diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9ed7389..897458d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,7 +26,7 @@ jobs: - name: Check code formatting run: | - black --check . + black --check --line-length=150 . # - name: Check typos # uses: crate-ci/typos@v1.29.10 diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..45085dc --- /dev/null +++ b/Makefile @@ -0,0 +1,2 @@ +fmt: + black --line-length=150 . diff --git a/benchmarks/file_io_benchmark.py b/benchmarks/file_io_benchmark.py index e30c7a7..797e1ca 100644 --- a/benchmarks/file_io_benchmark.py +++ b/benchmarks/file_io_benchmark.py @@ -64,9 +64,7 @@ def main(): driver = Driver() driver.add_argument("-i", "--input_paths", nargs="+") driver.add_argument("-n", "--npartitions", type=int, default=None) - driver.add_argument( - "-e", "--io_engine", default="duckdb", choices=("duckdb", "arrow", "stream") - ) + driver.add_argument("-e", "--io_engine", default="duckdb", choices=("duckdb", "arrow", "stream")) driver.add_argument("-b", "--batch_size", type=int, default=1024 * 1024) driver.add_argument("-s", "--row_group_size", type=int, default=1024 * 1024) driver.add_argument("-o", "--output_name", default="data") diff --git a/benchmarks/gray_sort_benchmark.py b/benchmarks/gray_sort_benchmark.py index 9cb28ad..11e9f12 100644 --- a/benchmarks/gray_sort_benchmark.py +++ b/benchmarks/gray_sort_benchmark.py @@ -73,26 +73,18 @@ def generate_records( subprocess.run(gensort_cmd.split()).check_returncode() runtime_task.add_elapsed_time("generate records (secs)") shm_file.seek(0) - buffer = arrow.py_buffer( - shm_file.read(record_count * record_nbytes) - ) + buffer = arrow.py_buffer(shm_file.read(record_count * record_nbytes)) runtime_task.add_elapsed_time("read records (secs)") # https://arrow.apache.org/docs/format/Columnar.html#fixed-size-primitive-layout - records = arrow.Array.from_buffers( - arrow.binary(record_nbytes), record_count, [None, buffer] - ) + records = arrow.Array.from_buffers(arrow.binary(record_nbytes), record_count, [None, buffer]) keys = pc.binary_slice(records, 0, key_nbytes) # get first 2 bytes and convert to big-endian uint16 binary_prefix = pc.binary_slice(records, 0, 2).cast(arrow.binary()) - reversed_prefix = pc.binary_reverse(binary_prefix).cast( - arrow.binary(2) - ) + reversed_prefix = pc.binary_reverse(binary_prefix).cast(arrow.binary(2)) uint16_prefix = reversed_prefix.view(arrow.uint16()) buckets = pc.shift_right(uint16_prefix, 16 - bucket_nbits) runtime_task.add_elapsed_time("build arrow table (secs)") - yield arrow.Table.from_arrays( - [buckets, keys, records], schema=schema - ) + yield arrow.Table.from_arrays([buckets, keys, records], schema=schema) yield StreamOutput( schema.empty_table(), batch_indices=[batch_idx], @@ -108,9 +100,7 @@ def sort_records( write_io_nbytes=500 * MB, ) -> bool: runtime_task: PythonScriptTask = runtime_ctx.task - data_file_path = os.path.join( - runtime_task.runtime_output_abspath, f"{runtime_task.output_filename}.dat" - ) + data_file_path = os.path.join(runtime_task.runtime_output_abspath, f"{runtime_task.output_filename}.dat") if sort_engine == "polars": input_data = polars.read_parquet( @@ -134,9 +124,7 @@ def sort_records( record_arrays = sorted_table.column("records").chunks runtime_task.add_elapsed_time("convert to chunks (secs)") elif sort_engine == "duckdb": - with duckdb.connect( - database=":memory:", config={"allow_unsigned_extensions": "true"} - ) as conn: + with duckdb.connect(database=":memory:", config={"allow_unsigned_extensions": "true"}) as conn: runtime_task.prepare_connection(conn) input_views = runtime_task.create_input_views(conn, input_datasets) sql_query = "select records from {0} order by keys".format(*input_views) @@ -154,8 +142,7 @@ def sort_records( buffer_mem = memoryview(values) total_write_nbytes = sum( - fout.write(buffer_mem[offset : offset + write_io_nbytes]) - for offset in range(0, len(buffer_mem), write_io_nbytes) + fout.write(buffer_mem[offset : offset + write_io_nbytes]) for offset in range(0, len(buffer_mem), write_io_nbytes) ) assert total_write_nbytes == len(buffer_mem) @@ -164,16 +151,10 @@ def sort_records( return True -def validate_records( - runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str -) -> bool: +def validate_records(runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str) -> bool: for data_path in input_datasets[0].resolved_paths: - summary_path = os.path.join( - output_path, PurePath(data_path).with_suffix(".sum").name - ) - cmdstr = ( - f"{SortBenchTool.valsort_path} -o {summary_path} {data_path},buf,trans=10m" - ) + summary_path = os.path.join(output_path, PurePath(data_path).with_suffix(".sum").name) + cmdstr = f"{SortBenchTool.valsort_path} -o {summary_path} {data_path},buf,trans=10m" logging.debug(f"running command: {cmdstr}") result = subprocess.run(cmdstr.split(), capture_output=True, encoding="utf8") if result.stderr: @@ -185,9 +166,7 @@ def validate_records( return True -def validate_summary( - runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str -) -> bool: +def validate_summary(runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str) -> bool: concated_summary_path = os.path.join(output_path, "merged.sum") with open(concated_summary_path, "wb") as fout: for path in input_datasets[0].resolved_paths: @@ -224,22 +203,13 @@ def generate_random_records( ) range_begin_at = [pos for pos in range(0, total_num_records, record_range_size)] - range_num_records = [ - min(total_num_records, record_range_size * (range_idx + 1)) - begin_at - for range_idx, begin_at in enumerate(range_begin_at) - ] + range_num_records = [min(total_num_records, record_range_size * (range_idx + 1)) - begin_at for range_idx, begin_at in enumerate(range_begin_at)] assert sum(range_num_records) == total_num_records record_range = DataSourceNode( ctx, - ArrowTableDataSet( - arrow.Table.from_arrays( - [range_begin_at, range_num_records], names=["begin_at", "num_records"] - ) - ), - ) - record_range_partitions = DataSetPartitionNode( - ctx, (record_range,), npartitions=num_data_partitions, partition_by_rows=True + ArrowTableDataSet(arrow.Table.from_arrays([range_begin_at, range_num_records], names=["begin_at", "num_records"])), ) + record_range_partitions = DataSetPartitionNode(ctx, (record_range,), npartitions=num_data_partitions, partition_by_rows=True) random_records = ArrowStreamNode( ctx, @@ -288,9 +258,7 @@ def gray_sort_benchmark( if input_paths: input_dataset = ParquetDataSet(input_paths) input_nbytes = sum(os.path.getsize(p) for p in input_dataset.resolved_paths) - logging.warning( - f"input data size: {input_nbytes/GB:.3f}GB, {input_dataset.num_files} files" - ) + logging.warning(f"input data size: {input_nbytes/GB:.3f}GB, {input_dataset.num_files} files") random_records = DataSourceNode(ctx, input_dataset) else: random_records = generate_random_records( @@ -335,12 +303,8 @@ def gray_sort_benchmark( process_func=validate_records, output_name="partitioned_summaries", ) - merged_summaries = DataSetPartitionNode( - ctx, (partitioned_summaries,), npartitions=1 - ) - final_check = PythonScriptNode( - ctx, (merged_summaries,), process_func=validate_summary - ) + merged_summaries = DataSetPartitionNode(ctx, (partitioned_summaries,), npartitions=1) + final_check = PythonScriptNode(ctx, (merged_summaries,), process_func=validate_summary) root = final_check else: root = sorted_records @@ -359,17 +323,11 @@ def main(): driver.add_argument("-n", "--num_data_partitions", type=int, default=None) driver.add_argument("-t", "--num_sort_partitions", type=int, default=None) driver.add_argument("-i", "--input_paths", nargs="+", default=[]) - driver.add_argument( - "-e", "--shuffle_engine", default="duckdb", choices=("duckdb", "arrow") - ) - driver.add_argument( - "-s", "--sort_engine", default="duckdb", choices=("duckdb", "arrow", "polars") - ) + driver.add_argument("-e", "--shuffle_engine", default="duckdb", choices=("duckdb", "arrow")) + driver.add_argument("-s", "--sort_engine", default="duckdb", choices=("duckdb", "arrow", "polars")) driver.add_argument("-H", "--hive_partitioning", action="store_true") driver.add_argument("-V", "--validate_results", action="store_true") - driver.add_argument( - "-C", "--shuffle_cpu_limit", type=int, default=ShuffleNode.default_cpu_limit - ) + driver.add_argument("-C", "--shuffle_cpu_limit", type=int, default=ShuffleNode.default_cpu_limit) driver.add_argument( "-M", "--shuffle_memory_limit", @@ -378,12 +336,8 @@ def main(): ) driver.add_argument("-TC", "--sort_cpu_limit", type=int, default=8) driver.add_argument("-TM", "--sort_memory_limit", type=int, default=None) - driver.add_argument( - "-NC", "--cpus_per_node", type=int, default=psutil.cpu_count(logical=False) - ) - driver.add_argument( - "-NM", "--memory_per_node", type=int, default=psutil.virtual_memory().total - ) + driver.add_argument("-NC", "--cpus_per_node", type=int, default=psutil.cpu_count(logical=False)) + driver.add_argument("-NM", "--memory_per_node", type=int, default=psutil.virtual_memory().total) driver.add_argument("-CP", "--parquet_compression", default=None) driver.add_argument("-LV", "--parquet_compression_level", type=int, default=None) @@ -393,16 +347,9 @@ def main(): total_num_cpus = max(1, driver_args.num_executors) * user_args.cpus_per_node memory_per_cpu = user_args.memory_per_node // user_args.cpus_per_node - user_args.sort_cpu_limit = ( - 1 if user_args.sort_engine == "arrow" else user_args.sort_cpu_limit - ) - sort_memory_limit = ( - user_args.sort_memory_limit or user_args.sort_cpu_limit * memory_per_cpu - ) - user_args.total_data_nbytes = ( - user_args.total_data_nbytes - or max(1, driver_args.num_executors) * user_args.memory_per_node - ) + user_args.sort_cpu_limit = 1 if user_args.sort_engine == "arrow" else user_args.sort_cpu_limit + sort_memory_limit = user_args.sort_memory_limit or user_args.sort_cpu_limit * memory_per_cpu + user_args.total_data_nbytes = user_args.total_data_nbytes or max(1, driver_args.num_executors) * user_args.memory_per_node user_args.num_data_partitions = user_args.num_data_partitions or total_num_cpus // 2 user_args.num_sort_partitions = user_args.num_sort_partitions or max( total_num_cpus // user_args.sort_cpu_limit, diff --git a/benchmarks/hash_partition_benchmark.py b/benchmarks/hash_partition_benchmark.py index 3ca7bb4..f90c965 100644 --- a/benchmarks/hash_partition_benchmark.py +++ b/benchmarks/hash_partition_benchmark.py @@ -70,18 +70,12 @@ def main(): driver.add_argument("-i", "--input_paths", nargs="+", required=True) driver.add_argument("-n", "--npartitions", type=int, default=None) driver.add_argument("-c", "--hash_columns", nargs="+", required=True) - driver.add_argument( - "-e", "--engine_type", default="duckdb", choices=("duckdb", "arrow") - ) + driver.add_argument("-e", "--engine_type", default="duckdb", choices=("duckdb", "arrow")) driver.add_argument("-S", "--partition_stats", action="store_true") driver.add_argument("-W", "--use_parquet_writer", action="store_true") driver.add_argument("-H", "--hive_partitioning", action="store_true") - driver.add_argument( - "-C", "--cpu_limit", type=int, default=HashPartitionNode.default_cpu_limit - ) - driver.add_argument( - "-M", "--memory_limit", type=int, default=HashPartitionNode.default_memory_limit - ) + driver.add_argument("-C", "--cpu_limit", type=int, default=HashPartitionNode.default_cpu_limit) + driver.add_argument("-M", "--memory_limit", type=int, default=HashPartitionNode.default_memory_limit) driver.add_argument("-NC", "--cpus_per_node", type=int, default=192) driver.add_argument("-NM", "--memory_per_node", type=int, default=2000 * GB) diff --git a/benchmarks/urls_sort_benchmark.py b/benchmarks/urls_sort_benchmark.py index 4065c04..b5bbcda 100644 --- a/benchmarks/urls_sort_benchmark.py +++ b/benchmarks/urls_sort_benchmark.py @@ -29,9 +29,7 @@ def urls_sort_benchmark( delim=r"\t", ) data_files = DataSourceNode(ctx, dataset) - data_partitions = DataSetPartitionNode( - ctx, (data_files,), npartitions=num_data_partitions - ) + data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=num_data_partitions) imported_urls = SqlEngineNode( ctx, @@ -80,16 +78,10 @@ def urls_sort_benchmark_v2( sort_cpu_limit=8, sort_memory_limit=None, ): - dataset = sp.read_csv( - input_paths, schema={"urlstr": "varchar", "valstr": "varchar"}, delim=r"\t" - ) + dataset = sp.read_csv(input_paths, schema={"urlstr": "varchar", "valstr": "varchar"}, delim=r"\t") data_partitions = dataset.repartition(num_data_partitions) - urls_partitions = data_partitions.repartition( - num_hash_partitions, hash_by="urlstr", engine_type=engine_type - ) - sorted_urls = urls_partitions.partial_sort( - by="urlstr", cpu_limit=sort_cpu_limit, memory_limit=sort_memory_limit - ) + urls_partitions = data_partitions.repartition(num_hash_partitions, hash_by="urlstr", engine_type=engine_type) + sorted_urls = urls_partitions.partial_sort(by="urlstr", cpu_limit=sort_cpu_limit, memory_limit=sort_memory_limit) sorted_urls.write_parquet(output_path) @@ -106,12 +98,8 @@ def main(): num_nodes = driver_args.num_executors cpus_per_node = 120 partition_rounds = 2 - user_args.num_data_partitions = ( - user_args.num_data_partitions or num_nodes * cpus_per_node * partition_rounds - ) - user_args.num_hash_partitions = ( - user_args.num_hash_partitions or num_nodes * cpus_per_node - ) + user_args.num_data_partitions = user_args.num_data_partitions or num_nodes * cpus_per_node * partition_rounds + user_args.num_hash_partitions = user_args.num_hash_partitions or num_nodes * cpus_per_node # v1 plan = urls_sort_benchmark(**vars(user_args)) diff --git a/examples/fstest.py b/examples/fstest.py index 21967b9..0ce8bd6 100644 --- a/examples/fstest.py +++ b/examples/fstest.py @@ -80,9 +80,7 @@ def check_data(actual: bytes, expected: bytes, offset: int) -> None: ) expected = expected[index : index + 16] actual = actual[index : index + 16] - raise ValueError( - f"Data mismatch at offset {offset + index}.\nexpect: {expected}\nactual: {actual}" - ) + raise ValueError(f"Data mismatch at offset {offset + index}.\nexpect: {expected}\nactual: {actual}") def generate_data(offset: int, length: int) -> bytes: @@ -92,16 +90,10 @@ def generate_data(offset: int, length: int) -> bytes: """ istart = offset // 4 iend = (offset + length + 3) // 4 - return ( - np.arange(istart, iend) - .astype(np.uint32) - .tobytes()[offset % 4 : offset % 4 + length] - ) + return np.arange(istart, iend).astype(np.uint32).tobytes()[offset % 4 : offset % 4 + length] -def iter_io_slice( - offset: int, length: int, block_size: Union[int, Tuple[int, int]] -) -> Iterator[Tuple[int, int]]: +def iter_io_slice(offset: int, length: int, block_size: Union[int, Tuple[int, int]]) -> Iterator[Tuple[int, int]]: """ Generate the IO (offset, size) for the slice [offset, offset + length) with the given block size. `block_size` can be an integer or a range [start, end]. If a range is provided, the IO size will be randomly selected from the range. @@ -161,9 +153,7 @@ def fstest( if output_path is not None: os.makedirs(output_path, exist_ok=True) - df = sp.from_items( - [{"path": os.path.join(output_path, f"{i}")} for i in range(npartitions)] - ) + df = sp.from_items([{"path": os.path.join(output_path, f"{i}")} for i in range(npartitions)]) df = df.repartition(npartitions, by_rows=True) stats = df.map(lambda x: fswrite(x["path"], size, blocksize)).to_pandas() logging.info(f"write stats:\n{stats}") @@ -187,18 +177,14 @@ if __name__ == "__main__": python example/fstest.py -o 'fstest' -j 8 -s 1G -i 'fstest/*' """ parser = argparse.ArgumentParser() - parser.add_argument( - "-o", "--output_path", type=str, help="The output path to write data to." - ) + parser.add_argument("-o", "--output_path", type=str, help="The output path to write data to.") parser.add_argument( "-i", "--input_path", type=str, help="The input path to read data from. If -o is provided, this is ignored.", ) - parser.add_argument( - "-j", "--npartitions", type=int, help="The number of parallel jobs", default=10 - ) + parser.add_argument("-j", "--npartitions", type=int, help="The number of parallel jobs", default=10) parser.add_argument( "-s", "--size", diff --git a/examples/shuffle_data.py b/examples/shuffle_data.py index 82cf493..34e2141 100644 --- a/examples/shuffle_data.py +++ b/examples/shuffle_data.py @@ -52,9 +52,7 @@ def shuffle_data( npartitions=num_out_data_partitions, partition_by_rows=True, ) - shuffled_urls = StreamCopy( - ctx, (repartitioned,), output_name="data_copy", cpu_limit=1 - ) + shuffled_urls = StreamCopy(ctx, (repartitioned,), output_name="data_copy", cpu_limit=1) plan = LogicalPlan(ctx, shuffled_urls) return plan @@ -66,9 +64,7 @@ def main(): driver.add_argument("-nd", "--num_data_partitions", type=int, default=1024) driver.add_argument("-nh", "--num_hash_partitions", type=int, default=3840) driver.add_argument("-no", "--num_out_data_partitions", type=int, default=1920) - driver.add_argument( - "-e", "--engine_type", default="duckdb", choices=("duckdb", "arrow") - ) + driver.add_argument("-e", "--engine_type", default="duckdb", choices=("duckdb", "arrow")) driver.add_argument("-x", "--skip_hash_partition", action="store_true") plan = shuffle_data(**driver.get_arguments()) driver.run(plan) diff --git a/examples/shuffle_mock_urls.py b/examples/shuffle_mock_urls.py index ffde6e6..2c9bf1c 100644 --- a/examples/shuffle_mock_urls.py +++ b/examples/shuffle_mock_urls.py @@ -11,9 +11,7 @@ from smallpond.logical.node import ( ) -def shuffle_mock_urls( - input_paths, npartitions: int = 10, sort_rand_keys=True, engine_type="duckdb" -) -> LogicalPlan: +def shuffle_mock_urls(input_paths, npartitions: int = 10, sort_rand_keys=True, engine_type="duckdb") -> LogicalPlan: ctx = Context() dataset = ParquetDataSet(input_paths) data_files = DataSourceNode(ctx, dataset) @@ -61,9 +59,7 @@ def main(): driver.add_argument("-i", "--input_paths", nargs="+") driver.add_argument("-n", "--npartitions", type=int, default=500) driver.add_argument("-s", "--sort_rand_keys", action="store_true") - driver.add_argument( - "-e", "--engine_type", default="duckdb", choices=("duckdb", "arrow") - ) + driver.add_argument("-e", "--engine_type", default="duckdb", choices=("duckdb", "arrow")) plan = shuffle_mock_urls(**driver.get_arguments()) driver.run(plan) diff --git a/examples/sort_mock_urls.py b/examples/sort_mock_urls.py index eb4042e..e4ffc96 100644 --- a/examples/sort_mock_urls.py +++ b/examples/sort_mock_urls.py @@ -20,9 +20,7 @@ from smallpond.logical.node import ( class SortUrlsNode(ArrowComputeNode): - def process( - self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table] - ) -> arrow.Table: + def process(self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table: logging.info(f"sorting urls by 'host', table shape: {input_tables[0].shape}") return input_tables[0].sort_by("host") @@ -90,9 +88,7 @@ def sort_mock_urls( def main(): driver = Driver() - driver.add_argument( - "-i", "--input_paths", nargs="+", default=["tests/data/mock_urls/*.tsv"] - ) + driver.add_argument("-i", "--input_paths", nargs="+", default=["tests/data/mock_urls/*.tsv"]) driver.add_argument("-n", "--npartitions", type=int, default=10) driver.add_argument("-e", "--engine_type", default="duckdb") diff --git a/examples/sort_mock_urls_v2.py b/examples/sort_mock_urls_v2.py index d642c3f..6a57195 100644 --- a/examples/sort_mock_urls_v2.py +++ b/examples/sort_mock_urls_v2.py @@ -5,12 +5,8 @@ import smallpond from smallpond.dataframe import Session -def sort_mock_urls_v2( - sp: Session, input_paths: List[str], output_path: str, npartitions: int -): - dataset = sp.read_csv( - input_paths, schema={"urlstr": "varchar", "valstr": "varchar"}, delim=r"\t" - ).repartition(npartitions) +def sort_mock_urls_v2(sp: Session, input_paths: List[str], output_path: str, npartitions: int): + dataset = sp.read_csv(input_paths, schema={"urlstr": "varchar", "valstr": "varchar"}, delim=r"\t").repartition(npartitions) urls = dataset.map( """ split_part(urlstr, '/', 1) as host, @@ -25,9 +21,7 @@ def sort_mock_urls_v2( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - "-i", "--input_paths", nargs="+", default=["tests/data/mock_urls/*.tsv"] - ) + parser.add_argument("-i", "--input_paths", nargs="+", default=["tests/data/mock_urls/*.tsv"]) parser.add_argument("-o", "--output_path", type=str, default="sort_mock_urls") parser.add_argument("-n", "--npartitions", type=int, default=10) args = parser.parse_args() diff --git a/smallpond/common.py b/smallpond/common.py index b73981b..a4021cd 100644 --- a/smallpond/common.py +++ b/smallpond/common.py @@ -41,9 +41,7 @@ def clamp_row_group_size(val, minval=DEFAULT_ROW_GROUP_SIZE, maxval=MAX_ROW_GROU return clamp_value(val, minval, maxval) -def clamp_row_group_bytes( - val, minval=DEFAULT_ROW_GROUP_BYTES, maxval=MAX_ROW_GROUP_BYTES -): +def clamp_row_group_bytes(val, minval=DEFAULT_ROW_GROUP_BYTES, maxval=MAX_ROW_GROUP_BYTES): return clamp_value(val, minval, maxval) @@ -74,10 +72,7 @@ def first_value_in_dict(d: Dict[K, V]) -> V: def split_into_cols(items: List[V], npartitions: int) -> List[List[V]]: none = object() chunks = [items[i : i + npartitions] for i in range(0, len(items), npartitions)] - return [ - [x for x in col if x is not none] - for col in itertools.zip_longest([none] * npartitions, *chunks, fillvalue=none) - ] + return [[x for x in col if x is not none] for col in itertools.zip_longest([none] * npartitions, *chunks, fillvalue=none)] def split_into_rows(items: List[V], npartitions: int) -> List[List[V]]: @@ -101,10 +96,7 @@ def get_nth_partition(items: List[V], n: int, npartitions: int) -> List[V]: start = n * large_partition_size items_in_partition = items[start : start + large_partition_size] else: - start = ( - large_partition_size * num_large_partitions - + (n - num_large_partitions) * small_partition_size - ) + start = large_partition_size * num_large_partitions + (n - num_large_partitions) * small_partition_size items_in_partition = items[start : start + small_partition_size] return items_in_partition diff --git a/smallpond/contrib/copy_table.py b/smallpond/contrib/copy_table.py index 1b309b0..a0f7515 100644 --- a/smallpond/contrib/copy_table.py +++ b/smallpond/contrib/copy_table.py @@ -8,17 +8,13 @@ from smallpond.logical.node import ArrowComputeNode, ArrowStreamNode class CopyArrowTable(ArrowComputeNode): - def process( - self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table] - ) -> arrow.Table: + def process(self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table: logger.info(f"copying table: {input_tables[0].num_rows} rows ...") return input_tables[0] class StreamCopy(ArrowStreamNode): - def process( - self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader] - ) -> Iterable[arrow.Table]: + def process(self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]) -> Iterable[arrow.Table]: for batch in input_readers[0]: logger.info(f"copying batch: {batch.num_rows} rows ...") yield arrow.Table.from_batches([batch]) diff --git a/smallpond/contrib/log_dataset.py b/smallpond/contrib/log_dataset.py index e2b20d9..948c353 100644 --- a/smallpond/contrib/log_dataset.py +++ b/smallpond/contrib/log_dataset.py @@ -25,9 +25,7 @@ class LogDataSetTask(PythonScriptTask): class LogDataSet(PythonScriptNode): - def __init__( - self, ctx: Context, input_deps: Tuple[Node, ...], num_rows=200, **kwargs - ) -> None: + def __init__(self, ctx: Context, input_deps: Tuple[Node, ...], num_rows=200, **kwargs) -> None: super().__init__(ctx, input_deps, **kwargs) self.num_rows = num_rows diff --git a/smallpond/contrib/warc.py b/smallpond/contrib/warc.py index 7f4e8dc..cda8727 100644 --- a/smallpond/contrib/warc.py +++ b/smallpond/contrib/warc.py @@ -29,16 +29,12 @@ class ImportWarcFiles(PythonScriptNode): ] ) - def import_warc_file( - self, warc_path: PurePath, parquet_path: PurePath - ) -> Tuple[int, int]: + def import_warc_file(self, warc_path: PurePath, parquet_path: PurePath) -> Tuple[int, int]: total_size = 0 docs = [] with open(warc_path, "rb") as warc_file: - zstd_reader = zstd.ZstdDecompressor().stream_reader( - warc_file, read_size=16 * MB - ) + zstd_reader = zstd.ZstdDecompressor().stream_reader(warc_file, read_size=16 * MB) for record in ArchiveIterator(zstd_reader): if record.rec_type == "response": url = record.rec_headers.get_header("WARC-Target-URI") @@ -48,9 +44,7 @@ class ImportWarcFiles(PythonScriptNode): total_size += len(content) docs.append((url, domain, date, content)) - table = arrow.Table.from_arrays( - [arrow.array(column) for column in zip(*docs)], schema=self.schema - ) + table = arrow.Table.from_arrays([arrow.array(column) for column in zip(*docs)], schema=self.schema) dump_to_parquet_files(table, parquet_path.parent, parquet_path.name) return len(docs), total_size @@ -60,14 +54,9 @@ class ImportWarcFiles(PythonScriptNode): input_datasets: List[DataSet], output_path: str, ) -> bool: - warc_paths = [ - PurePath(warc_path) - for dataset in input_datasets - for warc_path in dataset.resolved_paths - ] + warc_paths = [PurePath(warc_path) for dataset in input_datasets for warc_path in dataset.resolved_paths] parquet_paths = [ - PurePath(output_path) - / f"data{path_index}-{PurePath(warc_path.name).with_suffix('.parquet')}" + PurePath(output_path) / f"data{path_index}-{PurePath(warc_path.name).with_suffix('.parquet')}" for path_index, warc_path in enumerate(warc_paths) ] @@ -75,24 +64,16 @@ class ImportWarcFiles(PythonScriptNode): for warc_path, parquet_path in zip(warc_paths, parquet_paths): try: doc_count, total_size = self.import_warc_file(warc_path, parquet_path) - logger.info( - f"imported {doc_count} web pages ({total_size/MB:.3f}MB) from file '{warc_path}' to '{parquet_path}'" - ) + logger.info(f"imported {doc_count} web pages ({total_size/MB:.3f}MB) from file '{warc_path}' to '{parquet_path}'") except Exception as ex: - logger.opt(exception=ex).error( - f"failed to import web pages from file '{warc_path}'" - ) + logger.opt(exception=ex).error(f"failed to import web pages from file '{warc_path}'") return False return True class ExtractHtmlBody(ArrowStreamNode): - unicode_punctuation = "".join( - chr(i) - for i in range(sys.maxunicode) - if unicodedata.category(chr(i)).startswith("P") - ) + unicode_punctuation = "".join(chr(i) for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P")) separator_str = string.whitespace + string.punctuation + unicode_punctuation translator = str.maketrans(separator_str, " " * len(separator_str)) @@ -117,27 +98,19 @@ class ExtractHtmlBody(ArrowStreamNode): tokens.extend(self.split_string(doc.get_text(" ", strip=True).lower())) return tokens except Exception as ex: - logger.opt(exception=ex).error( - f"failed to extract tokens from {url.as_py()}" - ) + logger.opt(exception=ex).error(f"failed to extract tokens from {url.as_py()}") return [] - def process( - self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader] - ) -> Iterable[arrow.Table]: + def process(self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]) -> Iterable[arrow.Table]: for batch in input_readers[0]: urls, domains, dates, contents = batch.columns doc_tokens = [] try: for i, (url, content) in enumerate(zip(urls, contents)): tokens = self.extract_tokens(url, content) - logger.info( - f"#{i}/{len(urls)} extracted {len(tokens)} tokens from {url}" - ) + logger.info(f"#{i}/{len(urls)} extracted {len(tokens)} tokens from {url}") doc_tokens.append(tokens) - yield arrow.Table.from_arrays( - [urls, domains, dates, arrow.array(doc_tokens)], schema=self.schema - ) + yield arrow.Table.from_arrays([urls, domains, dates, arrow.array(doc_tokens)], schema=self.schema) except Exception as ex: logger.opt(exception=ex).error(f"failed to extract tokens") break diff --git a/smallpond/dataframe.py b/smallpond/dataframe.py index 4974910..9d8fd0f 100644 --- a/smallpond/dataframe.py +++ b/smallpond/dataframe.py @@ -35,9 +35,7 @@ class Session(SessionBase): Subsequent DataFrames can reuse the tasks to avoid recomputation. """ - def read_csv( - self, paths: Union[str, List[str]], schema: Dict[str, str], delim="," - ) -> DataFrame: + def read_csv(self, paths: Union[str, List[str]], schema: Dict[str, str], delim=",") -> DataFrame: """ Create a DataFrame from CSV files. """ @@ -55,15 +53,11 @@ class Session(SessionBase): """ Create a DataFrame from Parquet files. """ - dataset = ParquetDataSet( - paths, columns=columns, union_by_name=union_by_name, recursive=recursive - ) + dataset = ParquetDataSet(paths, columns=columns, union_by_name=union_by_name, recursive=recursive) plan = DataSourceNode(self._ctx, dataset) return DataFrame(self, plan) - def read_json( - self, paths: Union[str, List[str]], schema: Dict[str, str] - ) -> DataFrame: + def read_json(self, paths: Union[str, List[str]], schema: Dict[str, str]) -> DataFrame: """ Create a DataFrame from JSON files. """ @@ -115,9 +109,7 @@ class Session(SessionBase): c = sp.partial_sql("select * from {0} join {1} on a.id = b.id", a, b) """ - plan = SqlEngineNode( - self._ctx, tuple(input.plan for input in inputs), query, **kwargs - ) + plan = SqlEngineNode(self._ctx, tuple(input.plan for input in inputs), query, **kwargs) recompute = any(input.need_recompute for input in inputs) return DataFrame(self, plan, recompute=recompute) @@ -177,26 +169,15 @@ class Session(SessionBase): """ Return the total number of tasks and the number of tasks that are finished. """ - dataset_refs = [ - task._dataset_ref - for tasks in self._node_to_tasks.values() - for task in tasks - if task._dataset_ref is not None - ] - ready_tasks, _ = ray.wait( - dataset_refs, num_returns=len(dataset_refs), timeout=0, fetch_local=False - ) + dataset_refs = [task._dataset_ref for tasks in self._node_to_tasks.values() for task in tasks if task._dataset_ref is not None] + ready_tasks, _ = ray.wait(dataset_refs, num_returns=len(dataset_refs), timeout=0, fetch_local=False) return len(dataset_refs), len(ready_tasks) def _all_tasks_finished(self) -> bool: """ Check if all tasks are finished. """ - dataset_refs = [ - task._dataset_ref - for tasks in self._node_to_tasks.values() - for task in tasks - ] + dataset_refs = [task._dataset_ref for tasks in self._node_to_tasks.values() for task in tasks] try: ray.get(dataset_refs, timeout=0) except Exception: @@ -232,12 +213,8 @@ class DataFrame: # optimize the plan if self.optimized_plan is None: logger.info(f"optimizing\n{LogicalPlan(self.session._ctx, self.plan)}") - self.optimized_plan = Optimizer( - exclude_nodes=set(self.session._node_to_tasks.keys()) - ).visit(self.plan) - logger.info( - f"optimized\n{LogicalPlan(self.session._ctx, self.optimized_plan)}" - ) + self.optimized_plan = Optimizer(exclude_nodes=set(self.session._node_to_tasks.keys())).visit(self.plan) + logger.info(f"optimized\n{LogicalPlan(self.session._ctx, self.optimized_plan)}") # return the tasks if already created if tasks := self.session._node_to_tasks.get(self.optimized_plan): return tasks @@ -281,9 +258,7 @@ class DataFrame: """ for retry_count in range(3): try: - return ray.get( - [task.run_on_ray() for task in self._get_or_create_tasks()] - ) + return ray.get([task.run_on_ray() for task in self._get_or_create_tasks()]) except ray.exceptions.RuntimeEnvSetupError as e: # XXX: Ray may raise this error when a worker is interrupted. # ``` @@ -361,9 +336,7 @@ class DataFrame: ) elif hash_by is not None: hash_columns = [hash_by] if isinstance(hash_by, str) else hash_by - plan = HashPartitionNode( - self.session._ctx, (self.plan,), npartitions, hash_columns, **kwargs - ) + plan = HashPartitionNode(self.session._ctx, (self.plan,), npartitions, hash_columns, **kwargs) else: plan = EvenlyDistributedPartitionNode( self.session._ctx, @@ -420,9 +393,7 @@ class DataFrame: ) return DataFrame(self.session, plan, recompute=self.need_recompute) - def filter( - self, sql_or_func: Union[str, Callable[[Dict[str, Any]], bool]], **kwargs - ) -> DataFrame: + def filter(self, sql_or_func: Union[str, Callable[[Dict[str, Any]], bool]], **kwargs) -> DataFrame: """ Filter out rows that don't satisfy the given predicate. @@ -453,13 +424,9 @@ class DataFrame: table = tables[0] return table.filter([func(row) for row in table.to_pylist()]) - plan = ArrowBatchNode( - self.session._ctx, (self.plan,), process_func=process_func, **kwargs - ) + plan = ArrowBatchNode(self.session._ctx, (self.plan,), process_func=process_func, **kwargs) else: - raise ValueError( - "condition must be a SQL expression or a predicate function" - ) + raise ValueError("condition must be a SQL expression or a predicate function") return DataFrame(self.session, plan, recompute=self.need_recompute) def map( @@ -510,18 +477,14 @@ class DataFrame: """ if isinstance(sql := sql_or_func, str): - plan = SqlEngineNode( - self.session._ctx, (self.plan,), f"select {sql} from {{0}}", **kwargs - ) + plan = SqlEngineNode(self.session._ctx, (self.plan,), f"select {sql} from {{0}}", **kwargs) elif isinstance(func := sql_or_func, Callable): def process_func(_runtime_ctx, tables: List[arrow.Table]) -> arrow.Table: output_rows = [func(row) for row in tables[0].to_pylist()] return arrow.Table.from_pylist(output_rows, schema=schema) - plan = ArrowBatchNode( - self.session._ctx, (self.plan,), process_func=process_func, **kwargs - ) + plan = ArrowBatchNode(self.session._ctx, (self.plan,), process_func=process_func, **kwargs) else: raise ValueError(f"must be a SQL expression or a function: {sql_or_func!r}") return DataFrame(self.session, plan, recompute=self.need_recompute) @@ -555,20 +518,14 @@ class DataFrame: """ if isinstance(sql := sql_or_func, str): - plan = SqlEngineNode( - self.session._ctx, (self.plan,), f"select {sql} from {{0}}", **kwargs - ) + plan = SqlEngineNode(self.session._ctx, (self.plan,), f"select {sql} from {{0}}", **kwargs) elif isinstance(func := sql_or_func, Callable): def process_func(_runtime_ctx, tables: List[arrow.Table]) -> arrow.Table: - output_rows = [ - item for row in tables[0].to_pylist() for item in func(row) - ] + output_rows = [item for row in tables[0].to_pylist() for item in func(row)] return arrow.Table.from_pylist(output_rows, schema=schema) - plan = ArrowBatchNode( - self.session._ctx, (self.plan,), process_func=process_func, **kwargs - ) + plan = ArrowBatchNode(self.session._ctx, (self.plan,), process_func=process_func, **kwargs) else: raise ValueError(f"must be a SQL expression or a function: {sql_or_func!r}") return DataFrame(self.session, plan, recompute=self.need_recompute) @@ -642,9 +599,7 @@ class DataFrame: sp.wait(o1, o2) """ - plan = DataSinkNode( - self.session._ctx, (self.plan,), os.path.abspath(path), type="link_or_copy" - ) + plan = DataSinkNode(self.session._ctx, (self.plan,), os.path.abspath(path), type="link_or_copy") return DataFrame(self.session, plan, recompute=self.need_recompute) # inspection @@ -710,6 +665,4 @@ class DataFrame: """ datasets = self._compute() with ThreadPoolExecutor() as pool: - return arrow.concat_tables( - pool.map(lambda dataset: dataset.to_arrow_table(), datasets) - ) + return arrow.concat_tables(pool.map(lambda dataset: dataset.to_arrow_table(), datasets)) diff --git a/smallpond/execution/driver.py b/smallpond/execution/driver.py index 4c17b46..92555aa 100644 --- a/smallpond/execution/driver.py +++ b/smallpond/execution/driver.py @@ -28,43 +28,23 @@ class Driver(object): self.all_args = None def _create_driver_args_parser(self): - parser = argparse.ArgumentParser( - prog="driver.py", description="Smallpond Driver", add_help=False - ) - parser.add_argument( - "mode", choices=["executor", "scheduler", "ray"], default="executor" - ) - parser.add_argument( - "--exec_id", default=socket.gethostname(), help="Unique executor id" - ) + parser = argparse.ArgumentParser(prog="driver.py", description="Smallpond Driver", add_help=False) + parser.add_argument("mode", choices=["executor", "scheduler", "ray"], default="executor") + parser.add_argument("--exec_id", default=socket.gethostname(), help="Unique executor id") parser.add_argument("--job_id", type=str, help="Unique job id") - parser.add_argument( - "--job_time", type=float, help="Job create time (seconds since epoch)" - ) - parser.add_argument( - "--job_name", default="smallpond", help="Display name of the job" - ) + parser.add_argument("--job_time", type=float, help="Job create time (seconds since epoch)") + parser.add_argument("--job_name", default="smallpond", help="Display name of the job") parser.add_argument( "--job_priority", type=int, help="Job priority", ) parser.add_argument("--resource_group", type=str, help="Resource group") - parser.add_argument( - "--env_variables", nargs="*", default=[], help="Env variables for the job" - ) - parser.add_argument( - "--sidecars", nargs="*", default=[], help="Sidecars for the job" - ) - parser.add_argument( - "--tags", nargs="*", default=[], help="Tags for submitted platform task" - ) - parser.add_argument( - "--task_image", default="default", help="Container image of platform task" - ) - parser.add_argument( - "--python_venv", type=str, help="Python virtual env for the job" - ) + parser.add_argument("--env_variables", nargs="*", default=[], help="Env variables for the job") + parser.add_argument("--sidecars", nargs="*", default=[], help="Sidecars for the job") + parser.add_argument("--tags", nargs="*", default=[], help="Tags for submitted platform task") + parser.add_argument("--task_image", default="default", help="Container image of platform task") + parser.add_argument("--python_venv", type=str, help="Python virtual env for the job") parser.add_argument( "--data_root", type=str, @@ -257,9 +237,7 @@ class Driver(object): default="DEBUG", choices=log_level_choices, ) - parser.add_argument( - "--disable_log_rotation", action="store_true", help="Disable log rotation" - ) + parser.add_argument("--disable_log_rotation", action="store_true", help="Disable log rotation") parser.add_argument( "--output_path", help="Set the output directory of final results and all nodes that have output_name but no output_path specified", @@ -279,9 +257,7 @@ class Driver(object): def parse_arguments(self, args=None): if self.user_args is None or self.driver_args is None: - args_parser = argparse.ArgumentParser( - parents=[self.driver_args_parser, self.user_args_parser] - ) + args_parser = argparse.ArgumentParser(parents=[self.driver_args_parser, self.user_args_parser]) self.all_args = args_parser.parse_args(args) self.user_args, other_args = self.user_args_parser.parse_known_args(args) self.driver_args = self.driver_args_parser.parse_args(other_args) @@ -349,9 +325,7 @@ class Driver(object): DataFrame(sp, plan.root_node).compute() retval = True elif args.mode == "executor": - assert os.path.isfile( - args.runtime_ctx_path - ), f"cannot find runtime context: {args.runtime_ctx_path}" + assert os.path.isfile(args.runtime_ctx_path), f"cannot find runtime context: {args.runtime_ctx_path}" runtime_ctx: RuntimeContext = load(args.runtime_ctx_path) if runtime_ctx.bind_numa_node: @@ -371,9 +345,7 @@ class Driver(object): retval = run_executor(runtime_ctx, args.exec_id) elif args.mode == "scheduler": assert plan is not None - jobmgr = JobManager( - args.data_root, args.python_venv, args.task_image, args.platform - ) + jobmgr = JobManager(args.data_root, args.python_venv, args.task_image, args.platform) exec_plan = jobmgr.run( plan, job_id=args.job_id, diff --git a/smallpond/execution/executor.py b/smallpond/execution/executor.py index 890a69a..edd8763 100755 --- a/smallpond/execution/executor.py +++ b/smallpond/execution/executor.py @@ -35,9 +35,7 @@ class SimplePoolTask(object): def join(self, timeout=None): self.proc.join(timeout) if not self.ready() and timeout is not None: - logger.warning( - f"worker process {self.proc.name}({self.proc.pid}) does not exit after {timeout} secs, stopping it" - ) + logger.warning(f"worker process {self.proc.name}({self.proc.pid}) does not exit after {timeout} secs, stopping it") self.terminate() self.proc.join() @@ -45,17 +43,11 @@ class SimplePoolTask(object): return self.proc.pid and not self.proc.is_alive() def exitcode(self): - assert ( - self.ready() - ), f"worker process {self.proc.name}({self.proc.pid}) has not exited yet" + assert self.ready(), f"worker process {self.proc.name}({self.proc.pid}) has not exited yet" if self.stopping: - logger.info( - f"worker process stopped: {self.proc.name}({self.proc.pid}), exitcode: {self.proc.exitcode}" - ) + logger.info(f"worker process stopped: {self.proc.name}({self.proc.pid}), exitcode: {self.proc.exitcode}") elif self.proc.exitcode != 0: - logger.error( - f"worker process crashed: {self.proc.name}({self.proc.pid}), exitcode: {self.proc.exitcode}" - ) + logger.error(f"worker process crashed: {self.proc.name}({self.proc.pid}), exitcode: {self.proc.exitcode}") return self.proc.exitcode @@ -79,9 +71,7 @@ class SimplePool(object): def update_queue(self): self.running_tasks = [t for t in self.running_tasks if not t.ready()] tasks_to_run = self.queued_tasks[: self.pool_size - len(self.running_tasks)] - self.queued_tasks = self.queued_tasks[ - self.pool_size - len(self.running_tasks) : - ] + self.queued_tasks = self.queued_tasks[self.pool_size - len(self.running_tasks) :] for task in tasks_to_run: task.start() self.running_tasks += tasks_to_run @@ -97,9 +87,7 @@ class Executor(object): The task executor. """ - def __init__( - self, ctx: RuntimeContext, id: str, wq: WorkQueue, cq: WorkQueue - ) -> None: + def __init__(self, ctx: RuntimeContext, id: str, wq: WorkQueue, cq: WorkQueue) -> None: self.ctx = ctx self.id = id self.wq = wq @@ -135,9 +123,7 @@ class Executor(object): def process_work(item: WorkItem, cq: WorkQueue): item.exec(cq) cq.push(item) - logger.info( - f"finished work: {repr(item)}, status: {item.status}, elapsed time: {item.elapsed_time:.3f} secs" - ) + logger.info(f"finished work: {repr(item)}, status: {item.status}, elapsed time: {item.elapsed_time:.3f} secs") logger.complete() # for test @@ -146,16 +132,12 @@ class Executor(object): # for test def skip_probes(self, epochs: int): - self.wq.push( - Probe(self.ctx, f".FalseFail-{self.id}", epoch=0, epochs_to_skip=epochs) - ) + self.wq.push(Probe(self.ctx, f".FalseFail-{self.id}", epoch=0, epochs_to_skip=epochs)) @logger.catch(reraise=True, message="executor terminated unexpectedly") def run(self) -> bool: mp.current_process().name = "ExecutorMainProcess" - logger.info( - f"start to run executor {self.id} on numa node #{self.ctx.numa_node_id} of {socket.gethostname()}" - ) + logger.info(f"start to run executor {self.id} on numa node #{self.ctx.numa_node_id} of {socket.gethostname()}") with SimplePool(self.ctx.usable_cpu_count + 1) as pool: retval = self.exec_loop(pool) @@ -173,34 +155,20 @@ class Executor(object): try: items = self.wq.pop(count=self.ctx.usable_cpu_count) except Exception as ex: - logger.opt(exception=ex).critical( - f"failed to pop from work queue: {self.wq}" - ) + logger.opt(exception=ex).critical(f"failed to pop from work queue: {self.wq}") self.running = False items = [] if not items: secs_quiet_period = time.time() - latest_probe_time - if ( - secs_quiet_period > self.ctx.secs_executor_probe_interval * 2 - and os.path.exists(self.ctx.job_status_path) - ): + if secs_quiet_period > self.ctx.secs_executor_probe_interval * 2 and os.path.exists(self.ctx.job_status_path): with open(self.ctx.job_status_path) as status_file: - if ( - status := status_file.read().strip() - ) and not status.startswith("running"): - logger.critical( - f"job scheduler already stopped: {status}, stopping executor" - ) + if (status := status_file.read().strip()) and not status.startswith("running"): + logger.critical(f"job scheduler already stopped: {status}, stopping executor") self.running = False break - if ( - secs_quiet_period > self.ctx.secs_executor_probe_timeout * 2 - and not pytest_running() - ): - logger.critical( - f"no probe received for {secs_quiet_period:.1f} secs, stopping executor" - ) + if secs_quiet_period > self.ctx.secs_executor_probe_timeout * 2 and not pytest_running(): + logger.critical(f"no probe received for {secs_quiet_period:.1f} secs, stopping executor") self.running = False break # no pending works, so wait a few seconds before checking results @@ -216,9 +184,7 @@ class Executor(object): if isinstance(item, StopWorkItem): running_work = self.running_works.get(item.work_to_stop, None) if running_work is None: - logger.debug( - f"cannot find {item.work_to_stop} in running works of {self.id}" - ) + logger.debug(f"cannot find {item.work_to_stop} in running works of {self.id}") self.cq.push(item) else: logger.info(f"stopping work: {item.work_to_stop}") @@ -250,20 +216,14 @@ class Executor(object): self.collect_finished_works() time.sleep(self.ctx.secs_wq_poll_interval) item._local_gpu = granted_gpus - logger.info( - f"{repr(item)} is assigned to run on GPU: { {gpu.id: quota for gpu, quota in item._local_gpu.items()} }" - ) + logger.info(f"{repr(item)} is assigned to run on GPU: { {gpu.id: quota for gpu, quota in item._local_gpu.items()} }") # enqueue work item to the pool self.running_works[item.key] = ( - pool.apply_async( - func=Executor.process_work, args=(item, self.cq), name=item.key - ), + pool.apply_async(func=Executor.process_work, args=(item, self.cq), name=item.key), item, ) - logger.info( - f"started work: {repr(item)}, {len(self.running_works)} running works: {list(self.running_works.keys())[:10]}..." - ) + logger.info(f"started work: {repr(item)}, {len(self.running_works)} running works: {list(self.running_works.keys())[:10]}...") # start to run works pool.update_queue() @@ -287,15 +247,11 @@ class Executor(object): work.join() if (exitcode := work.exitcode()) != 0: item.status = WorkStatus.CRASHED - item.exception = NonzeroExitCode( - f"worker process {work.proc.name}({work.proc.pid}) exited with non-zero code {exitcode}" - ) + item.exception = NonzeroExitCode(f"worker process {work.proc.name}({work.proc.pid}) exited with non-zero code {exitcode}") try: self.cq.push(item) except Exception as ex: - logger.opt(exception=ex).critical( - f"failed to push into completion queue: {self.cq}" - ) + logger.opt(exception=ex).critical(f"failed to push into completion queue: {self.cq}") self.running = False finished_works.append(item) @@ -304,9 +260,7 @@ class Executor(object): self.running_works.pop(item.key) if item._local_gpu: self.release_gpu(item._local_gpu) - logger.info( - f"{repr(item)} released GPU: { {gpu.id: quota for gpu, quota in item._local_gpu.items()} }" - ) + logger.info(f"{repr(item)} released GPU: { {gpu.id: quota for gpu, quota in item._local_gpu.items()} }") def acquire_gpu(self, quota: float) -> Dict[GPU, float]: """ @@ -336,6 +290,4 @@ class Executor(object): """ for gpu, quota in gpus.items(): self.local_gpus[gpu] += quota - assert ( - self.local_gpus[gpu] <= 1.0 - ), f"GPU {gpu} quota is greater than 1.0: {self.local_gpus[gpu]}" + assert self.local_gpus[gpu] <= 1.0, f"GPU {gpu} quota is greater than 1.0: {self.local_gpus[gpu]}" diff --git a/smallpond/execution/manager.py b/smallpond/execution/manager.py index a516f5e..6201e5b 100644 --- a/smallpond/execution/manager.py +++ b/smallpond/execution/manager.py @@ -26,13 +26,9 @@ class SchedStateExporter(Scheduler.StateObserver): if sched_state.large_runtime_state: logger.debug(f"pause exporting scheduler state") elif sched_state.num_local_running_works > 0: - logger.debug( - f"pause exporting scheduler state: {sched_state.num_local_running_works} local running works" - ) + logger.debug(f"pause exporting scheduler state: {sched_state.num_local_running_works} local running works") else: - dump( - sched_state, self.sched_state_path, buffering=32 * MB, atomic_write=True - ) + dump(sched_state, self.sched_state_path, buffering=32 * MB, atomic_write=True) sched_state.log_overall_progress() logger.debug(f"exported scheduler state to {self.sched_state_path}") @@ -97,12 +93,8 @@ class JobManager(object): bind_numa_node=False, enforce_memory_limit=False, share_log_analytics: Optional[bool] = None, - console_log_level: Literal[ - "CRITICAL", "ERROR", "WARNING", "SUCCESS", "INFO", "DEBUG", "TRACE" - ] = "INFO", - file_log_level: Literal[ - "CRITICAL", "ERROR", "WARNING", "SUCCESS", "INFO", "DEBUG", "TRACE" - ] = "DEBUG", + console_log_level: Literal["CRITICAL", "ERROR", "WARNING", "SUCCESS", "INFO", "DEBUG", "TRACE"] = "INFO", + file_log_level: Literal["CRITICAL", "ERROR", "WARNING", "SUCCESS", "INFO", "DEBUG", "TRACE"] = "DEBUG", disable_log_rotation=False, sched_state_observers: Optional[List[Scheduler.StateObserver]] = None, output_path: Optional[str] = None, @@ -111,11 +103,7 @@ class JobManager(object): logger.info(f"using platform: {self.platform}") job_id = JobId(hex=job_id or self.platform.default_job_id()) - job_time = ( - datetime.fromtimestamp(job_time) - if job_time is not None - else self.platform.default_job_time() - ) + job_time = datetime.fromtimestamp(job_time) if job_time is not None else self.platform.default_job_time() malloc_path = "" if memory_allocator == "system": @@ -131,19 +119,10 @@ class JobManager(object): arrow_default_malloc=memory_allocator, ).splitlines() env_overrides = env_overrides + (env_variables or []) - env_overrides = dict( - tuple(kv.strip().split("=", maxsplit=1)) - for kv in filter(None, env_overrides) - ) + env_overrides = dict(tuple(kv.strip().split("=", maxsplit=1)) for kv in filter(None, env_overrides)) - share_log_analytics = ( - share_log_analytics - if share_log_analytics is not None - else self.platform.default_share_log_analytics() - ) - shared_log_root = ( - self.platform.shared_log_root() if share_log_analytics else None - ) + share_log_analytics = share_log_analytics if share_log_analytics is not None else self.platform.default_share_log_analytics() + shared_log_root = self.platform.shared_log_root() if share_log_analytics else None runtime_ctx = RuntimeContext( job_id, @@ -167,9 +146,7 @@ class JobManager(object): **kwargs, ) runtime_ctx.initialize(socket.gethostname(), root_exist_ok=True) - logger.info( - f"command-line arguments: {' '.join([os.path.basename(sys.argv[0]), *sys.argv[1:]])}" - ) + logger.info(f"command-line arguments: {' '.join([os.path.basename(sys.argv[0]), *sys.argv[1:]])}") dump(runtime_ctx, runtime_ctx.runtime_ctx_path, atomic_write=True) logger.info(f"saved runtime context at {runtime_ctx.runtime_ctx_path}") @@ -178,13 +155,9 @@ class JobManager(object): logger.info(f"saved logcial plan at {runtime_ctx.logcial_plan_path}") plan.graph().render(runtime_ctx.logcial_plan_graph_path, format="png") - logger.info( - f"saved logcial plan graph at {runtime_ctx.logcial_plan_graph_path}.png" - ) + logger.info(f"saved logcial plan graph at {runtime_ctx.logcial_plan_graph_path}.png") - exec_plan = Planner(runtime_ctx).create_exec_plan( - plan, manifest_only_final_results - ) + exec_plan = Planner(runtime_ctx).create_exec_plan(plan, manifest_only_final_results) dump(exec_plan, runtime_ctx.exec_plan_path, atomic_write=True) logger.info(f"saved execution plan at {runtime_ctx.exec_plan_path}") @@ -229,9 +202,7 @@ class JobManager(object): sched_state_observers.insert(0, sched_state_exporter) if os.path.exists(runtime_ctx.sched_state_path): - logger.warning( - f"loading scheduler state from: {runtime_ctx.sched_state_path}" - ) + logger.warning(f"loading scheduler state from: {runtime_ctx.sched_state_path}") scheduler: Scheduler = load(runtime_ctx.sched_state_path) scheduler.sched_epoch += 1 scheduler.sched_state_observers = sched_state_observers diff --git a/smallpond/execution/scheduler.py b/smallpond/execution/scheduler.py index f64135a..a4b8d56 100644 --- a/smallpond/execution/scheduler.py +++ b/smallpond/execution/scheduler.py @@ -54,9 +54,7 @@ class ExecutorState(Enum): class RemoteExecutor(object): - def __init__( - self, ctx: RuntimeContext, id: str, wq: WorkQueue, cq: WorkQueue, init_epoch=0 - ) -> None: + def __init__(self, ctx: RuntimeContext, id: str, wq: WorkQueue, cq: WorkQueue, init_epoch=0) -> None: self.ctx = ctx self.id = id self.wq = wq @@ -79,9 +77,7 @@ state={self.state}, probe={self.last_acked_probe}" return f"RemoteExecutor({self.id}):{self.state}" @staticmethod - def create( - ctx: RuntimeContext, id: str, queue_dir: str, init_epoch=0 - ) -> "RemoteExecutor": + def create(ctx: RuntimeContext, id: str, queue_dir: str, init_epoch=0) -> "RemoteExecutor": wq = WorkQueueOnFilesystem(os.path.join(queue_dir, "wq")) cq = WorkQueueOnFilesystem(os.path.join(queue_dir, "cq")) return RemoteExecutor(ctx, id, wq, cq, init_epoch) @@ -173,9 +169,7 @@ state={self.state}, probe={self.last_acked_probe}" return self.cpu_count - self.cpu_count // 16 def add_running_work(self, item: WorkItem): - assert ( - item.key not in self.running_works - ), f"duplicate work item assigned to {repr(self)}: {item.key}" + assert item.key not in self.running_works, f"duplicate work item assigned to {repr(self)}: {item.key}" self.running_works[item.key] = item self._allocated_cpus += item.cpu_limit self._allocated_gpus += item.gpu_limit @@ -219,9 +213,7 @@ state={self.state}, probe={self.last_acked_probe}" def push(self, item: WorkItem, buffering=False) -> bool: if item.key in self.running_works: - logger.warning( - f"work item {item.key} already exists in running works of {self}" - ) + logger.warning(f"work item {item.key} already exists in running works of {self}") return False item.start_time = time.time() item.exec_id = self.id @@ -250,9 +242,7 @@ state={self.state}, probe={self.last_acked_probe}" elif num_missed_probes > self.ctx.max_num_missed_probes: if self.state != ExecutorState.FAIL: self.state = ExecutorState.FAIL - logger.error( - f"find failed executor: {self}, missed probes: {num_missed_probes}, current epoch: {current_epoch}" - ) + logger.error(f"find failed executor: {self}, missed probes: {num_missed_probes}, current epoch: {current_epoch}") return True elif self.state == ExecutorState.STOPPING: if self.stop_request_acked: @@ -277,9 +267,7 @@ state={self.state}, probe={self.last_acked_probe}" class LocalExecutor(RemoteExecutor): - def __init__( - self, ctx: RuntimeContext, id: str, wq: WorkQueue, cq: WorkQueue - ) -> None: + def __init__(self, ctx: RuntimeContext, id: str, wq: WorkQueue, cq: WorkQueue) -> None: super().__init__(ctx, id, wq, cq) self.work = None self.running = False @@ -321,9 +309,7 @@ class LocalExecutor(RemoteExecutor): if item.gpu_limit > 0: assert len(local_gpus) > 0 item._local_gpu = local_gpus[0] - logger.info( - f"{repr(item)} is assigned to run on GPU #{item.local_rank}: {item.local_gpu}" - ) + logger.info(f"{repr(item)} is assigned to run on GPU #{item.local_rank}: {item.local_gpu}") item = copy.copy(item) item.exec() @@ -368,9 +354,7 @@ class Scheduler(object): self.callback = callback def __repr__(self) -> str: - return ( - repr(self.callback) if self.callback is not None else super().__repr__() - ) + return repr(self.callback) if self.callback is not None else super().__repr__() __str__ = __repr__ @@ -403,9 +387,7 @@ class Scheduler(object): self.stop_executor_on_failure = stop_executor_on_failure self.nonzero_exitcode_as_oom = nonzero_exitcode_as_oom self.remove_output_root = remove_output_root - self.sched_state_observers: List[Scheduler.StateObserver] = ( - sched_state_observers or [] - ) + self.sched_state_observers: List[Scheduler.StateObserver] = sched_state_observers or [] self.secs_state_notify_interval = self.ctx.secs_executor_probe_interval * 2 # task states self.local_queue: List[Task] = [] @@ -414,11 +396,7 @@ class Scheduler(object): self.scheduled_tasks: Dict[TaskRuntimeId, Task] = OrderedDict() self.finished_tasks: Dict[TaskRuntimeId, Task] = OrderedDict() self.succeeded_tasks: Dict[str, Task] = OrderedDict() - self.nontrivial_tasks = dict( - (key, task) - for (key, task) in self.tasks.items() - if not task.exec_on_scheduler - ) + self.nontrivial_tasks = dict((key, task) for (key, task) in self.tasks.items() if not task.exec_on_scheduler) self.succeeded_nontrivial_tasks: Dict[str, Task] = OrderedDict() # executor pool self.local_executor = LocalExecutor.create(self.ctx, "localhost") @@ -463,18 +441,11 @@ class Scheduler(object): @property def running_works(self) -> Iterable[WorkItem]: - return ( - work - for executor in (self.alive_executors + self.local_executors) - for work in executor.running_works.values() - ) + return (work for executor in (self.alive_executors + self.local_executors) for work in executor.running_works.values()) @property def num_running_works(self) -> int: - return sum( - len(executor.running_works) - for executor in (self.alive_executors + self.local_executors) - ) + return sum(len(executor.running_works) for executor in (self.alive_executors + self.local_executors)) @property def num_local_running_works(self) -> int: @@ -489,11 +460,7 @@ class Scheduler(object): @property def pending_nontrivial_tasks(self) -> Dict[str, Task]: - return dict( - (key, task) - for key, task in self.nontrivial_tasks.items() - if key not in self.succeeded_nontrivial_tasks - ) + return dict((key, task) for key, task in self.nontrivial_tasks.items() if key not in self.succeeded_nontrivial_tasks) @property def num_pending_nontrivial_tasks(self) -> int: @@ -504,33 +471,20 @@ class Scheduler(object): @property def succeeded_task_ids(self) -> Set[TaskRuntimeId]: - return set( - TaskRuntimeId(task.id, task.sched_epoch, task.retry_count) - for task in self.succeeded_tasks.values() - ) + return set(TaskRuntimeId(task.id, task.sched_epoch, task.retry_count) for task in self.succeeded_tasks.values()) @property def abandoned_tasks(self) -> List[Task]: succeeded_task_ids = self.succeeded_task_ids - return [ - task - for task in {**self.scheduled_tasks, **self.finished_tasks}.values() - if task.runtime_id not in succeeded_task_ids - ] + return [task for task in {**self.scheduled_tasks, **self.finished_tasks}.values() if task.runtime_id not in succeeded_task_ids] @cached_property def remote_executors(self) -> List[RemoteExecutor]: - return [ - executor - for executor in self.available_executors.values() - if not executor.local - ] + return [executor for executor in self.available_executors.values() if not executor.local] @cached_property def local_executors(self) -> List[RemoteExecutor]: - return [ - executor for executor in self.available_executors.values() if executor.local - ] + return [executor for executor in self.available_executors.values() if executor.local] @cached_property def working_executors(self) -> List[RemoteExecutor]: @@ -592,10 +546,7 @@ class Scheduler(object): def start_speculative_execution(self): for executor in self.working_executors: for idx, item in enumerate(executor.running_works.values()): - aggressive_retry = ( - self.aggressive_speculative_exec - and len(self.good_executors) >= self.ctx.num_executors - ) + aggressive_retry = self.aggressive_speculative_exec and len(self.good_executors) >= self.ctx.num_executors short_sched_queue = len(self.sched_queue) < len(self.good_executors) if ( isinstance(item, Task) @@ -603,8 +554,7 @@ class Scheduler(object): and item.allow_speculative_exec and item.retry_count < self.max_retry_count and item.retry_count == self.tasks[item.key].retry_count - and (logical_node := self.logical_nodes.get(item.node_id, None)) - is not None + and (logical_node := self.logical_nodes.get(item.node_id, None)) is not None ): perf_stats = logical_node.get_perf_stats("elapsed wall time (secs)") if perf_stats is not None and perf_stats.cnt >= 20: @@ -639,12 +589,8 @@ class Scheduler(object): if entry.is_dir(): _, exec_id = os.path.split(entry.path) if exec_id not in self.available_executors: - self.available_executors[exec_id] = RemoteExecutor.create( - self.ctx, exec_id, entry.path, self.probe_epoch - ) - logger.info( - f"find a new executor #{len(self.available_executors)}: {self.available_executors[exec_id]}" - ) + self.available_executors[exec_id] = RemoteExecutor.create(self.ctx, exec_id, entry.path, self.probe_epoch) + logger.info(f"find a new executor #{len(self.available_executors)}: {self.available_executors[exec_id]}") self.clear_cached_executor_lists() # start a new probe epoch self.last_executor_probe_time = time.time() @@ -668,9 +614,7 @@ class Scheduler(object): item.status = WorkStatus.EXEC_FAILED item.finish_time = time.time() if isinstance(item, Task) and item.key not in self.succeeded_tasks: - logger.warning( - f"reschedule {repr(item)} on failed executor: {repr(executor)}" - ) + logger.warning(f"reschedule {repr(item)} on failed executor: {repr(executor)}") self.try_enqueue(self.get_retry_task(item.key)) if any(executor_state_changed): @@ -690,9 +634,7 @@ class Scheduler(object): # remove the reference to input deps task.input_deps = {dep_key: None for dep_key in task.input_deps} # feed input datasets - task.input_datasets = [ - self.succeeded_tasks[dep_key].output for dep_key in task.input_deps - ] + task.input_datasets = [self.succeeded_tasks[dep_key].output for dep_key in task.input_deps] task.sched_epoch = self.sched_epoch return task @@ -713,9 +655,7 @@ class Scheduler(object): task.dataset = finished_task.dataset def get_runnable_tasks(self, finished_task: Task) -> Iterable[Task]: - assert ( - finished_task.status == WorkStatus.SUCCEED - ), f"task not succeeded: {finished_task}" + assert finished_task.status == WorkStatus.SUCCEED, f"task not succeeded: {finished_task}" for output_key in finished_task.output_deps: output_dep = self.tasks[output_key] if all(key in self.succeeded_tasks for key in output_dep.input_deps): @@ -730,14 +670,8 @@ class Scheduler(object): for executor in self.remote_executors: running_task = executor.running_works.get(task_key, None) if running_task is not None: - logger.info( - f"try to stop {repr(running_task)} running on {repr(executor)}" - ) - executor.wq.push( - StopWorkItem( - f".StopWorkItem-{repr(running_task)}", running_task.key - ) - ) + logger.info(f"try to stop {repr(running_task)} running on {repr(executor)}") + executor.wq.push(StopWorkItem(f".StopWorkItem-{repr(running_task)}", running_task.key)) def try_relax_memory_limit(self, task: Task, executor: RemoteExecutor) -> bool: if task.memory_limit >= executor.memory_size: @@ -745,9 +679,7 @@ class Scheduler(object): return False relaxed_memory_limit = min(executor.memory_size, task.memory_limit * 2) task._memory_boost = relaxed_memory_limit / task._memory_limit - logger.warning( - f"relax memory limit of {task.key} to {task.memory_limit/GB:.3f}GB and retry ..." - ) + logger.warning(f"relax memory limit of {task.key} to {task.memory_limit/GB:.3f}GB and retry ...") return True def try_boost_resource(self, item: WorkItem, executor: RemoteExecutor): @@ -777,9 +709,7 @@ class Scheduler(object): if item._cpu_limit < boost_cpu or item._memory_limit < boost_mem: item._cpu_boost = boost_cpu / item._cpu_limit item._memory_boost = boost_mem / item._memory_limit - logger.info( - f"boost resource usage of {repr(item)}: {item.cpu_limit} CPUs, {item.memory_limit/GB:.3f}GB" - ) + logger.info(f"boost resource usage of {repr(item)}: {item.cpu_limit} CPUs, {item.memory_limit/GB:.3f}GB") def get_retry_task(self, key: str) -> Task: task = self.tasks[key] @@ -794,9 +724,7 @@ class Scheduler(object): remove_path(self.ctx.staging_root) if abandoned_tasks := self.abandoned_tasks: - logger.info( - f"removing outputs of {len(abandoned_tasks)} abandoned tasks: {abandoned_tasks[:3]} ..." - ) + logger.info(f"removing outputs of {len(abandoned_tasks)} abandoned tasks: {abandoned_tasks[:3]} ...") assert list(pool.map(lambda t: t.clean_output(force=True), abandoned_tasks)) @logger.catch(reraise=pytest_running(), message="failed to export task metrics") @@ -825,15 +753,9 @@ class Scheduler(object): buffering=32 * MB, ) - task_props = arrow.array( - pristine_attrs_dict(task) for task in self.finished_tasks.values() - ) - partition_infos = arrow.array( - task.partition_infos_as_dict for task in self.finished_tasks.values() - ) - perf_metrics = arrow.array( - dict(task.perf_metrics) for task in self.finished_tasks.values() - ) + task_props = arrow.array(pristine_attrs_dict(task) for task in self.finished_tasks.values()) + partition_infos = arrow.array(task.partition_infos_as_dict for task in self.finished_tasks.values()) + perf_metrics = arrow.array(dict(task.perf_metrics) for task in self.finished_tasks.values()) task_metrics = arrow.Table.from_arrays( [task_props, partition_infos, perf_metrics], names=["task_props", "partition_infos", "perf_metrics"], @@ -862,12 +784,7 @@ class Scheduler(object): [ dict( task=repr(task), - node=( - repr(node) - if (node := self.logical_nodes.get(task.node_id, None)) - is not None - else "StandaloneTasks" - ), + node=(repr(node) if (node := self.logical_nodes.get(task.node_id, None)) is not None else "StandaloneTasks"), status=str(task.status), executor=task.exec_id, start_time=datetime.fromtimestamp(task.start_time), @@ -925,23 +842,16 @@ class Scheduler(object): fig_filename, _ = fig_title.split(" - ", maxsplit=1) fig_filename += ".html" fig_path = os.path.join(self.ctx.log_root, fig_filename) - fig.update_yaxes( - autorange="reversed" - ) # otherwise tasks are listed from the bottom up + fig.update_yaxes(autorange="reversed") # otherwise tasks are listed from the bottom up fig.update_traces(marker_line_color="black", marker_line_width=1, opacity=1) - fig.write_html( - fig_path, include_plotlyjs="cdn" if pytest_running() else True - ) + fig.write_html(fig_path, include_plotlyjs="cdn" if pytest_running() else True) if self.ctx.shared_log_root: shutil.copy(fig_path, self.ctx.shared_log_root) logger.debug(f"exported timeline figure to {fig_path}") def notify_state_observers(self, force_notify=False) -> bool: secs_since_last_state_notify = time.time() - self.last_state_notify_time - if ( - force_notify - or secs_since_last_state_notify >= self.secs_state_notify_interval - ): + if force_notify or secs_since_last_state_notify >= self.secs_state_notify_interval: self.last_state_notify_time = time.time() for observer in self.sched_state_observers: if force_notify or observer.enabled: @@ -949,14 +859,10 @@ class Scheduler(object): observer.update(self) elapsed_time = time.time() - start_time if elapsed_time >= self.ctx.secs_executor_probe_interval / 2: - self.secs_state_notify_interval = ( - self.ctx.secs_executor_probe_timeout - ) + self.secs_state_notify_interval = self.ctx.secs_executor_probe_timeout if elapsed_time >= self.ctx.secs_executor_probe_interval: observer.enabled = False - logger.warning( - f"disabled slow scheduler state observer (elapsed time: {elapsed_time:.1f} secs): {observer}" - ) + logger.warning(f"disabled slow scheduler state observer (elapsed time: {elapsed_time:.1f} secs): {observer}") return True else: return False @@ -984,9 +890,7 @@ class Scheduler(object): def run(self) -> bool: mp.current_process().name = f"SchedulerMainProcess#{self.sched_epoch}" - logger.info( - f"start to run scheduler #{self.sched_epoch} on {socket.gethostname()}" - ) + logger.info(f"start to run scheduler #{self.sched_epoch} on {socket.gethostname()}") perf_profile = None if self.ctx.enable_profiling: @@ -1001,48 +905,30 @@ class Scheduler(object): self.prioritize_retry |= self.sched_epoch > 0 if self.local_queue or self.sched_queue: - pending_tasks = [ - item - for item in self.local_queue + self.sched_queue - if isinstance(item, Task) - ] + pending_tasks = [item for item in self.local_queue + self.sched_queue if isinstance(item, Task)] self.local_queue.clear() self.sched_queue.clear() - logger.info( - f"requeue {len(pending_tasks)} pending tasks with latest epoch #{self.sched_epoch}: {pending_tasks[:3]} ..." - ) + logger.info(f"requeue {len(pending_tasks)} pending tasks with latest epoch #{self.sched_epoch}: {pending_tasks[:3]} ...") self.try_enqueue(pending_tasks) if self.sched_epoch == 0: leaf_tasks = self.exec_plan.leaves - logger.info( - f"enqueue {len(leaf_tasks)} leaf tasks: {leaf_tasks[:3]} ..." - ) + logger.info(f"enqueue {len(leaf_tasks)} leaf tasks: {leaf_tasks[:3]} ...") self.try_enqueue(leaf_tasks) self.log_overall_progress() while (num_finished_tasks := self.process_finished_tasks(pool)) > 0: - logger.info( - f"processed {num_finished_tasks} finished tasks during startup" - ) + logger.info(f"processed {num_finished_tasks} finished tasks during startup") self.log_overall_progress() - earlier_running_tasks = [ - item for item in self.running_works if isinstance(item, Task) - ] + earlier_running_tasks = [item for item in self.running_works if isinstance(item, Task)] if earlier_running_tasks: - logger.info( - f"enqueue {len(earlier_running_tasks)} earlier running tasks: {earlier_running_tasks[:3]} ..." - ) + logger.info(f"enqueue {len(earlier_running_tasks)} earlier running tasks: {earlier_running_tasks[:3]} ...") self.try_enqueue(earlier_running_tasks) self.suspend_good_executors() - self.add_state_observer( - Scheduler.StateObserver(Scheduler.log_current_status) - ) - self.add_state_observer( - Scheduler.StateObserver(Scheduler.export_timeline_figs) - ) + self.add_state_observer(Scheduler.StateObserver(Scheduler.log_current_status)) + self.add_state_observer(Scheduler.StateObserver(Scheduler.export_timeline_figs)) self.notify_state_observers(force_notify=True) try: @@ -1063,14 +949,10 @@ class Scheduler(object): if self.success: self.clean_temp_files(pool) logger.success(f"final output path: {self.exec_plan.final_output_path}") - logger.info( - f"analyzed plan:{os.linesep}{self.exec_plan.analyzed_logical_plan.explain_str()}" - ) + logger.info(f"analyzed plan:{os.linesep}{self.exec_plan.analyzed_logical_plan.explain_str()}") if perf_profile is not None: - logger.debug( - f"scheduler perf profile:{os.linesep}{cprofile_to_string(perf_profile)}" - ) + logger.debug(f"scheduler perf profile:{os.linesep}{cprofile_to_string(perf_profile)}") logger.info(f"scheduler of job {self.ctx.job_id} exits") logger.complete() @@ -1082,20 +964,14 @@ class Scheduler(object): task = self.copy_task_for_execution(task) if task.key in self.succeeded_tasks: logger.debug(f"task {repr(task)} already succeeded, skipping") - self.try_enqueue( - self.get_runnable_tasks(self.succeeded_tasks[task.key]) - ) + self.try_enqueue(self.get_runnable_tasks(self.succeeded_tasks[task.key])) continue if task.runtime_id in self.scheduled_tasks: logger.debug(f"task {repr(task)} already scheduled, skipping") continue # save enqueued task self.scheduled_tasks[task.runtime_id] = task - if ( - self.standalone_mode - or task.exec_on_scheduler - or task.skip_when_any_input_empty - ): + if self.standalone_mode or task.exec_on_scheduler or task.skip_when_any_input_empty: self.local_queue.append(task) else: self.sched_queue.append(task) @@ -1114,34 +990,20 @@ class Scheduler(object): if self.local_queue: assert self.local_executor.alive - logger.info( - f"running {len(self.local_queue)} works on local executor: {self.local_queue[:3]} ..." - ) - self.local_queue = [ - item - for item in self.local_queue - if not self.local_executor.push(item, buffering=True) - ] + logger.info(f"running {len(self.local_queue)} works on local executor: {self.local_queue[:3]} ...") + self.local_queue = [item for item in self.local_queue if not self.local_executor.push(item, buffering=True)] self.local_executor.flush() has_progress |= self.dispatch_tasks(pool) > 0 - if len( - self.sched_queue - ) == 0 and self.num_pending_nontrivial_tasks + 1 < len(self.good_executors): + if len(self.sched_queue) == 0 and self.num_pending_nontrivial_tasks + 1 < len(self.good_executors): for executor in self.good_executors: if executor.idle: - logger.info( - f"{len(self.good_executors)} remote executors running, stopping {executor}" - ) + logger.info(f"{len(self.good_executors)} remote executors running, stopping {executor}") executor.stop() break - if ( - len(self.sched_queue) == 0 - and len(self.local_queue) == 0 - and self.num_running_works == 0 - ): + if len(self.sched_queue) == 0 and len(self.local_queue) == 0 and self.num_running_works == 0: self.log_overall_progress() assert ( self.num_pending_tasks == 0 @@ -1166,29 +1028,13 @@ class Scheduler(object): def dispatch_tasks(self, pool: ThreadPoolExecutor): # sort pending tasks - item_sort_key = ( - (lambda item: (-item.retry_count, item.id)) - if self.prioritize_retry - else (lambda item: (item.retry_count, item.id)) - ) + item_sort_key = (lambda item: (-item.retry_count, item.id)) if self.prioritize_retry else (lambda item: (item.retry_count, item.id)) items_sorted_by_node_id = sorted(self.sched_queue, key=lambda t: t.node_id) - items_group_by_node_id = itertools.groupby( - items_sorted_by_node_id, key=lambda t: t.node_id - ) - sorted_item_groups = [ - sorted(items, key=item_sort_key) for _, items in items_group_by_node_id - ] - self.sched_queue = [ - item - for batch in itertools.zip_longest(*sorted_item_groups, fillvalue=None) - for item in batch - if item is not None - ] + items_group_by_node_id = itertools.groupby(items_sorted_by_node_id, key=lambda t: t.node_id) + sorted_item_groups = [sorted(items, key=item_sort_key) for _, items in items_group_by_node_id] + self.sched_queue = [item for batch in itertools.zip_longest(*sorted_item_groups, fillvalue=None) for item in batch if item is not None] - final_phase = ( - self.num_pending_nontrivial_tasks - self.num_running_works - <= len(self.good_executors) * 2 - ) + final_phase = self.num_pending_nontrivial_tasks - self.num_running_works <= len(self.good_executors) * 2 num_dispatched_tasks = 0 unassigned_tasks = [] @@ -1196,42 +1042,31 @@ class Scheduler(object): first_item = self.sched_queue[0] # assign tasks to executors in round-robin fashion - usable_executors = [ - executor for executor in self.good_executors if not executor.busy - ] - for executor in sorted( - usable_executors, key=lambda exec: len(exec.running_works) - ): + usable_executors = [executor for executor in self.good_executors if not executor.busy] + for executor in sorted(usable_executors, key=lambda exec: len(exec.running_works)): if not self.sched_queue: break item = self.sched_queue[0] if item._memory_limit is None: - item._memory_limit = np.int64( - executor.memory_size * item._cpu_limit // executor.cpu_count - ) + item._memory_limit = np.int64(executor.memory_size * item._cpu_limit // executor.cpu_count) if item.key in self.succeeded_tasks: logger.debug(f"task {repr(item)} already succeeded, skipping") self.sched_queue.pop(0) - self.try_enqueue( - self.get_runnable_tasks(self.succeeded_tasks[item.key]) - ) + self.try_enqueue(self.get_runnable_tasks(self.succeeded_tasks[item.key])) elif ( len(executor.running_works) < executor.max_running_works and executor.allocated_cpus + item.cpu_limit <= executor.cpu_count and executor.allocated_gpus + item.gpu_limit <= executor.gpu_count - and executor.allocated_memory + item.memory_limit - <= executor.memory_size + and executor.allocated_memory + item.memory_limit <= executor.memory_size and item.key not in executor.running_works ): if final_phase: self.try_boost_resource(item, executor) # push to wq of executor but not flushed yet executor.push(item, buffering=True) - logger.info( - f"appended {repr(item)} ({item.cpu_limit} CPUs, {item.memory_limit/GB:.3f}GB) to the queue of {executor}" - ) + logger.info(f"appended {repr(item)} ({item.cpu_limit} CPUs, {item.memory_limit/GB:.3f}GB) to the queue of {executor}") self.sched_queue.pop(0) num_dispatched_tasks += 1 @@ -1242,55 +1077,35 @@ class Scheduler(object): self.sched_queue.extend(unassigned_tasks) # flush the buffered work items into wq - assert all( - pool.map(RemoteExecutor.flush, self.good_executors) - ), f"failed to flush work queues" + assert all(pool.map(RemoteExecutor.flush, self.good_executors)), f"failed to flush work queues" return num_dispatched_tasks def process_finished_tasks(self, pool: ThreadPoolExecutor) -> int: pop_results = pool.map(RemoteExecutor.pop, self.available_executors.values()) num_finished_tasks = 0 - for executor, finished_tasks in zip( - self.available_executors.values(), pop_results - ): + for executor, finished_tasks in zip(self.available_executors.values(), pop_results): for finished_task in finished_tasks: assert isinstance(finished_task, Task) - scheduled_task = self.scheduled_tasks.get( - finished_task.runtime_id, None - ) + scheduled_task = self.scheduled_tasks.get(finished_task.runtime_id, None) if scheduled_task is None: - logger.info( - f"task not initiated by current scheduler: {finished_task}" - ) + logger.info(f"task not initiated by current scheduler: {finished_task}") if finished_task.status != WorkStatus.SUCCEED and ( - missing_inputs := [ - key - for key in finished_task.input_deps - if key not in self.succeeded_tasks - ] + missing_inputs := [key for key in finished_task.input_deps if key not in self.succeeded_tasks] ): - logger.info( - f"ignore {repr(finished_task)} since some of the input deps are missing: {missing_inputs}" - ) + logger.info(f"ignore {repr(finished_task)} since some of the input deps are missing: {missing_inputs}") continue if finished_task.status == WorkStatus.INCOMPLETE: - logger.trace( - f"{repr(finished_task)} checkpoint created on {executor.id}: {finished_task.runtime_state}" - ) - self.tasks[finished_task.key].runtime_state = ( - finished_task.runtime_state - ) + logger.trace(f"{repr(finished_task)} checkpoint created on {executor.id}: {finished_task.runtime_state}") + self.tasks[finished_task.key].runtime_state = finished_task.runtime_state continue prior_task = self.finished_tasks.get(finished_task.runtime_id, None) if prior_task is not None: - logger.info( - f"found duplicate tasks, current: {repr(finished_task)}, prior: {repr(prior_task)}" - ) + logger.info(f"found duplicate tasks, current: {repr(finished_task)}, prior: {repr(prior_task)}") continue else: self.finished_tasks[finished_task.runtime_id] = finished_task @@ -1298,30 +1113,22 @@ class Scheduler(object): succeeded_task = self.succeeded_tasks.get(finished_task.key, None) if succeeded_task is not None: - logger.info( - f"task already succeeded, current: {repr(finished_task)}, succeeded: {repr(succeeded_task)}" - ) + logger.info(f"task already succeeded, current: {repr(finished_task)}, succeeded: {repr(succeeded_task)}") continue if finished_task.status in (WorkStatus.FAILED, WorkStatus.CRASHED): - logger.warning( - f"task failed on {executor.id}: {finished_task}, error: {finished_task.exception}" - ) + logger.warning(f"task failed on {executor.id}: {finished_task}, error: {finished_task.exception}") finished_task.dump() task = self.tasks[finished_task.key] task.fail_count += 1 if task.fail_count > self.max_fail_count: - logger.critical( - f"task failed too many times: {finished_task}, stopping ..." - ) + logger.critical(f"task failed too many times: {finished_task}, stopping ...") self.stop_executors() self.sched_running = False - if not executor.local and finished_task.oom( - self.nonzero_exitcode_as_oom - ): + if not executor.local and finished_task.oom(self.nonzero_exitcode_as_oom): if task._memory_limit is None: task._memory_limit = finished_task._memory_limit self.try_relax_memory_limit(task, executor) @@ -1332,9 +1139,7 @@ class Scheduler(object): self.try_enqueue(self.get_retry_task(finished_task.key)) else: - assert ( - finished_task.status == WorkStatus.SUCCEED - ), f"unexpected task status: {finished_task}" + assert finished_task.status == WorkStatus.SUCCEED, f"unexpected task status: {finished_task}" logger.log( "TRACE" if finished_task.exec_on_scheduler else "INFO", "task succeeded on {}: {}", @@ -1344,9 +1149,7 @@ class Scheduler(object): self.succeeded_tasks[finished_task.key] = finished_task if not finished_task.exec_on_scheduler: - self.succeeded_nontrivial_tasks[finished_task.key] = ( - finished_task - ) + self.succeeded_nontrivial_tasks[finished_task.key] = finished_task # stop the redundant retries of finished task self.stop_running_tasks(finished_task.key) @@ -1356,9 +1159,7 @@ class Scheduler(object): if finished_task.id == self.exec_plan.root_task.id: self.sched_queue = [] self.stop_executors() - logger.success( - f"all tasks completed, root task: {finished_task}" - ) + logger.success(f"all tasks completed, root task: {finished_task}") logger.success( f"{len(self.succeeded_tasks)} succeeded tasks, success: {self.success}, elapsed time: {self.elapsed_time:.3f} secs" ) diff --git a/smallpond/execution/task.py b/smallpond/execution/task.py index 43d8d22..d36bf00 100644 --- a/smallpond/execution/task.py +++ b/smallpond/execution/task.py @@ -126,11 +126,7 @@ class TaskRuntimeId: return f"{self.id}.{self.epoch}.{self.retry}" -class PerfStats( - namedtuple( - "PerfStats", ("cnt", "sum", "min", "max", "avg", "p50", "p75", "p95", "p99") - ) -): +class PerfStats(namedtuple("PerfStats", ("cnt", "sum", "min", "max", "avg", "p50", "p75", "p95", "p99"))): """ Performance statistics for a task. """ @@ -179,9 +175,7 @@ class RuntimeContext(object): self.data_root = data_root self.next_task_id = 0 self.num_executors = num_executors - self.random_seed: int = random_seed or int.from_bytes( - os.urandom(RAND_SEED_BYTE_LEN), byteorder=sys.byteorder - ) + self.random_seed: int = random_seed or int.from_bytes(os.urandom(RAND_SEED_BYTE_LEN), byteorder=sys.byteorder) self.env_overrides = env_overrides or {} self.bind_numa_node = bind_numa_node self.numa_node_id: Optional[int] = None @@ -198,11 +192,7 @@ class RuntimeContext(object): self.remove_empty_parquet = remove_empty_parquet self.skip_task_with_empty_input = skip_task_with_empty_input - self.shared_log_root = ( - os.path.join(shared_log_root, self.job_root_dirname) - if shared_log_root - else None - ) + self.shared_log_root = os.path.join(shared_log_root, self.job_root_dirname) if shared_log_root else None self.console_log_level = console_log_level self.file_log_level = file_log_level self.disable_log_rotation = disable_log_rotation @@ -270,20 +260,12 @@ class RuntimeContext(object): @property def available_memory(self): available_memory = psutil.virtual_memory().available - return ( - available_memory // self.numa_node_count - if self.bind_numa_node - else available_memory - ) + return available_memory // self.numa_node_count if self.bind_numa_node else available_memory @property def total_memory(self): total_memory = psutil.virtual_memory().total - return ( - total_memory // self.numa_node_count - if self.bind_numa_node - else total_memory - ) + return total_memory // self.numa_node_count if self.bind_numa_node else total_memory @property def usable_cpu_count(self): @@ -300,11 +282,7 @@ class RuntimeContext(object): def get_local_gpus(self) -> List[GPUtil.GPU]: gpus = GPUtil.getGPUs() gpus_on_node = split_into_rows(gpus, self.numa_node_count) - return ( - gpus_on_node[self.numa_node_id] - if self.bind_numa_node and self.numa_node_id is not None - else gpus - ) + return gpus_on_node[self.numa_node_id] if self.bind_numa_node and self.numa_node_id is not None else gpus @property def usable_gpu_count(self): @@ -359,14 +337,10 @@ class RuntimeContext(object): ld_library_path = os.getenv("LD_LIBRARY_PATH", "") py_library_path = os.path.join(sys.exec_prefix, "lib") if py_library_path not in ld_library_path: - env_overrides["LD_LIBRARY_PATH"] = ":".join( - [py_library_path, ld_library_path] - ) + env_overrides["LD_LIBRARY_PATH"] = ":".join([py_library_path, ld_library_path]) for key, val in env_overrides.items(): if (old := os.getenv(key, None)) is not None: - logger.info( - f"overwrite environment variable '{key}': '{old}' -> '{val}'" - ) + logger.info(f"overwrite environment variable '{key}': '{old}' -> '{val}'") else: logger.info(f"set environment variable '{key}': '{val}'") os.environ[key] = val @@ -375,11 +349,7 @@ class RuntimeContext(object): ) def _init_logs(self, exec_id: str, capture_stdout_stderr: bool = False) -> None: - log_rotation = ( - {"rotation": "100 MB", "retention": 5} - if not self.disable_log_rotation - else {} - ) + log_rotation = {"rotation": "100 MB", "retention": 5} if not self.disable_log_rotation else {} log_file_paths = [os.path.join(self.log_root, f"{exec_id}.log")] user_log_only = {"": self.file_log_level, "smallpond": False} user_log_path = os.path.join(self.log_root, f"{exec_id}-user.log") @@ -425,9 +395,7 @@ class RuntimeContext(object): ) logger.info(f"initialized user logging to file: {user_log_path}") # intercept messages from logging - logging.basicConfig( - handlers=[InterceptHandler()], level=logging.INFO, force=True - ) + logging.basicConfig(handlers=[InterceptHandler()], level=logging.INFO, force=True) # capture stdout as INFO level # https://loguru.readthedocs.io/en/stable/resources/recipes.html#capturing-standard-stdout-stderr-and-warnings if capture_stdout_stderr: @@ -459,9 +427,7 @@ class RuntimeContext(object): class Probe(WorkItem): - def __init__( - self, ctx: RuntimeContext, key: str, epoch: int, epochs_to_skip=0 - ) -> None: + def __init__(self, ctx: RuntimeContext, key: str, epoch: int, epochs_to_skip=0) -> None: super().__init__(key, cpu_limit=0, gpu_limit=0, memory_limit=0) self.ctx = ctx self.epoch = epoch @@ -483,14 +449,10 @@ class Probe(WorkItem): return self.epoch < other.epoch def run(self) -> bool: - self.cpu_percent = psutil.cpu_percent( - interval=min(self.ctx.secs_executor_probe_interval / 2, 3) - ) + self.cpu_percent = psutil.cpu_percent(interval=min(self.ctx.secs_executor_probe_interval / 2, 3)) self.total_memory = self.ctx.usable_memory_size self.available_memory = self.ctx.available_memory - self.resource_low = ( - self.cpu_percent >= 80.0 or self.available_memory < self.total_memory // 16 - ) + self.resource_low = self.cpu_percent >= 80.0 or self.available_memory < self.total_memory // 16 self.cpu_count = self.ctx.usable_cpu_count self.gpu_count = self.ctx.usable_gpu_count logger.info("resource usage: {}", self) @@ -511,9 +473,7 @@ class PartitionInfo(object): "dimension", ) - def __init__( - self, index: int = 0, npartitions: int = 1, dimension: str = toplevel_dimension - ) -> None: + def __init__(self, index: int = 0, npartitions: int = 1, dimension: str = toplevel_dimension) -> None: self.index = index self.npartitions = npartitions self.dimension = dimension @@ -587,18 +547,11 @@ class Task(WorkItem): memory_limit: Optional[int] = None, ) -> None: assert isinstance(input_deps, Iterable), f"{input_deps} is not iterable" - assert all( - isinstance(dep, Task) for dep in input_deps - ), f"not every element of {input_deps} is a task" - assert isinstance( - partition_infos, Iterable - ), f"{partition_infos} is not iterable" - assert all( - isinstance(info, PartitionInfo) for info in partition_infos - ), f"not every element of {partition_infos} is a partition info" + assert all(isinstance(dep, Task) for dep in input_deps), f"not every element of {input_deps} is a task" + assert isinstance(partition_infos, Iterable), f"{partition_infos} is not iterable" + assert all(isinstance(info, PartitionInfo) for info in partition_infos), f"not every element of {partition_infos} is a partition info" assert any( - info.dimension == PartitionInfo.toplevel_dimension - for info in partition_infos + info.dimension == PartitionInfo.toplevel_dimension for info in partition_infos ), f"cannot find toplevel partition dimension: {partition_infos}" assert cpu_limit > 0, f"cpu_limit should be greater than zero" self.ctx = ctx @@ -611,12 +564,8 @@ class Task(WorkItem): self.perf_metrics: Dict[str, np.int64] = defaultdict(np.int64) self.perf_profile = None self._partition_infos = sorted(partition_infos) or [] - assert len(self.partition_dims) == len( - set(self.partition_dims) - ), f"found duplicate partition dimensions: {self.partition_dims}" - super().__init__( - f"{self.__class__.__name__}-{self.id}", cpu_limit, gpu_limit, memory_limit - ) + assert len(self.partition_dims) == len(set(self.partition_dims)), f"found duplicate partition dimensions: {self.partition_dims}" + super().__init__(f"{self.__class__.__name__}-{self.id}", cpu_limit, gpu_limit, memory_limit) self.output_name = output_name self.output_root = output_path self._temp_output = output_name is None and output_path is None @@ -663,11 +612,7 @@ class Task(WorkItem): @property def _pristine_attrs(self) -> Set[str]: """All attributes in __slots__.""" - return set( - itertools.chain.from_iterable( - getattr(cls, "__slots__", []) for cls in type(self).__mro__ - ) - ) + return set(itertools.chain.from_iterable(getattr(cls, "__slots__", []) for cls in type(self).__mro__)) @property def partition_infos(self) -> Tuple[PartitionInfo]: @@ -685,20 +630,11 @@ class Task(WorkItem): ("__job_id__", str(self.ctx.job_id)), ("__job_root__", self.ctx.job_root), ] - partition_infos = [ - (info.dimension, str(info.index)) - for info in self._partition_infos + (extra_partitions or []) - ] - return dict( - (PARQUET_METADATA_KEY_PREFIX + k, v) - for k, v in task_infos + partition_infos - ) + partition_infos = [(info.dimension, str(info.index)) for info in self._partition_infos + (extra_partitions or [])] + return dict((PARQUET_METADATA_KEY_PREFIX + k, v) for k, v in task_infos + partition_infos) def parquet_kv_metadata_bytes(self, extra_partitions: List[PartitionInfo] = None): - return dict( - (k.encode("utf-8"), v.encode("utf-8")) - for k, v in self.parquet_kv_metadata_str(extra_partitions).items() - ) + return dict((k.encode("utf-8"), v.encode("utf-8")) for k, v in self.parquet_kv_metadata_str(extra_partitions).items()) @property def partition_dims(self): @@ -741,19 +677,11 @@ class Task(WorkItem): """ If the task has a special output directory, its runtime output directory will be under it. """ - return ( - None - if self.output_root is None - else os.path.join(self.output_root, ".staging") - ) + return None if self.output_root is None else os.path.join(self.output_root, ".staging") @property def _final_output_root(self): - return ( - self.ctx.staging_root - if self.temp_output - else (self.output_root or self.ctx.output_root) - ) + return self.ctx.staging_root if self.temp_output else (self.output_root or self.ctx.output_root) @property def _runtime_output_root(self): @@ -816,9 +744,7 @@ class Task(WorkItem): The path of an empty file that is used to determine if the task has been started. Only used by the ray executor. """ - return os.path.join( - self.ctx.started_task_dir, f"{self.node_id}.{self.key}.{self.retry_count}" - ) + return os.path.join(self.ctx.started_task_dir, f"{self.node_id}.{self.key}.{self.retry_count}") @property def ray_dataset_path(self) -> str: @@ -827,30 +753,22 @@ class Task(WorkItem): If this file exists, the task is considered finished. Only used by the ray executor. """ - return os.path.join( - self.ctx.completed_task_dir, str(self.node_id), f"{self.key}.pickle" - ) + return os.path.join(self.ctx.completed_task_dir, str(self.node_id), f"{self.key}.pickle") @property def random_seed_bytes(self) -> bytes: - return self.id.to_bytes(4, sys.byteorder) + self.ctx.random_seed.to_bytes( - RAND_SEED_BYTE_LEN, sys.byteorder - ) + return self.id.to_bytes(4, sys.byteorder) + self.ctx.random_seed.to_bytes(RAND_SEED_BYTE_LEN, sys.byteorder) @property def numpy_random_gen(self): if self._np_randgen is None: - self._np_randgen = np.random.default_rng( - int.from_bytes(self.random_seed_bytes, sys.byteorder) - ) + self._np_randgen = np.random.default_rng(int.from_bytes(self.random_seed_bytes, sys.byteorder)) return self._np_randgen @property def python_random_gen(self): if self._py_randgen is None: - self._py_randgen = random.Random( - int.from_bytes(self.random_seed_bytes, sys.byteorder) - ) + self._py_randgen = random.Random(int.from_bytes(self.random_seed_bytes, sys.byteorder)) return self._py_randgen def random_uint32(self) -> int: @@ -866,10 +784,7 @@ class Task(WorkItem): def inject_fault(self): if self.ctx.fault_inject_prob > 0 and self.fail_count <= 1: random_value = self.random_float() - if ( - random_value < self.uniform_failure_prob - and random_value < self.ctx.fault_inject_prob - ): + if random_value < self.uniform_failure_prob and random_value < self.ctx.fault_inject_prob: raise InjectedFault( f"inject fault to {repr(self)}, uniform_failure_prob={self.uniform_failure_prob:.6f}, fault_inject_prob={self.ctx.fault_inject_prob:.6f}" ) @@ -890,21 +805,15 @@ class Task(WorkItem): if num_row_groups > max_num_row_groups: parquet_row_group_size = round_up( - clamp_row_group_size( - num_rows // max_num_row_groups, maxval=max_row_group_size - ), + clamp_row_group_size(num_rows // max_num_row_groups, maxval=max_row_group_size), KB, ) avg_row_size = self.compute_avg_row_size(nbytes, num_rows) - parquet_row_group_size = round_up( - min(parquet_row_group_size, max_row_group_bytes // avg_row_size), KB - ) + parquet_row_group_size = round_up(min(parquet_row_group_size, max_row_group_bytes // avg_row_size), KB) if self.parquet_row_group_size != parquet_row_group_size: parquet_row_group_bytes = round_up( - clamp_row_group_bytes( - parquet_row_group_size * avg_row_size, maxval=max_row_group_bytes - ), + clamp_row_group_bytes(parquet_row_group_size * avg_row_size, maxval=max_row_group_bytes), MB, ) logger.info( @@ -919,19 +828,13 @@ class Task(WorkItem): def set_memory_limit(self, soft_limit: int, hard_limit: int): soft_old, hard_old = resource.getrlimit(resource.RLIMIT_DATA) resource.setrlimit(resource.RLIMIT_DATA, (soft_limit, hard_limit)) - logger.debug( - f"updated memory limit from ({soft_old/GB:.3f}GB, {hard_old/GB:.3f}GB) to ({soft_limit/GB:.3f}GB, {hard_limit/GB:.3f}GB)" - ) + logger.debug(f"updated memory limit from ({soft_old/GB:.3f}GB, {hard_old/GB:.3f}GB) to ({soft_limit/GB:.3f}GB, {hard_limit/GB:.3f}GB)") def initialize(self): self.inject_fault() if self._memory_limit is None: - self._memory_limit = np.int64( - self.ctx.usable_memory_size - * self._cpu_limit - // self.ctx.usable_cpu_count - ) + self._memory_limit = np.int64(self.ctx.usable_memory_size * self._cpu_limit // self.ctx.usable_cpu_count) assert self.partition_infos, f"empty partition infos: {self}" os.makedirs(self.runtime_output_abspath, exist_ok=self.output_root is not None) os.makedirs(self.temp_abspath, exist_ok=False) @@ -941,9 +844,7 @@ class Task(WorkItem): self.perf_profile = cProfile.Profile() self.perf_profile.enable() if self.ctx.enforce_memory_limit: - self.set_memory_limit( - round_up(self.memory_limit * 1.2), round_up(self.memory_limit * 1.5) - ) + self.set_memory_limit(round_up(self.memory_limit * 1.2), round_up(self.memory_limit * 1.5)) if self.ctx.remove_empty_parquet: for dataset in self.input_datasets: if isinstance(dataset, ParquetDataSet): @@ -952,9 +853,7 @@ class Task(WorkItem): logger.debug("input datasets: {}", self.input_datasets) logger.trace(f"final output directory: {self.final_output_abspath}") logger.trace(f"runtime output directory: {self.runtime_output_abspath}") - logger.trace( - f"resource limit: {self.cpu_limit} cpus, {self.gpu_limit} gpus, {self.memory_limit/GB:.3f}GB memory" - ) + logger.trace(f"resource limit: {self.cpu_limit} cpus, {self.gpu_limit} gpus, {self.memory_limit/GB:.3f}GB memory") random.seed(self.random_seed_bytes) arrow.set_cpu_count(self.cpu_limit) arrow.set_io_thread_count(self.cpu_limit) @@ -967,9 +866,7 @@ class Task(WorkItem): logger.info("finished task: {}", self) # move the task output from staging dir to output dir - if self.runtime_output_abspath != self.final_output_abspath and os.path.exists( - self.runtime_output_abspath - ): + if self.runtime_output_abspath != self.final_output_abspath and os.path.exists(self.runtime_output_abspath): os.makedirs(os.path.dirname(self.final_output_abspath), exist_ok=True) os.rename(self.runtime_output_abspath, self.final_output_abspath) @@ -980,18 +877,12 @@ class Task(WorkItem): with ThreadPoolExecutor(min(32, len(file_paths))) as pool: file_sizes = list(pool.map(os.path.getsize, file_paths)) except FileNotFoundError: - logger.warning( - f"some of the output files not found: {file_paths[:3]}..." - ) + logger.warning(f"some of the output files not found: {file_paths[:3]}...") file_sizes = [] return file_sizes if self.ctx.enable_diagnostic_metrics: - input_file_paths = [ - path - for dataset in self.input_datasets - for path in dataset.resolved_paths - ] + input_file_paths = [path for dataset in self.input_datasets for path in dataset.resolved_paths] output_file_paths = self.output.resolved_paths for metric_name, file_paths in [ ("input", input_file_paths), @@ -1000,26 +891,18 @@ class Task(WorkItem): file_sizes = collect_file_sizes(file_paths) if file_paths and file_sizes: self.perf_metrics[f"num {metric_name} files"] += len(file_paths) - self.perf_metrics[f"total {metric_name} size (MB)"] += ( - sum(file_sizes) / MB - ) + self.perf_metrics[f"total {metric_name} size (MB)"] += sum(file_sizes) / MB self.perf_metrics["elapsed wall time (secs)"] += self.elapsed_time if not self.exec_on_scheduler: resource_usage = resource.getrusage(resource.RUSAGE_SELF) - self.perf_metrics["max resident set size (MB)"] += ( - resource_usage.ru_maxrss / 1024 - ) + self.perf_metrics["max resident set size (MB)"] += resource_usage.ru_maxrss / 1024 self.perf_metrics["user mode cpu time (secs)"] += resource_usage.ru_utime self.perf_metrics["system mode cpu time (secs)"] += resource_usage.ru_stime - logger.debug( - f"{self.key} perf metrics:{os.linesep}{os.linesep.join(f'{name}: {value}' for name, value in self.perf_metrics.items())}" - ) + logger.debug(f"{self.key} perf metrics:{os.linesep}{os.linesep.join(f'{name}: {value}' for name, value in self.perf_metrics.items())}") if self.perf_profile is not None and self.elapsed_time > 3: - logger.debug( - f"{self.key} perf profile:{os.linesep}{cprofile_to_string(self.perf_profile)}" - ) + logger.debug(f"{self.key} perf profile:{os.linesep}{cprofile_to_string(self.perf_profile)}") def cleanup(self): if self.perf_profile is not None: @@ -1038,28 +921,17 @@ class Task(WorkItem): def is_primitive_iterable(obj: Any): if isinstance(obj, dict): - return all( - is_primitive(key) and is_primitive(value) - for key, value in obj.items() - ) + return all(is_primitive(key) and is_primitive(value) for key, value in obj.items()) elif isinstance(obj, Iterable): return all(is_primitive(elem) for elem in obj) return False if hasattr(self, "__dict__"): complex_attrs = [ - attr - for attr, obj in vars(self).items() - if not ( - attr in self._pristine_attrs - or is_primitive(obj) - or is_primitive_iterable(obj) - ) + attr for attr, obj in vars(self).items() if not (attr in self._pristine_attrs or is_primitive(obj) or is_primitive_iterable(obj)) ] if complex_attrs: - logger.debug( - f"removing complex attributes not explicitly declared in __slots__: {complex_attrs}" - ) + logger.debug(f"removing complex attributes not explicitly declared in __slots__: {complex_attrs}") for attr in complex_attrs: delattr(self, attr) @@ -1087,9 +959,7 @@ class Task(WorkItem): ``` """ self.inject_fault() - assert ( - self._timer_start is not None or metric_name is None - ), f"timer not started, cannot save '{metric_name}'" + assert self._timer_start is not None or metric_name is None, f"timer not started, cannot save '{metric_name}'" if self._timer_start is None or metric_name is None: self._timer_start = time.time() return 0.0 @@ -1141,13 +1011,9 @@ class Task(WorkItem): while os.path.exists(task.ray_marker_path): task.retry_count += 1 if task.retry_count > DEFAULT_MAX_RETRY_COUNT: - raise RuntimeError( - f"task {task.key} failed after {task.retry_count} retries" - ) + raise RuntimeError(f"task {task.key} failed after {task.retry_count} retries") if task.retry_count > 0: - logger.warning( - f"task {task.key} is being retried for the {task.retry_count}th time" - ) + logger.warning(f"task {task.key} is being retried for the {task.retry_count}th time") # create the marker file Path(task.ray_marker_path).touch() @@ -1157,9 +1023,7 @@ class Task(WorkItem): # execute the task status = task.exec() if status != WorkStatus.SUCCEED: - raise task.exception or RuntimeError( - f"task {task.key} failed with status {status}" - ) + raise task.exception or RuntimeError(f"task {task.key} failed with status {status}") # dump the output dataset atomically os.makedirs(os.path.dirname(task.ray_dataset_path), exist_ok=True) @@ -1180,14 +1044,9 @@ class Task(WorkItem): # because dataset is distributed on ray ) try: - self._dataset_ref = remote_function.remote( - task, *[dep.run_on_ray() for dep in self.input_deps.values()] - ) + self._dataset_ref = remote_function.remote(task, *[dep.run_on_ray() for dep in self.input_deps.values()]) except RuntimeError as e: - if ( - "SimpleQueue objects should only be shared between processes through inheritance" - in str(e) - ): + if "SimpleQueue objects should only be shared between processes through inheritance" in str(e): raise RuntimeError( f"Can't pickle task '{task.key}'. Please check if your function has captured unpicklable objects. {task.location}\n" f"HINT: DO NOT use externally imported loguru logger in your task. Please import it within the task." @@ -1226,26 +1085,19 @@ class ExecSqlQueryMixin(Task): @property def compression_type_str(self): - return ( - f"COMPRESSION '{self.parquet_compression}'" - if self.parquet_compression is not None - else "COMPRESSION 'uncompressed'" - ) + return f"COMPRESSION '{self.parquet_compression}'" if self.parquet_compression is not None else "COMPRESSION 'uncompressed'" @property def compression_level_str(self): return ( f"COMPRESSION_LEVEL {self.parquet_compression_level}" - if self.parquet_compression == "ZSTD" - and self.parquet_compression_level is not None + if self.parquet_compression == "ZSTD" and self.parquet_compression_level is not None else "" ) @property def compression_options(self): - return ", ".join( - filter(None, (self.compression_type_str, self.compression_level_str)) - ) + return ", ".join(filter(None, (self.compression_type_str, self.compression_level_str))) def prepare_connection(self, conn: duckdb.DuckDBPyConnection): logger.debug(f"duckdb version: {duckdb.__version__}") @@ -1253,9 +1105,7 @@ class ExecSqlQueryMixin(Task): self.exec_query(conn, f"select setseed({self.rand_seed_float})") # prepare connection effective_cpu_count = math.ceil(self.cpu_limit * self.cpu_overcommit_ratio) - effective_memory_size = round_up( - self.memory_limit * self.memory_overcommit_ratio, MB - ) + effective_memory_size = round_up(self.memory_limit * self.memory_overcommit_ratio, MB) self.exec_query( conn, f""" @@ -1282,9 +1132,7 @@ class ExecSqlQueryMixin(Task): for input_dataset in input_datasets: self.input_view_index += 1 view_name = f"{INPUT_VIEW_PREFIX}_{self.id}_{self.input_view_index:06d}" - input_views[view_name] = ( - f"CREATE VIEW {view_name} AS {input_dataset.sql_query_fragment(filesystem, conn)};" - ) + input_views[view_name] = f"CREATE VIEW {view_name} AS {input_dataset.sql_query_fragment(filesystem, conn)};" logger.debug(f"create input view '{view_name}': {input_views[view_name]}") conn.sql(input_views[view_name]) return list(input_views.keys()) @@ -1303,11 +1151,7 @@ class ExecSqlQueryMixin(Task): if log_query: logger.debug(f"running sql query: {query_statement}") start_time = time.time() - query_output = conn.sql( - "SET enable_profiling='json';" - if enable_profiling - else "RESET enable_profiling;" - ) + query_output = conn.sql("SET enable_profiling='json';" if enable_profiling else "RESET enable_profiling;") query_output = conn.sql(query_statement) elapsed_time = time.time() - start_time if log_query: @@ -1332,10 +1176,7 @@ class ExecSqlQueryMixin(Task): perf_metrics["num input rows"] += obj["operator_cardinality"] perf_metrics["input load time (secs)"] += obj["operator_timing"] elif name.startswith("COPY_TO_FILE"): - perf_metrics["num output rows"] += sum( - sum_children_metrics(child, "operator_cardinality") - for child in obj["children"] - ) + perf_metrics["num output rows"] += sum(sum_children_metrics(child, "operator_cardinality") for child in obj["children"]) perf_metrics["output dump time (secs)"] += obj["operator_timing"] return obj @@ -1343,9 +1184,7 @@ class ExecSqlQueryMixin(Task): output_rows = query_output.fetchall() if log_output or (enable_profiling and self.ctx.enable_profiling): for row in output_rows: - logger.debug( - f"query output:{os.linesep}{''.join(filter(None, row))}" - ) + logger.debug(f"query output:{os.linesep}{''.join(filter(None, row))}") if enable_profiling: _, json_str = output_rows[0] json.loads(json_str, object_hook=extract_perf_metrics) @@ -1373,9 +1212,7 @@ class DataSourceTask(Task): def run(self) -> bool: logger.info(f"added data source: {self.dataset}") if isinstance(self.dataset, (SqlQueryDataSet, ArrowTableDataSet)): - self.dataset = ParquetDataSet.create_from( - self.dataset.to_arrow_table(), self.runtime_output_abspath - ) + self.dataset = ParquetDataSet.create_from(self.dataset.to_arrow_table(), self.runtime_output_abspath) return True @@ -1397,9 +1234,7 @@ class MergeDataSetsTask(Task): def run(self) -> bool: datasets = self.input_datasets assert datasets, f"empty list of input datasets: {self}" - assert all( - isinstance(dataset, (DataSet, type(datasets[0]))) for dataset in datasets - ) + assert all(isinstance(dataset, (DataSet, type(datasets[0]))) for dataset in datasets) self.dataset = datasets[0].merge(datasets) logger.info(f"created merged dataset: {self.dataset}") return True @@ -1418,9 +1253,7 @@ class SplitDataSetTask(Task): input_deps: List[Task], partition_infos: List[PartitionInfo], ) -> None: - assert ( - len(input_deps) == 1 - ), f"wrong number of input deps for data set partition: {input_deps}" + assert len(input_deps) == 1, f"wrong number of input deps for data set partition: {input_deps}" super().__init__(ctx, input_deps, partition_infos) self.partition = partition_infos[-1].index self.npartitions = partition_infos[-1].npartitions @@ -1440,9 +1273,7 @@ class SplitDataSetTask(Task): pass def run(self) -> bool: - self.dataset = self.input_datasets[0].partition_by_files(self.npartitions)[ - self.partition - ] + self.dataset = self.input_datasets[0].partition_by_files(self.npartitions)[self.partition] return True @@ -1467,9 +1298,7 @@ class PartitionProducerTask(Task): memory_limit: int = None, ) -> None: assert len(input_deps) == 1, f"wrong number of inputs: {input_deps}" - assert isinstance( - npartitions, int - ), f"npartitions is not an integer: {npartitions}" + assert isinstance(npartitions, int), f"npartitions is not an integer: {npartitions}" super().__init__( ctx, input_deps, @@ -1485,10 +1314,7 @@ class PartitionProducerTask(Task): self.partitioned_datasets: List[DataSet] = None def __str__(self) -> str: - return ( - super().__str__() - + f", npartitions={self.npartitions}, dimension={self.dimension}" - ) + return super().__str__() + f", npartitions={self.npartitions}, dimension={self.dimension}" def _create_empty_file(self, partition_idx: int, dataset: DataSet) -> str: """ @@ -1523,14 +1349,12 @@ class PartitionProducerTask(Task): if not isinstance(self, HashPartitionTask) else [ PartitionInfo(partition_idx, self.npartitions, self.dimension), - PartitionInfo( - partition_idx, self.npartitions, self.data_partition_column - ), + PartitionInfo(partition_idx, self.npartitions, self.data_partition_column), ] ) - schema_with_metadata = filter_schema( - dataset_schema, excluded_cols=GENERATED_COLUMNS - ).with_metadata(self.parquet_kv_metadata_bytes(extra_partitions)) + schema_with_metadata = filter_schema(dataset_schema, excluded_cols=GENERATED_COLUMNS).with_metadata( + self.parquet_kv_metadata_bytes(extra_partitions) + ) empty_file_path = Path(empty_file_prefix + ".parquet") parquet.ParquetWriter(empty_file_path, schema_with_metadata).close() else: @@ -1550,41 +1374,27 @@ class PartitionProducerTask(Task): else: # Create an empty file for each empty partition. # This is to ensure that partition consumers have at least one file to read. - empty_partitions = [ - idx for idx, empty in enumerate(is_empty_partition) if empty - ] - nonempty_partitions = [ - idx for idx, empty in enumerate(is_empty_partition) if not empty - ] + empty_partitions = [idx for idx, empty in enumerate(is_empty_partition) if empty] + nonempty_partitions = [idx for idx, empty in enumerate(is_empty_partition) if not empty] first_nonempty_dataset = self.partitioned_datasets[nonempty_partitions[0]] if empty_partitions: with ThreadPoolExecutor(self.cpu_limit) as pool: empty_file_paths = list( pool.map( - lambda idx: self._create_empty_file( - idx, first_nonempty_dataset - ), + lambda idx: self._create_empty_file(idx, first_nonempty_dataset), empty_partitions, ) ) - for partition_idx, empty_file_path in zip( - empty_partitions, empty_file_paths - ): - self.partitioned_datasets[partition_idx].reset( - [empty_file_path], self.runtime_output_abspath - ) - logger.debug( - f"created empty output files in partitions {empty_partitions} of {repr(self)}: {empty_file_paths[:3]}..." - ) + for partition_idx, empty_file_path in zip(empty_partitions, empty_file_paths): + self.partitioned_datasets[partition_idx].reset([empty_file_path], self.runtime_output_abspath) + logger.debug(f"created empty output files in partitions {empty_partitions} of {repr(self)}: {empty_file_paths[:3]}...") # reset root_dir from runtime_output_abspath to final_output_abspath for dataset in self.partitioned_datasets: # XXX: if the task has output in `runtime_output_abspath`, # `root_dir` must be set and all row ranges must be full ranges. if dataset.root_dir == self.runtime_output_abspath: - dataset.reset( - dataset.paths, self.final_output_abspath, dataset.recursive - ) + dataset.reset(dataset.paths, self.final_output_abspath, dataset.recursive) # XXX: otherwise, we assume there is no output in `runtime_output_abspath`. # do nothing to the dataset. self.dataset = PartitionedDataSet(self.partitioned_datasets) @@ -1628,9 +1438,7 @@ class RepeatPartitionProducerTask(PartitionProducerTask): pass def run(self) -> bool: - self.partitioned_datasets = [ - self.input_datasets[0] for _ in range(self.npartitions) - ] + self.partitioned_datasets = [self.input_datasets[0] for _ in range(self.npartitions)] return True @@ -1662,9 +1470,7 @@ class UserDefinedPartitionProducerTask(PartitionProducerTask): def run(self) -> bool: try: - self.partitioned_datasets = self.partition_func( - self.ctx, self.input_datasets[0] - ) + self.partitioned_datasets = self.partition_func(self.ctx, self.input_datasets[0]) return True finally: self.partition_func = None @@ -1715,13 +1521,9 @@ class EvenlyDistributedPartitionProducerTask(PartitionProducerTask): self.partition_by_rows and not isinstance(input_dataset, ParquetDataSet) ), f"Only parquet dataset supports partition by rows, found: {input_dataset}" if isinstance(input_dataset, ParquetDataSet) and self.partition_by_rows: - self.partitioned_datasets = input_dataset.partition_by_rows( - self.npartitions, self.random_shuffle - ) + self.partitioned_datasets = input_dataset.partition_by_rows(self.npartitions, self.random_shuffle) else: - self.partitioned_datasets = input_dataset.partition_by_files( - self.npartitions, self.random_shuffle - ) + self.partitioned_datasets = input_dataset.partition_by_files(self.npartitions, self.random_shuffle) return True @@ -1758,12 +1560,8 @@ class LoadPartitionedDataSetProducerTask(PartitionProducerTask): def run(self) -> bool: input_dataset = self.input_datasets[0] - assert isinstance( - input_dataset, ParquetDataSet - ), f"Not parquet dataset: {input_dataset}" - self.partitioned_datasets = input_dataset.load_partitioned_datasets( - self.npartitions, self.data_partition_column, self.hive_partitioning - ) + assert isinstance(input_dataset, ParquetDataSet), f"Not parquet dataset: {input_dataset}" + self.partitioned_datasets = input_dataset.load_partitioned_datasets(self.npartitions, self.data_partition_column, self.hive_partitioning) return True @@ -1801,12 +1599,8 @@ class PartitionConsumerTask(Task): def run(self) -> bool: # Build the dataset only after all `input_deps` finished, since `input_deps` could be tried multiple times. # Consumers always follow producers, so the input is a list of partitioned datasets. - assert all( - isinstance(dataset, PartitionedDataSet) for dataset in self.input_datasets - ) - datasets = [ - dataset[self.last_partition.index] for dataset in self.input_datasets - ] + assert all(isinstance(dataset, PartitionedDataSet) for dataset in self.input_datasets) + datasets = [dataset[self.last_partition.index] for dataset in self.input_datasets] self.dataset = datasets[0].merge(datasets) if self.ctx.remove_empty_parquet and isinstance(self.dataset, ParquetDataSet): @@ -1938,9 +1732,7 @@ class ArrowComputeTask(ExecSqlQueryMixin, Task): ) self.process_func = process_func self.parquet_row_group_size = parquet_row_group_size - self.parquet_row_group_bytes = clamp_row_group_bytes( - parquet_row_group_size * 4 * KB - ) + self.parquet_row_group_bytes = clamp_row_group_bytes(parquet_row_group_size * 4 * KB) self.parquet_dictionary_encoding = parquet_dictionary_encoding self.parquet_compression = parquet_compression self.parquet_compression_level = parquet_compression_level @@ -1951,14 +1743,10 @@ class ArrowComputeTask(ExecSqlQueryMixin, Task): self.process_func = None super().clean_complex_attrs() - def _call_process( - self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table] - ) -> arrow.Table: + def _call_process(self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table: return self.process(runtime_ctx, input_tables) - def process( - self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table] - ) -> arrow.Table: + def process(self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table: """ This method can be overridden in subclass of `ArrowComputeTask`. @@ -1983,25 +1771,16 @@ class ArrowComputeTask(ExecSqlQueryMixin, Task): return True if self.use_duckdb_reader: - conn = duckdb.connect( - database=":memory:", config={"allow_unsigned_extensions": "true"} - ) + conn = duckdb.connect(database=":memory:", config={"allow_unsigned_extensions": "true"}) self.prepare_connection(conn) - input_tables = [ - dataset.to_arrow_table(max_workers=self.cpu_limit, conn=conn) - for dataset in self.input_datasets - ] - self.perf_metrics["num input rows"] += sum( - table.num_rows for table in input_tables - ) + input_tables = [dataset.to_arrow_table(max_workers=self.cpu_limit, conn=conn) for dataset in self.input_datasets] + self.perf_metrics["num input rows"] += sum(table.num_rows for table in input_tables) self.add_elapsed_time("input load time (secs)") if conn is not None: conn.close() - output_table = self._call_process( - self.ctx.set_current_task(self), input_tables - ) + output_table = self._call_process(self.ctx.set_current_task(self), input_tables) self.add_elapsed_time("compute time (secs)") return self.dump_output(output_table) @@ -2025,11 +1804,7 @@ class ArrowComputeTask(ExecSqlQueryMixin, Task): output_table.replace_schema_metadata(self.parquet_kv_metadata_bytes()), self.runtime_output_abspath, self.output_filename, - compression=( - self.parquet_compression - if self.parquet_compression is not None - else "NONE" - ), + compression=(self.parquet_compression if self.parquet_compression is not None else "NONE"), compression_level=self.parquet_compression_level, row_group_size=self.parquet_row_group_size, row_group_bytes=self.parquet_row_group_bytes, @@ -2089,9 +1864,7 @@ class ArrowStreamTask(ExecSqlQueryMixin, Task): "streaming_batch_count", ) - def __init__( - self, streaming_batch_size: int, streaming_batch_count: int - ) -> None: + def __init__(self, streaming_batch_size: int, streaming_batch_count: int) -> None: self.last_batch_indices: List[int] = None self.input_batch_offsets: List[int] = None self.streaming_output_paths: List[str] = [] @@ -2112,12 +1885,7 @@ class ArrowStreamTask(ExecSqlQueryMixin, Task): self.last_batch_indices = [-1] * len(batch_indices) if self.input_batch_offsets is None: self.input_batch_offsets = [0] * len(batch_indices) - self.input_batch_offsets = [ - i + j - k - for i, j, k in zip( - self.input_batch_offsets, batch_indices, self.last_batch_indices - ) - ] + self.input_batch_offsets = [i + j - k for i, j, k in zip(self.input_batch_offsets, batch_indices, self.last_batch_indices)] self.last_batch_indices = batch_indices def reset(self): @@ -2129,9 +1897,7 @@ class ArrowStreamTask(ExecSqlQueryMixin, Task): ctx: RuntimeContext, input_deps: List[Task], partition_infos: List[PartitionInfo], - process_func: Callable[ - [RuntimeContext, List[arrow.RecordBatchReader]], Iterable[arrow.Table] - ] = None, + process_func: Callable[[RuntimeContext, List[arrow.RecordBatchReader]], Iterable[arrow.Table]] = None, background_io_thread=True, streaming_batch_size: int = DEFAULT_BATCH_SIZE, secs_checkpoint_interval: int = None, @@ -2161,16 +1927,12 @@ class ArrowStreamTask(ExecSqlQueryMixin, Task): self.streaming_batch_size = streaming_batch_size self.streaming_batch_count = 1 self.parquet_row_group_size = parquet_row_group_size - self.parquet_row_group_bytes = clamp_row_group_bytes( - parquet_row_group_size * 4 * KB - ) + self.parquet_row_group_bytes = clamp_row_group_bytes(parquet_row_group_size * 4 * KB) self.parquet_dictionary_encoding = parquet_dictionary_encoding self.parquet_compression = parquet_compression self.parquet_compression_level = parquet_compression_level self.use_duckdb_reader = use_duckdb_reader - self.secs_checkpoint_interval = ( - secs_checkpoint_interval or self.ctx.secs_executor_probe_timeout - ) + self.secs_checkpoint_interval = secs_checkpoint_interval or self.ctx.secs_executor_probe_timeout self.runtime_state: Optional[ArrowStreamTask.RuntimeState] = None def __str__(self) -> str: @@ -2189,9 +1951,7 @@ class ArrowStreamTask(ExecSqlQueryMixin, Task): if not path.startswith(self.runtime_output_abspath): os.link( path, - os.path.join( - self.runtime_output_abspath, os.path.basename(path) - ), + os.path.join(self.runtime_output_abspath, os.path.basename(path)), ) self.runtime_state = None super().finalize() @@ -2201,9 +1961,7 @@ class ArrowStreamTask(ExecSqlQueryMixin, Task): self.process_func = None super().clean_complex_attrs() - def _wrap_output( - self, output: Union[arrow.Table, StreamOutput], batch_indices: List[int] = None - ) -> StreamOutput: + def _wrap_output(self, output: Union[arrow.Table, StreamOutput], batch_indices: List[int] = None) -> StreamOutput: if isinstance(output, StreamOutput): assert len(output.batch_indices) == 0 or len(output.batch_indices) == len( self.input_deps @@ -2213,15 +1971,11 @@ class ArrowStreamTask(ExecSqlQueryMixin, Task): assert isinstance(output, arrow.Table) return StreamOutput(output, batch_indices) - def _call_process( - self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader] - ) -> Iterable[StreamOutput]: + def _call_process(self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]) -> Iterable[StreamOutput]: for output in self.process(runtime_ctx, input_readers): yield self._wrap_output(output) - def process( - self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader] - ) -> Iterable[arrow.Table]: + def process(self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]) -> Iterable[arrow.Table]: """ This method can be overridden in subclass of `ArrowStreamTask`. @@ -2238,26 +1992,20 @@ class ArrowStreamTask(ExecSqlQueryMixin, Task): """ return self.process_func(runtime_ctx, input_readers) - def restore_input_state( - self, runtime_state: RuntimeState, input_readers: List[arrow.RecordBatchReader] - ): + def restore_input_state(self, runtime_state: RuntimeState, input_readers: List[arrow.RecordBatchReader]): logger.info(f"restore input state to: {runtime_state}") assert len(runtime_state.input_batch_offsets) == len( input_readers ), f"num of batch offsets {len(runtime_state.input_batch_offsets)} not equal to num of input readers {len(input_readers)}" - for batch_offset, input_reader in zip( - runtime_state.input_batch_offsets, input_readers - ): + for batch_offset, input_reader in zip(runtime_state.input_batch_offsets, input_readers): if batch_offset <= 0: continue for ( batch_index, input_batch, ) in enumerate(input_reader): - logger.debug( - f"skipped input batch #{batch_index}: {input_batch.num_rows} rows" - ) + logger.debug(f"skipped input batch #{batch_index}: {input_batch.num_rows} rows") if batch_index + 1 == batch_offset: break assert batch_index + 1 <= batch_offset @@ -2267,31 +2015,15 @@ class ArrowStreamTask(ExecSqlQueryMixin, Task): if self.skip_when_any_input_empty: return True - input_row_ranges = [ - dataset.resolved_row_ranges - for dataset in self.input_datasets - if isinstance(dataset, ParquetDataSet) - ] - input_byte_size = [ - sum(row_range.estimated_data_size for row_range in row_ranges) - for row_ranges in input_row_ranges - ] - input_num_rows = [ - sum(row_range.num_rows for row_range in row_ranges) - for row_ranges in input_row_ranges - ] - input_files = [ - set(row_range.path for row_range in row_ranges) - for row_ranges in input_row_ranges - ] + input_row_ranges = [dataset.resolved_row_ranges for dataset in self.input_datasets if isinstance(dataset, ParquetDataSet)] + input_byte_size = [sum(row_range.estimated_data_size for row_range in row_ranges) for row_ranges in input_row_ranges] + input_num_rows = [sum(row_range.num_rows for row_range in row_ranges) for row_ranges in input_row_ranges] + input_files = [set(row_range.path for row_range in row_ranges) for row_ranges in input_row_ranges] self.perf_metrics["num input rows"] += sum(input_num_rows) self.perf_metrics["input data size (MB)"] += sum(input_byte_size) / MB # calculate the max streaming batch size based on memory limit - avg_input_row_size = sum( - self.compute_avg_row_size(nbytes, num_rows) - for nbytes, num_rows in zip(input_byte_size, input_num_rows) - ) + avg_input_row_size = sum(self.compute_avg_row_size(nbytes, num_rows) for nbytes, num_rows in zip(input_byte_size, input_num_rows)) max_batch_rows = self.max_batch_size // avg_input_row_size if self.runtime_state is None: @@ -2312,9 +2044,7 @@ class ArrowStreamTask(ExecSqlQueryMixin, Task): try: conn = None if self.use_duckdb_reader: - conn = duckdb.connect( - database=":memory:", config={"allow_unsigned_extensions": "true"} - ) + conn = duckdb.connect(database=":memory:", config={"allow_unsigned_extensions": "true"}) self.prepare_connection(conn) input_readers = [ @@ -2326,16 +2056,12 @@ class ArrowStreamTask(ExecSqlQueryMixin, Task): ] if self.runtime_state is None: - self.runtime_state = ArrowStreamTask.RuntimeState( - self.streaming_batch_size, self.streaming_batch_count - ) + self.runtime_state = ArrowStreamTask.RuntimeState(self.streaming_batch_size, self.streaming_batch_count) else: self.restore_input_state(self.runtime_state, input_readers) self.runtime_state.last_batch_indices = None - output_iter = self._call_process( - self.ctx.set_current_task(self), input_readers - ) + output_iter = self._call_process(self.ctx.set_current_task(self), input_readers) self.add_elapsed_time("compute time (secs)") if self.background_io_thread: @@ -2358,9 +2084,7 @@ class ArrowStreamTask(ExecSqlQueryMixin, Task): self.add_elapsed_time("output dump time (secs)") create_checkpoint = False - last_checkpoint_time = ( - time.time() - self.random_float() * self.secs_checkpoint_interval / 2 - ) + last_checkpoint_time = time.time() - self.random_float() * self.secs_checkpoint_interval / 2 output: StreamOutput = next(output_iter, None) self.add_elapsed_time("compute time (secs)") @@ -2389,15 +2113,9 @@ class ArrowStreamTask(ExecSqlQueryMixin, Task): try: with parquet.ParquetWriter( where=output_file, - schema=buffered_output.schema.with_metadata( - self.parquet_kv_metadata_bytes() - ), + schema=buffered_output.schema.with_metadata(self.parquet_kv_metadata_bytes()), use_dictionary=self.parquet_dictionary_encoding, - compression=( - self.parquet_compression - if self.parquet_compression is not None - else "NONE" - ), + compression=(self.parquet_compression if self.parquet_compression is not None else "NONE"), compression_level=self.parquet_compression_level, write_batch_size=max(16 * 1024, self.parquet_row_group_size // 8), data_page_size=max(64 * MB, self.parquet_row_group_bytes // 8), @@ -2406,30 +2124,17 @@ class ArrowStreamTask(ExecSqlQueryMixin, Task): while (output := next(output_iter, None)) is not None: self.add_elapsed_time("compute time (secs)") - if ( - buffered_output.num_rows + output.output_table.num_rows - < self.parquet_row_group_size - ): - buffered_output = arrow.concat_tables( - (buffered_output, output.output_table) - ) + if buffered_output.num_rows + output.output_table.num_rows < self.parquet_row_group_size: + buffered_output = arrow.concat_tables((buffered_output, output.output_table)) else: write_table(writer, buffered_output) buffered_output = output.output_table - periodic_checkpoint = ( - bool(output.batch_indices) - and (time.time() - last_checkpoint_time) - >= self.secs_checkpoint_interval - ) - create_checkpoint = ( - output.force_checkpoint or periodic_checkpoint - ) + periodic_checkpoint = bool(output.batch_indices) and (time.time() - last_checkpoint_time) >= self.secs_checkpoint_interval + create_checkpoint = output.force_checkpoint or periodic_checkpoint if create_checkpoint: - self.runtime_state.update_batch_offsets( - output.batch_indices - ) + self.runtime_state.update_batch_offsets(output.batch_indices) last_checkpoint_time = time.time() break @@ -2463,72 +2168,41 @@ class ArrowBatchTask(ArrowStreamTask): def max_batch_size(self) -> int: return self._memory_limit // 3 - def _call_process( - self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader] - ) -> Iterable[arrow.Table]: + def _call_process(self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]) -> Iterable[arrow.Table]: with contextlib.ExitStack() as stack: - opened_readers = [ - stack.enter_context( - ConcurrentIter(reader) if self.background_io_thread else reader - ) - for reader in input_readers - ] - for batch_index, input_batches in enumerate( - itertools.zip_longest(*opened_readers, fillvalue=None) - ): + opened_readers = [stack.enter_context(ConcurrentIter(reader) if self.background_io_thread else reader) for reader in input_readers] + for batch_index, input_batches in enumerate(itertools.zip_longest(*opened_readers, fillvalue=None)): input_tables = [ - ( - reader.schema.empty_table() - if batch is None - else arrow.Table.from_batches([batch], reader.schema) - ) + (reader.schema.empty_table() if batch is None else arrow.Table.from_batches([batch], reader.schema)) for reader, batch in zip(input_readers, input_batches) ] output_table = self._process_batches(runtime_ctx, input_tables) - yield self._wrap_output( - output_table, [batch_index] * len(input_batches) - ) + yield self._wrap_output(output_table, [batch_index] * len(input_batches)) - def _process_batches( - self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table] - ) -> arrow.Table: + def _process_batches(self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table: return self.process(runtime_ctx, input_tables) - def process( - self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table] - ) -> arrow.Table: + def process(self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table: return self.process_func(runtime_ctx, input_tables) class PandasComputeTask(ArrowComputeTask): - def _call_process( - self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table] - ) -> arrow.Table: + def _call_process(self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table: input_dfs = [table.to_pandas() for table in input_tables] output_df = self.process(runtime_ctx, input_dfs) - return ( - arrow.Table.from_pandas(output_df, preserve_index=False) - if output_df is not None - else None - ) + return arrow.Table.from_pandas(output_df, preserve_index=False) if output_df is not None else None - def process( - self, runtime_ctx: RuntimeContext, input_dfs: List[pd.DataFrame] - ) -> pd.DataFrame: + def process(self, runtime_ctx: RuntimeContext, input_dfs: List[pd.DataFrame]) -> pd.DataFrame: return self.process_func(runtime_ctx, input_dfs) class PandasBatchTask(ArrowBatchTask): - def _process_batches( - self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table] - ) -> arrow.Table: + def _process_batches(self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table: input_dfs = [table.to_pandas() for table in input_tables] output_df = self.process(runtime_ctx, input_dfs) return arrow.Table.from_pandas(output_df, preserve_index=False) - def process( - self, runtime_ctx: RuntimeContext, input_dfs: List[pd.DataFrame] - ) -> pd.DataFrame: + def process(self, runtime_ctx: RuntimeContext, input_dfs: List[pd.DataFrame]) -> pd.DataFrame: return self.process_func(runtime_ctx, input_dfs) @@ -2593,25 +2267,17 @@ class SqlEngineTask(ExecSqlQueryMixin, Task): self.batched_processing = batched_processing and len(self.input_deps) == 1 self.enable_temp_directory = enable_temp_directory self.parquet_row_group_size = parquet_row_group_size - self.parquet_row_group_bytes = clamp_row_group_bytes( - parquet_row_group_size * 4 * KB - ) + self.parquet_row_group_bytes = clamp_row_group_bytes(parquet_row_group_size * 4 * KB) self.parquet_dictionary_encoding = parquet_dictionary_encoding self.parquet_compression = parquet_compression self.parquet_compression_level = parquet_compression_level def __str__(self) -> str: - return ( - super().__str__() - + f", sql_query=<{self.oneline_query[:100]}...>, udfs={self.udfs}, batched_processing={self.batched_processing}" - ) + return super().__str__() + f", sql_query=<{self.oneline_query[:100]}...>, udfs={self.udfs}, batched_processing={self.batched_processing}" @property def oneline_query(self) -> str: - return "; ".join( - " ".join(filter(None, map(str.strip, query.splitlines()))) - for query in self.sql_queries - ) + return "; ".join(" ".join(filter(None, map(str.strip, query.splitlines()))) for query in self.sql_queries) @property def max_batch_size(self) -> int: @@ -2625,22 +2291,13 @@ class SqlEngineTask(ExecSqlQueryMixin, Task): if self.skip_when_any_input_empty: return True - if self.batched_processing and isinstance( - self.input_datasets[0], ParquetDataSet - ): - input_batches = [ - [batch] - for batch in self.input_datasets[0].partition_by_size( - self.max_batch_size - ) - ] + if self.batched_processing and isinstance(self.input_datasets[0], ParquetDataSet): + input_batches = [[batch] for batch in self.input_datasets[0].partition_by_size(self.max_batch_size)] else: input_batches = [self.input_datasets] for batch_index, input_batch in enumerate(input_batches): - with duckdb.connect( - database=":memory:", config={"allow_unsigned_extensions": "true"} - ) as conn: + with duckdb.connect(database=":memory:", config={"allow_unsigned_extensions": "true"}) as conn: self.prepare_connection(conn) self.process_batch(batch_index, input_batch, conn) @@ -2656,13 +2313,9 @@ class SqlEngineTask(ExecSqlQueryMixin, Task): input_views = self.create_input_views(conn, input_datasets) if isinstance(self.parquet_dictionary_encoding, bool): - dictionary_encoding_cfg = ( - "DICTIONARY_ENCODING TRUE," if self.parquet_dictionary_encoding else "" - ) + dictionary_encoding_cfg = "DICTIONARY_ENCODING TRUE," if self.parquet_dictionary_encoding else "" else: - dictionary_encoding_cfg = "DICTIONARY_ENCODING ({}),".format( - ", ".join(self.parquet_dictionary_encoding) - ) + dictionary_encoding_cfg = "DICTIONARY_ENCODING ({}),".format(", ".join(self.parquet_dictionary_encoding)) for query_index, sql_query in enumerate(self.sql_queries): last_query = query_index + 1 == len(self.sql_queries) @@ -2709,11 +2362,7 @@ class SqlEngineTask(ExecSqlQueryMixin, Task): OVERWRITE_OR_IGNORE true) """ - self.merge_metrics( - self.exec_query( - conn, f"EXPLAIN ANALYZE {sql_query}", enable_profiling=True - ) - ) + self.merge_metrics(self.exec_query(conn, f"EXPLAIN ANALYZE {sql_query}", enable_profiling=True)) class HashPartitionTask(PartitionProducerTask): @@ -2770,9 +2419,7 @@ class HashPartitionTask(PartitionProducerTask): self.use_parquet_writer = use_parquet_writer self.hive_partitioning = hive_partitioning self.parquet_row_group_size = parquet_row_group_size - self.parquet_row_group_bytes = clamp_row_group_bytes( - parquet_row_group_size * 4 * KB - ) + self.parquet_row_group_bytes = clamp_row_group_bytes(parquet_row_group_size * 4 * KB) self.parquet_dictionary_encoding = parquet_dictionary_encoding self.parquet_compression = parquet_compression self.parquet_compression_level = parquet_compression_level @@ -2795,15 +2442,10 @@ class HashPartitionTask(PartitionProducerTask): self._file_writer_closed = True def __str__(self) -> str: - return ( - super().__str__() - + f", hash_columns={self.hash_columns}, data_partition_column={self.data_partition_column}" - ) + return super().__str__() + f", hash_columns={self.hash_columns}, data_partition_column={self.data_partition_column}" @staticmethod - def create( - engine_type: Literal["duckdb", "arrow"], *args, **kwargs - ) -> "HashPartitionTask": + def create(engine_type: Literal["duckdb", "arrow"], *args, **kwargs) -> "HashPartitionTask": if engine_type == "duckdb": return HashPartitionDuckDbTask(*args, *kwargs) if engine_type == "arrow": @@ -2820,9 +2462,7 @@ class HashPartitionTask(PartitionProducerTask): 4 * MB, round_up(min(16 * GB, self.max_batch_size) // self.npartitions, 16 * KB), ) - return ( - write_buffer_size if write_buffer_size >= 128 * KB else -1 - ) # disable write buffer if too small + return write_buffer_size if write_buffer_size >= 128 * KB else -1 # disable write buffer if too small @property def num_workers(self) -> int: @@ -2847,16 +2487,8 @@ class HashPartitionTask(PartitionProducerTask): self.add_elapsed_time() self._wait_pending_writes() if self._io_workers is not None: - list( - self._io_workers.map( - lambda w: w.close(), filter(None, self._partition_writers) - ) - ) - list( - self._io_workers.map( - lambda f: f.close(), filter(None, self._partition_files) - ) - ) + list(self._io_workers.map(lambda w: w.close(), filter(None, self._partition_writers))) + list(self._io_workers.map(lambda f: f.close(), filter(None, self._partition_files))) self._io_workers.shutdown(wait=True) self.add_elapsed_time("output dump time (secs)") @@ -2864,9 +2496,7 @@ class HashPartitionTask(PartitionProducerTask): partition_filename = f"{self.output_filename}-{partition_idx}.parquet" partition_path = os.path.join(self.runtime_output_abspath, partition_filename) - self._partition_files[partition_idx] = open( - partition_path, "wb", buffering=self.write_buffer_size - ) + self._partition_files[partition_idx] = open(partition_path, "wb", buffering=self.write_buffer_size) output_file = self._partition_files[partition_idx] self.partitioned_datasets[partition_idx].paths.append(partition_filename) @@ -2876,33 +2506,23 @@ class HashPartitionTask(PartitionProducerTask): self.parquet_kv_metadata_bytes( [ PartitionInfo(partition_idx, self.npartitions, self.dimension), - PartitionInfo( - partition_idx, self.npartitions, self.data_partition_column - ), + PartitionInfo(partition_idx, self.npartitions, self.data_partition_column), ] ) ), use_dictionary=self.parquet_dictionary_encoding, - compression=( - self.parquet_compression - if self.parquet_compression is not None - else "NONE" - ), + compression=(self.parquet_compression if self.parquet_compression is not None else "NONE"), compression_level=self.parquet_compression_level, write_batch_size=max(16 * 1024, self.parquet_row_group_size // 8), data_page_size=max(64 * MB, self.parquet_row_group_bytes // 8), ) return self._partition_writers[partition_idx] - def _write_to_partition( - self, partition_idx, partition, pending_write: Future = None - ): + def _write_to_partition(self, partition_idx, partition, pending_write: Future = None): if pending_write is not None: pending_write.result() if partition is not None: - writer = self._partition_writers[partition_idx] or self._create_file_writer( - partition_idx, partition.schema - ) + writer = self._partition_writers[partition_idx] or self._create_file_writer(partition_idx, partition.schema) writer.write_table(partition, self.parquet_row_group_size) def _write_partitioned_tables(self, partitioned_tables): @@ -2910,18 +2530,10 @@ class HashPartitionTask(PartitionProducerTask): assert len(self._pending_write_works) == self.npartitions self._pending_write_works = [ - self.io_workers.submit( - self._write_to_partition, partition_idx, partition, pending_write - ) - for partition_idx, (partition, pending_write) in enumerate( - zip(partitioned_tables, self._pending_write_works) - ) + self.io_workers.submit(self._write_to_partition, partition_idx, partition, pending_write) + for partition_idx, (partition, pending_write) in enumerate(zip(partitioned_tables, self._pending_write_works)) ] - self.perf_metrics["num output rows"] += sum( - partition.num_rows - for partition in partitioned_tables - if partition is not None - ) + self.perf_metrics["num output rows"] += sum(partition.num_rows for partition in partitioned_tables if partition is not None) self._wait_pending_writes() def initialize(self): @@ -2929,20 +2541,13 @@ class HashPartitionTask(PartitionProducerTask): if isinstance(self, HashPartitionDuckDbTask) and self.hive_partitioning: self.partitioned_datasets = [ ParquetDataSet( - [ - os.path.join( - f"{self.data_partition_column}={partition_idx}", "*.parquet" - ) - ], + [os.path.join(f"{self.data_partition_column}={partition_idx}", "*.parquet")], root_dir=self.runtime_output_abspath, ) for partition_idx in range(self.npartitions) ] else: - self.partitioned_datasets = [ - ParquetDataSet([], root_dir=self.runtime_output_abspath) - for _ in range(self.npartitions) - ] + self.partitioned_datasets = [ParquetDataSet([], root_dir=self.runtime_output_abspath) for _ in range(self.npartitions)] self._partition_files = [None] * self.npartitions self._partition_writers = [None] * self.npartitions self._pending_write_works = [None] * self.npartitions @@ -2972,16 +2577,12 @@ class HashPartitionTask(PartitionProducerTask): return True input_dataset = self.input_datasets[0] - assert isinstance( - input_dataset, ParquetDataSet - ), f"only parquet dataset supported, found {input_dataset}" + assert isinstance(input_dataset, ParquetDataSet), f"only parquet dataset supported, found {input_dataset}" input_paths = input_dataset.resolved_paths input_byte_size = input_dataset.estimated_data_size input_num_rows = input_dataset.num_rows - logger.info( - f"partitioning dataset: {len(input_paths)} files, {input_byte_size/GB:.3f}GB, {input_num_rows} rows" - ) + logger.info(f"partitioning dataset: {len(input_paths)} files, {input_byte_size/GB:.3f}GB, {input_num_rows} rows") input_batches = input_dataset.partition_by_size(self.max_batch_size) for batch_index, input_batch in enumerate(input_batches): @@ -2992,9 +2593,7 @@ class HashPartitionTask(PartitionProducerTask): f"start to partition batch #{batch_index+1}/{len(input_batches)}: {len(input_batch.resolved_paths)} files, {batch_byte_size/GB:.3f}GB, {batch_num_rows} rows" ) self.partition(batch_index, input_batch) - logger.info( - f"finished to partition batch #{batch_index+1}/{len(input_batches)}: {time.time() - batch_start_time:.3f} secs" - ) + logger.info(f"finished to partition batch #{batch_index+1}/{len(input_batches)}: {time.time() - batch_start_time:.3f} secs") return True @@ -3012,13 +2611,9 @@ class HashPartitionDuckDbTask(ExecSqlQueryMixin, HashPartitionTask): partition_query = r"SELECT * FROM {0}" else: if self.random_shuffle: - hash_values = ( - f"random() * {2147483647 // self.npartitions * self.npartitions}" - ) + hash_values = f"random() * {2147483647 // self.npartitions * self.npartitions}" else: - hash_values = ( - f"hash( concat_ws( '##', {', '.join(self.hash_columns)} ) )" - ) + hash_values = f"hash( concat_ws( '##', {', '.join(self.hash_columns)} ) )" partition_keys = f"CAST({hash_values} AS UINT64) % {self.npartitions}::UINT64 AS {self.data_partition_column}" partition_query = f""" SELECT *, @@ -3029,19 +2624,13 @@ class HashPartitionDuckDbTask(ExecSqlQueryMixin, HashPartitionTask): return partition_query def partition(self, batch_index: int, input_dataset: ParquetDataSet): - with duckdb.connect( - database=":memory:", config={"allow_unsigned_extensions": "true"} - ) as conn: + with duckdb.connect(database=":memory:", config={"allow_unsigned_extensions": "true"}) as conn: self.prepare_connection(conn) if self.hive_partitioning: - self.load_input_batch( - conn, batch_index, input_dataset, sort_by_partition_key=True - ) + self.load_input_batch(conn, batch_index, input_dataset, sort_by_partition_key=True) self.write_hive_partitions(conn, batch_index, input_dataset) else: - self.load_input_batch( - conn, batch_index, input_dataset, sort_by_partition_key=True - ) + self.load_input_batch(conn, batch_index, input_dataset, sort_by_partition_key=True) self.write_flat_partitions(conn, batch_index, input_dataset) def load_input_batch( @@ -3052,9 +2641,7 @@ class HashPartitionDuckDbTask(ExecSqlQueryMixin, HashPartitionTask): sort_by_partition_key=False, ): input_views = self.create_input_views(conn, [input_dataset]) - partition_query = self.partition_query.format( - *input_views, **self.partition_infos_as_dict - ) + partition_query = self.partition_query.format(*input_views, **self.partition_infos_as_dict) if sort_by_partition_key: partition_query += f" ORDER BY {self.data_partition_column}" @@ -3070,12 +2657,8 @@ class HashPartitionDuckDbTask(ExecSqlQueryMixin, HashPartitionTask): min_partition_key, max_partition_key = conn.sql( f"SELECT MIN({self.data_partition_column}), MAX({self.data_partition_column}) FROM temp_query_result" ).fetchall()[0] - assert ( - min_partition_key >= 0 - ), f"partition key {min_partition_key} is out of range 0-{self.npartitions-1}" - assert ( - max_partition_key < self.npartitions - ), f"partition key {max_partition_key} is out of range 0-{self.npartitions-1}" + assert min_partition_key >= 0, f"partition key {min_partition_key} is out of range 0-{self.npartitions-1}" + assert max_partition_key < self.npartitions, f"partition key {max_partition_key} is out of range 0-{self.npartitions-1}" logger.debug(f"load input dataset #{batch_index+1}: {elapsed_time:.3f} secs") @@ -3105,9 +2688,7 @@ class HashPartitionDuckDbTask(ExecSqlQueryMixin, HashPartitionTask): {"DICTIONARY_ENCODING TRUE," if self.parquet_dictionary_encoding else ""} FILENAME_PATTERN '{self.output_filename}-{batch_index}.{{i}}') """ - perf_metrics = self.exec_query( - conn, f"EXPLAIN ANALYZE {copy_query_result}", enable_profiling=True - ) + perf_metrics = self.exec_query(conn, f"EXPLAIN ANALYZE {copy_query_result}", enable_profiling=True) self.perf_metrics["num output rows"] += perf_metrics["num output rows"] elapsed_time = self.add_elapsed_time("output dump time (secs)") logger.debug(f"write partition data #{batch_index+1}: {elapsed_time:.3f} secs") @@ -3118,9 +2699,7 @@ class HashPartitionDuckDbTask(ExecSqlQueryMixin, HashPartitionTask): batch_index: int, input_dataset: ParquetDataSet, ): - def write_partition_data( - conn: duckdb.DuckDBPyConnection, partition_batch: List[Tuple[int, str]] - ) -> int: + def write_partition_data(conn: duckdb.DuckDBPyConnection, partition_batch: List[Tuple[int, str]]) -> int: total_num_rows = 0 for partition_idx, partition_filter in partition_batch: if self.use_parquet_writer: @@ -3128,15 +2707,9 @@ class HashPartitionDuckDbTask(ExecSqlQueryMixin, HashPartitionTask): self._write_to_partition(partition_idx, partition_data) total_num_rows += partition_data.num_rows else: - partition_filename = ( - f"{self.output_filename}-{partition_idx}.{batch_index}.parquet" - ) - partition_path = os.path.join( - self.runtime_output_abspath, partition_filename - ) - self.partitioned_datasets[partition_idx].paths.append( - partition_filename - ) + partition_filename = f"{self.output_filename}-{partition_idx}.{batch_index}.parquet" + partition_path = os.path.join(self.runtime_output_abspath, partition_filename) + self.partitioned_datasets[partition_idx].paths.append(partition_filename) perf_metrics = self.exec_query( conn, f""" @@ -3160,11 +2733,7 @@ class HashPartitionDuckDbTask(ExecSqlQueryMixin, HashPartitionTask): total_num_rows += perf_metrics["num output rows"] return total_num_rows - column_projection = ( - f"* EXCLUDE ({self.data_partition_column})" - if self.drop_partition_column - else "*" - ) + column_projection = f"* EXCLUDE ({self.data_partition_column})" if self.drop_partition_column else "*" partition_filters = [ ( partition_idx, @@ -3175,12 +2744,8 @@ class HashPartitionDuckDbTask(ExecSqlQueryMixin, HashPartitionTask): partition_batches = split_into_rows(partition_filters, self.num_workers) with contextlib.ExitStack() as stack: - db_conns = [ - stack.enter_context(conn.cursor()) for _ in range(self.num_workers) - ] - self.perf_metrics["num output rows"] += sum( - self.io_workers.map(write_partition_data, db_conns, partition_batches) - ) + db_conns = [stack.enter_context(conn.cursor()) for _ in range(self.num_workers)] + self.perf_metrics["num output rows"] += sum(self.io_workers.map(write_partition_data, db_conns, partition_batches)) elapsed_time = self.add_elapsed_time("output dump time (secs)") logger.debug(f"write partition data #{batch_index+1}: {elapsed_time:.3f} secs") @@ -3202,16 +2767,12 @@ class HashPartitionArrowTask(HashPartitionTask): table = input_dataset.to_arrow_table(max_workers=self.cpu_limit) self.perf_metrics["num input rows"] += table.num_rows elapsed_time = self.add_elapsed_time("input load time (secs)") - logger.debug( - f"load input dataset: {table.nbytes/MB:.3f}MB, {table.num_rows} rows, {elapsed_time:.3f} secs" - ) + logger.debug(f"load input dataset: {table.nbytes/MB:.3f}MB, {table.num_rows} rows, {elapsed_time:.3f} secs") if self.shuffle_only: partition_keys = table.column(self.data_partition_column) elif self.random_shuffle: - partition_keys = arrow.array( - self.numpy_random_gen.integers(self.npartitions, size=table.num_rows) - ) + partition_keys = arrow.array(self.numpy_random_gen.integers(self.npartitions, size=table.num_rows)) else: hash_columns = polars.from_arrow(table.select(self.hash_columns)) hash_values = hash_columns.hash_rows(*self.fixed_rand_seeds) @@ -3223,9 +2784,7 @@ class HashPartitionArrowTask(HashPartitionTask): elapsed_time = self.add_elapsed_time("compute time (secs)") logger.debug(f"generate partition keys: {elapsed_time:.3f} secs") - table_slice_size = max( - DEFAULT_BATCH_SIZE, min(table.num_rows // 2, 100 * 1024 * 1024) - ) + table_slice_size = max(DEFAULT_BATCH_SIZE, min(table.num_rows // 2, 100 * 1024 * 1024)) num_iterations = math.ceil(table.num_rows / table_slice_size) def write_partition_data( @@ -3237,20 +2796,14 @@ class HashPartitionArrowTask(HashPartitionTask): self._write_to_partition(partition_idx, partition_data.to_arrow()) return total_num_rows - for table_slice_idx, table_slice_offset in enumerate( - range(0, table.num_rows, table_slice_size) - ): + for table_slice_idx, table_slice_offset in enumerate(range(0, table.num_rows, table_slice_size)): table_slice = table.slice(table_slice_offset, table_slice_size) - logger.debug( - f"table slice #{table_slice_idx+1}/{num_iterations}: {table_slice.nbytes/MB:.3f}MB, {table_slice.num_rows} rows" - ) + logger.debug(f"table slice #{table_slice_idx+1}/{num_iterations}: {table_slice.nbytes/MB:.3f}MB, {table_slice.num_rows} rows") df = polars.from_arrow(table_slice) del table_slice elapsed_time = self.add_elapsed_time("compute time (secs)") - logger.debug( - f"convert from arrow table #{table_slice_idx+1}/{num_iterations}: {elapsed_time:.3f} secs" - ) + logger.debug(f"convert from arrow table #{table_slice_idx+1}/{num_iterations}: {elapsed_time:.3f} secs") partitioned_dfs = df.partition_by( [self.data_partition_column], @@ -3258,23 +2811,15 @@ class HashPartitionArrowTask(HashPartitionTask): include_key=not self.drop_partition_column, as_dict=True, ) - partitioned_dfs = [ - (partition_idx, df) for (partition_idx,), df in partitioned_dfs.items() - ] + partitioned_dfs = [(partition_idx, df) for (partition_idx,), df in partitioned_dfs.items()] del df elapsed_time = self.add_elapsed_time("compute time (secs)") - logger.debug( - f"build partition data #{table_slice_idx+1}/{num_iterations}: {elapsed_time:.3f} secs" - ) + logger.debug(f"build partition data #{table_slice_idx+1}/{num_iterations}: {elapsed_time:.3f} secs") partition_batches = split_into_rows(partitioned_dfs, self.num_workers) - self.perf_metrics["num output rows"] += sum( - self.io_workers.map(write_partition_data, partition_batches) - ) + self.perf_metrics["num output rows"] += sum(self.io_workers.map(write_partition_data, partition_batches)) elapsed_time = self.add_elapsed_time("output dump time (secs)") - logger.debug( - f"write partition data #{table_slice_idx+1}/{num_iterations}: {elapsed_time:.3f} secs" - ) + logger.debug(f"write partition data #{table_slice_idx+1}/{num_iterations}: {elapsed_time:.3f} secs") class ProjectionTask(Task): @@ -3403,27 +2948,16 @@ class DataSinkTask(Task): ) for p in paths ] - logger.info( - f"collected {len(src_paths)} files from {len(self.input_datasets)} input datasets" - ) + logger.info(f"collected {len(src_paths)} files from {len(self.input_datasets)} input datasets") if len(set(p.name for p in src_paths)) == len(src_paths): dst_paths = [runtime_output_dir / p.name for p in src_paths] else: logger.warning(f"found duplicate filenames, appending index to filename...") - dst_paths = [ - runtime_output_dir / f"{p.stem}.{idx}{p.suffix}" - for idx, p in enumerate(src_paths) - ] + dst_paths = [runtime_output_dir / f"{p.stem}.{idx}{p.suffix}" for idx, p in enumerate(src_paths)] - output_paths = ( - src_paths - if sink_type == "manifest" - else [final_output_dir / p.name for p in dst_paths] - ) - self.dataset = ParquetDataSet( - [str(p) for p in output_paths] - ) # FIXME: what if the dataset is not parquet? + output_paths = src_paths if sink_type == "manifest" else [final_output_dir / p.name for p in dst_paths] + self.dataset = ParquetDataSet([str(p) for p in output_paths]) # FIXME: what if the dataset is not parquet? def copy_file(src_path: Path, dst_path: Path): # XXX: DO NOT use shutil.{copy, copy2, copyfileobj} @@ -3433,9 +2967,7 @@ class DataSinkTask(Task): def create_link_or_copy(src_path: Path, dst_path: Path): if dst_path.exists(): - logger.warning( - f"destination path already exists, replacing {dst_path} with {src_path}" - ) + logger.warning(f"destination path already exists, replacing {dst_path} with {src_path}") dst_path.unlink(missing_ok=True) same_mount_point = str(src_path).startswith(dst_mount_point) if sink_type == "copy": @@ -3477,9 +3009,7 @@ class DataSinkTask(Task): # check the output parquet files # if any file is broken, an exception will be raised if len(dst_paths) > 0 and dst_paths[0].suffix == ".parquet": - logger.info( - f"checked dataset files and found {self.dataset.num_rows} rows" - ) + logger.info(f"checked dataset files and found {self.dataset.num_rows} rows") return True @@ -3512,9 +3042,7 @@ class ExecutionPlan(object): A directed acyclic graph (DAG) of tasks. """ - def __init__( - self, ctx: RuntimeContext, root_task: RootTask, logical_plan: "LogicalPlan" - ) -> None: + def __init__(self, ctx: RuntimeContext, root_task: RootTask, logical_plan: "LogicalPlan") -> None: from smallpond.logical.node import LogicalPlan self.ctx = ctx @@ -3612,9 +3140,7 @@ def main(): parser.add_argument("-t", "--task_id", default=None, help="Task id") parser.add_argument("-r", "--retry_count", default=0, help="Task retry count") parser.add_argument("-o", "--output_path", default=None, help="Task output path") - parser.add_argument( - "-l", "--log_level", default="DEBUG", help="Logging message level" - ) + parser.add_argument("-l", "--log_level", default="DEBUG", help="Logging message level") args = parser.parse_args() logger.remove() diff --git a/smallpond/execution/workqueue.py b/smallpond/execution/workqueue.py index 5ab2525..b290cc8 100644 --- a/smallpond/execution/workqueue.py +++ b/smallpond/execution/workqueue.py @@ -55,9 +55,7 @@ class WorkItem(object): ) -> None: self._cpu_limit = cpu_limit self._gpu_limit = gpu_limit - self._memory_limit = ( - np.int64(memory_limit) if memory_limit is not None else None - ) + self._memory_limit = np.int64(memory_limit) if memory_limit is not None else None self._cpu_boost = 1 self._memory_boost = 1 self._numa_node = None @@ -88,11 +86,7 @@ class WorkItem(object): @property def memory_limit(self) -> np.int64: - return ( - np.int64(self._memory_boost * self._memory_limit) - if self._memory_limit - else 0 - ) + return np.int64(self._memory_boost * self._memory_limit) if self._memory_limit else 0 @property def elapsed_time(self) -> float: @@ -142,13 +136,7 @@ class WorkItem(object): return ( self._memory_limit is not None and self.status == WorkStatus.CRASHED - and ( - isinstance(self.exception, (OutOfMemory, MemoryError)) - or ( - isinstance(self.exception, NonzeroExitCode) - and nonzero_exitcode_as_oom - ) - ) + and (isinstance(self.exception, (OutOfMemory, MemoryError)) or (isinstance(self.exception, NonzeroExitCode) and nonzero_exitcode_as_oom)) ) def run(self) -> bool: @@ -175,9 +163,7 @@ class WorkItem(object): else: self.status = WorkStatus.FAILED except Exception as ex: - logger.opt(exception=ex).error( - f"{repr(self)} crashed with error. node location at {self.location}" - ) + logger.opt(exception=ex).error(f"{repr(self)} crashed with error. node location at {self.location}") self.status = WorkStatus.CRASHED self.exception = ex finally: @@ -204,25 +190,18 @@ class WorkBatch(WorkItem): cpu_limit = max(w.cpu_limit for w in works) gpu_limit = max(w.gpu_limit for w in works) memory_limit = max(w.memory_limit for w in works) - super().__init__( - f"{self.__class__.__name__}-{key}", cpu_limit, gpu_limit, memory_limit - ) + super().__init__(f"{self.__class__.__name__}-{key}", cpu_limit, gpu_limit, memory_limit) self.works = works def __str__(self) -> str: - return ( - super().__str__() - + f", works[{len(self.works)}]={self.works[:1]}...{self.works[-1:]}" - ) + return super().__str__() + f", works[{len(self.works)}]={self.works[:1]}...{self.works[-1:]}" def run(self) -> bool: logger.info(f"processing {len(self.works)} works in the batch") for index, work in enumerate(self.works): work.exec_id = self.exec_id if work.exec(self.exec_cq) != WorkStatus.SUCCEED: - logger.error( - f"work item #{index+1}/{len(self.works)} in {self.key} failed: {work}" - ) + logger.error(f"work item #{index+1}/{len(self.works)} in {self.key} failed: {work}") return False logger.info(f"done {len(self.works)} works in the batch") return True @@ -375,9 +354,7 @@ class WorkQueueOnFilesystem(WorkQueue): os.rename(tempfile_path, enqueued_path) return True except OSError as err: - logger.critical( - f"failed to rename {tempfile_path} to {enqueued_path}: {err}" - ) + logger.critical(f"failed to rename {tempfile_path} to {enqueued_path}: {err}") return False @@ -405,27 +382,17 @@ def count_objects(obj, object_cnt=None, visited_objs=None, depth=0): object_cnt[class_name] = (cnt + 1, size + sys.getsizeof(obj)) key_attributes = ("__self__", "__dict__", "__slots__") - if not isinstance(obj, (bool, str, int, float, type(None))) and any( - attr_name in key_attributes for attr_name in dir(obj) - ): + if not isinstance(obj, (bool, str, int, float, type(None))) and any(attr_name in key_attributes for attr_name in dir(obj)): logger.debug(f"{' ' * depth}{class_name}@{id(obj):x}") for attr_name in dir(obj): try: - if ( - not attr_name.startswith("__") or attr_name in key_attributes - ) and not isinstance( + if (not attr_name.startswith("__") or attr_name in key_attributes) and not isinstance( getattr(obj.__class__, attr_name, None), property ): - logger.debug( - f"{' ' * depth}{class_name}.{attr_name}@{id(obj):x}" - ) - count_objects( - getattr(obj, attr_name), object_cnt, visited_objs, depth + 1 - ) + logger.debug(f"{' ' * depth}{class_name}.{attr_name}@{id(obj):x}") + count_objects(getattr(obj, attr_name), object_cnt, visited_objs, depth + 1) except Exception as ex: - logger.warning( - f"failed to get '{attr_name}' from {repr(obj)}: {ex}" - ) + logger.warning(f"failed to get '{attr_name}' from {repr(obj)}: {ex}") def main(): @@ -433,23 +400,13 @@ def main(): from smallpond.execution.task import Probe - parser = argparse.ArgumentParser( - prog="workqueue.py", description="Work Queue Reader" - ) + parser = argparse.ArgumentParser(prog="workqueue.py", description="Work Queue Reader") parser.add_argument("wq_root", help="Work queue root path") parser.add_argument("-f", "--work_filter", default="", help="Work item filter") - parser.add_argument( - "-x", "--expand_batch", action="store_true", help="Expand batched works" - ) - parser.add_argument( - "-c", "--count_object", action="store_true", help="Count number of objects" - ) - parser.add_argument( - "-n", "--top_n_class", default=20, type=int, help="Show the top n classes" - ) - parser.add_argument( - "-l", "--log_level", default="INFO", help="Logging message level" - ) + parser.add_argument("-x", "--expand_batch", action="store_true", help="Expand batched works") + parser.add_argument("-c", "--count_object", action="store_true", help="Count number of objects") + parser.add_argument("-n", "--top_n_class", default=20, type=int, help="Show the top n classes") + parser.add_argument("-l", "--log_level", default="INFO", help="Logging message level") args = parser.parse_args() logger.remove() @@ -468,9 +425,7 @@ def main(): if args.count_object: object_cnt = {} count_objects(work, object_cnt) - sorted_counts = sorted( - [(v, k) for k, v in object_cnt.items()], reverse=True - ) + sorted_counts = sorted([(v, k) for k, v in object_cnt.items()], reverse=True) for count, class_name in sorted_counts[: args.top_n_class]: logger.info(f" {class_name}: {count}") diff --git a/smallpond/io/arrow.py b/smallpond/io/arrow.py index 9c47cae..d315ba6 100644 --- a/smallpond/io/arrow.py +++ b/smallpond/io/arrow.py @@ -44,11 +44,7 @@ class RowRange: @property def estimated_data_size(self) -> int: """The estimated uncompressed data size in bytes.""" - return ( - self.data_size * self.num_rows // self.file_num_rows - if self.file_num_rows > 0 - else 0 - ) + return self.data_size * self.num_rows // self.file_num_rows if self.file_num_rows > 0 else 0 def take(self, num_rows: int) -> "RowRange": """ @@ -62,9 +58,7 @@ class RowRange: return head @staticmethod - def partition_by_rows( - row_ranges: List["RowRange"], npartition: int - ) -> List[List["RowRange"]]: + def partition_by_rows(row_ranges: List["RowRange"], npartition: int) -> List[List["RowRange"]]: """Evenly split a list of row ranges into `npartition` partitions.""" # NOTE: `row_ranges` should not be modified by this function row_ranges = copy.deepcopy(row_ranges) @@ -128,9 +122,7 @@ def convert_types_to_large_string(schema: arrow.Schema) -> arrow.Schema: new_fields = [] for field in schema: new_type = convert_type_to_large(field.type) - new_field = arrow.field( - field.name, new_type, nullable=field.nullable, metadata=field.metadata - ) + new_field = arrow.field(field.name, new_type, nullable=field.nullable, metadata=field.metadata) new_fields.append(new_field) return arrow.schema(new_fields, metadata=schema.metadata) @@ -151,11 +143,7 @@ def filter_schema( if included_cols is not None: fields = [schema.field(col_name) for col_name in included_cols] if excluded_cols is not None: - fields = [ - schema.field(col_name) - for col_name in schema.names - if col_name not in excluded_cols - ] + fields = [schema.field(col_name) for col_name in schema.names if col_name not in excluded_cols] return arrow.schema(fields, metadata=schema.metadata) @@ -172,9 +160,7 @@ def _iter_record_batches( current_offset = 0 required_l, required_r = offset, offset + length - for batch in file.iter_batches( - batch_size=batch_size, columns=columns, use_threads=False - ): + for batch in file.iter_batches(batch_size=batch_size, columns=columns, use_threads=False): current_l, current_r = current_offset, current_offset + batch.num_rows # check if intersection is null if current_r <= required_l: @@ -184,9 +170,7 @@ def _iter_record_batches( else: intersection_l = max(required_l, current_l) intersection_r = min(required_r, current_r) - trimmed = batch.slice( - intersection_l - current_offset, intersection_r - intersection_l - ) + trimmed = batch.slice(intersection_l - current_offset, intersection_r - intersection_l) assert ( trimmed.num_rows == intersection_r - intersection_l ), f"trimmed.num_rows {trimmed.num_rows} != batch_length {intersection_r - intersection_l}" @@ -204,9 +188,7 @@ def build_batch_reader_from_files( ) -> arrow.RecordBatchReader: assert len(paths_or_ranges) > 0, "paths_or_ranges must be a non-empty list" schema = _read_schema_from_file(paths_or_ranges[0], columns, filesystem) - iterator = _iter_record_batches_from_files( - paths_or_ranges, columns, batch_size, max_batch_byte_size, filesystem - ) + iterator = _iter_record_batches_from_files(paths_or_ranges, columns, batch_size, max_batch_byte_size, filesystem) return arrow.RecordBatchReader.from_batches(schema, iterator) @@ -216,9 +198,7 @@ def _read_schema_from_file( filesystem: fsspec.AbstractFileSystem = None, ) -> arrow.Schema: path = path_or_range.path if isinstance(path_or_range, RowRange) else path_or_range - schema = parquet.read_schema( - filesystem.unstrip_protocol(path) if filesystem else path, filesystem=filesystem - ) + schema = parquet.read_schema(filesystem.unstrip_protocol(path) if filesystem else path, filesystem=filesystem) if columns is not None: assert all( c in schema.names for c in columns @@ -253,9 +233,7 @@ def _iter_record_batches_from_files( yield from table.combine_chunks().to_batches(batch_size) for path_or_range in paths_or_ranges: - path = ( - path_or_range.path if isinstance(path_or_range, RowRange) else path_or_range - ) + path = path_or_range.path if isinstance(path_or_range, RowRange) else path_or_range with parquet.ParquetFile( filesystem.unstrip_protocol(path) if filesystem else path, buffer_size=16 * MB, @@ -265,23 +243,16 @@ def _iter_record_batches_from_files( offset, length = path_or_range.begin, path_or_range.num_rows else: offset, length = 0, file.metadata.num_rows - for batch in _iter_record_batches( - file, columns, offset, length, batch_size - ): + for batch in _iter_record_batches(file, columns, offset, length, batch_size): batch_size_exceeded = batch.num_rows + buffered_rows >= batch_size - batch_byte_size_exceeded = ( - max_batch_byte_size is not None - and batch.nbytes + buffered_bytes >= max_batch_byte_size - ) + batch_byte_size_exceeded = max_batch_byte_size is not None and batch.nbytes + buffered_bytes >= max_batch_byte_size if not batch_size_exceeded and not batch_byte_size_exceeded: buffered_batches.append(batch) buffered_rows += batch.num_rows buffered_bytes += batch.nbytes else: if batch_size_exceeded: - buffered_batches.append( - batch.slice(0, batch_size - buffered_rows) - ) + buffered_batches.append(batch.slice(0, batch_size - buffered_rows)) batch = batch.slice(batch_size - buffered_rows) if buffered_batches: yield from combine_buffered_batches(buffered_batches) @@ -298,9 +269,7 @@ def read_parquet_files_into_table( columns: List[str] = None, filesystem: fsspec.AbstractFileSystem = None, ) -> arrow.Table: - batch_reader = build_batch_reader_from_files( - paths_or_ranges, columns=columns, filesystem=filesystem - ) + batch_reader = build_batch_reader_from_files(paths_or_ranges, columns=columns, filesystem=filesystem) return batch_reader.read_all() @@ -312,37 +281,22 @@ def load_from_parquet_files( ) -> arrow.Table: start_time = time.time() assert len(paths_or_ranges) > 0, "paths_or_ranges must be a non-empty list" - paths = [ - path_or_range.path if isinstance(path_or_range, RowRange) else path_or_range - for path_or_range in paths_or_ranges - ] + paths = [path_or_range.path if isinstance(path_or_range, RowRange) else path_or_range for path_or_range in paths_or_ranges] total_compressed_size = sum( - ( - path_or_range.data_size - if isinstance(path_or_range, RowRange) - else os.path.getsize(path_or_range) - ) - for path_or_range in paths_or_ranges - ) - logger.debug( - f"loading {len(paths)} parquet files (compressed size: {total_compressed_size/MB:.3f}MB): {paths[:3]}..." + (path_or_range.data_size if isinstance(path_or_range, RowRange) else os.path.getsize(path_or_range)) for path_or_range in paths_or_ranges ) + logger.debug(f"loading {len(paths)} parquet files (compressed size: {total_compressed_size/MB:.3f}MB): {paths[:3]}...") num_workers = min(len(paths), max_workers) with ThreadPoolExecutor(num_workers) as pool: running_works = [ - pool.submit(read_parquet_files_into_table, batch, columns, filesystem) - for batch in split_into_rows(paths_or_ranges, num_workers) + pool.submit(read_parquet_files_into_table, batch, columns, filesystem) for batch in split_into_rows(paths_or_ranges, num_workers) ] tables = [work.result() for work in running_works] - logger.debug( - f"collected {len(tables)} tables from: {paths[:3]}... (elapsed: {time.time() - start_time:.3f} secs)" - ) + logger.debug(f"collected {len(tables)} tables from: {paths[:3]}... (elapsed: {time.time() - start_time:.3f} secs)") return arrow.concat_tables(tables) -def parquet_write_table( - table, where, filesystem: fsspec.AbstractFileSystem = None, **write_table_args -) -> int: +def parquet_write_table(table, where, filesystem: fsspec.AbstractFileSystem = None, **write_table_args) -> int: if filesystem is not None: return parquet.write_table( table, @@ -388,10 +342,7 @@ def dump_to_parquet_files( num_workers = min(len(batches), max_workers) num_tables = max(math.ceil(table.nbytes / MAX_PARQUET_FILE_BYTES), num_workers) logger.debug(f"evenly distributed {len(batches)} batches into {num_tables} files") - tables = [ - arrow.Table.from_batches(batch, table.schema) - for batch in split_into_rows(batches, num_tables) - ] + tables = [arrow.Table.from_batches(batch, table.schema) for batch in split_into_rows(batches, num_tables)] assert sum(t.num_rows for t in tables) == table.num_rows logger.debug(f"writing {len(tables)} files to {output_dir}") @@ -413,7 +364,5 @@ def dump_to_parquet_files( ] assert all(work.result() or True for work in running_works) - logger.debug( - f"finished writing {len(tables)} files to {output_dir} (elapsed: {time.time() - start_time:.3f} secs)" - ) + logger.debug(f"finished writing {len(tables)} files to {output_dir} (elapsed: {time.time() - start_time:.3f} secs)") return True diff --git a/smallpond/io/filesystem.py b/smallpond/io/filesystem.py index c76d406..0d8f942 100644 --- a/smallpond/io/filesystem.py +++ b/smallpond/io/filesystem.py @@ -44,9 +44,7 @@ def remove_path(path: str): os.symlink(realpath, link) return except Exception as ex: - logger.opt(exception=ex).debug( - f"fast recursive remove failed, fall back to shutil.rmtree('{realpath}')" - ) + logger.opt(exception=ex).debug(f"fast recursive remove failed, fall back to shutil.rmtree('{realpath}')") shutil.rmtree(realpath, ignore_errors=True) @@ -94,9 +92,7 @@ def dump(obj: Any, path: str, buffering=2 * MB, atomic_write=False) -> int: raise except Exception as ex: trace_str, trace_err = get_pickle_trace(obj) - logger.opt(exception=ex).error( - f"pickle trace of {repr(obj)}:{os.linesep}{trace_str}" - ) + logger.opt(exception=ex).error(f"pickle trace of {repr(obj)}:{os.linesep}{trace_str}") if trace_err is None: raise else: @@ -107,9 +103,7 @@ def dump(obj: Any, path: str, buffering=2 * MB, atomic_write=False) -> int: if atomic_write: directory, filename = os.path.split(path) - with tempfile.NamedTemporaryFile( - "wb", buffering=buffering, dir=directory, prefix=filename, delete=False - ) as fout: + with tempfile.NamedTemporaryFile("wb", buffering=buffering, dir=directory, prefix=filename, delete=False) as fout: write_to_file(fout) fout.seek(0, os.SEEK_END) size = fout.tell() diff --git a/smallpond/logical/dataset.py b/smallpond/logical/dataset.py index a507f94..123fe53 100644 --- a/smallpond/logical/dataset.py +++ b/smallpond/logical/dataset.py @@ -206,9 +206,7 @@ class DataSet(object): resolved_paths.append(path) if wildcard_paths: if len(wildcard_paths) == 1: - expanded_paths = glob.glob( - wildcard_paths[0], recursive=self.recursive - ) + expanded_paths = glob.glob(wildcard_paths[0], recursive=self.recursive) else: logger.debug( "resolving {} paths with wildcards in {}", @@ -247,9 +245,7 @@ class DataSet(object): if self.root_dir is None: self._absolute_paths = sorted(self.paths) else: - self._absolute_paths = [ - os.path.join(self.root_dir, p) for p in sorted(self.paths) - ] + self._absolute_paths = [os.path.join(self.root_dir, p) for p in sorted(self.paths)] return self._absolute_paths def sql_query_fragment( @@ -340,23 +336,15 @@ class DataSet(object): return file_partitions @functools.lru_cache - def partition_by_files( - self, npartition: int, random_shuffle: bool = False - ) -> "List[DataSet]": + def partition_by_files(self, npartition: int, random_shuffle: bool = False) -> "List[DataSet]": """ Evenly split into `npartition` datasets by files. """ assert npartition > 0, f"npartition has negative value: {npartition}" if npartition > self.num_files: - logger.debug( - f"number of partitions {npartition} is greater than the number of files {self.num_files}" - ) + logger.debug(f"number of partitions {npartition} is greater than the number of files {self.num_files}") - resolved_paths = ( - random.sample(self.resolved_paths, len(self.resolved_paths)) - if random_shuffle - else self.resolved_paths - ) + resolved_paths = random.sample(self.resolved_paths, len(self.resolved_paths)) if random_shuffle else self.resolved_paths evenly_split_groups = split_into_rows(resolved_paths, npartition) num_paths_in_groups = list(map(len, evenly_split_groups)) @@ -367,11 +355,7 @@ class DataSet(object): logger.debug( f"created {npartition} file partitions (min #files: {min(num_paths_in_groups)}, max #files: {max(num_paths_in_groups)}, avg #files: {sum(num_paths_in_groups)/len(num_paths_in_groups):.3f}) from {self}" ) - return ( - random.sample(file_partitions, len(file_partitions)) - if random_shuffle - else file_partitions - ) + return random.sample(file_partitions, len(file_partitions)) if random_shuffle else file_partitions class PartitionedDataSet(DataSet): @@ -463,9 +447,7 @@ class CsvDataSet(DataSet): union_by_name=False, ) -> None: super().__init__(paths, root_dir, recursive, columns, union_by_name) - assert isinstance( - schema, OrderedDict - ), f"type of csv schema is not OrderedDict: {type(schema)}" + assert isinstance(schema, OrderedDict), f"type of csv schema is not OrderedDict: {type(schema)}" self.schema = schema self.delim = delim self.max_line_size = max_line_size @@ -492,14 +474,8 @@ class CsvDataSet(DataSet): filesystem: fsspec.AbstractFileSystem = None, conn: duckdb.DuckDBPyConnection = None, ) -> str: - schema_str = ", ".join( - map(lambda x: f"'{x[0]}': '{x[1]}'", self.schema.items()) - ) - max_line_size_str = ( - f"max_line_size={self.max_line_size}, " - if self.max_line_size is not None - else "" - ) + schema_str = ", ".join(map(lambda x: f"'{x[0]}': '{x[1]}'", self.schema.items())) + max_line_size_str = f"max_line_size={self.max_line_size}, " if self.max_line_size is not None else "" return ( f"( select {self._column_str} from read_csv([ {self._resolved_path_str} ], delim='{self.delim}', columns={{ {schema_str} }}, header={self.header}, " f"{max_line_size_str} parallel={self.parallel}, union_by_name={self.union_by_name}) )" @@ -552,9 +528,7 @@ class JsonDataSet(DataSet): filesystem: fsspec.AbstractFileSystem = None, conn: duckdb.DuckDBPyConnection = None, ) -> str: - schema_str = ", ".join( - map(lambda x: f"'{x[0]}': '{x[1]}'", self.schema.items()) - ) + schema_str = ", ".join(map(lambda x: f"'{x[0]}': '{x[1]}'", self.schema.items())) return ( f"( select {self._column_str} from read_json([ {self._resolved_path_str} ], format='{self.format}', columns={{ {schema_str} }}, " f"maximum_object_size={self.max_object_size}, union_by_name={self.union_by_name}) )" @@ -614,11 +588,7 @@ class ParquetDataSet(DataSet): ) # merge row ranges if any dataset has resolved row ranges if any(dataset._resolved_row_ranges is not None for dataset in datasets): - dataset._resolved_row_ranges = [ - row_range - for dataset in datasets - for row_range in dataset.resolved_row_ranges - ] + dataset._resolved_row_ranges = [row_range for dataset in datasets for row_range in dataset.resolved_row_ranges] return dataset @staticmethod @@ -655,10 +625,7 @@ class ParquetDataSet(DataSet): # read parquet metadata to get number of rows metadata = parquet.read_metadata(path) num_rows = metadata.num_rows - uncompressed_data_size = sum( - metadata.row_group(i).total_byte_size - for i in range(metadata.num_row_groups) - ) + uncompressed_data_size = sum(metadata.row_group(i).total_byte_size for i in range(metadata.num_row_groups)) return RowRange( path, data_size=uncompressed_data_size, @@ -667,20 +634,14 @@ class ParquetDataSet(DataSet): end=num_rows, ) - with ThreadPoolExecutor( - max_workers=min(32, len(self.resolved_paths)) - ) as pool: - self._resolved_row_ranges = list( - pool.map(resolve_row_range, self.resolved_paths) - ) + with ThreadPoolExecutor(max_workers=min(32, len(self.resolved_paths))) as pool: + self._resolved_row_ranges = list(pool.map(resolve_row_range, self.resolved_paths)) return self._resolved_row_ranges @property def num_rows(self) -> int: if self._resolved_num_rows is None: - self._resolved_num_rows = sum( - row_range.num_rows for row_range in self.resolved_row_ranges - ) + self._resolved_num_rows = sum(row_range.num_rows for row_range in self.resolved_row_ranges) return self._resolved_num_rows @property @@ -695,29 +656,19 @@ class ParquetDataSet(DataSet): """ Return the estimated data size in bytes. """ - return sum( - row_range.estimated_data_size for row_range in self.resolved_row_ranges - ) + return sum(row_range.estimated_data_size for row_range in self.resolved_row_ranges) def sql_query_fragment( self, filesystem: fsspec.AbstractFileSystem = None, conn: duckdb.DuckDBPyConnection = None, ) -> str: - extra_parameters = ( - "".join(f", {col}=true" for col in self.generated_columns) - if self.generated_columns - else "" - ) + extra_parameters = "".join(f", {col}=true" for col in self.generated_columns) if self.generated_columns else "" parquet_file_queries = [] full_row_ranges = [] for row_range in self.resolved_row_ranges: - path = ( - filesystem.unstrip_protocol(row_range.path) - if filesystem - else row_range.path - ) + path = filesystem.unstrip_protocol(row_range.path) if filesystem else row_range.path if row_range.num_rows == row_range.file_num_rows: full_row_ranges.append(row_range) else: @@ -744,9 +695,7 @@ class ParquetDataSet(DataSet): full_row_ranges[largest_index], full_row_ranges[0], ) - parquet_file_str = ",\n ".join( - map(lambda x: f"'{x.path}'", full_row_ranges) - ) + parquet_file_str = ",\n ".join(map(lambda x: f"'{x.path}'", full_row_ranges)) parquet_file_queries.insert( 0, f""" @@ -771,11 +720,7 @@ class ParquetDataSet(DataSet): tables = [] if self.resolved_row_ranges: - tables.append( - load_from_parquet_files( - self.resolved_row_ranges, self.columns, max_workers, filesystem - ) - ) + tables.append(load_from_parquet_files(self.resolved_row_ranges, self.columns, max_workers, filesystem)) return arrow.concat_tables(tables) def to_batch_reader( @@ -795,18 +740,14 @@ class ParquetDataSet(DataSet): ) @functools.lru_cache - def partition_by_files( - self, npartition: int, random_shuffle: bool = False - ) -> "List[ParquetDataSet]": + def partition_by_files(self, npartition: int, random_shuffle: bool = False) -> "List[ParquetDataSet]": if self._resolved_row_ranges is not None: return self.partition_by_rows(npartition, random_shuffle) else: return super().partition_by_files(npartition, random_shuffle) @functools.lru_cache - def partition_by_rows( - self, npartition: int, random_shuffle: bool = False - ) -> "List[ParquetDataSet]": + def partition_by_rows(self, npartition: int, random_shuffle: bool = False) -> "List[ParquetDataSet]": """ Evenly split the dataset into `npartition` partitions by rows. If `random_shuffle` is True, shuffle the files before partitioning. @@ -814,11 +755,7 @@ class ParquetDataSet(DataSet): assert npartition > 0, f"npartition has negative value: {npartition}" resolved_row_ranges = self.resolved_row_ranges - resolved_row_ranges = ( - random.sample(resolved_row_ranges, len(resolved_row_ranges)) - if random_shuffle - else resolved_row_ranges - ) + resolved_row_ranges = random.sample(resolved_row_ranges, len(resolved_row_ranges)) if random_shuffle else resolved_row_ranges def create_dataset(row_ranges: List[RowRange]) -> ParquetDataSet: row_ranges = sorted(row_ranges, key=lambda x: x.path) @@ -833,12 +770,7 @@ class ParquetDataSet(DataSet): dataset._resolved_row_ranges = row_ranges return dataset - return [ - create_dataset(row_ranges) - for row_ranges in RowRange.partition_by_rows( - resolved_row_ranges, npartition - ) - ] + return [create_dataset(row_ranges) for row_ranges in RowRange.partition_by_rows(resolved_row_ranges, npartition)] @functools.lru_cache def partition_by_size(self, max_partition_size: int) -> "List[ParquetDataSet]": @@ -847,16 +779,12 @@ class ParquetDataSet(DataSet): """ if self.empty: return [] - estimated_data_size = sum( - row_range.estimated_data_size for row_range in self.resolved_row_ranges - ) + estimated_data_size = sum(row_range.estimated_data_size for row_range in self.resolved_row_ranges) npartition = estimated_data_size // max_partition_size + 1 return self.partition_by_rows(npartition) @staticmethod - def _read_partition_key( - path: str, data_partition_column: str, hive_partitioning: bool - ) -> int: + def _read_partition_key(path: str, data_partition_column: str, hive_partitioning: bool) -> int: """ Get the partition key of the parquet file. @@ -874,9 +802,7 @@ class ParquetDataSet(DataSet): try: return int(key) except ValueError: - logger.error( - f"cannot parse partition key '{data_partition_column}' of {path} from: {key}" - ) + logger.error(f"cannot parse partition key '{data_partition_column}' of {path} from: {key}") raise if hive_partitioning: @@ -884,9 +810,7 @@ class ParquetDataSet(DataSet): for part in path.split(os.path.sep): if part.startswith(path_part_prefix): return parse_partition_key(part[len(path_part_prefix) :]) - raise RuntimeError( - f"cannot extract hive partition key '{data_partition_column}' from path: {path}" - ) + raise RuntimeError(f"cannot extract hive partition key '{data_partition_column}' from path: {path}") with parquet.ParquetFile(path) as file: kv_metadata = file.schema_arrow.metadata or file.metadata.metadata @@ -896,36 +820,22 @@ class ParquetDataSet(DataSet): if key == PARQUET_METADATA_KEY_PREFIX + data_partition_column: return parse_partition_key(val) if file.metadata.num_rows == 0: - logger.warning( - f"cannot read partition keys from empty parquet file: {path}" - ) + logger.warning(f"cannot read partition keys from empty parquet file: {path}") return None - for batch in file.iter_batches( - batch_size=128, columns=[data_partition_column], use_threads=False - ): - assert ( - data_partition_column in batch.column_names - ), f"cannot find column '{data_partition_column}' in {batch.column_names}" - assert ( - batch.num_columns == 1 - ), f"unexpected num of columns: {batch.column_names}" + for batch in file.iter_batches(batch_size=128, columns=[data_partition_column], use_threads=False): + assert data_partition_column in batch.column_names, f"cannot find column '{data_partition_column}' in {batch.column_names}" + assert batch.num_columns == 1, f"unexpected num of columns: {batch.column_names}" uniq_partition_keys = set(batch.columns[0].to_pylist()) - assert ( - uniq_partition_keys and len(uniq_partition_keys) == 1 - ), f"partition keys found in {path} not unique: {uniq_partition_keys}" + assert uniq_partition_keys and len(uniq_partition_keys) == 1, f"partition keys found in {path} not unique: {uniq_partition_keys}" return uniq_partition_keys.pop() - def load_partitioned_datasets( - self, npartition: int, data_partition_column: str, hive_partitioning=False - ) -> "List[ParquetDataSet]": + def load_partitioned_datasets(self, npartition: int, data_partition_column: str, hive_partitioning=False) -> "List[ParquetDataSet]": """ Split the dataset into a list of partitioned datasets. """ assert npartition > 0, f"npartition has negative value: {npartition}" if npartition > self.num_files: - logger.debug( - f"number of partitions {npartition} is greater than the number of files {self.num_files}" - ) + logger.debug(f"number of partitions {npartition} is greater than the number of files {self.num_files}") file_partitions: List[ParquetDataSet] = self._init_file_partitions(npartition) for dataset in file_partitions: @@ -940,17 +850,13 @@ class ParquetDataSet(DataSet): with ThreadPoolExecutor(min(32, len(self.resolved_paths))) as pool: partition_keys = pool.map( - lambda path: ParquetDataSet._read_partition_key( - path, data_partition_column, hive_partitioning - ), + lambda path: ParquetDataSet._read_partition_key(path, data_partition_column, hive_partitioning), self.resolved_paths, ) for row_range, partition_key in zip(self.resolved_row_ranges, partition_keys): if partition_key is not None: - assert ( - 0 <= partition_key <= npartition - ), f"invalid partition key {partition_key} found in {row_range.path}" + assert 0 <= partition_key <= npartition, f"invalid partition key {partition_key} found in {row_range.path}" dataset = file_partitions[partition_key] dataset.paths.append(row_range.path) dataset._absolute_paths.append(row_range.path) @@ -964,20 +870,14 @@ class ParquetDataSet(DataSet): """ Remove empty parquet files from the dataset. """ - new_row_ranges = [ - row_range - for row_range in self.resolved_row_ranges - if row_range.num_rows > 0 - ] + new_row_ranges = [row_range for row_range in self.resolved_row_ranges if row_range.num_rows > 0] if len(new_row_ranges) == 0: # keep at least one file to avoid empty dataset new_row_ranges = self.resolved_row_ranges[:1] if len(new_row_ranges) == len(self.resolved_row_ranges): # no empty files found return - logger.info( - f"removed {len(self.resolved_row_ranges) - len(new_row_ranges)}/{len(self.resolved_row_ranges)} empty parquet files from {self}" - ) + logger.info(f"removed {len(self.resolved_row_ranges) - len(new_row_ranges)}/{len(self.resolved_row_ranges)} empty parquet files from {self}") self._resolved_row_ranges = new_row_ranges self._resolved_paths = [row_range.path for row_range in new_row_ranges] self._absolute_paths = self._resolved_paths @@ -997,9 +897,7 @@ class SqlQueryDataSet(DataSet): def __init__( self, sql_query: str, - query_builder: Callable[ - [duckdb.DuckDBPyConnection, fsspec.AbstractFileSystem], str - ] = None, + query_builder: Callable[[duckdb.DuckDBPyConnection, fsspec.AbstractFileSystem], str] = None, ) -> None: super().__init__([]) self.sql_query = sql_query @@ -1007,9 +905,7 @@ class SqlQueryDataSet(DataSet): @property def num_rows(self) -> int: - num_rows = duckdb.sql( - f"select count(*) as num_rows from {self.sql_query_fragment()}" - ).fetchall() + num_rows = duckdb.sql(f"select count(*) as num_rows from {self.sql_query_fragment()}").fetchall() return num_rows[0][0] def sql_query_fragment( @@ -1017,11 +913,7 @@ class SqlQueryDataSet(DataSet): filesystem: fsspec.AbstractFileSystem = None, conn: duckdb.DuckDBPyConnection = None, ) -> str: - sql_query = ( - self.sql_query - if self.query_builder is None - else self.query_builder(conn, filesystem) - ) + sql_query = self.sql_query if self.query_builder is None else self.query_builder(conn, filesystem) return f"( {sql_query} )" diff --git a/smallpond/logical/node.py b/smallpond/logical/node.py index 0ebab8f..a89016d 100644 --- a/smallpond/logical/node.py +++ b/smallpond/logical/node.py @@ -132,9 +132,7 @@ class Context(object): ------- The unique function name. """ - self.udfs[name] = PythonUDFContext( - name, func, params, return_type, use_arrow_type - ) + self.udfs[name] = PythonUDFContext(name, func, params, return_type, use_arrow_type) return name def create_external_module(self, module_path: str, name: str = None) -> str: @@ -213,18 +211,10 @@ class Node(object): This is a resource requirement specified by the user and used to guide task scheduling. smallpond does NOT enforce this limit. """ - assert isinstance( - input_deps, Iterable - ), f"input_deps is not iterable: {input_deps}" - assert all( - isinstance(node, Node) for node in input_deps - ), f"some of input_deps are not instances of Node: {input_deps}" - assert output_name is None or re.match( - "[a-zA-Z0-9_]+", output_name - ), f"output_name has invalid format: {output_name}" - assert output_path is None or os.path.isabs( - output_path - ), f"output_path is not an absolute path: {output_path}" + assert isinstance(input_deps, Iterable), f"input_deps is not iterable: {input_deps}" + assert all(isinstance(node, Node) for node in input_deps), f"some of input_deps are not instances of Node: {input_deps}" + assert output_name is None or re.match("[a-zA-Z0-9_]+", output_name), f"output_name has invalid format: {output_name}" + assert output_path is None or os.path.isabs(output_path), f"output_path is not an absolute path: {output_path}" self.ctx = ctx self.id = self.ctx._new_node_id() self.input_deps = input_deps @@ -238,10 +228,7 @@ class Node(object): self.perf_metrics: Dict[str, List[float]] = defaultdict(list) # record the location where the node is constructed in user code frame = next( - frame - for frame in reversed(traceback.extract_stack()) - if frame.filename != __file__ - and not frame.filename.endswith("/dataframe.py") + frame for frame in reversed(traceback.extract_stack()) if frame.filename != __file__ and not frame.filename.endswith("/dataframe.py") ) self.location = f"{frame.filename}:{frame.lineno}" @@ -290,9 +277,7 @@ class Node(object): values = self.perf_metrics[name] min, max, avg = np.min(values), np.max(values), np.average(values) p50, p75, p95, p99 = np.percentile(values, (50, 75, 95, 99)) - self.perf_stats[name] = PerfStats( - len(values), sum(values), min, max, avg, p50, p75, p95, p99 - ) + self.perf_stats[name] = PerfStats(len(values), sum(values), min, max, avg, p50, p75, p95, p99) return self.perf_stats[name] @property @@ -378,9 +363,7 @@ class DataSinkNode(Node): "link_or_copy", "manifest", ), f"invalid sink type: {type}" - super().__init__( - ctx, input_deps, None, output_path, cpu_limit=1, gpu_limit=0, memory_limit=0 - ) + super().__init__(ctx, input_deps, None, output_path, cpu_limit=1, gpu_limit=0, memory_limit=0) self.type: DataSinkType = "manifest" if manifest_only else type self.is_final_node = is_final_node @@ -402,12 +385,7 @@ class DataSinkNode(Node): if self.type == "copy" or self.type == "link_or_copy": # so we create two phase tasks: # phase1: copy data to a temp directory, for each input partition in parallel - input_deps = [ - self._create_phase1_task( - runtime_ctx, task, [PartitionInfo(i, len(input_deps))] - ) - for i, task in enumerate(input_deps) - ] + input_deps = [self._create_phase1_task(runtime_ctx, task, [PartitionInfo(i, len(input_deps))]) for i, task in enumerate(input_deps)] # phase2: resolve file name conflicts, hard link files, create manifest file, and clean up temp directory return DataSinkTask( runtime_ctx, @@ -445,9 +423,7 @@ class DataSinkNode(Node): input_dep: Task, partition_infos: List[PartitionInfo], ) -> DataSinkTask: - return DataSinkTask( - runtime_ctx, [input_dep], partition_infos, self.output_path, type=self.type - ) + return DataSinkTask(runtime_ctx, [input_dep], partition_infos, self.output_path, type=self.type) class PythonScriptNode(Node): @@ -467,9 +443,7 @@ class PythonScriptNode(Node): ctx: Context, input_deps: Tuple[Node, ...], *, - process_func: Optional[ - Callable[[RuntimeContext, List[DataSet], str], bool] - ] = None, + process_func: Optional[Callable[[RuntimeContext, List[DataSet], str], bool]] = None, output_name: Optional[str] = None, output_path: Optional[str] = None, cpu_limit: int = 1, @@ -646,9 +620,7 @@ class ArrowComputeNode(Node): gpu_limit, memory_limit, ) - self.parquet_row_group_size = ( - parquet_row_group_size or self.default_row_group_size - ) + self.parquet_row_group_size = parquet_row_group_size or self.default_row_group_size self.parquet_dictionary_encoding = parquet_dictionary_encoding self.parquet_compression = parquet_compression self.parquet_compression_level = parquet_compression_level @@ -708,9 +680,7 @@ class ArrowComputeNode(Node): """ return ArrowComputeTask(*args, **kwargs) - def process( - self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table] - ) -> arrow.Table: + def process(self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table: """ Put user-defined code here. @@ -749,9 +719,7 @@ class ArrowStreamNode(Node): ctx: Context, input_deps: Tuple[Node, ...], *, - process_func: Callable[ - [RuntimeContext, List[arrow.RecordBatchReader]], Iterable[arrow.Table] - ] = None, + process_func: Callable[[RuntimeContext, List[arrow.RecordBatchReader]], Iterable[arrow.Table]] = None, background_io_thread=True, streaming_batch_size: int = None, secs_checkpoint_interval: int = None, @@ -816,12 +784,9 @@ class ArrowStreamNode(Node): self.background_io_thread = background_io_thread and self.cpu_limit > 1 self.streaming_batch_size = streaming_batch_size or self.default_batch_size self.secs_checkpoint_interval = secs_checkpoint_interval or math.ceil( - self.default_secs_checkpoint_interval - / min(6, self.gpu_limit + 2, self.cpu_limit) - ) - self.parquet_row_group_size = ( - parquet_row_group_size or self.default_row_group_size + self.default_secs_checkpoint_interval / min(6, self.gpu_limit + 2, self.cpu_limit) ) + self.parquet_row_group_size = parquet_row_group_size or self.default_row_group_size self.parquet_dictionary_encoding = parquet_dictionary_encoding self.parquet_compression = parquet_compression self.parquet_compression_level = parquet_compression_level @@ -890,9 +855,7 @@ class ArrowStreamNode(Node): """ return ArrowStreamTask(*args, **kwargs) - def process( - self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader] - ) -> Iterable[arrow.Table]: + def process(self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]) -> Iterable[arrow.Table]: """ Put user-defined code here. @@ -918,9 +881,7 @@ class ArrowBatchNode(ArrowStreamNode): def spawn(self, *args, **kwargs) -> ArrowBatchTask: return ArrowBatchTask(*args, **kwargs) - def process( - self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table] - ) -> arrow.Table: + def process(self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table: raise NotImplementedError @@ -932,9 +893,7 @@ class PandasComputeNode(ArrowComputeNode): def spawn(self, *args, **kwargs) -> PandasComputeTask: return PandasComputeTask(*args, **kwargs) - def process( - self, runtime_ctx: RuntimeContext, input_dfs: List[pd.DataFrame] - ) -> pd.DataFrame: + def process(self, runtime_ctx: RuntimeContext, input_dfs: List[pd.DataFrame]) -> pd.DataFrame: raise NotImplementedError @@ -946,9 +905,7 @@ class PandasBatchNode(ArrowStreamNode): def spawn(self, *args, **kwargs) -> PandasBatchTask: return PandasBatchTask(*args, **kwargs) - def process( - self, runtime_ctx: RuntimeContext, input_dfs: List[pd.DataFrame] - ) -> pd.DataFrame: + def process(self, runtime_ctx: RuntimeContext, input_dfs: List[pd.DataFrame]) -> pd.DataFrame: raise NotImplementedError @@ -1064,13 +1021,8 @@ class SqlEngineNode(Node): cpu_limit = cpu_limit or self.default_cpu_limit memory_limit = memory_limit or self.default_memory_limit if udfs is not None: - if ( - self.max_udf_cpu_limit is not None - and cpu_limit > self.max_udf_cpu_limit - ): - warnings.warn( - f"UDF execution is not highly paralleled, downgrade cpu_limit from {cpu_limit} to {self.max_udf_cpu_limit}" - ) + if self.max_udf_cpu_limit is not None and cpu_limit > self.max_udf_cpu_limit: + warnings.warn(f"UDF execution is not highly paralleled, downgrade cpu_limit from {cpu_limit} to {self.max_udf_cpu_limit}") cpu_limit = self.max_udf_cpu_limit memory_limit = None if relax_memory_if_oom is not None: @@ -1080,10 +1032,7 @@ class SqlEngineNode(Node): stacklevel=3, ) - assert isinstance(sql_query, str) or ( - isinstance(sql_query, Iterable) - and all(isinstance(q, str) for q in sql_query) - ) + assert isinstance(sql_query, str) or (isinstance(sql_query, Iterable) and all(isinstance(q, str) for q in sql_query)) super().__init__( ctx, input_deps, @@ -1092,17 +1041,13 @@ class SqlEngineNode(Node): cpu_limit=cpu_limit, memory_limit=memory_limit, ) - self.sql_queries = ( - [sql_query] if isinstance(sql_query, str) else list(sql_query) - ) - self.udfs = [ - ctx.create_duckdb_extension(path) for path in extension_paths or [] - ] + [ctx.create_external_module(path) for path in udf_module_paths or []] + self.sql_queries = [sql_query] if isinstance(sql_query, str) else list(sql_query) + self.udfs = [ctx.create_duckdb_extension(path) for path in extension_paths or []] + [ + ctx.create_external_module(path) for path in udf_module_paths or [] + ] for udf in udfs or []: if isinstance(udf, UserDefinedFunction): - name = ctx.create_function( - udf.name, udf.func, udf.params, udf.return_type, udf.use_arrow_type - ) + name = ctx.create_function(udf.name, udf.func, udf.params, udf.return_type, udf.use_arrow_type) else: assert isinstance(udf, str), f"udf must be a string: {udf}" if udf in ctx.udfs: @@ -1120,9 +1065,7 @@ class SqlEngineNode(Node): self.materialize_in_memory = materialize_in_memory self.batched_processing = batched_processing and len(input_deps) == 1 self.enable_temp_directory = enable_temp_directory - self.parquet_row_group_size = ( - parquet_row_group_size or self.default_row_group_size - ) + self.parquet_row_group_size = parquet_row_group_size or self.default_row_group_size self.parquet_dictionary_encoding = parquet_dictionary_encoding self.parquet_compression = parquet_compression self.parquet_compression_level = parquet_compression_level @@ -1130,17 +1073,11 @@ class SqlEngineNode(Node): self.memory_overcommit_ratio = memory_overcommit_ratio def __str__(self) -> str: - return ( - super().__str__() - + f", sql_query=<{self.oneline_query[:100]}...>, udfs={self.udfs}, batched_processing={self.batched_processing}" - ) + return super().__str__() + f", sql_query=<{self.oneline_query[:100]}...>, udfs={self.udfs}, batched_processing={self.batched_processing}" @property def oneline_query(self) -> str: - return "; ".join( - " ".join(filter(None, map(str.strip, query.splitlines()))) - for query in self.sql_queries - ) + return "; ".join(" ".join(filter(None, map(str.strip, query.splitlines()))) for query in self.sql_queries) @Node.task_factory def create_task( @@ -1224,12 +1161,8 @@ class ConsolidateNode(Node): dimensions Partitions would be grouped by these `dimensions` and consolidated into larger partitions. """ - assert isinstance( - dimensions, Iterable - ), f"dimensions is not iterable: {dimensions}" - assert all( - isinstance(dim, str) for dim in dimensions - ), f"some dimensions are not strings: {dimensions}" + assert isinstance(dimensions, Iterable), f"dimensions is not iterable: {dimensions}" + assert all(isinstance(dim, str) for dim in dimensions), f"some dimensions are not strings: {dimensions}" super().__init__(ctx, [input_dep]) self.dimensions = set(list(dimensions) + [PartitionInfo.toplevel_dimension]) @@ -1283,29 +1216,16 @@ class PartitionNode(Node): See unit tests in `test/test_partition.py`. For nested partition see `test_nested_partition`. Why nested partition? See **5.1 Partial Partitioning** of [Advanced partitioning techniques for massively distributed computation](https://dl.acm.org/doi/10.1145/2213836.2213839). """ - assert isinstance( - npartitions, int - ), f"npartitions is not an integer: {npartitions}" - assert dimension is None or re.match( - "[a-zA-Z0-9_]+", dimension - ), f"dimension has invalid format: {dimension}" - assert not ( - nested and dimension is None - ), f"nested partition should have dimension" - super().__init__( - ctx, input_deps, output_name, output_path, cpu_limit, 0, memory_limit - ) + assert isinstance(npartitions, int), f"npartitions is not an integer: {npartitions}" + assert dimension is None or re.match("[a-zA-Z0-9_]+", dimension), f"dimension has invalid format: {dimension}" + assert not (nested and dimension is None), f"nested partition should have dimension" + super().__init__(ctx, input_deps, output_name, output_path, cpu_limit, 0, memory_limit) self.npartitions = npartitions - self.dimension = ( - dimension if dimension is not None else PartitionInfo.default_dimension - ) + self.dimension = dimension if dimension is not None else PartitionInfo.default_dimension self.nested = nested def __str__(self) -> str: - return ( - super().__str__() - + f", npartitions={self.npartitions}, dimension={self.dimension}, nested={self.nested}" - ) + return super().__str__() + f", npartitions={self.npartitions}, dimension={self.dimension}, nested={self.nested}" @Node.task_factory def create_producer_task( @@ -1441,12 +1361,8 @@ class UserDefinedPartitionNode(PartitionNode): class UserPartitionedDataSourceNode(UserDefinedPartitionNode): max_num_producer_tasks = 1 - def __init__( - self, ctx: Context, partitioned_datasets: List[DataSet], dimension: str = None - ) -> None: - assert isinstance(partitioned_datasets, Iterable) and all( - isinstance(dataset, DataSet) for dataset in partitioned_datasets - ) + def __init__(self, ctx: Context, partitioned_datasets: List[DataSet], dimension: str = None) -> None: + assert isinstance(partitioned_datasets, Iterable) and all(isinstance(dataset, DataSet) for dataset in partitioned_datasets) super().__init__( ctx, [DataSourceNode(ctx, dataset=None)], @@ -1507,10 +1423,7 @@ class EvenlyDistributedPartitionNode(PartitionNode): self.random_shuffle = random_shuffle def __str__(self) -> str: - return ( - super().__str__() - + f", partition_by_rows={self.partition_by_rows}, random_shuffle={self.random_shuffle}" - ) + return super().__str__() + f", partition_by_rows={self.partition_by_rows}, random_shuffle={self.random_shuffle}" @Node.task_factory def create_producer_task( @@ -1551,9 +1464,7 @@ class LoadPartitionedDataSetNode(PartitionNode): cpu_limit: int = 1, memory_limit: Optional[int] = None, ) -> None: - assert ( - dimension or data_partition_column - ), f"Both 'dimension' and 'data_partition_column' are none or empty" + assert dimension or data_partition_column, f"Both 'dimension' and 'data_partition_column' are none or empty" super().__init__( ctx, input_deps, @@ -1567,10 +1478,7 @@ class LoadPartitionedDataSetNode(PartitionNode): self.hive_partitioning = hive_partitioning def __str__(self) -> str: - return ( - super().__str__() - + f", data_partition_column={self.data_partition_column}, hive_partitioning={self.hive_partitioning}" - ) + return super().__str__() + f", data_partition_column={self.data_partition_column}, hive_partitioning={self.hive_partitioning}" @Node.task_factory def create_producer_task( @@ -1620,9 +1528,7 @@ def DataSetPartitionNode( -------- See unit test `test_load_partitioned_datasets` in `test/test_partition.py`. """ - assert not ( - partition_by_rows and data_partition_column - ), "partition_by_rows and data_partition_column cannot be set at the same time" + assert not (partition_by_rows and data_partition_column), "partition_by_rows and data_partition_column cannot be set at the same time" if data_partition_column is None: partition_node = EvenlyDistributedPartitionNode( ctx, @@ -1720,12 +1626,8 @@ class HashPartitionNode(PartitionNode): Specify if we should use dictionary encoding in general or only for some columns. See `use_dictionary` in https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetWriter.html. """ - assert ( - not random_shuffle or not shuffle_only - ), f"random_shuffle and shuffle_only cannot be enabled at the same time" - assert ( - not shuffle_only or data_partition_column is not None - ), f"data_partition_column not specified for shuffle-only partitioning" + assert not random_shuffle or not shuffle_only, f"random_shuffle and shuffle_only cannot be enabled at the same time" + assert not shuffle_only or data_partition_column is not None, f"data_partition_column not specified for shuffle-only partitioning" assert data_partition_column is None or re.match( "[a-zA-Z0-9_]+", data_partition_column ), f"data_partition_column has invalid format: {data_partition_column}" @@ -1734,9 +1636,7 @@ class HashPartitionNode(PartitionNode): "duckdb", "arrow", ), f"unknown query engine type: {engine_type}" - data_partition_column = ( - data_partition_column or self.default_data_partition_column - ) + data_partition_column = data_partition_column or self.default_data_partition_column super().__init__( ctx, input_deps, @@ -1756,9 +1656,7 @@ class HashPartitionNode(PartitionNode): self.drop_partition_column = drop_partition_column self.use_parquet_writer = use_parquet_writer self.hive_partitioning = hive_partitioning and self.engine_type == "duckdb" - self.parquet_row_group_size = ( - parquet_row_group_size or self.default_row_group_size - ) + self.parquet_row_group_size = parquet_row_group_size or self.default_row_group_size self.parquet_dictionary_encoding = parquet_dictionary_encoding self.parquet_compression = parquet_compression self.parquet_compression_level = parquet_compression_level @@ -1929,22 +1827,15 @@ class ProjectionNode(Node): """ columns = columns or ["*"] generated_columns = generated_columns or [] - assert all( - col in GENERATED_COLUMNS for col in generated_columns - ), f"invalid values found in generated columns: {generated_columns}" - assert not ( - set(columns) & set(generated_columns) - ), f"columns {columns} and generated columns {generated_columns} share common columns" + assert all(col in GENERATED_COLUMNS for col in generated_columns), f"invalid values found in generated columns: {generated_columns}" + assert not (set(columns) & set(generated_columns)), f"columns {columns} and generated columns {generated_columns} share common columns" super().__init__(ctx, [input_dep]) self.columns = columns self.generated_columns = generated_columns self.union_by_name = union_by_name def __str__(self) -> str: - return ( - super().__str__() - + f", columns={self.columns}, generated_columns={self.generated_columns}, union_by_name={self.union_by_name}" - ) + return super().__str__() + f", columns={self.columns}, generated_columns={self.generated_columns}, union_by_name={self.union_by_name}" @Node.task_factory def create_task( @@ -2100,10 +1991,7 @@ class LogicalPlan(object): if node.id in visited: return lines + [" " * depth + " (omitted ...)"] visited.add(node.id) - lines += [ - " " * depth + f" | {name}: {stats}" - for name, stats in node.perf_stats.items() - ] + lines += [" " * depth + f" | {name}: {stats}" for name, stats in node.perf_stats.items()] for dep in node.input_deps: lines.extend(to_str(dep, depth + 1)) return lines diff --git a/smallpond/logical/optimizer.py b/smallpond/logical/optimizer.py index c84a0b7..1e9b905 100644 --- a/smallpond/logical/optimizer.py +++ b/smallpond/logical/optimizer.py @@ -32,9 +32,7 @@ class Optimizer(LogicalPlanVisitor[Node]): def visit_query_engine_node(self, node: SqlEngineNode, depth: int) -> Node: # fuse consecutive SqlEngineNodes - if len(node.input_deps) == 1 and isinstance( - child := self.visit(node.input_deps[0], depth + 1), SqlEngineNode - ): + if len(node.input_deps) == 1 and isinstance(child := self.visit(node.input_deps[0], depth + 1), SqlEngineNode): fused = copy.copy(node) fused.input_deps = child.input_deps fused.udfs = node.udfs + child.udfs @@ -52,8 +50,6 @@ class Optimizer(LogicalPlanVisitor[Node]): # node.sql_queries = ["select a, b from {0}"] # fused.sql_queries = ["select a, b from (select * from {0})"] # ``` - fused.sql_queries = child.sql_queries[:-1] + [ - query.format(f"({child.sql_queries[-1]})") for query in node.sql_queries - ] + fused.sql_queries = child.sql_queries[:-1] + [query.format(f"({child.sql_queries[-1]})") for query in node.sql_queries] return fused return self.generic_visit(node, depth) diff --git a/smallpond/logical/planner.py b/smallpond/logical/planner.py index e641ed4..921cf8a 100644 --- a/smallpond/logical/planner.py +++ b/smallpond/logical/planner.py @@ -14,30 +14,20 @@ class Planner(LogicalPlanVisitor[TaskGroup]): self.node_to_tasks: Dict[Node, TaskGroup] = {} @logger.catch(reraise=True, message="failed to build computation graph") - def create_exec_plan( - self, logical_plan: LogicalPlan, manifest_only_final_results=True - ) -> ExecutionPlan: + def create_exec_plan(self, logical_plan: LogicalPlan, manifest_only_final_results=True) -> ExecutionPlan: logical_plan = copy.deepcopy(logical_plan) # if --output_path is specified, copy files to the output path # otherwise, create manifest files only - sink_type = ( - "copy" if self.runtime_ctx.final_output_path is not None else "manifest" - ) - final_sink_type = ( - "copy" - if self.runtime_ctx.final_output_path is not None - else "manifest" if manifest_only_final_results else "link" - ) + sink_type = "copy" if self.runtime_ctx.final_output_path is not None else "manifest" + final_sink_type = "copy" if self.runtime_ctx.final_output_path is not None else "manifest" if manifest_only_final_results else "link" # create DataSinkNode for each named output node (same name share the same sink node) nodes_groupby_output_name: Dict[str, List[Node]] = defaultdict(list) for node in logical_plan.nodes.values(): if node.output_name is not None: if node.output_name in nodes_groupby_output_name: - warnings.warn( - f"{node} has duplicate output name: {node.output_name}" - ) + warnings.warn(f"{node} has duplicate output name: {node.output_name}") nodes_groupby_output_name[node.output_name].append(node) sink_nodes = {} # { output_name: DataSinkNode } for output_name, nodes in nodes_groupby_output_name.items(): @@ -45,9 +35,7 @@ class Planner(LogicalPlanVisitor[TaskGroup]): self.runtime_ctx.final_output_path or self.runtime_ctx.output_root, output_name, ) - sink_nodes[output_name] = DataSinkNode( - logical_plan.ctx, tuple(nodes), output_path, type=sink_type - ) + sink_nodes[output_name] = DataSinkNode(logical_plan.ctx, tuple(nodes), output_path, type=sink_type) # create DataSinkNode for root node # XXX: special case optimization to avoid copying files twice @@ -63,9 +51,7 @@ class Planner(LogicalPlanVisitor[TaskGroup]): ): sink_nodes["FinalResults"] = DataSinkNode( logical_plan.ctx, - tuple( - sink_nodes[node.output_name] for node in partition_node.input_deps - ), + tuple(sink_nodes[node.output_name] for node in partition_node.input_deps), output_path=os.path.join( self.runtime_ctx.final_output_path or self.runtime_ctx.output_root, "FinalResults", @@ -124,38 +110,24 @@ class Planner(LogicalPlanVisitor[TaskGroup]): return [node.create_task(self.runtime_ctx, [], [PartitionInfo()])] def visit_data_sink_node(self, node: DataSinkNode, depth: int) -> TaskGroup: - all_input_deps = [ - task for dep in node.input_deps for task in self.visit(dep, depth + 1) - ] + all_input_deps = [task for dep in node.input_deps for task in self.visit(dep, depth + 1)] return [node.create_task(self.runtime_ctx, all_input_deps, [PartitionInfo()])] def visit_root_node(self, node: RootNode, depth: int) -> TaskGroup: - all_input_deps = [ - task for dep in node.input_deps for task in self.visit(dep, depth + 1) - ] + all_input_deps = [task for dep in node.input_deps for task in self.visit(dep, depth + 1)] return [RootTask(self.runtime_ctx, all_input_deps, [PartitionInfo()])] def visit_union_node(self, node: UnionNode, depth: int) -> TaskGroup: - all_input_deps = [ - task for dep in node.input_deps for task in self.visit(dep, depth + 1) - ] + all_input_deps = [task for dep in node.input_deps for task in self.visit(dep, depth + 1)] unique_partition_dims = set(task.partition_dims for task in all_input_deps) - assert ( - len(unique_partition_dims) == 1 - ), f"cannot union partitions with different dimensions: {unique_partition_dims}" + assert len(unique_partition_dims) == 1, f"cannot union partitions with different dimensions: {unique_partition_dims}" return all_input_deps def visit_consolidate_node(self, node: ConsolidateNode, depth: int) -> TaskGroup: input_deps_taskgroups = [self.visit(dep, depth + 1) for dep in node.input_deps] - assert ( - len(input_deps_taskgroups) == 1 - ), f"consolidate node only accepts one input node, but found: {input_deps_taskgroups}" - unique_partition_dims = set( - task.partition_dims for task in input_deps_taskgroups[0] - ) - assert ( - len(unique_partition_dims) == 1 - ), f"cannot consolidate partitions with different dimensions: {unique_partition_dims}" + assert len(input_deps_taskgroups) == 1, f"consolidate node only accepts one input node, but found: {input_deps_taskgroups}" + unique_partition_dims = set(task.partition_dims for task in input_deps_taskgroups[0]) + assert len(unique_partition_dims) == 1, f"cannot consolidate partitions with different dimensions: {unique_partition_dims}" existing_dimensions = set(unique_partition_dims.pop()) assert ( node.dimensions.intersection(existing_dimensions) == node.dimensions @@ -163,46 +135,30 @@ class Planner(LogicalPlanVisitor[TaskGroup]): # group tasks by partitions input_deps_groupby_partitions: Dict[Tuple, List[Task]] = defaultdict(list) for task in input_deps_taskgroups[0]: - partition_infos = tuple( - info - for info in task.partition_infos - if info.dimension in node.dimensions - ) + partition_infos = tuple(info for info in task.partition_infos if info.dimension in node.dimensions) input_deps_groupby_partitions[partition_infos].append(task) return [ - node.create_task(self.runtime_ctx, input_deps, partition_infos) - for partition_infos, input_deps in input_deps_groupby_partitions.items() + node.create_task(self.runtime_ctx, input_deps, partition_infos) for partition_infos, input_deps in input_deps_groupby_partitions.items() ] def visit_partition_node(self, node: PartitionNode, depth: int) -> TaskGroup: - all_input_deps = [ - task for dep in node.input_deps for task in self.visit(dep, depth + 1) - ] + all_input_deps = [task for dep in node.input_deps for task in self.visit(dep, depth + 1)] unique_partition_dims = set(task.partition_dims for task in all_input_deps) - assert ( - len(unique_partition_dims) == 1 - ), f"cannot partition input_deps with different dimensions: {unique_partition_dims}" + assert len(unique_partition_dims) == 1, f"cannot partition input_deps with different dimensions: {unique_partition_dims}" if node.nested: assert ( node.dimension not in unique_partition_dims ), f"found duplicate partition dimension '{node.dimension}', existing dimensions: {unique_partition_dims}" assert ( - len(all_input_deps) * node.npartitions - <= node.max_card_of_producers_x_consumers + len(all_input_deps) * node.npartitions <= node.max_card_of_producers_x_consumers ), f"{len(all_input_deps)=} * {node.npartitions=} > {node.max_card_of_producers_x_consumers=}" - producer_tasks = [ - node.create_producer_task( - self.runtime_ctx, [task], task.partition_infos - ) - for task in all_input_deps - ] + producer_tasks = [node.create_producer_task(self.runtime_ctx, [task], task.partition_infos) for task in all_input_deps] return [ node.create_consumer_task( self.runtime_ctx, [producer], - list(producer.partition_infos) - + [PartitionInfo(partition_idx, node.npartitions, node.dimension)], + list(producer.partition_infos) + [PartitionInfo(partition_idx, node.npartitions, node.dimension)], ) for producer in producer_tasks for partition_idx in range(node.npartitions) @@ -212,16 +168,10 @@ class Planner(LogicalPlanVisitor[TaskGroup]): node.max_num_producer_tasks, math.ceil(node.max_card_of_producers_x_consumers / node.npartitions), ) - num_parallel_tasks = ( - 2 - * self.runtime_ctx.num_executors - * math.ceil(self.runtime_ctx.usable_cpu_count / node.cpu_limit) - ) + num_parallel_tasks = 2 * self.runtime_ctx.num_executors * math.ceil(self.runtime_ctx.usable_cpu_count / node.cpu_limit) num_producer_tasks = max(1, min(max_num_producer_tasks, num_parallel_tasks)) if len(all_input_deps) < num_producer_tasks: - merge_datasets_task = node.create_merge_task( - self.runtime_ctx, all_input_deps, [PartitionInfo()] - ) + merge_datasets_task = node.create_merge_task(self.runtime_ctx, all_input_deps, [PartitionInfo()]) split_dataset_tasks = [ node.create_split_task( self.runtime_ctx, @@ -237,15 +187,10 @@ class Planner(LogicalPlanVisitor[TaskGroup]): tasks, [PartitionInfo(partition_idx, num_producer_tasks)], ) - for partition_idx, tasks in enumerate( - split_into_rows(all_input_deps, num_producer_tasks) - ) + for partition_idx, tasks in enumerate(split_into_rows(all_input_deps, num_producer_tasks)) ] producer_tasks = [ - node.create_producer_task( - self.runtime_ctx, [split_dataset], split_dataset.partition_infos - ) - for split_dataset in split_dataset_tasks + node.create_producer_task(self.runtime_ctx, [split_dataset], split_dataset.partition_infos) for split_dataset in split_dataset_tasks ] return [ node.create_consumer_task( @@ -284,11 +229,7 @@ class Planner(LogicalPlanVisitor[TaskGroup]): for main_input in input_deps_most_ndims: input_deps = [] for input_deps_dims, input_deps_map in input_deps_maps: - partition_infos = tuple( - info - for info in main_input.partition_infos - if info.dimension in input_deps_dims - ) + partition_infos = tuple(info for info in main_input.partition_infos if info.dimension in input_deps_dims) input_dep = input_deps_map.get(partition_infos, None) assert ( input_dep is not None @@ -299,50 +240,32 @@ class Planner(LogicalPlanVisitor[TaskGroup]): def visit_python_script_node(self, node: PythonScriptNode, depth: int) -> TaskGroup: return [ - node.create_task(self.runtime_ctx, input_deps, partition_infos) - for input_deps, partition_infos in self.broadcast_input_deps(node, depth) + node.create_task(self.runtime_ctx, input_deps, partition_infos) for input_deps, partition_infos in self.broadcast_input_deps(node, depth) ] def visit_arrow_compute_node(self, node: ArrowComputeNode, depth: int) -> TaskGroup: return [ - node.create_task(self.runtime_ctx, input_deps, partition_infos) - for input_deps, partition_infos in self.broadcast_input_deps(node, depth) + node.create_task(self.runtime_ctx, input_deps, partition_infos) for input_deps, partition_infos in self.broadcast_input_deps(node, depth) ] def visit_arrow_stream_node(self, node: ArrowStreamNode, depth: int) -> TaskGroup: return [ - node.create_task(self.runtime_ctx, input_deps, partition_infos) - for input_deps, partition_infos in self.broadcast_input_deps(node, depth) + node.create_task(self.runtime_ctx, input_deps, partition_infos) for input_deps, partition_infos in self.broadcast_input_deps(node, depth) ] def visit_query_engine_node(self, node: SqlEngineNode, depth: int) -> TaskGroup: return [ - node.create_task(self.runtime_ctx, input_deps, partition_infos) - for input_deps, partition_infos in self.broadcast_input_deps(node, depth) + node.create_task(self.runtime_ctx, input_deps, partition_infos) for input_deps, partition_infos in self.broadcast_input_deps(node, depth) ] def visit_projection_node(self, node: ProjectionNode, depth: int) -> TaskGroup: - assert ( - len(node.input_deps) == 1 - ), f"projection node only accepts one input node, but found: {node.input_deps}" - return [ - node.create_task(self.runtime_ctx, [task], task.partition_infos) - for task in self.visit(node.input_deps[0], depth + 1) - ] + assert len(node.input_deps) == 1, f"projection node only accepts one input node, but found: {node.input_deps}" + return [node.create_task(self.runtime_ctx, [task], task.partition_infos) for task in self.visit(node.input_deps[0], depth + 1)] def visit_limit_node(self, node: LimitNode, depth: int) -> TaskGroup: - assert ( - len(node.input_deps) == 1 - ), f"limit node only accepts one input node, but found: {node.input_deps}" + assert len(node.input_deps) == 1, f"limit node only accepts one input node, but found: {node.input_deps}" all_input_deps = self.visit(node.input_deps[0], depth + 1) - partial_limit_tasks = [ - node.create_task(self.runtime_ctx, [task], task.partition_infos) - for task in all_input_deps - ] - merge_task = node.create_merge_task( - self.runtime_ctx, partial_limit_tasks, [PartitionInfo()] - ) - global_limit_task = node.create_task( - self.runtime_ctx, [merge_task], merge_task.partition_infos - ) + partial_limit_tasks = [node.create_task(self.runtime_ctx, [task], task.partition_infos) for task in all_input_deps] + merge_task = node.create_merge_task(self.runtime_ctx, partial_limit_tasks, [PartitionInfo()]) + global_limit_task = node.create_task(self.runtime_ctx, [merge_task], merge_task.partition_infos) return [global_limit_task] diff --git a/smallpond/logical/udf.py b/smallpond/logical/udf.py index f8a8b7a..e45ae4b 100644 --- a/smallpond/logical/udf.py +++ b/smallpond/logical/udf.py @@ -264,6 +264,4 @@ def udf( See `Context.create_function` for more details. """ - return lambda func: UserDefinedFunction( - name or func.__name__, func, params, return_type, use_arrow_type - ) + return lambda func: UserDefinedFunction(name or func.__name__, func, params, return_type, use_arrow_type) diff --git a/smallpond/session.py b/smallpond/session.py index 171be42..dd9f70f 100644 --- a/smallpond/session.py +++ b/smallpond/session.py @@ -57,9 +57,7 @@ class SessionBase: logger.info(f"session config: {self.config}") def setup_worker(): - runtime_ctx._init_logs( - exec_id=socket.gethostname(), capture_stdout_stderr=True - ) + runtime_ctx._init_logs(exec_id=socket.gethostname(), capture_stdout_stderr=True) if self.config.ray_address is None: # find the memory allocator @@ -72,9 +70,7 @@ class SessionBase: malloc_path = shutil.which("libmimalloc.so.2.1") assert malloc_path is not None, "mimalloc is not installed" else: - raise ValueError( - f"unsupported memory allocator: {self.config.memory_allocator}" - ) + raise ValueError(f"unsupported memory allocator: {self.config.memory_allocator}") memory_purge_delay = 10000 # start ray head node @@ -84,11 +80,7 @@ class SessionBase: # start a new local cluster address="local", # disable local CPU resource if not running on localhost - num_cpus=( - 0 - if self.config.num_executors > 0 - else self._runtime_ctx.usable_cpu_count - ), + num_cpus=(0 if self.config.num_executors > 0 else self._runtime_ctx.usable_cpu_count), # set the memory limit to the available memory size _memory=self._runtime_ctx.usable_memory_size, # setup logging for workers @@ -142,9 +134,7 @@ class SessionBase: # spawn a thread to periodically dump metrics self._stop_event = threading.Event() - self._dump_thread = threading.Thread( - name="dump_thread", target=self._dump_periodically, daemon=True - ) + self._dump_thread = threading.Thread(name="dump_thread", target=self._dump_periodically, daemon=True) self._dump_thread.start() def shutdown(self): @@ -184,11 +174,7 @@ class SessionBase: extra_opts=dict( tags=["smallpond", "scheduler", smallpond.__version__], ), - envs={ - k: v - for k, v in os.environ.items() - if k.startswith("SP_") and k != "SP_SPAWN" - }, + envs={k: v for k, v in os.environ.items() if k.startswith("SP_") and k != "SP_SPAWN"}, ) def _start_prometheus(self) -> Optional[subprocess.Popen]: @@ -233,8 +219,7 @@ class SessionBase: stdout=open(f"{self._runtime_ctx.log_root}/grafana/grafana.log", "w"), env={ "GF_SERVER_HTTP_PORT": "8122", # redirect to an available port - "GF_SERVER_ROOT_URL": os.environ.get("RAY_GRAFANA_IFRAME_HOST") - or "http://localhost:8122", + "GF_SERVER_ROOT_URL": os.environ.get("RAY_GRAFANA_IFRAME_HOST") or "http://localhost:8122", "GF_PATHS_DATA": f"{self._runtime_ctx.log_root}/grafana/data", }, ) @@ -309,12 +294,8 @@ class SessionBase: self.dump_graph() self.dump_timeline() num_total_tasks, num_finished_tasks = self._summarize_task() - percent = ( - num_finished_tasks / num_total_tasks * 100 if num_total_tasks > 0 else 0 - ) - logger.info( - f"progress: {num_finished_tasks}/{num_total_tasks} tasks ({percent:.1f}%)" - ) + percent = num_finished_tasks / num_total_tasks * 100 if num_total_tasks > 0 else 0 + logger.info(f"progress: {num_finished_tasks}/{num_total_tasks} tasks ({percent:.1f}%)") @dataclass @@ -360,20 +341,12 @@ class Config: platform = get_platform(get_env("PLATFORM") or platform) job_id = get_env("JOBID") or job_id or platform.default_job_id() - job_time = ( - get_env("JOB_TIME", datetime.fromisoformat) - or job_time - or platform.default_job_time() - ) + job_time = get_env("JOB_TIME", datetime.fromisoformat) or job_time or platform.default_job_time() data_root = get_env("DATA_ROOT") or data_root or platform.default_data_root() num_executors = get_env("NUM_EXECUTORS", int) or num_executors or 0 ray_address = get_env("RAY_ADDRESS") or ray_address bind_numa_node = get_env("BIND_NUMA_NODE") == "1" or bind_numa_node - memory_allocator = ( - get_env("MEMORY_ALLOCATOR") - or memory_allocator - or platform.default_memory_allocator() - ) + memory_allocator = get_env("MEMORY_ALLOCATOR") or memory_allocator or platform.default_memory_allocator() config = Config( job_id=job_id, diff --git a/smallpond/utility.py b/smallpond/utility.py index 80c37f0..4fcd9e3 100644 --- a/smallpond/utility.py +++ b/smallpond/utility.py @@ -24,9 +24,7 @@ def overall_stats( ): from smallpond.logical.node import DataSetPartitionNode, DataSinkNode, SqlEngineNode - n = SqlEngineNode( - ctx, inp, sql_per_part, cpu_limit=cpu_limit, memory_limit=memory_limit - ) + n = SqlEngineNode(ctx, inp, sql_per_part, cpu_limit=cpu_limit, memory_limit=memory_limit) p = DataSetPartitionNode(ctx, (n,), npartitions=1) n2 = SqlEngineNode( ctx, @@ -59,9 +57,7 @@ def execute_command(cmd: str, env: Dict[str, str] = None, shell=False): raise subprocess.CalledProcessError(return_code, cmd) -def cprofile_to_string( - perf_profile: cProfile.Profile, order_by=pstats.SortKey.TIME, top_k=20 -): +def cprofile_to_string(perf_profile: cProfile.Profile, order_by=pstats.SortKey.TIME, top_k=20): perf_profile.disable() pstats_output = io.StringIO() profile_stats = pstats.Stats(perf_profile, stream=pstats_output) @@ -111,9 +107,7 @@ class ConcurrentIter(object): """ def __init__(self, iterable: Iterable, max_buffer_size=1) -> None: - assert isinstance( - iterable, Iterable - ), f"expect an iterable but found: {repr(iterable)}" + assert isinstance(iterable, Iterable), f"expect an iterable but found: {repr(iterable)}" self.__iterable = iterable self.__queue = queue.Queue(max_buffer_size) self.__last = object() @@ -194,6 +188,4 @@ class InterceptHandler(logging.Handler): frame = frame.f_back depth += 1 - logger.opt(depth=depth, exception=record.exc_info).log( - level, record.getMessage() - ) + logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage()) diff --git a/smallpond/worker.py b/smallpond/worker.py index 891cc6e..f2eb7b9 100644 --- a/smallpond/worker.py +++ b/smallpond/worker.py @@ -14,9 +14,7 @@ if __name__ == "__main__": required=True, help="The address of the Ray cluster to connect to", ) - parser.add_argument( - "--log_dir", required=True, help="The directory where logs will be stored" - ) + parser.add_argument("--log_dir", required=True, help="The directory where logs will be stored") parser.add_argument( "--bind_numa_node", action="store_true", diff --git a/tests/datagen.py b/tests/datagen.py index 1750137..56c3686 100644 --- a/tests/datagen.py +++ b/tests/datagen.py @@ -14,19 +14,13 @@ from filelock import FileLock def generate_url_and_domain() -> Tuple[str, str]: - domain_part = "".join( - random.choices(string.ascii_lowercase, k=random.randint(5, 15)) - ) + domain_part = "".join(random.choices(string.ascii_lowercase, k=random.randint(5, 15))) tld = random.choice(["com", "net", "org", "cn", "edu", "gov", "co", "io"]) domain = f"www.{domain_part}.{tld}" path_segments = [] for _ in range(random.randint(1, 3)): - segment = "".join( - random.choices( - string.ascii_lowercase + string.digits, k=random.randint(3, 10) - ) - ) + segment = "".join(random.choices(string.ascii_lowercase + string.digits, k=random.randint(3, 10))) path_segments.append(segment) path = "/" + "/".join(path_segments) @@ -42,26 +36,18 @@ def generate_random_date() -> str: start = datetime(2023, 1, 1, tzinfo=timezone.utc) end = datetime(2023, 12, 31, tzinfo=timezone.utc) delta = end - start - random_date = start + timedelta( - seconds=random.randint(0, int(delta.total_seconds())) - ) + random_date = start + timedelta(seconds=random.randint(0, int(delta.total_seconds()))) return random_date.strftime("%Y-%m-%dT%H:%M:%SZ") def generate_content() -> bytes: - target_length = ( - random.randint(1000, 100000) - if random.random() < 0.8 - else random.randint(100000, 1000000) - ) + target_length = random.randint(1000, 100000) if random.random() < 0.8 else random.randint(100000, 1000000) before = b"Random Page" after = b"" total_before_after = len(before) + len(after) fill_length = max(target_length - total_before_after, 0) - filler = "".join(random.choices(string.printable, k=fill_length)).encode("ascii")[ - :fill_length - ] + filler = "".join(random.choices(string.printable, k=fill_length)).encode("ascii")[:fill_length] return before + filler + after @@ -103,9 +89,7 @@ def generate_random_string(length: int) -> str: def generate_random_url() -> str: """Generate a random URL""" path = generate_random_string(random.randint(10, 20)) - return ( - f"com.{random.randint(10000, 999999)}.{random.randint(100, 9999)}/{path}.html" - ) + return f"com.{random.randint(10000, 999999)}.{random.randint(100, 9999)}/{path}.html" def generate_random_data() -> str: @@ -138,9 +122,7 @@ def generate_url_parquet_files(output_dir: str, num_files: int = 10): ) -def generate_url_tsv_files( - output_dir: str, num_files: int = 10, lines_per_file: int = 100 -): +def generate_url_tsv_files(output_dir: str, num_files: int = 10, lines_per_file: int = 100): """Generate multiple files, each containing a specified number of random data lines""" os.makedirs(output_dir, exist_ok=True) for i in range(num_files): @@ -166,16 +148,12 @@ def generate_data(path: str = "tests/data"): with FileLock(path + "/data.lock"): print("Generating data...") if not os.path.exists(path + "/mock_urls"): - generate_url_tsv_files( - output_dir=path + "/mock_urls", num_files=10, lines_per_file=100 - ) + generate_url_tsv_files(output_dir=path + "/mock_urls", num_files=10, lines_per_file=100) generate_url_parquet_files(output_dir=path + "/mock_urls", num_files=10) if not os.path.exists(path + "/arrow"): generate_arrow_files(output_dir=path + "/arrow", num_files=10) if not os.path.exists(path + "/large_array"): - concat_arrow_files( - input_dir=path + "/arrow", output_dir=path + "/large_array" - ) + concat_arrow_files(input_dir=path + "/arrow", output_dir=path + "/large_array") if not os.path.exists(path + "/long_path_list.txt"): generate_long_path_list(path=path + "/long_path_list.txt") except Exception as e: diff --git a/tests/test_arrow.py b/tests/test_arrow.py index 4e5a13f..9aa90a4 100644 --- a/tests/test_arrow.py +++ b/tests/test_arrow.py @@ -37,10 +37,7 @@ class TestArrow(TestFabric, unittest.TestCase): with self.subTest(dataset_path=dataset_path): metadata = parquet.read_metadata(dataset_path) file_num_rows = metadata.num_rows - data_size = sum( - metadata.row_group(i).total_byte_size - for i in range(metadata.num_row_groups) - ) + data_size = sum(metadata.row_group(i).total_byte_size for i in range(metadata.num_row_groups)) row_range = RowRange( path=dataset_path, begin=100, @@ -48,9 +45,7 @@ class TestArrow(TestFabric, unittest.TestCase): data_size=data_size, file_num_rows=file_num_rows, ) - expected = self._load_parquet_files([dataset_path]).slice( - offset=100, length=100 - ) + expected = self._load_parquet_files([dataset_path]).slice(offset=100, length=100) actual = load_from_parquet_files([row_range]) self._compare_arrow_tables(expected, actual) @@ -62,21 +57,15 @@ class TestArrow(TestFabric, unittest.TestCase): with self.subTest(dataset_path=dataset_path): parquet_files = glob.glob(dataset_path) expected = self._load_parquet_files(parquet_files) - with tempfile.TemporaryDirectory( - dir=self.output_root_abspath - ) as output_dir: + with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir: ok = dump_to_parquet_files(expected, output_dir) self.assertTrue(ok) - actual = self._load_parquet_files( - glob.glob(f"{output_dir}/*.parquet") - ) + actual = self._load_parquet_files(glob.glob(f"{output_dir}/*.parquet")) self._compare_arrow_tables(expected, actual) def test_dump_load_empty_table(self): # create empty table - empty_table = self._load_parquet_files( - ["tests/data/arrow/data0.parquet"] - ).slice(length=0) + empty_table = self._load_parquet_files(["tests/data/arrow/data0.parquet"]).slice(length=0) self.assertEqual(empty_table.num_rows, 0) # dump empty table with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir: @@ -94,9 +83,7 @@ class TestArrow(TestFabric, unittest.TestCase): ): with self.subTest(dataset_path=dataset_path): parquet_files = glob.glob(dataset_path) - expected_num_rows = sum( - parquet.read_metadata(file).num_rows for file in parquet_files - ) + expected_num_rows = sum(parquet.read_metadata(file).num_rows for file in parquet_files) with build_batch_reader_from_files( parquet_files, batch_size=expected_num_rows, @@ -104,9 +91,7 @@ class TestArrow(TestFabric, unittest.TestCase): ) as batch_reader, ConcurrentIter(batch_reader) as concurrent_iter: total_num_rows = 0 for batch in concurrent_iter: - print( - f"batch.num_rows {batch.num_rows}, max_batch_row_size {expected_num_rows}" - ) + print(f"batch.num_rows {batch.num_rows}, max_batch_row_size {expected_num_rows}") self.assertLessEqual(batch.num_rows, expected_num_rows) total_num_rows += batch.num_rows self.assertEqual(total_num_rows, expected_num_rows) @@ -121,9 +106,7 @@ class TestArrow(TestFabric, unittest.TestCase): table = self._load_parquet_files(parquet_files) total_num_rows = 0 for batch in table.to_batches(max_chunksize=table.num_rows): - print( - f"batch.num_rows {batch.num_rows}, max_batch_row_size {table.num_rows}" - ) + print(f"batch.num_rows {batch.num_rows}, max_batch_row_size {table.num_rows}") self.assertLessEqual(batch.num_rows, table.num_rows) total_num_rows += batch.num_rows self.assertEqual(total_num_rows, table.num_rows) @@ -135,26 +118,14 @@ class TestArrow(TestFabric, unittest.TestCase): print(f"table_with_meta.schema.metadata {table_with_meta.schema.metadata}") with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir: - self.assertTrue( - dump_to_parquet_files( - table_with_meta, output_dir, "arrow_schema_metadata", max_workers=2 - ) - ) - parquet_files = glob.glob( - os.path.join(output_dir, "arrow_schema_metadata*.parquet") - ) - loaded_table = load_from_parquet_files( - parquet_files, table.column_names[:1] - ) + self.assertTrue(dump_to_parquet_files(table_with_meta, output_dir, "arrow_schema_metadata", max_workers=2)) + parquet_files = glob.glob(os.path.join(output_dir, "arrow_schema_metadata*.parquet")) + loaded_table = load_from_parquet_files(parquet_files, table.column_names[:1]) print(f"loaded_table.schema.metadata {loaded_table.schema.metadata}") - self.assertEqual( - table_with_meta.schema.metadata, loaded_table.schema.metadata - ) + self.assertEqual(table_with_meta.schema.metadata, loaded_table.schema.metadata) with parquet.ParquetFile(parquet_files[0]) as file: print(f"file.schema_arrow.metadata {file.schema_arrow.metadata}") - self.assertEqual( - table_with_meta.schema.metadata, file.schema_arrow.metadata - ) + self.assertEqual(table_with_meta.schema.metadata, file.schema_arrow.metadata) def test_load_mixed_string_types(self): parquet_paths = glob.glob("tests/data/arrow/*.parquet") @@ -166,9 +137,7 @@ class TestArrow(TestFabric, unittest.TestCase): loaded_table = load_from_parquet_files(parquet_paths) self.assertEqual(table.num_rows * 2, loaded_table.num_rows) batch_reader = build_batch_reader_from_files(parquet_paths) - self.assertEqual( - table.num_rows * 2, sum(batch.num_rows for batch in batch_reader) - ) + self.assertEqual(table.num_rows * 2, sum(batch.num_rows for batch in batch_reader)) @logger.catch(reraise=True, message="failed to load parquet files") def _load_from_parquet_files_with_log(self, paths, columns): diff --git a/tests/test_bench.py b/tests/test_bench.py index f0fe872..63728c4 100644 --- a/tests/test_bench.py +++ b/tests/test_bench.py @@ -45,9 +45,7 @@ class TestBench(TestFabric, unittest.TestCase): num_sort_partitions = 1 << 3 for shuffle_engine in ("duckdb", "arrow"): for sort_engine in ("duckdb", "arrow", "polars"): - with self.subTest( - shuffle_engine=shuffle_engine, sort_engine=sort_engine - ): + with self.subTest(shuffle_engine=shuffle_engine, sort_engine=sort_engine): ctx = Context() random_records = generate_random_records( ctx, diff --git a/tests/test_common.py b/tests/test_common.py index d0edbfd..4f0dccd 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -34,9 +34,7 @@ class TestCommon(TestFabric, unittest.TestCase): npartitions = data.draw(st.integers(1, 2 * nelements)) items = list(range(nelements)) computed = split_into_rows(items, npartitions) - expected = [ - get_nth_partition(items, n, npartitions) for n in range(npartitions) - ] + expected = [get_nth_partition(items, n, npartitions) for n in range(npartitions)] self.assertEqual(expected, computed) @given(st.data()) diff --git a/tests/test_dataframe.py b/tests/test_dataframe.py index 49aff23..6d9b085 100644 --- a/tests/test_dataframe.py +++ b/tests/test_dataframe.py @@ -70,9 +70,7 @@ def test_flat_map(sp: Session): # user need to specify the schema if can not be inferred from the mapping values df3 = df.flat_map(lambda r: [{"c": None}], schema=pa.schema([("c", pa.int64())])) - assert df3.to_arrow() == pa.table( - {"c": pa.array([None, None, None], type=pa.int64())} - ) + assert df3.to_arrow() == pa.table({"c": pa.array([None, None, None], type=pa.int64())}) def test_map_batches(sp: Session): @@ -99,10 +97,7 @@ def test_random_shuffle(sp: Session): assert sorted(shuffled) == list(range(1000)) def count_inversions(arr: List[int]) -> int: - return sum( - sum(1 for j in range(i + 1, len(arr)) if arr[i] > arr[j]) - for i in range(len(arr)) - ) + return sum(sum(1 for j in range(i + 1, len(arr)) if arr[i] > arr[j]) for i in range(len(arr))) # check the shuffle is random enough # the expected number of inversions is n*(n-1)/4 = 249750 @@ -158,9 +153,7 @@ def test_partial_sql(sp: Session): # join df1 = sp.from_arrow(pa.table({"id1": [1, 2, 3], "val1": ["a", "b", "c"]})) df2 = sp.from_arrow(pa.table({"id2": [1, 2, 3], "val2": ["d", "e", "f"]})) - joined = sp.partial_sql( - "select id1, val1, val2 from {0} join {1} on id1 = id2", df1, df2 - ) + joined = sp.partial_sql("select id1, val1, val2 from {0} join {1} on id1 = id2", df1, df2) assert joined.to_arrow() == pa.table( {"id1": [1, 2, 3], "val1": ["a", "b", "c"], "val2": ["d", "e", "f"]}, schema=pa.schema( @@ -193,10 +186,7 @@ def test_unpicklable_task_exception(sp: Session): df.map(lambda x: logger.info("use outside logger")).to_arrow() except Exception as ex: assert "Can't pickle task" in str(ex) - assert ( - "HINT: DO NOT use externally imported loguru logger in your task. Please import it within the task." - in str(ex) - ) + assert "HINT: DO NOT use externally imported loguru logger in your task. Please import it within the task." in str(ex) else: assert False, "expected exception" diff --git a/tests/test_dataset.py b/tests/test_dataset.py index b9c7a05..974673e 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -30,9 +30,7 @@ class TestDataSet(TestFabric, unittest.TestCase): dataset = ParquetDataSet([os.path.join(self.output_root_abspath, "*.parquet")]) self.assertEqual(num_urls, dataset.num_rows) - def _generate_parquet_dataset( - self, output_path, npartitions, num_rows, row_group_size - ): + def _generate_parquet_dataset(self, output_path, npartitions, num_rows, row_group_size): duckdb.sql( f"""copy ( select range as i, range % {npartitions} as partition from range(0, {num_rows}) ) @@ -41,9 +39,7 @@ class TestDataSet(TestFabric, unittest.TestCase): ) return ParquetDataSet([f"{output_path}/**/*.parquet"]) - def _check_partition_datasets( - self, orig_dataset: ParquetDataSet, partition_func, npartition - ): + def _check_partition_datasets(self, orig_dataset: ParquetDataSet, partition_func, npartition): # build partitioned datasets partitioned_datasets = partition_func(npartition) self.assertEqual(npartition, len(partitioned_datasets)) @@ -52,9 +48,7 @@ class TestDataSet(TestFabric, unittest.TestCase): sum(dataset.num_rows for dataset in partitioned_datasets), ) # load as arrow table - loaded_table = arrow.concat_tables( - [dataset.to_arrow_table(max_workers=1) for dataset in partitioned_datasets] - ) + loaded_table = arrow.concat_tables([dataset.to_arrow_table(max_workers=1) for dataset in partitioned_datasets]) self.assertEqual(orig_dataset.num_rows, loaded_table.num_rows) # compare arrow tables orig_table = orig_dataset.to_arrow_table(max_workers=1) @@ -74,9 +68,7 @@ class TestDataSet(TestFabric, unittest.TestCase): def test_partition_by_files(self): output_path = os.path.join(self.output_root_abspath, "test_partition_by_files") - orig_dataset = self._generate_parquet_dataset( - output_path, npartitions=11, num_rows=170 * 1000, row_group_size=10 * 1000 - ) + orig_dataset = self._generate_parquet_dataset(output_path, npartitions=11, num_rows=170 * 1000, row_group_size=10 * 1000) num_files = len(orig_dataset.resolved_paths) for npartition in range(1, num_files + 1): for random_shuffle in (False, True): @@ -84,17 +76,13 @@ class TestDataSet(TestFabric, unittest.TestCase): orig_dataset.reset(orig_dataset.paths, orig_dataset.root_dir) self._check_partition_datasets( orig_dataset, - lambda n: orig_dataset.partition_by_files( - n, random_shuffle=random_shuffle - ), + lambda n: orig_dataset.partition_by_files(n, random_shuffle=random_shuffle), npartition, ) def test_partition_by_rows(self): output_path = os.path.join(self.output_root_abspath, "test_partition_by_rows") - orig_dataset = self._generate_parquet_dataset( - output_path, npartitions=11, num_rows=170 * 1000, row_group_size=10 * 1000 - ) + orig_dataset = self._generate_parquet_dataset(output_path, npartitions=11, num_rows=170 * 1000, row_group_size=10 * 1000) num_files = len(orig_dataset.resolved_paths) for npartition in range(1, 2 * num_files + 1): for random_shuffle in (False, True): @@ -102,9 +90,7 @@ class TestDataSet(TestFabric, unittest.TestCase): orig_dataset.reset(orig_dataset.paths, orig_dataset.root_dir) self._check_partition_datasets( orig_dataset, - lambda n: orig_dataset.partition_by_rows( - n, random_shuffle=random_shuffle - ), + lambda n: orig_dataset.partition_by_rows(n, random_shuffle=random_shuffle), npartition, ) @@ -116,9 +102,7 @@ class TestDataSet(TestFabric, unittest.TestCase): self.assertEqual(len(dataset.resolved_paths), len(filenames)) def test_paths_with_char_ranges(self): - dataset_with_char_ranges = ParquetDataSet( - ["tests/data/arrow/data[0-9].parquet"] - ) + dataset_with_char_ranges = ParquetDataSet(["tests/data/arrow/data[0-9].parquet"]) dataset_with_wildcards = ParquetDataSet(["tests/data/arrow/*.parquet"]) self.assertEqual( len(dataset_with_char_ranges.resolved_paths), @@ -126,9 +110,7 @@ class TestDataSet(TestFabric, unittest.TestCase): ) def test_to_arrow_table_batch_reader(self): - memdb = duckdb.connect( - database=":memory:", config={"arrow_large_buffer_size": "true"} - ) + memdb = duckdb.connect(database=":memory:", config={"arrow_large_buffer_size": "true"}) for dataset_path in ( "tests/data/arrow/*.parquet", "tests/data/large_array/*.parquet", @@ -137,24 +119,14 @@ class TestDataSet(TestFabric, unittest.TestCase): print(f"dataset_path: {dataset_path}, conn: {conn}") with self.subTest(dataset_path=dataset_path, conn=conn): dataset = ParquetDataSet([dataset_path]) - to_batches = dataset.to_arrow_table( - max_workers=1, conn=conn - ).to_batches(max_chunksize=DEFAULT_ROW_GROUP_SIZE * 2) - batch_reader = dataset.to_batch_reader( - batch_size=DEFAULT_ROW_GROUP_SIZE * 2, conn=conn - ) - with ConcurrentIter( - batch_reader, max_buffer_size=2 - ) as batch_reader: + to_batches = dataset.to_arrow_table(max_workers=1, conn=conn).to_batches(max_chunksize=DEFAULT_ROW_GROUP_SIZE * 2) + batch_reader = dataset.to_batch_reader(batch_size=DEFAULT_ROW_GROUP_SIZE * 2, conn=conn) + with ConcurrentIter(batch_reader, max_buffer_size=2) as batch_reader: for batch_iter in (to_batches, batch_reader): total_num_rows = 0 for batch in batch_iter: - print( - f"batch.num_rows {batch.num_rows}, max_batch_row_size {DEFAULT_ROW_GROUP_SIZE*2}" - ) - self.assertLessEqual( - batch.num_rows, DEFAULT_ROW_GROUP_SIZE * 2 - ) + print(f"batch.num_rows {batch.num_rows}, max_batch_row_size {DEFAULT_ROW_GROUP_SIZE*2}") + self.assertLessEqual(batch.num_rows, DEFAULT_ROW_GROUP_SIZE * 2) total_num_rows += batch.num_rows print(f"{dataset_path}: total_num_rows {total_num_rows}") self.assertEqual(total_num_rows, dataset.num_rows) @@ -167,8 +139,6 @@ def test_arrow_reader(benchmark, reader: str, dataset_path: str): dataset = ParquetDataSet([dataset_path]) conn = None if reader == "duckdb": - conn = duckdb.connect( - database=":memory:", config={"arrow_large_buffer_size": "true"} - ) + conn = duckdb.connect(database=":memory:", config={"arrow_large_buffer_size": "true"}) benchmark(dataset.to_arrow_table, conn=conn) # result: arrow reader is 4x faster than duckdb reader in small dataset, 1.4x faster in large dataset diff --git a/tests/test_deltalake.py b/tests/test_deltalake.py index c6c2d4b..e647911 100644 --- a/tests/test_deltalake.py +++ b/tests/test_deltalake.py @@ -7,9 +7,7 @@ from smallpond.io.arrow import cast_columns_to_large_string from tests.test_fabric import TestFabric -@unittest.skipUnless( - importlib.util.find_spec("deltalake") is not None, "cannot find deltalake" -) +@unittest.skipUnless(importlib.util.find_spec("deltalake") is not None, "cannot find deltalake") class TestDeltaLake(TestFabric, unittest.TestCase): def test_read_write_deltalake(self): from deltalake import DeltaTable, write_deltalake @@ -20,9 +18,7 @@ class TestDeltaLake(TestFabric, unittest.TestCase): ): parquet_files = glob.glob(dataset_path) expected = self._load_parquet_files(parquet_files) - with self.subTest(dataset_path=dataset_path), tempfile.TemporaryDirectory( - dir=self.output_root_abspath - ) as output_dir: + with self.subTest(dataset_path=dataset_path), tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir: write_deltalake(output_dir, expected, large_dtypes=True) dt = DeltaTable(output_dir) self._compare_arrow_tables(expected, dt.to_pyarrow_table()) @@ -35,12 +31,8 @@ class TestDeltaLake(TestFabric, unittest.TestCase): "tests/data/large_array/*.parquet", ): parquet_files = glob.glob(dataset_path) - with self.subTest(dataset_path=dataset_path), tempfile.TemporaryDirectory( - dir=self.output_root_abspath - ) as output_dir: - table = cast_columns_to_large_string( - self._load_parquet_files(parquet_files) - ) + with self.subTest(dataset_path=dataset_path), tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir: + table = cast_columns_to_large_string(self._load_parquet_files(parquet_files)) write_deltalake(output_dir, table, large_dtypes=True, mode="overwrite") write_deltalake(output_dir, table, large_dtypes=False, mode="append") loaded_table = DeltaTable(output_dir).to_pyarrow_table() diff --git a/tests/test_execution.py b/tests/test_execution.py index b503b3b..92f9243 100644 --- a/tests/test_execution.py +++ b/tests/test_execution.py @@ -69,9 +69,7 @@ class OutputMsgPythonTask(PythonScriptTask): input_datasets: List[DataSet], output_path: str, ) -> bool: - logger.info( - f"msg: {self.msg}, num files: {input_datasets[0].num_files}, local gpu ranks: {self.local_gpu_ranks}" - ) + logger.info(f"msg: {self.msg}, num files: {input_datasets[0].num_files}, local gpu ranks: {self.local_gpu_ranks}") self.inject_fault() return True @@ -105,9 +103,7 @@ class CopyInputArrowNode(ArrowComputeNode): super().__init__(ctx, input_deps) self.msg = msg - def process( - self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table] - ) -> arrow.Table: + def process(self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table: return copy_input_arrow(runtime_ctx, input_tables, self.msg) @@ -117,24 +113,18 @@ class CopyInputStreamNode(ArrowStreamNode): super().__init__(ctx, input_deps) self.msg = msg - def process( - self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader] - ) -> Iterable[arrow.Table]: + def process(self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]) -> Iterable[arrow.Table]: return copy_input_stream(runtime_ctx, input_readers, self.msg) -def copy_input_arrow( - runtime_ctx: RuntimeContext, input_tables: List[arrow.Table], msg: str -) -> arrow.Table: +def copy_input_arrow(runtime_ctx: RuntimeContext, input_tables: List[arrow.Table], msg: str) -> arrow.Table: logger.info(f"msg: {msg}, num rows: {input_tables[0].num_rows}") time.sleep(runtime_ctx.secs_executor_probe_interval) runtime_ctx.task.inject_fault() return input_tables[0] -def copy_input_stream( - runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader], msg: str -) -> Iterable[arrow.Table]: +def copy_input_stream(runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader], msg: str) -> Iterable[arrow.Table]: for index, batch in enumerate(input_readers[0]): logger.info(f"msg: {msg}, batch index: {index}, num rows: {batch.num_rows}") time.sleep(runtime_ctx.secs_executor_probe_interval) @@ -146,62 +136,44 @@ def copy_input_stream( runtime_ctx.task.inject_fault() -def copy_input_batch( - runtime_ctx: RuntimeContext, input_batches: List[arrow.Table], msg: str -) -> arrow.Table: +def copy_input_batch(runtime_ctx: RuntimeContext, input_batches: List[arrow.Table], msg: str) -> arrow.Table: logger.info(f"msg: {msg}, num rows: {input_batches[0].num_rows}") time.sleep(runtime_ctx.secs_executor_probe_interval) runtime_ctx.task.inject_fault() return input_batches[0] -def copy_input_data_frame( - runtime_ctx: RuntimeContext, input_dfs: List[DataFrame] -) -> DataFrame: +def copy_input_data_frame(runtime_ctx: RuntimeContext, input_dfs: List[DataFrame]) -> DataFrame: runtime_ctx.task.inject_fault() return input_dfs[0] -def copy_input_data_frame_batch( - runtime_ctx: RuntimeContext, input_dfs: List[DataFrame] -) -> DataFrame: +def copy_input_data_frame_batch(runtime_ctx: RuntimeContext, input_dfs: List[DataFrame]) -> DataFrame: runtime_ctx.task.inject_fault() return input_dfs[0] -def merge_input_tables( - runtime_ctx: RuntimeContext, input_batches: List[arrow.Table] -) -> arrow.Table: +def merge_input_tables(runtime_ctx: RuntimeContext, input_batches: List[arrow.Table]) -> arrow.Table: runtime_ctx.task.inject_fault() output = arrow.concat_tables(input_batches) - logger.info( - f"input rows: {[len(batch) for batch in input_batches]}, output rows: {len(output)}" - ) + logger.info(f"input rows: {[len(batch) for batch in input_batches]}, output rows: {len(output)}") return output -def merge_input_data_frames( - runtime_ctx: RuntimeContext, input_dfs: List[DataFrame] -) -> DataFrame: +def merge_input_data_frames(runtime_ctx: RuntimeContext, input_dfs: List[DataFrame]) -> DataFrame: runtime_ctx.task.inject_fault() output = pandas.concat(input_dfs) - logger.info( - f"input rows: {[len(df) for df in input_dfs]}, output rows: {len(output)}" - ) + logger.info(f"input rows: {[len(df) for df in input_dfs]}, output rows: {len(output)}") return output -def parse_url( - runtime_ctx: RuntimeContext, input_tables: List[arrow.Table] -) -> arrow.Table: +def parse_url(runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table: urls = input_tables[0].columns[0] hosts = [url.as_py().split("/", maxsplit=2)[0] for url in urls] return input_tables[0].append_column("host", arrow.array(hosts)) -def nonzero_exit_code( - runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str -) -> bool: +def nonzero_exit_code(runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str) -> bool: import sys if runtime_ctx.task._memory_boost == 1: @@ -210,9 +182,7 @@ def nonzero_exit_code( # create an empty file with a fixed name -def empty_file( - runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str -) -> bool: +def empty_file(runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str) -> bool: import os with open(os.path.join(output_path, "file"), "w") as fout: @@ -231,9 +201,7 @@ def split_url(urls: arrow.array) -> arrow.array: return arrow.array(url_parts, type=arrow.list_(arrow.string())) -def choose_random_urls( - runtime_ctx: RuntimeContext, input_tables: List[arrow.Table], k: int = 5 -) -> arrow.Table: +def choose_random_urls(runtime_ctx: RuntimeContext, input_tables: List[arrow.Table], k: int = 5) -> arrow.Table: # get the current running task runtime_task = runtime_ctx.task # access task-specific attributes @@ -255,16 +223,12 @@ class TestExecution(TestFabric, unittest.TestCase): def test_arrow_task(self): for use_duckdb_reader in (False, True): with self.subTest(use_duckdb_reader=use_duckdb_reader): - with tempfile.TemporaryDirectory( - dir=self.output_root_abspath - ) as output_dir: + with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir: ctx = Context() dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) data_table = dataset.to_arrow_table() data_files = DataSourceNode(ctx, dataset) - data_partitions = DataSetPartitionNode( - ctx, (data_files,), npartitions=7 - ) + data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=7) if use_duckdb_reader: data_partitions = ProjectionNode( ctx, @@ -274,9 +238,7 @@ class TestExecution(TestFabric, unittest.TestCase): arrow_compute = ArrowComputeNode( ctx, (data_partitions,), - process_func=functools.partial( - copy_input_arrow, msg="arrow compute" - ), + process_func=functools.partial(copy_input_arrow, msg="arrow compute"), use_duckdb_reader=use_duckdb_reader, output_name="arrow_compute", output_path=output_dir, @@ -285,9 +247,7 @@ class TestExecution(TestFabric, unittest.TestCase): arrow_stream = ArrowStreamNode( ctx, (data_partitions,), - process_func=functools.partial( - copy_input_stream, msg="arrow stream" - ), + process_func=functools.partial(copy_input_stream, msg="arrow stream"), streaming_batch_size=10, secs_checkpoint_interval=0.5, use_duckdb_reader=use_duckdb_reader, @@ -298,9 +258,7 @@ class TestExecution(TestFabric, unittest.TestCase): arrow_batch = ArrowBatchNode( ctx, (data_partitions,), - process_func=functools.partial( - copy_input_batch, msg="arrow batch" - ), + process_func=functools.partial(copy_input_batch, msg="arrow batch"), streaming_batch_size=10, secs_checkpoint_interval=0.5, use_duckdb_reader=use_duckdb_reader, @@ -314,12 +272,8 @@ class TestExecution(TestFabric, unittest.TestCase): output_path=output_dir, ) plan = LogicalPlan(ctx, data_sink) - exec_plan = self.execute_plan( - plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5 - ) - self.assertTrue( - all(map(os.path.exists, exec_plan.final_output.resolved_paths)) - ) + exec_plan = self.execute_plan(plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5) + self.assertTrue(all(map(os.path.exists, exec_plan.final_output.resolved_paths))) arrow_compute_output = ParquetDataSet( [os.path.join(output_dir, "arrow_compute", "**/*.parquet")], recursive=True, @@ -334,21 +288,15 @@ class TestExecution(TestFabric, unittest.TestCase): ) self._compare_arrow_tables( data_table, - arrow_compute_output.to_arrow_table().select( - data_table.column_names - ), + arrow_compute_output.to_arrow_table().select(data_table.column_names), ) self._compare_arrow_tables( data_table, - arrow_stream_output.to_arrow_table().select( - data_table.column_names - ), + arrow_stream_output.to_arrow_table().select(data_table.column_names), ) self._compare_arrow_tables( data_table, - arrow_batch_output.to_arrow_table().select( - data_table.column_names - ), + arrow_batch_output.to_arrow_table().select(data_table.column_names), ) def test_pandas_task(self): @@ -376,16 +324,10 @@ class TestExecution(TestFabric, unittest.TestCase): output_path=output_dir, cpu_limit=2, ) - data_sink = DataSinkNode( - ctx, (pandas_compute, pandas_batch), output_path=output_dir - ) + data_sink = DataSinkNode(ctx, (pandas_compute, pandas_batch), output_path=output_dir) plan = LogicalPlan(ctx, data_sink) - exec_plan = self.execute_plan( - plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5 - ) - self.assertTrue( - all(map(os.path.exists, exec_plan.final_output.resolved_paths)) - ) + exec_plan = self.execute_plan(plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5) + self.assertTrue(all(map(os.path.exists, exec_plan.final_output.resolved_paths))) pandas_compute_output = ParquetDataSet( [os.path.join(output_dir, "pandas_compute", "**/*.parquet")], recursive=True, @@ -394,21 +336,15 @@ class TestExecution(TestFabric, unittest.TestCase): [os.path.join(output_dir, "pandas_batch", "**/*.parquet")], recursive=True, ) - self._compare_arrow_tables( - data_table, pandas_compute_output.to_arrow_table() - ) + self._compare_arrow_tables(data_table, pandas_compute_output.to_arrow_table()) self._compare_arrow_tables(data_table, pandas_batch_output.to_arrow_table()) def test_variable_length_input_datasets(self): ctx = Context() small_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) large_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"] * 10) - small_partitions = DataSetPartitionNode( - ctx, (DataSourceNode(ctx, small_dataset),), npartitions=7 - ) - large_partitions = DataSetPartitionNode( - ctx, (DataSourceNode(ctx, large_dataset),), npartitions=7 - ) + small_partitions = DataSetPartitionNode(ctx, (DataSourceNode(ctx, small_dataset),), npartitions=7) + large_partitions = DataSetPartitionNode(ctx, (DataSourceNode(ctx, large_dataset),), npartitions=7) arrow_batch = ArrowBatchNode( ctx, (small_partitions, large_partitions), @@ -428,9 +364,7 @@ class TestExecution(TestFabric, unittest.TestCase): cpu_limit=2, ) plan = LogicalPlan(ctx, RootNode(ctx, (arrow_batch, pandas_batch))) - exec_plan = self.execute_plan( - plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5 - ) + exec_plan = self.execute_plan(plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5) self.assertTrue(all(map(os.path.exists, exec_plan.final_output.resolved_paths))) arrow_batch_output = ParquetDataSet( [os.path.join(exec_plan.ctx.output_root, "arrow_batch", "**/*.parquet")], @@ -440,9 +374,7 @@ class TestExecution(TestFabric, unittest.TestCase): [os.path.join(exec_plan.ctx.output_root, "pandas_batch", "**/*.parquet")], recursive=True, ) - self.assertEqual( - small_dataset.num_rows + large_dataset.num_rows, arrow_batch_output.num_rows - ) + self.assertEqual(small_dataset.num_rows + large_dataset.num_rows, arrow_batch_output.num_rows) self.assertEqual( small_dataset.num_rows + large_dataset.num_rows, pandas_batch_output.num_rows, @@ -453,9 +385,7 @@ class TestExecution(TestFabric, unittest.TestCase): # select columns when defining dataset dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"], columns=["url"]) data_files = DataSourceNode(ctx, dataset) - data_partitions = DataSetPartitionNode( - ctx, (data_files,), npartitions=3, partition_by_rows=True - ) + data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=3, partition_by_rows=True) # projection as input of arrow node generated_columns = ["filename", "file_row_number"] urls_with_host = ArrowComputeNode( @@ -480,9 +410,7 @@ class TestExecution(TestFabric, unittest.TestCase): # unify different schemas merged_diff_schemas = ProjectionNode( ctx, - DataSetPartitionNode( - ctx, (distinct_urls_with_host, urls_with_host), npartitions=1 - ), + DataSetPartitionNode(ctx, (distinct_urls_with_host, urls_with_host), npartitions=1), union_by_name=True, ) host_partitions = HashPartitionNode( @@ -513,9 +441,7 @@ class TestExecution(TestFabric, unittest.TestCase): ctx = Context() dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) data_files = DataSourceNode(ctx, dataset) - data_partitions = DataSetPartitionNode( - ctx, (data_files,), npartitions=dataset.num_files - ) + data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=dataset.num_files) ctx.create_function( "split_url", split_url, @@ -537,9 +463,7 @@ class TestExecution(TestFabric, unittest.TestCase): npartitions = 1000 dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"] * npartitions) data_files = DataSourceNode(ctx, dataset) - data_partitions = EvenlyDistributedPartitionNode( - ctx, (data_files,), npartitions=npartitions - ) + data_partitions = EvenlyDistributedPartitionNode(ctx, (data_files,), npartitions=npartitions) output_msg = OutputMsgPythonNode(ctx, (data_partitions,)) plan = LogicalPlan(ctx, output_msg) self.execute_plan( @@ -552,13 +476,9 @@ class TestExecution(TestFabric, unittest.TestCase): def test_many_producers_and_partitions(self): ctx = Context() npartitions = 10000 - dataset = ParquetDataSet( - ["tests/data/mock_urls/*.parquet"] * (npartitions * 10) - ) + dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"] * (npartitions * 10)) data_files = DataSourceNode(ctx, dataset) - data_partitions = EvenlyDistributedPartitionNode( - ctx, (data_files,), npartitions=npartitions, cpu_limit=1 - ) + data_partitions = EvenlyDistributedPartitionNode(ctx, (data_files,), npartitions=npartitions, cpu_limit=1) data_partitions.max_num_producer_tasks = 20 output_msg = OutputMsgPythonNode(ctx, (data_partitions,)) plan = LogicalPlan(ctx, output_msg) @@ -573,12 +493,8 @@ class TestExecution(TestFabric, unittest.TestCase): ctx = Context() dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) data_files = DataSourceNode(ctx, dataset) - data_partitions = DataSetPartitionNode( - ctx, (data_files,), npartitions=dataset.num_files - ) - output_msg = OutputMsgPythonNode( - ctx, (data_partitions,), cpu_limit=1, gpu_limit=0.5 - ) + data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=dataset.num_files) + output_msg = OutputMsgPythonNode(ctx, (data_partitions,), cpu_limit=1, gpu_limit=0.5) plan = LogicalPlan(ctx, output_msg) runtime_ctx = RuntimeContext( JobId.new(), @@ -596,9 +512,7 @@ class TestExecution(TestFabric, unittest.TestCase): data_files = DataSourceNode(ctx, dataset) copy_input_arrow_node = CopyInputArrowNode(ctx, (data_files,), "hello") copy_input_stream_node = CopyInputStreamNode(ctx, (data_files,), "hello") - output_msg = OutputMsgPythonNode2( - ctx, (copy_input_arrow_node, copy_input_stream_node), "hello" - ) + output_msg = OutputMsgPythonNode2(ctx, (copy_input_arrow_node, copy_input_stream_node), "hello") plan = LogicalPlan(ctx, output_msg) self.execute_plan(plan) @@ -606,9 +520,7 @@ class TestExecution(TestFabric, unittest.TestCase): ctx = Context() dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) data_files = DataSourceNode(ctx, dataset) - uniq_urls = SqlEngineNode( - ctx, (data_files,), r"select distinct * from {0}", memory_limit=2 * MB - ) + uniq_urls = SqlEngineNode(ctx, (data_files,), r"select distinct * from {0}", memory_limit=2 * MB) uniq_url_partitions = DataSetPartitionNode(ctx, (uniq_urls,), 2) uniq_url_count = SqlEngineNode( ctx, @@ -637,9 +549,7 @@ class TestExecution(TestFabric, unittest.TestCase): memory_limit=1 * GB, ) with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir: - data_sink = DataSinkNode( - ctx, (arrow_compute, arrow_stream), output_path=output_dir - ) + data_sink = DataSinkNode(ctx, (arrow_compute, arrow_stream), output_path=output_dir) plan = LogicalPlan(ctx, data_sink) self.execute_plan( plan, @@ -652,17 +562,11 @@ class TestExecution(TestFabric, unittest.TestCase): ctx = Context() dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) data_files = DataSourceNode(ctx, dataset) - nonzero_exitcode = PythonScriptNode( - ctx, (data_files,), process_func=nonzero_exit_code - ) + nonzero_exitcode = PythonScriptNode(ctx, (data_files,), process_func=nonzero_exit_code) plan = LogicalPlan(ctx, nonzero_exitcode) - exec_plan = self.execute_plan( - plan, num_executors=1, check_result=False, nonzero_exitcode_as_oom=False - ) + exec_plan = self.execute_plan(plan, num_executors=1, check_result=False, nonzero_exitcode_as_oom=False) self.assertFalse(exec_plan.successful) - exec_plan = self.execute_plan( - plan, num_executors=1, check_result=False, nonzero_exitcode_as_oom=True - ) + exec_plan = self.execute_plan(plan, num_executors=1, check_result=False, nonzero_exitcode_as_oom=True) self.assertTrue(exec_plan.successful) def test_manifest_only_data_sink(self): @@ -675,9 +579,7 @@ class TestExecution(TestFabric, unittest.TestCase): dataset = ParquetDataSet(filenames) data_files = DataSourceNode(ctx, dataset) data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=512) - data_sink = DataSinkNode( - ctx, (data_partitions,), output_path=output_dir, manifest_only=True - ) + data_sink = DataSinkNode(ctx, (data_partitions,), output_path=output_dir, manifest_only=True) plan = LogicalPlan(ctx, data_sink) self.execute_plan(plan) @@ -734,12 +636,8 @@ class TestExecution(TestFabric, unittest.TestCase): dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) data_files = DataSourceNode(ctx, dataset) data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=10) - url_counts = SqlEngineNode( - ctx, (data_partitions,), r"select count(url) as cnt from {0}" - ) - distinct_url_counts = SqlEngineNode( - ctx, (data_partitions,), r"select count(distinct url) as cnt from {0}" - ) + url_counts = SqlEngineNode(ctx, (data_partitions,), r"select count(url) as cnt from {0}") + distinct_url_counts = SqlEngineNode(ctx, (data_partitions,), r"select count(distinct url) as cnt from {0}") merged_counts = DataSetPartitionNode( ctx, ( @@ -764,9 +662,7 @@ class TestExecution(TestFabric, unittest.TestCase): r"select count(url) as cnt from {0}", output_name="url_counts", ) - distinct_url_counts = SqlEngineNode( - ctx, (data_partitions,), r"select count(distinct url) as cnt from {0}" - ) + distinct_url_counts = SqlEngineNode(ctx, (data_partitions,), r"select count(distinct url) as cnt from {0}") merged_counts = DataSetPartitionNode( ctx, ( @@ -787,24 +683,14 @@ class TestExecution(TestFabric, unittest.TestCase): dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) data_files = DataSourceNode(ctx, dataset) data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=10) - empty_files1 = PythonScriptNode( - ctx, (data_partitions,), process_func=empty_file - ) - empty_files2 = PythonScriptNode( - ctx, (data_partitions,), process_func=empty_file - ) + empty_files1 = PythonScriptNode(ctx, (data_partitions,), process_func=empty_file) + empty_files2 = PythonScriptNode(ctx, (data_partitions,), process_func=empty_file) link_path = os.path.join(self.runtime_ctx.output_root, "link") copy_path = os.path.join(self.runtime_ctx.output_root, "copy") copy_input_path = os.path.join(self.runtime_ctx.output_root, "copy_input") - data_link = DataSinkNode( - ctx, (empty_files1, empty_files2), type="link", output_path=link_path - ) - data_copy = DataSinkNode( - ctx, (empty_files1, empty_files2), type="copy", output_path=copy_path - ) - data_copy_input = DataSinkNode( - ctx, (data_partitions,), type="copy", output_path=copy_input_path - ) + data_link = DataSinkNode(ctx, (empty_files1, empty_files2), type="link", output_path=link_path) + data_copy = DataSinkNode(ctx, (empty_files1, empty_files2), type="copy", output_path=copy_path) + data_copy_input = DataSinkNode(ctx, (data_partitions,), type="copy", output_path=copy_input_path) plan = LogicalPlan(ctx, RootNode(ctx, (data_link, data_copy, data_copy_input))) self.execute_plan(plan) @@ -813,33 +699,19 @@ class TestExecution(TestFabric, unittest.TestCase): self.assertEqual(21, len(os.listdir(copy_path))) # file name should not be modified if no conflict self.assertEqual( - set( - filename - for filename in os.listdir("tests/data/mock_urls") - if filename.endswith(".parquet") - ), - set( - filename - for filename in os.listdir(copy_input_path) - if filename.endswith(".parquet") - ), + set(filename for filename in os.listdir("tests/data/mock_urls") if filename.endswith(".parquet")), + set(filename for filename in os.listdir(copy_input_path) if filename.endswith(".parquet")), ) def test_literal_datasets_as_data_sources(self): ctx = Context() num_rows = 10 query_dataset = SqlQueryDataSet(f"select i from range({num_rows}) as x(i)") - table_dataset = ArrowTableDataSet( - arrow.Table.from_arrays([list(range(num_rows))], names=["i"]) - ) + table_dataset = ArrowTableDataSet(arrow.Table.from_arrays([list(range(num_rows))], names=["i"])) query_source = DataSourceNode(ctx, query_dataset) table_source = DataSourceNode(ctx, table_dataset) - query_partitions = DataSetPartitionNode( - ctx, (query_source,), npartitions=num_rows, partition_by_rows=True - ) - table_partitions = DataSetPartitionNode( - ctx, (table_source,), npartitions=num_rows, partition_by_rows=True - ) + query_partitions = DataSetPartitionNode(ctx, (query_source,), npartitions=num_rows, partition_by_rows=True) + table_partitions = DataSetPartitionNode(ctx, (table_source,), npartitions=num_rows, partition_by_rows=True) joined_rows = SqlEngineNode( ctx, (query_partitions, table_partitions), diff --git a/tests/test_fabric.py b/tests/test_fabric.py index 4fe9782..99fb89a 100644 --- a/tests/test_fabric.py +++ b/tests/test_fabric.py @@ -27,9 +27,7 @@ from tests.datagen import generate_data generate_data() -def run_scheduler( - runtime_ctx: RuntimeContext, scheduler: Scheduler, queue: queue.Queue -): +def run_scheduler(runtime_ctx: RuntimeContext, scheduler: Scheduler, queue: queue.Queue): runtime_ctx.initialize("scheduler") scheduler.add_state_observer(Scheduler.StateObserver(SaveSchedState(queue))) retval = scheduler.run() @@ -130,9 +128,7 @@ class TestFabric(unittest.TestCase): process.kill() process.join() - logger.info( - f"#{i} process {process.name} exited with code {process.exitcode}" - ) + logger.info(f"#{i} process {process.name} exited with code {process.exitcode}") def start_execution( self, @@ -189,11 +185,7 @@ class TestFabric(unittest.TestCase): secs_wq_poll_interval=secs_wq_poll_interval, secs_executor_probe_interval=secs_executor_probe_interval, max_num_missed_probes=max_num_missed_probes, - fault_inject_prob=( - fault_inject_prob - if fault_inject_prob is not None - else self.fault_inject_prob - ), + fault_inject_prob=(fault_inject_prob if fault_inject_prob is not None else self.fault_inject_prob), enable_profiling=enable_profiling, enable_diagnostic_metrics=enable_diagnostic_metrics, remove_empty_parquet=remove_empty_parquet, @@ -217,9 +209,7 @@ class TestFabric(unittest.TestCase): nonzero_exitcode_as_oom=nonzero_exitcode_as_oom, ) self.latest_state = scheduler - self.executors = [ - Executor.create(runtime_ctx, f"executor-{i}") for i in range(num_executors) - ] + self.executors = [Executor.create(runtime_ctx, f"executor-{i}") for i in range(num_executors)] self.processes = [ Process( target=run_scheduler, @@ -229,10 +219,7 @@ class TestFabric(unittest.TestCase): name="scheduler", ) ] - self.processes += [ - Process(target=run_executor, args=(runtime_ctx, executor), name=executor.id) - for executor in self.executors - ] + self.processes += [Process(target=run_executor, args=(runtime_ctx, executor), name=executor.id) for executor in self.executors] for process in reversed(self.processes): process.start() @@ -264,15 +251,9 @@ class TestFabric(unittest.TestCase): self.assertTrue(latest_state.success) return latest_state.exec_plan - def _load_parquet_files( - self, paths, filesystem: fsspec.AbstractFileSystem = None - ) -> arrow.Table: + def _load_parquet_files(self, paths, filesystem: fsspec.AbstractFileSystem = None) -> arrow.Table: def read_parquet_file(path): - return arrow.Table.from_batches( - parquet.ParquetFile( - path, buffer_size=16 * MB, filesystem=filesystem - ).iter_batches() - ) + return arrow.Table.from_batches(parquet.ParquetFile(path, buffer_size=16 * MB, filesystem=filesystem).iter_batches()) with ThreadPoolExecutor(16) as pool: return arrow.concat_tables(pool.map(read_parquet_file, paths)) diff --git a/tests/test_filesystem.py b/tests/test_filesystem.py index d8b1dda..2613716 100644 --- a/tests/test_filesystem.py +++ b/tests/test_filesystem.py @@ -17,9 +17,7 @@ class TestFilesystem(TestFabric, unittest.TestCase): def test_pickle_trace(self): with self.assertRaises(TypeError) as context: - with tempfile.TemporaryDirectory( - dir=self.output_root_abspath - ) as output_dir: + with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir: thread = threading.Thread() pickle_path = os.path.join(output_dir, "thread.pickle") dump(thread, pickle_path) diff --git a/tests/test_logical.py b/tests/test_logical.py index a2901a2..ff08dba 100644 --- a/tests/test_logical.py +++ b/tests/test_logical.py @@ -21,19 +21,11 @@ class TestLogicalPlan(TestFabric, unittest.TestCase): def test_join_chunkmeta_inodes(self): ctx = Context() - chunkmeta_dump = DataSourceNode( - ctx, dataset=ParquetDataSet(["tests/data/chunkmeta*.parquet"]) - ) - chunkmeta_partitions = HashPartitionNode( - ctx, (chunkmeta_dump,), npartitions=2, hash_columns=["inodeId"] - ) + chunkmeta_dump = DataSourceNode(ctx, dataset=ParquetDataSet(["tests/data/chunkmeta*.parquet"])) + chunkmeta_partitions = HashPartitionNode(ctx, (chunkmeta_dump,), npartitions=2, hash_columns=["inodeId"]) - inodes_dump = DataSourceNode( - ctx, dataset=ParquetDataSet(["tests/data/inodes*.parquet"]) - ) - inodes_partitions = HashPartitionNode( - ctx, (inodes_dump,), npartitions=2, hash_columns=["inode_id"] - ) + inodes_dump = DataSourceNode(ctx, dataset=ParquetDataSet(["tests/data/inodes*.parquet"])) + inodes_partitions = HashPartitionNode(ctx, (inodes_dump,), npartitions=2, hash_columns=["inode_id"]) num_gc_chunks = SqlEngineNode( ctx, @@ -53,12 +45,8 @@ class TestLogicalPlan(TestFabric, unittest.TestCase): ctx = Context() parquet_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) data_source = DataSourceNode(ctx, parquet_dataset) - partition_dim_a = EvenlyDistributedPartitionNode( - ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="A" - ) - partition_dim_b = EvenlyDistributedPartitionNode( - ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="B" - ) + partition_dim_a = EvenlyDistributedPartitionNode(ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="A") + partition_dim_b = EvenlyDistributedPartitionNode(ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="B") join_two_inputs = SqlEngineNode( ctx, (partition_dim_a, partition_dim_b), @@ -73,9 +61,7 @@ class TestLogicalPlan(TestFabric, unittest.TestCase): ctx = Context() parquet_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) data_source = DataSourceNode(ctx, parquet_dataset) - partition_dim_a = EvenlyDistributedPartitionNode( - ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="A" - ) + partition_dim_a = EvenlyDistributedPartitionNode(ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="A") partition_dim_a2 = EvenlyDistributedPartitionNode( ctx, (data_source,), @@ -94,9 +80,7 @@ class TestLogicalPlan(TestFabric, unittest.TestCase): ) plan = LogicalPlan( ctx, - DataSetPartitionNode( - ctx, (join_two_inputs1, join_two_inputs2), npartitions=1 - ), + DataSetPartitionNode(ctx, (join_two_inputs1, join_two_inputs2), npartitions=1), ) logger.info(str(plan)) with self.assertRaises(AssertionError) as context: diff --git a/tests/test_partition.py b/tests/test_partition.py index ac7704c..8ff6898 100644 --- a/tests/test_partition.py +++ b/tests/test_partition.py @@ -31,9 +31,7 @@ from tests.test_fabric import TestFabric class CalculatePartitionFromFilename(UserDefinedPartitionNode): def partition(self, runtime_ctx: RuntimeContext, dataset: DataSet) -> List[DataSet]: - partitioned_datasets: List[ParquetDataSet] = [ - ParquetDataSet([]) for _ in range(self.npartitions) - ] + partitioned_datasets: List[ParquetDataSet] = [ParquetDataSet([]) for _ in range(self.npartitions)] for path in dataset.resolved_paths: partition_idx = hash(path) % self.npartitions partitioned_datasets[partition_idx].paths.append(path) @@ -45,9 +43,7 @@ class TestPartition(TestFabric, unittest.TestCase): ctx = Context() dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"] * 10) data_files = DataSourceNode(ctx, dataset) - data_partitions = DataSetPartitionNode( - ctx, (data_files,), npartitions=dataset.num_files - ) + data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=dataset.num_files) count_rows = SqlEngineNode( ctx, (data_partitions,), @@ -62,9 +58,7 @@ class TestPartition(TestFabric, unittest.TestCase): ctx = Context() dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) data_files = DataSourceNode(ctx, dataset) - data_partitions = DataSetPartitionNode( - ctx, (data_files,), npartitions=dataset.num_rows, partition_by_rows=True - ) + data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=dataset.num_rows, partition_by_rows=True) count_rows = SqlEngineNode( ctx, (data_partitions,), @@ -74,18 +68,14 @@ class TestPartition(TestFabric, unittest.TestCase): ) plan = LogicalPlan(ctx, count_rows) exec_plan = self.execute_plan(plan, num_executors=5) - self.assertEqual( - exec_plan.final_output.to_arrow_table().num_rows, dataset.num_rows - ) + self.assertEqual(exec_plan.final_output.to_arrow_table().num_rows, dataset.num_rows) def test_empty_dataset_partition(self): ctx = Context() dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) data_files = DataSourceNode(ctx, dataset) # create more partitions than files - data_partitions = EvenlyDistributedPartitionNode( - ctx, (data_files,), npartitions=dataset.num_files * 2 - ) + data_partitions = EvenlyDistributedPartitionNode(ctx, (data_files,), npartitions=dataset.num_files * 2) data_partitions.max_num_producer_tasks = 3 unique_urls = SqlEngineNode( ctx, @@ -95,9 +85,7 @@ class TestPartition(TestFabric, unittest.TestCase): memory_limit=1 * GB, ) # nested partition - nested_partitioned_urls = EvenlyDistributedPartitionNode( - ctx, (unique_urls,), npartitions=3, dimension="nested", nested=True - ) + nested_partitioned_urls = EvenlyDistributedPartitionNode(ctx, (unique_urls,), npartitions=3, dimension="nested", nested=True) parsed_urls = ArrowComputeNode( ctx, (nested_partitioned_urls,), @@ -106,18 +94,14 @@ class TestPartition(TestFabric, unittest.TestCase): memory_limit=1 * GB, ) plan = LogicalPlan(ctx, parsed_urls) - final_output = self.execute_plan( - plan, remove_empty_parquet=True, skip_task_with_empty_input=True - ).final_output + final_output = self.execute_plan(plan, remove_empty_parquet=True, skip_task_with_empty_input=True).final_output self.assertTrue(isinstance(final_output, ParquetDataSet)) self.assertEqual(dataset.num_rows, final_output.num_rows) def test_hash_partition(self): for engine_type in ("duckdb", "arrow"): for partition_by_rows in (False, True): - for hive_partitioning in ( - (False, True) if engine_type == "duckdb" else (False,) - ): + for hive_partitioning in (False, True) if engine_type == "duckdb" else (False,): with self.subTest( engine_type=engine_type, partition_by_rows=partition_by_rows, @@ -155,26 +139,16 @@ class TestPartition(TestFabric, unittest.TestCase): exec_plan = self.execute_plan(plan) self.assertEqual( dataset.num_rows, - pc.sum( - exec_plan.final_output.to_arrow_table().column( - "row_count" - ) - ).as_py(), + pc.sum(exec_plan.final_output.to_arrow_table().column("row_count")).as_py(), + ) + self.assertEqual( + npartitions, + len(exec_plan.final_output.load_partitioned_datasets(npartitions, DATA_PARTITION_COLUMN_NAME)), ) self.assertEqual( npartitions, len( - exec_plan.final_output.load_partitioned_datasets( - npartitions, DATA_PARTITION_COLUMN_NAME - ) - ), - ) - self.assertEqual( - npartitions, - len( - exec_plan.get_output( - "hash_partitions" - ).load_partitioned_datasets( + exec_plan.get_output("hash_partitions").load_partitioned_datasets( npartitions, DATA_PARTITION_COLUMN_NAME, hive_partitioning, @@ -185,9 +159,7 @@ class TestPartition(TestFabric, unittest.TestCase): def test_empty_hash_partition(self): for engine_type in ("duckdb", "arrow"): for partition_by_rows in (False, True): - for hive_partitioning in ( - (False, True) if engine_type == "duckdb" else (False,) - ): + for hive_partitioning in (False, True) if engine_type == "duckdb" else (False,): with self.subTest( engine_type=engine_type, partition_by_rows=partition_by_rows, @@ -199,9 +171,7 @@ class TestPartition(TestFabric, unittest.TestCase): npartitions = 3 npartitions_nested = 4 num_rows = 1 - head_rows = SqlEngineNode( - ctx, (data_files,), f"select * from {{0}} limit {num_rows}" - ) + head_rows = SqlEngineNode(ctx, (data_files,), f"select * from {{0}} limit {num_rows}") data_partitions = DataSetPartitionNode( ctx, (head_rows,), @@ -241,53 +211,31 @@ class TestPartition(TestFabric, unittest.TestCase): memory_limit=1 * GB, ) plan = LogicalPlan(ctx, select_every_row) - exec_plan = self.execute_plan( - plan, skip_task_with_empty_input=True - ) + exec_plan = self.execute_plan(plan, skip_task_with_empty_input=True) self.assertEqual(num_rows, exec_plan.final_output.num_rows) self.assertEqual( npartitions, - len( - exec_plan.final_output.load_partitioned_datasets( - npartitions, "hash_partitions" - ) - ), + len(exec_plan.final_output.load_partitioned_datasets(npartitions, "hash_partitions")), ) self.assertEqual( npartitions_nested, - len( - exec_plan.final_output.load_partitioned_datasets( - npartitions_nested, "nested_hash_partitions" - ) - ), + len(exec_plan.final_output.load_partitioned_datasets(npartitions_nested, "nested_hash_partitions")), ) self.assertEqual( npartitions, - len( - exec_plan.get_output( - "hash_partitions" - ).load_partitioned_datasets( - npartitions, "hash_partitions" - ) - ), + len(exec_plan.get_output("hash_partitions").load_partitioned_datasets(npartitions, "hash_partitions")), ) self.assertEqual( npartitions_nested, len( - exec_plan.get_output( - "nested_hash_partitions" - ).load_partitioned_datasets( - npartitions_nested, "nested_hash_partitions" - ) + exec_plan.get_output("nested_hash_partitions").load_partitioned_datasets(npartitions_nested, "nested_hash_partitions") ), ) if hive_partitioning: self.assertEqual( npartitions, len( - exec_plan.get_output( - "hash_partitions" - ).load_partitioned_datasets( + exec_plan.get_output("hash_partitions").load_partitioned_datasets( npartitions, "hash_partitions", hive_partitioning=True, @@ -297,9 +245,7 @@ class TestPartition(TestFabric, unittest.TestCase): self.assertEqual( npartitions_nested, len( - exec_plan.get_output( - "nested_hash_partitions" - ).load_partitioned_datasets( + exec_plan.get_output("nested_hash_partitions").load_partitioned_datasets( npartitions_nested, "nested_hash_partitions", hive_partitioning=True, @@ -341,19 +287,11 @@ class TestPartition(TestFabric, unittest.TestCase): exec_plan = self.execute_plan(plan) self.assertEqual( npartitions, - len( - exec_plan.final_output.load_partitioned_datasets( - npartitions, data_partition_column - ) - ), + len(exec_plan.final_output.load_partitioned_datasets(npartitions, data_partition_column)), ) self.assertEqual( npartitions, - len( - exec_plan.get_output("input_partitions").load_partitioned_datasets( - npartitions, data_partition_column, hive_partitioning - ) - ), + len(exec_plan.get_output("input_partitions").load_partitioned_datasets(npartitions, data_partition_column, hive_partitioning)), ) return exec_plan @@ -376,12 +314,8 @@ class TestPartition(TestFabric, unittest.TestCase): ) ctx = Context() - output1 = DataSourceNode( - ctx, dataset=exec_plan1.get_output("input_partitions") - ) - output2 = DataSourceNode( - ctx, dataset=exec_plan2.get_output("input_partitions") - ) + output1 = DataSourceNode(ctx, dataset=exec_plan1.get_output("input_partitions")) + output2 = DataSourceNode(ctx, dataset=exec_plan2.get_output("input_partitions")) split_urls1 = LoadPartitionedDataSetNode( ctx, (output1,), @@ -411,16 +345,8 @@ class TestPartition(TestFabric, unittest.TestCase): plan = LogicalPlan(ctx, split_urls3) exec_plan3 = self.execute_plan(plan) # load each partition as arrow table and compare - final_output_partitions1 = ( - exec_plan1.final_output.load_partitioned_datasets( - npartitions, data_partition_column - ) - ) - final_output_partitions3 = ( - exec_plan3.final_output.load_partitioned_datasets( - npartitions, data_partition_column - ) - ) + final_output_partitions1 = exec_plan1.final_output.load_partitioned_datasets(npartitions, data_partition_column) + final_output_partitions3 = exec_plan3.final_output.load_partitioned_datasets(npartitions, data_partition_column) self.assertEqual(npartitions, len(final_output_partitions3)) for x, y in zip(final_output_partitions1, final_output_partitions3): self._compare_arrow_tables(x.to_arrow_table(), y.to_arrow_table()) @@ -433,9 +359,7 @@ class TestPartition(TestFabric, unittest.TestCase): SqlEngineNode.default_cpu_limit = 1 SqlEngineNode.default_memory_limit = 1 * GB initial_reduce = r"select host, count(*) as cnt from {0} group by host" - combine_reduce_results = ( - r"select host, cast(sum(cnt) as bigint) as cnt from {0} group by host" - ) + combine_reduce_results = r"select host, cast(sum(cnt) as bigint) as cnt from {0} group by host" join_query = r"select host, cnt from {0} where (exists (select * from {1} where {1}.host = {0}.host)) and (exists (select * from {2} where {2}.host = {0}.host))" partition_by_hosts = HashPartitionNode( @@ -496,11 +420,7 @@ class TestPartition(TestFabric, unittest.TestCase): url_count_by_3dims = SqlEngineNode(ctx, (partitioned_3dims,), initial_reduce) url_count_by_hosts_x_urls2 = SqlEngineNode( ctx, - ( - ConsolidateNode( - ctx, url_count_by_3dims, ["host_partition", "url_partition"] - ), - ), + (ConsolidateNode(ctx, url_count_by_3dims, ["host_partition", "url_partition"]),), combine_reduce_results, output_name="url_count_by_hosts_x_urls2", ) @@ -524,9 +444,7 @@ class TestPartition(TestFabric, unittest.TestCase): output_name="join_count_by_hosts_x_urls2", ) - union_url_count_by_hosts = UnionNode( - ctx, (url_count_by_hosts1, url_count_by_hosts2) - ) + union_url_count_by_hosts = UnionNode(ctx, (url_count_by_hosts1, url_count_by_hosts2)) union_url_count_by_hosts_x_urls = UnionNode( ctx, ( @@ -576,18 +494,14 @@ class TestPartition(TestFabric, unittest.TestCase): ctx = Context() parquet_files = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) data_source = DataSourceNode(ctx, parquet_files) - file_partitions1 = CalculatePartitionFromFilename( - ctx, (data_source,), npartitions=3, dimension="by_filename_hash1" - ) + file_partitions1 = CalculatePartitionFromFilename(ctx, (data_source,), npartitions=3, dimension="by_filename_hash1") url_count1 = SqlEngineNode( ctx, (file_partitions1,), r"select host, count(*) as cnt from {0} group by host", output_name="url_count1", ) - file_partitions2 = CalculatePartitionFromFilename( - ctx, (url_count1,), npartitions=3, dimension="by_filename_hash2" - ) + file_partitions2 = CalculatePartitionFromFilename(ctx, (url_count1,), npartitions=3, dimension="by_filename_hash2") url_count2 = SqlEngineNode( ctx, (file_partitions2,), @@ -606,9 +520,7 @@ class TestPartition(TestFabric, unittest.TestCase): ctx = Context() parquet_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) data_source = DataSourceNode(ctx, parquet_dataset) - evenly_dist_data_source = EvenlyDistributedPartitionNode( - ctx, (data_source,), npartitions=parquet_dataset.num_files - ) + evenly_dist_data_source = EvenlyDistributedPartitionNode(ctx, (data_source,), npartitions=parquet_dataset.num_files) parquet_datasets = [ParquetDataSet([p]) for p in parquet_dataset.resolved_paths] partitioned_data_source = UserPartitionedDataSourceNode(ctx, parquet_datasets) @@ -631,9 +543,7 @@ class TestPartition(TestFabric, unittest.TestCase): memory_limit=1 * GB, ) - plan = LogicalPlan( - ctx, UnionNode(ctx, [url_count_by_host1, url_count_by_host2]) - ) + plan = LogicalPlan(ctx, UnionNode(ctx, [url_count_by_host1, url_count_by_host2])) exec_plan = self.execute_plan(plan, enable_diagnostic_metrics=True) self._compare_arrow_tables( exec_plan.get_output("url_count_by_host1").to_arrow_table(), @@ -647,9 +557,7 @@ class TestPartition(TestFabric, unittest.TestCase): ctx = Context() parquet_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) data_source = DataSourceNode(ctx, parquet_dataset) - evenly_dist_data_source = EvenlyDistributedPartitionNode( - ctx, (data_source,), npartitions=parquet_dataset.num_files - ) + evenly_dist_data_source = EvenlyDistributedPartitionNode(ctx, (data_source,), npartitions=parquet_dataset.num_files) sql_query = SqlEngineNode( ctx, (evenly_dist_data_source,), diff --git a/tests/test_plan.py b/tests/test_plan.py index 39f6fa8..c02717b 100644 --- a/tests/test_plan.py +++ b/tests/test_plan.py @@ -65,6 +65,4 @@ def test_fstest(sp: Session): def test_sort_mock_urls_v2(sp: Session): - sort_mock_urls_v2( - sp, ["tests/data/mock_urls/*.tsv"], sp._runtime_ctx.output_root, npartitions=3 - ) + sort_mock_urls_v2(sp, ["tests/data/mock_urls/*.tsv"], sp._runtime_ctx.output_root, npartitions=3) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 85fb807..0a8a86a 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -21,9 +21,7 @@ from tests.test_fabric import TestFabric class RandomSleepTask(PythonScriptTask): - def __init__( - self, *args, sleep_secs: float, fail_first_try: bool, **kwargs - ) -> None: + def __init__(self, *args, sleep_secs: float, fail_first_try: bool, **kwargs) -> None: super().__init__(*args, **kwargs) self.sleep_secs = sleep_secs self.fail_first_try = fail_first_try @@ -58,24 +56,16 @@ class RandomSleepNode(PythonScriptNode): self.fail_first_try = fail_first_try def spawn(self, *args, **kwargs) -> RandomSleepTask: - sleep_secs = ( - random.random() if len(self.generated_tasks) % 20 else self.max_sleep_secs - ) - return RandomSleepTask( - *args, **kwargs, sleep_secs=sleep_secs, fail_first_try=self.fail_first_try - ) + sleep_secs = random.random() if len(self.generated_tasks) % 20 else self.max_sleep_secs + return RandomSleepTask(*args, **kwargs, sleep_secs=sleep_secs, fail_first_try=self.fail_first_try) class TestScheduler(TestFabric, unittest.TestCase): - def create_random_sleep_plan( - self, npartitions, max_sleep_secs, fail_first_try=False - ): + def create_random_sleep_plan(self, npartitions, max_sleep_secs, fail_first_try=False): ctx = Context() dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) data_files = DataSourceNode(ctx, dataset) - data_partitions = DataSetPartitionNode( - ctx, (data_files,), npartitions=npartitions, partition_by_rows=True - ) + data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=npartitions, partition_by_rows=True) random_sleep = RandomSleepNode( ctx, (data_partitions,), @@ -87,13 +77,8 @@ class TestScheduler(TestFabric, unittest.TestCase): def check_executor_state(self, target_state: ExecutorState, nloops=200): for _ in range(nloops): latest_sched_state = self.get_latest_sched_state() - if any( - executor.state == target_state - for executor in latest_sched_state.remote_executors - ): - logger.info( - f"found {target_state} executor in: {latest_sched_state.remote_executors}" - ) + if any(executor.state == target_state for executor in latest_sched_state.remote_executors): + logger.info(f"found {target_state} executor in: {latest_sched_state.remote_executors}") break time.sleep(0.1) else: @@ -121,9 +106,7 @@ class TestScheduler(TestFabric, unittest.TestCase): latest_sched_state = self.get_latest_sched_state() self.check_executor_state(ExecutorState.GOOD) - for i, (executor, process) in enumerate( - random.sample(list(zip(executors, processes[1:])), k=num_fail) - ): + for i, (executor, process) in enumerate(random.sample(list(zip(executors, processes[1:])), k=num_fail)): if i % 2 == 0: logger.warning(f"kill executor: {executor}") process.kill() @@ -165,9 +148,7 @@ class TestScheduler(TestFabric, unittest.TestCase): self.assertGreater(len(latest_sched_state.abandoned_tasks), 0) def test_stop_executor_on_failure(self): - plan = self.create_random_sleep_plan( - npartitions=3, max_sleep_secs=5, fail_first_try=True - ) + plan = self.create_random_sleep_plan(npartitions=3, max_sleep_secs=5, fail_first_try=True) exec_plan = self.execute_plan( plan, num_executors=5, diff --git a/tests/test_session.py b/tests/test_session.py index 83df07a..510934f 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -5,9 +5,7 @@ from smallpond.dataframe import Session def test_shutdown_cleanup(sp: Session): assert os.path.exists(sp._runtime_ctx.queue_root), "queue directory should exist" - assert os.path.exists( - sp._runtime_ctx.staging_root - ), "staging directory should exist" + assert os.path.exists(sp._runtime_ctx.staging_root), "staging directory should exist" assert os.path.exists(sp._runtime_ctx.temp_root), "temp directory should exist" # create some tasks and complete them @@ -16,15 +14,9 @@ def test_shutdown_cleanup(sp: Session): sp.shutdown() # shutdown should clean up directories - assert not os.path.exists( - sp._runtime_ctx.queue_root - ), "queue directory should be cleared" - assert not os.path.exists( - sp._runtime_ctx.staging_root - ), "staging directory should be cleared" - assert not os.path.exists( - sp._runtime_ctx.temp_root - ), "temp directory should be cleared" + assert not os.path.exists(sp._runtime_ctx.queue_root), "queue directory should be cleared" + assert not os.path.exists(sp._runtime_ctx.staging_root), "staging directory should be cleared" + assert not os.path.exists(sp._runtime_ctx.temp_root), "temp directory should be cleared" with open(sp._runtime_ctx.job_status_path) as fin: assert "success" in fin.read(), "job status should be success" @@ -41,14 +33,8 @@ def test_shutdown_no_cleanup_on_failure(sp: Session): sp.shutdown() # shutdown should not clean up directories - assert os.path.exists( - sp._runtime_ctx.queue_root - ), "queue directory should not be cleared" - assert os.path.exists( - sp._runtime_ctx.staging_root - ), "staging directory should not be cleared" - assert os.path.exists( - sp._runtime_ctx.temp_root - ), "temp directory should not be cleared" + assert os.path.exists(sp._runtime_ctx.queue_root), "queue directory should not be cleared" + assert os.path.exists(sp._runtime_ctx.staging_root), "staging directory should not be cleared" + assert os.path.exists(sp._runtime_ctx.temp_root), "temp directory should not be cleared" with open(sp._runtime_ctx.job_status_path) as fin: assert "failure" in fin.read(), "job status should be failure" diff --git a/tests/test_workqueue.py b/tests/test_workqueue.py index b9072e0..7fc8641 100644 --- a/tests/test_workqueue.py +++ b/tests/test_workqueue.py @@ -85,9 +85,7 @@ class WorkQueueTestBase(object): def test_multi_consumers(self): numConsumers = 10 numItems = 200 - result = self.pool.starmap_async( - consumer, [(self.wq, id) for id in range(numConsumers)] - ) + result = self.pool.starmap_async(consumer, [(self.wq, id) for id in range(numConsumers)]) producer(self.wq, 0, numItems, numConsumers) logger.info("waiting for result")