mirror of
https://github.com/deepseek-ai/smallpond
synced 2025-06-26 18:27:45 +00:00
reformat code with --line-length=150 (#18)
This commit is contained in:
2
.github/workflows/ci.yml
vendored
2
.github/workflows/ci.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
||||
|
||||
- name: Check code formatting
|
||||
run: |
|
||||
black --check .
|
||||
black --check --line-length=150 .
|
||||
|
||||
# - name: Check typos
|
||||
# uses: crate-ci/typos@v1.29.10
|
||||
|
||||
@@ -64,9 +64,7 @@ def main():
|
||||
driver = Driver()
|
||||
driver.add_argument("-i", "--input_paths", nargs="+")
|
||||
driver.add_argument("-n", "--npartitions", type=int, default=None)
|
||||
driver.add_argument(
|
||||
"-e", "--io_engine", default="duckdb", choices=("duckdb", "arrow", "stream")
|
||||
)
|
||||
driver.add_argument("-e", "--io_engine", default="duckdb", choices=("duckdb", "arrow", "stream"))
|
||||
driver.add_argument("-b", "--batch_size", type=int, default=1024 * 1024)
|
||||
driver.add_argument("-s", "--row_group_size", type=int, default=1024 * 1024)
|
||||
driver.add_argument("-o", "--output_name", default="data")
|
||||
|
||||
@@ -73,26 +73,18 @@ def generate_records(
|
||||
subprocess.run(gensort_cmd.split()).check_returncode()
|
||||
runtime_task.add_elapsed_time("generate records (secs)")
|
||||
shm_file.seek(0)
|
||||
buffer = arrow.py_buffer(
|
||||
shm_file.read(record_count * record_nbytes)
|
||||
)
|
||||
buffer = arrow.py_buffer(shm_file.read(record_count * record_nbytes))
|
||||
runtime_task.add_elapsed_time("read records (secs)")
|
||||
# https://arrow.apache.org/docs/format/Columnar.html#fixed-size-primitive-layout
|
||||
records = arrow.Array.from_buffers(
|
||||
arrow.binary(record_nbytes), record_count, [None, buffer]
|
||||
)
|
||||
records = arrow.Array.from_buffers(arrow.binary(record_nbytes), record_count, [None, buffer])
|
||||
keys = pc.binary_slice(records, 0, key_nbytes)
|
||||
# get first 2 bytes and convert to big-endian uint16
|
||||
binary_prefix = pc.binary_slice(records, 0, 2).cast(arrow.binary())
|
||||
reversed_prefix = pc.binary_reverse(binary_prefix).cast(
|
||||
arrow.binary(2)
|
||||
)
|
||||
reversed_prefix = pc.binary_reverse(binary_prefix).cast(arrow.binary(2))
|
||||
uint16_prefix = reversed_prefix.view(arrow.uint16())
|
||||
buckets = pc.shift_right(uint16_prefix, 16 - bucket_nbits)
|
||||
runtime_task.add_elapsed_time("build arrow table (secs)")
|
||||
yield arrow.Table.from_arrays(
|
||||
[buckets, keys, records], schema=schema
|
||||
)
|
||||
yield arrow.Table.from_arrays([buckets, keys, records], schema=schema)
|
||||
yield StreamOutput(
|
||||
schema.empty_table(),
|
||||
batch_indices=[batch_idx],
|
||||
@@ -108,9 +100,7 @@ def sort_records(
|
||||
write_io_nbytes=500 * MB,
|
||||
) -> bool:
|
||||
runtime_task: PythonScriptTask = runtime_ctx.task
|
||||
data_file_path = os.path.join(
|
||||
runtime_task.runtime_output_abspath, f"{runtime_task.output_filename}.dat"
|
||||
)
|
||||
data_file_path = os.path.join(runtime_task.runtime_output_abspath, f"{runtime_task.output_filename}.dat")
|
||||
|
||||
if sort_engine == "polars":
|
||||
input_data = polars.read_parquet(
|
||||
@@ -134,9 +124,7 @@ def sort_records(
|
||||
record_arrays = sorted_table.column("records").chunks
|
||||
runtime_task.add_elapsed_time("convert to chunks (secs)")
|
||||
elif sort_engine == "duckdb":
|
||||
with duckdb.connect(
|
||||
database=":memory:", config={"allow_unsigned_extensions": "true"}
|
||||
) as conn:
|
||||
with duckdb.connect(database=":memory:", config={"allow_unsigned_extensions": "true"}) as conn:
|
||||
runtime_task.prepare_connection(conn)
|
||||
input_views = runtime_task.create_input_views(conn, input_datasets)
|
||||
sql_query = "select records from {0} order by keys".format(*input_views)
|
||||
@@ -154,8 +142,7 @@ def sort_records(
|
||||
buffer_mem = memoryview(values)
|
||||
|
||||
total_write_nbytes = sum(
|
||||
fout.write(buffer_mem[offset : offset + write_io_nbytes])
|
||||
for offset in range(0, len(buffer_mem), write_io_nbytes)
|
||||
fout.write(buffer_mem[offset : offset + write_io_nbytes]) for offset in range(0, len(buffer_mem), write_io_nbytes)
|
||||
)
|
||||
assert total_write_nbytes == len(buffer_mem)
|
||||
|
||||
@@ -164,16 +151,10 @@ def sort_records(
|
||||
return True
|
||||
|
||||
|
||||
def validate_records(
|
||||
runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str
|
||||
) -> bool:
|
||||
def validate_records(runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str) -> bool:
|
||||
for data_path in input_datasets[0].resolved_paths:
|
||||
summary_path = os.path.join(
|
||||
output_path, PurePath(data_path).with_suffix(".sum").name
|
||||
)
|
||||
cmdstr = (
|
||||
f"{SortBenchTool.valsort_path} -o {summary_path} {data_path},buf,trans=10m"
|
||||
)
|
||||
summary_path = os.path.join(output_path, PurePath(data_path).with_suffix(".sum").name)
|
||||
cmdstr = f"{SortBenchTool.valsort_path} -o {summary_path} {data_path},buf,trans=10m"
|
||||
logging.debug(f"running command: {cmdstr}")
|
||||
result = subprocess.run(cmdstr.split(), capture_output=True, encoding="utf8")
|
||||
if result.stderr:
|
||||
@@ -185,9 +166,7 @@ def validate_records(
|
||||
return True
|
||||
|
||||
|
||||
def validate_summary(
|
||||
runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str
|
||||
) -> bool:
|
||||
def validate_summary(runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str) -> bool:
|
||||
concated_summary_path = os.path.join(output_path, "merged.sum")
|
||||
with open(concated_summary_path, "wb") as fout:
|
||||
for path in input_datasets[0].resolved_paths:
|
||||
@@ -224,22 +203,13 @@ def generate_random_records(
|
||||
)
|
||||
|
||||
range_begin_at = [pos for pos in range(0, total_num_records, record_range_size)]
|
||||
range_num_records = [
|
||||
min(total_num_records, record_range_size * (range_idx + 1)) - begin_at
|
||||
for range_idx, begin_at in enumerate(range_begin_at)
|
||||
]
|
||||
range_num_records = [min(total_num_records, record_range_size * (range_idx + 1)) - begin_at for range_idx, begin_at in enumerate(range_begin_at)]
|
||||
assert sum(range_num_records) == total_num_records
|
||||
record_range = DataSourceNode(
|
||||
ctx,
|
||||
ArrowTableDataSet(
|
||||
arrow.Table.from_arrays(
|
||||
[range_begin_at, range_num_records], names=["begin_at", "num_records"]
|
||||
)
|
||||
),
|
||||
)
|
||||
record_range_partitions = DataSetPartitionNode(
|
||||
ctx, (record_range,), npartitions=num_data_partitions, partition_by_rows=True
|
||||
ArrowTableDataSet(arrow.Table.from_arrays([range_begin_at, range_num_records], names=["begin_at", "num_records"])),
|
||||
)
|
||||
record_range_partitions = DataSetPartitionNode(ctx, (record_range,), npartitions=num_data_partitions, partition_by_rows=True)
|
||||
|
||||
random_records = ArrowStreamNode(
|
||||
ctx,
|
||||
@@ -288,9 +258,7 @@ def gray_sort_benchmark(
|
||||
if input_paths:
|
||||
input_dataset = ParquetDataSet(input_paths)
|
||||
input_nbytes = sum(os.path.getsize(p) for p in input_dataset.resolved_paths)
|
||||
logging.warning(
|
||||
f"input data size: {input_nbytes/GB:.3f}GB, {input_dataset.num_files} files"
|
||||
)
|
||||
logging.warning(f"input data size: {input_nbytes/GB:.3f}GB, {input_dataset.num_files} files")
|
||||
random_records = DataSourceNode(ctx, input_dataset)
|
||||
else:
|
||||
random_records = generate_random_records(
|
||||
@@ -335,12 +303,8 @@ def gray_sort_benchmark(
|
||||
process_func=validate_records,
|
||||
output_name="partitioned_summaries",
|
||||
)
|
||||
merged_summaries = DataSetPartitionNode(
|
||||
ctx, (partitioned_summaries,), npartitions=1
|
||||
)
|
||||
final_check = PythonScriptNode(
|
||||
ctx, (merged_summaries,), process_func=validate_summary
|
||||
)
|
||||
merged_summaries = DataSetPartitionNode(ctx, (partitioned_summaries,), npartitions=1)
|
||||
final_check = PythonScriptNode(ctx, (merged_summaries,), process_func=validate_summary)
|
||||
root = final_check
|
||||
else:
|
||||
root = sorted_records
|
||||
@@ -359,17 +323,11 @@ def main():
|
||||
driver.add_argument("-n", "--num_data_partitions", type=int, default=None)
|
||||
driver.add_argument("-t", "--num_sort_partitions", type=int, default=None)
|
||||
driver.add_argument("-i", "--input_paths", nargs="+", default=[])
|
||||
driver.add_argument(
|
||||
"-e", "--shuffle_engine", default="duckdb", choices=("duckdb", "arrow")
|
||||
)
|
||||
driver.add_argument(
|
||||
"-s", "--sort_engine", default="duckdb", choices=("duckdb", "arrow", "polars")
|
||||
)
|
||||
driver.add_argument("-e", "--shuffle_engine", default="duckdb", choices=("duckdb", "arrow"))
|
||||
driver.add_argument("-s", "--sort_engine", default="duckdb", choices=("duckdb", "arrow", "polars"))
|
||||
driver.add_argument("-H", "--hive_partitioning", action="store_true")
|
||||
driver.add_argument("-V", "--validate_results", action="store_true")
|
||||
driver.add_argument(
|
||||
"-C", "--shuffle_cpu_limit", type=int, default=ShuffleNode.default_cpu_limit
|
||||
)
|
||||
driver.add_argument("-C", "--shuffle_cpu_limit", type=int, default=ShuffleNode.default_cpu_limit)
|
||||
driver.add_argument(
|
||||
"-M",
|
||||
"--shuffle_memory_limit",
|
||||
@@ -378,12 +336,8 @@ def main():
|
||||
)
|
||||
driver.add_argument("-TC", "--sort_cpu_limit", type=int, default=8)
|
||||
driver.add_argument("-TM", "--sort_memory_limit", type=int, default=None)
|
||||
driver.add_argument(
|
||||
"-NC", "--cpus_per_node", type=int, default=psutil.cpu_count(logical=False)
|
||||
)
|
||||
driver.add_argument(
|
||||
"-NM", "--memory_per_node", type=int, default=psutil.virtual_memory().total
|
||||
)
|
||||
driver.add_argument("-NC", "--cpus_per_node", type=int, default=psutil.cpu_count(logical=False))
|
||||
driver.add_argument("-NM", "--memory_per_node", type=int, default=psutil.virtual_memory().total)
|
||||
driver.add_argument("-CP", "--parquet_compression", default=None)
|
||||
driver.add_argument("-LV", "--parquet_compression_level", type=int, default=None)
|
||||
|
||||
@@ -393,16 +347,9 @@ def main():
|
||||
total_num_cpus = max(1, driver_args.num_executors) * user_args.cpus_per_node
|
||||
memory_per_cpu = user_args.memory_per_node // user_args.cpus_per_node
|
||||
|
||||
user_args.sort_cpu_limit = (
|
||||
1 if user_args.sort_engine == "arrow" else user_args.sort_cpu_limit
|
||||
)
|
||||
sort_memory_limit = (
|
||||
user_args.sort_memory_limit or user_args.sort_cpu_limit * memory_per_cpu
|
||||
)
|
||||
user_args.total_data_nbytes = (
|
||||
user_args.total_data_nbytes
|
||||
or max(1, driver_args.num_executors) * user_args.memory_per_node
|
||||
)
|
||||
user_args.sort_cpu_limit = 1 if user_args.sort_engine == "arrow" else user_args.sort_cpu_limit
|
||||
sort_memory_limit = user_args.sort_memory_limit or user_args.sort_cpu_limit * memory_per_cpu
|
||||
user_args.total_data_nbytes = user_args.total_data_nbytes or max(1, driver_args.num_executors) * user_args.memory_per_node
|
||||
user_args.num_data_partitions = user_args.num_data_partitions or total_num_cpus // 2
|
||||
user_args.num_sort_partitions = user_args.num_sort_partitions or max(
|
||||
total_num_cpus // user_args.sort_cpu_limit,
|
||||
|
||||
@@ -70,18 +70,12 @@ def main():
|
||||
driver.add_argument("-i", "--input_paths", nargs="+", required=True)
|
||||
driver.add_argument("-n", "--npartitions", type=int, default=None)
|
||||
driver.add_argument("-c", "--hash_columns", nargs="+", required=True)
|
||||
driver.add_argument(
|
||||
"-e", "--engine_type", default="duckdb", choices=("duckdb", "arrow")
|
||||
)
|
||||
driver.add_argument("-e", "--engine_type", default="duckdb", choices=("duckdb", "arrow"))
|
||||
driver.add_argument("-S", "--partition_stats", action="store_true")
|
||||
driver.add_argument("-W", "--use_parquet_writer", action="store_true")
|
||||
driver.add_argument("-H", "--hive_partitioning", action="store_true")
|
||||
driver.add_argument(
|
||||
"-C", "--cpu_limit", type=int, default=HashPartitionNode.default_cpu_limit
|
||||
)
|
||||
driver.add_argument(
|
||||
"-M", "--memory_limit", type=int, default=HashPartitionNode.default_memory_limit
|
||||
)
|
||||
driver.add_argument("-C", "--cpu_limit", type=int, default=HashPartitionNode.default_cpu_limit)
|
||||
driver.add_argument("-M", "--memory_limit", type=int, default=HashPartitionNode.default_memory_limit)
|
||||
driver.add_argument("-NC", "--cpus_per_node", type=int, default=192)
|
||||
driver.add_argument("-NM", "--memory_per_node", type=int, default=2000 * GB)
|
||||
|
||||
|
||||
@@ -29,9 +29,7 @@ def urls_sort_benchmark(
|
||||
delim=r"\t",
|
||||
)
|
||||
data_files = DataSourceNode(ctx, dataset)
|
||||
data_partitions = DataSetPartitionNode(
|
||||
ctx, (data_files,), npartitions=num_data_partitions
|
||||
)
|
||||
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=num_data_partitions)
|
||||
|
||||
imported_urls = SqlEngineNode(
|
||||
ctx,
|
||||
@@ -80,16 +78,10 @@ def urls_sort_benchmark_v2(
|
||||
sort_cpu_limit=8,
|
||||
sort_memory_limit=None,
|
||||
):
|
||||
dataset = sp.read_csv(
|
||||
input_paths, schema={"urlstr": "varchar", "valstr": "varchar"}, delim=r"\t"
|
||||
)
|
||||
dataset = sp.read_csv(input_paths, schema={"urlstr": "varchar", "valstr": "varchar"}, delim=r"\t")
|
||||
data_partitions = dataset.repartition(num_data_partitions)
|
||||
urls_partitions = data_partitions.repartition(
|
||||
num_hash_partitions, hash_by="urlstr", engine_type=engine_type
|
||||
)
|
||||
sorted_urls = urls_partitions.partial_sort(
|
||||
by="urlstr", cpu_limit=sort_cpu_limit, memory_limit=sort_memory_limit
|
||||
)
|
||||
urls_partitions = data_partitions.repartition(num_hash_partitions, hash_by="urlstr", engine_type=engine_type)
|
||||
sorted_urls = urls_partitions.partial_sort(by="urlstr", cpu_limit=sort_cpu_limit, memory_limit=sort_memory_limit)
|
||||
sorted_urls.write_parquet(output_path)
|
||||
|
||||
|
||||
@@ -106,12 +98,8 @@ def main():
|
||||
num_nodes = driver_args.num_executors
|
||||
cpus_per_node = 120
|
||||
partition_rounds = 2
|
||||
user_args.num_data_partitions = (
|
||||
user_args.num_data_partitions or num_nodes * cpus_per_node * partition_rounds
|
||||
)
|
||||
user_args.num_hash_partitions = (
|
||||
user_args.num_hash_partitions or num_nodes * cpus_per_node
|
||||
)
|
||||
user_args.num_data_partitions = user_args.num_data_partitions or num_nodes * cpus_per_node * partition_rounds
|
||||
user_args.num_hash_partitions = user_args.num_hash_partitions or num_nodes * cpus_per_node
|
||||
|
||||
# v1
|
||||
plan = urls_sort_benchmark(**vars(user_args))
|
||||
|
||||
@@ -80,9 +80,7 @@ def check_data(actual: bytes, expected: bytes, offset: int) -> None:
|
||||
)
|
||||
expected = expected[index : index + 16]
|
||||
actual = actual[index : index + 16]
|
||||
raise ValueError(
|
||||
f"Data mismatch at offset {offset + index}.\nexpect: {expected}\nactual: {actual}"
|
||||
)
|
||||
raise ValueError(f"Data mismatch at offset {offset + index}.\nexpect: {expected}\nactual: {actual}")
|
||||
|
||||
|
||||
def generate_data(offset: int, length: int) -> bytes:
|
||||
@@ -92,16 +90,10 @@ def generate_data(offset: int, length: int) -> bytes:
|
||||
"""
|
||||
istart = offset // 4
|
||||
iend = (offset + length + 3) // 4
|
||||
return (
|
||||
np.arange(istart, iend)
|
||||
.astype(np.uint32)
|
||||
.tobytes()[offset % 4 : offset % 4 + length]
|
||||
)
|
||||
return np.arange(istart, iend).astype(np.uint32).tobytes()[offset % 4 : offset % 4 + length]
|
||||
|
||||
|
||||
def iter_io_slice(
|
||||
offset: int, length: int, block_size: Union[int, Tuple[int, int]]
|
||||
) -> Iterator[Tuple[int, int]]:
|
||||
def iter_io_slice(offset: int, length: int, block_size: Union[int, Tuple[int, int]]) -> Iterator[Tuple[int, int]]:
|
||||
"""
|
||||
Generate the IO (offset, size) for the slice [offset, offset + length) with the given block size.
|
||||
`block_size` can be an integer or a range [start, end]. If a range is provided, the IO size will be randomly selected from the range.
|
||||
@@ -161,9 +153,7 @@ def fstest(
|
||||
|
||||
if output_path is not None:
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
df = sp.from_items(
|
||||
[{"path": os.path.join(output_path, f"{i}")} for i in range(npartitions)]
|
||||
)
|
||||
df = sp.from_items([{"path": os.path.join(output_path, f"{i}")} for i in range(npartitions)])
|
||||
df = df.repartition(npartitions, by_rows=True)
|
||||
stats = df.map(lambda x: fswrite(x["path"], size, blocksize)).to_pandas()
|
||||
logging.info(f"write stats:\n{stats}")
|
||||
@@ -187,18 +177,14 @@ if __name__ == "__main__":
|
||||
python example/fstest.py -o 'fstest' -j 8 -s 1G -i 'fstest/*'
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-o", "--output_path", type=str, help="The output path to write data to."
|
||||
)
|
||||
parser.add_argument("-o", "--output_path", type=str, help="The output path to write data to.")
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--input_path",
|
||||
type=str,
|
||||
help="The input path to read data from. If -o is provided, this is ignored.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-j", "--npartitions", type=int, help="The number of parallel jobs", default=10
|
||||
)
|
||||
parser.add_argument("-j", "--npartitions", type=int, help="The number of parallel jobs", default=10)
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--size",
|
||||
|
||||
@@ -52,9 +52,7 @@ def shuffle_data(
|
||||
npartitions=num_out_data_partitions,
|
||||
partition_by_rows=True,
|
||||
)
|
||||
shuffled_urls = StreamCopy(
|
||||
ctx, (repartitioned,), output_name="data_copy", cpu_limit=1
|
||||
)
|
||||
shuffled_urls = StreamCopy(ctx, (repartitioned,), output_name="data_copy", cpu_limit=1)
|
||||
|
||||
plan = LogicalPlan(ctx, shuffled_urls)
|
||||
return plan
|
||||
@@ -66,9 +64,7 @@ def main():
|
||||
driver.add_argument("-nd", "--num_data_partitions", type=int, default=1024)
|
||||
driver.add_argument("-nh", "--num_hash_partitions", type=int, default=3840)
|
||||
driver.add_argument("-no", "--num_out_data_partitions", type=int, default=1920)
|
||||
driver.add_argument(
|
||||
"-e", "--engine_type", default="duckdb", choices=("duckdb", "arrow")
|
||||
)
|
||||
driver.add_argument("-e", "--engine_type", default="duckdb", choices=("duckdb", "arrow"))
|
||||
driver.add_argument("-x", "--skip_hash_partition", action="store_true")
|
||||
plan = shuffle_data(**driver.get_arguments())
|
||||
driver.run(plan)
|
||||
|
||||
@@ -11,9 +11,7 @@ from smallpond.logical.node import (
|
||||
)
|
||||
|
||||
|
||||
def shuffle_mock_urls(
|
||||
input_paths, npartitions: int = 10, sort_rand_keys=True, engine_type="duckdb"
|
||||
) -> LogicalPlan:
|
||||
def shuffle_mock_urls(input_paths, npartitions: int = 10, sort_rand_keys=True, engine_type="duckdb") -> LogicalPlan:
|
||||
ctx = Context()
|
||||
dataset = ParquetDataSet(input_paths)
|
||||
data_files = DataSourceNode(ctx, dataset)
|
||||
@@ -61,9 +59,7 @@ def main():
|
||||
driver.add_argument("-i", "--input_paths", nargs="+")
|
||||
driver.add_argument("-n", "--npartitions", type=int, default=500)
|
||||
driver.add_argument("-s", "--sort_rand_keys", action="store_true")
|
||||
driver.add_argument(
|
||||
"-e", "--engine_type", default="duckdb", choices=("duckdb", "arrow")
|
||||
)
|
||||
driver.add_argument("-e", "--engine_type", default="duckdb", choices=("duckdb", "arrow"))
|
||||
|
||||
plan = shuffle_mock_urls(**driver.get_arguments())
|
||||
driver.run(plan)
|
||||
|
||||
@@ -20,9 +20,7 @@ from smallpond.logical.node import (
|
||||
|
||||
|
||||
class SortUrlsNode(ArrowComputeNode):
|
||||
def process(
|
||||
self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]
|
||||
) -> arrow.Table:
|
||||
def process(self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table:
|
||||
logging.info(f"sorting urls by 'host', table shape: {input_tables[0].shape}")
|
||||
return input_tables[0].sort_by("host")
|
||||
|
||||
@@ -90,9 +88,7 @@ def sort_mock_urls(
|
||||
|
||||
def main():
|
||||
driver = Driver()
|
||||
driver.add_argument(
|
||||
"-i", "--input_paths", nargs="+", default=["tests/data/mock_urls/*.tsv"]
|
||||
)
|
||||
driver.add_argument("-i", "--input_paths", nargs="+", default=["tests/data/mock_urls/*.tsv"])
|
||||
driver.add_argument("-n", "--npartitions", type=int, default=10)
|
||||
driver.add_argument("-e", "--engine_type", default="duckdb")
|
||||
|
||||
|
||||
@@ -5,12 +5,8 @@ import smallpond
|
||||
from smallpond.dataframe import Session
|
||||
|
||||
|
||||
def sort_mock_urls_v2(
|
||||
sp: Session, input_paths: List[str], output_path: str, npartitions: int
|
||||
):
|
||||
dataset = sp.read_csv(
|
||||
input_paths, schema={"urlstr": "varchar", "valstr": "varchar"}, delim=r"\t"
|
||||
).repartition(npartitions)
|
||||
def sort_mock_urls_v2(sp: Session, input_paths: List[str], output_path: str, npartitions: int):
|
||||
dataset = sp.read_csv(input_paths, schema={"urlstr": "varchar", "valstr": "varchar"}, delim=r"\t").repartition(npartitions)
|
||||
urls = dataset.map(
|
||||
"""
|
||||
split_part(urlstr, '/', 1) as host,
|
||||
@@ -25,9 +21,7 @@ def sort_mock_urls_v2(
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-i", "--input_paths", nargs="+", default=["tests/data/mock_urls/*.tsv"]
|
||||
)
|
||||
parser.add_argument("-i", "--input_paths", nargs="+", default=["tests/data/mock_urls/*.tsv"])
|
||||
parser.add_argument("-o", "--output_path", type=str, default="sort_mock_urls")
|
||||
parser.add_argument("-n", "--npartitions", type=int, default=10)
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -41,9 +41,7 @@ def clamp_row_group_size(val, minval=DEFAULT_ROW_GROUP_SIZE, maxval=MAX_ROW_GROU
|
||||
return clamp_value(val, minval, maxval)
|
||||
|
||||
|
||||
def clamp_row_group_bytes(
|
||||
val, minval=DEFAULT_ROW_GROUP_BYTES, maxval=MAX_ROW_GROUP_BYTES
|
||||
):
|
||||
def clamp_row_group_bytes(val, minval=DEFAULT_ROW_GROUP_BYTES, maxval=MAX_ROW_GROUP_BYTES):
|
||||
return clamp_value(val, minval, maxval)
|
||||
|
||||
|
||||
@@ -74,10 +72,7 @@ def first_value_in_dict(d: Dict[K, V]) -> V:
|
||||
def split_into_cols(items: List[V], npartitions: int) -> List[List[V]]:
|
||||
none = object()
|
||||
chunks = [items[i : i + npartitions] for i in range(0, len(items), npartitions)]
|
||||
return [
|
||||
[x for x in col if x is not none]
|
||||
for col in itertools.zip_longest([none] * npartitions, *chunks, fillvalue=none)
|
||||
]
|
||||
return [[x for x in col if x is not none] for col in itertools.zip_longest([none] * npartitions, *chunks, fillvalue=none)]
|
||||
|
||||
|
||||
def split_into_rows(items: List[V], npartitions: int) -> List[List[V]]:
|
||||
@@ -101,10 +96,7 @@ def get_nth_partition(items: List[V], n: int, npartitions: int) -> List[V]:
|
||||
start = n * large_partition_size
|
||||
items_in_partition = items[start : start + large_partition_size]
|
||||
else:
|
||||
start = (
|
||||
large_partition_size * num_large_partitions
|
||||
+ (n - num_large_partitions) * small_partition_size
|
||||
)
|
||||
start = large_partition_size * num_large_partitions + (n - num_large_partitions) * small_partition_size
|
||||
items_in_partition = items[start : start + small_partition_size]
|
||||
return items_in_partition
|
||||
|
||||
|
||||
@@ -8,17 +8,13 @@ from smallpond.logical.node import ArrowComputeNode, ArrowStreamNode
|
||||
|
||||
|
||||
class CopyArrowTable(ArrowComputeNode):
|
||||
def process(
|
||||
self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]
|
||||
) -> arrow.Table:
|
||||
def process(self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table:
|
||||
logger.info(f"copying table: {input_tables[0].num_rows} rows ...")
|
||||
return input_tables[0]
|
||||
|
||||
|
||||
class StreamCopy(ArrowStreamNode):
|
||||
def process(
|
||||
self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]
|
||||
) -> Iterable[arrow.Table]:
|
||||
def process(self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]) -> Iterable[arrow.Table]:
|
||||
for batch in input_readers[0]:
|
||||
logger.info(f"copying batch: {batch.num_rows} rows ...")
|
||||
yield arrow.Table.from_batches([batch])
|
||||
|
||||
@@ -25,9 +25,7 @@ class LogDataSetTask(PythonScriptTask):
|
||||
|
||||
|
||||
class LogDataSet(PythonScriptNode):
|
||||
def __init__(
|
||||
self, ctx: Context, input_deps: Tuple[Node, ...], num_rows=200, **kwargs
|
||||
) -> None:
|
||||
def __init__(self, ctx: Context, input_deps: Tuple[Node, ...], num_rows=200, **kwargs) -> None:
|
||||
super().__init__(ctx, input_deps, **kwargs)
|
||||
self.num_rows = num_rows
|
||||
|
||||
|
||||
@@ -29,16 +29,12 @@ class ImportWarcFiles(PythonScriptNode):
|
||||
]
|
||||
)
|
||||
|
||||
def import_warc_file(
|
||||
self, warc_path: PurePath, parquet_path: PurePath
|
||||
) -> Tuple[int, int]:
|
||||
def import_warc_file(self, warc_path: PurePath, parquet_path: PurePath) -> Tuple[int, int]:
|
||||
total_size = 0
|
||||
docs = []
|
||||
|
||||
with open(warc_path, "rb") as warc_file:
|
||||
zstd_reader = zstd.ZstdDecompressor().stream_reader(
|
||||
warc_file, read_size=16 * MB
|
||||
)
|
||||
zstd_reader = zstd.ZstdDecompressor().stream_reader(warc_file, read_size=16 * MB)
|
||||
for record in ArchiveIterator(zstd_reader):
|
||||
if record.rec_type == "response":
|
||||
url = record.rec_headers.get_header("WARC-Target-URI")
|
||||
@@ -48,9 +44,7 @@ class ImportWarcFiles(PythonScriptNode):
|
||||
total_size += len(content)
|
||||
docs.append((url, domain, date, content))
|
||||
|
||||
table = arrow.Table.from_arrays(
|
||||
[arrow.array(column) for column in zip(*docs)], schema=self.schema
|
||||
)
|
||||
table = arrow.Table.from_arrays([arrow.array(column) for column in zip(*docs)], schema=self.schema)
|
||||
dump_to_parquet_files(table, parquet_path.parent, parquet_path.name)
|
||||
return len(docs), total_size
|
||||
|
||||
@@ -60,14 +54,9 @@ class ImportWarcFiles(PythonScriptNode):
|
||||
input_datasets: List[DataSet],
|
||||
output_path: str,
|
||||
) -> bool:
|
||||
warc_paths = [
|
||||
PurePath(warc_path)
|
||||
for dataset in input_datasets
|
||||
for warc_path in dataset.resolved_paths
|
||||
]
|
||||
warc_paths = [PurePath(warc_path) for dataset in input_datasets for warc_path in dataset.resolved_paths]
|
||||
parquet_paths = [
|
||||
PurePath(output_path)
|
||||
/ f"data{path_index}-{PurePath(warc_path.name).with_suffix('.parquet')}"
|
||||
PurePath(output_path) / f"data{path_index}-{PurePath(warc_path.name).with_suffix('.parquet')}"
|
||||
for path_index, warc_path in enumerate(warc_paths)
|
||||
]
|
||||
|
||||
@@ -75,24 +64,16 @@ class ImportWarcFiles(PythonScriptNode):
|
||||
for warc_path, parquet_path in zip(warc_paths, parquet_paths):
|
||||
try:
|
||||
doc_count, total_size = self.import_warc_file(warc_path, parquet_path)
|
||||
logger.info(
|
||||
f"imported {doc_count} web pages ({total_size/MB:.3f}MB) from file '{warc_path}' to '{parquet_path}'"
|
||||
)
|
||||
logger.info(f"imported {doc_count} web pages ({total_size/MB:.3f}MB) from file '{warc_path}' to '{parquet_path}'")
|
||||
except Exception as ex:
|
||||
logger.opt(exception=ex).error(
|
||||
f"failed to import web pages from file '{warc_path}'"
|
||||
)
|
||||
logger.opt(exception=ex).error(f"failed to import web pages from file '{warc_path}'")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class ExtractHtmlBody(ArrowStreamNode):
|
||||
|
||||
unicode_punctuation = "".join(
|
||||
chr(i)
|
||||
for i in range(sys.maxunicode)
|
||||
if unicodedata.category(chr(i)).startswith("P")
|
||||
)
|
||||
unicode_punctuation = "".join(chr(i) for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P"))
|
||||
separator_str = string.whitespace + string.punctuation + unicode_punctuation
|
||||
translator = str.maketrans(separator_str, " " * len(separator_str))
|
||||
|
||||
@@ -117,27 +98,19 @@ class ExtractHtmlBody(ArrowStreamNode):
|
||||
tokens.extend(self.split_string(doc.get_text(" ", strip=True).lower()))
|
||||
return tokens
|
||||
except Exception as ex:
|
||||
logger.opt(exception=ex).error(
|
||||
f"failed to extract tokens from {url.as_py()}"
|
||||
)
|
||||
logger.opt(exception=ex).error(f"failed to extract tokens from {url.as_py()}")
|
||||
return []
|
||||
|
||||
def process(
|
||||
self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]
|
||||
) -> Iterable[arrow.Table]:
|
||||
def process(self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]) -> Iterable[arrow.Table]:
|
||||
for batch in input_readers[0]:
|
||||
urls, domains, dates, contents = batch.columns
|
||||
doc_tokens = []
|
||||
try:
|
||||
for i, (url, content) in enumerate(zip(urls, contents)):
|
||||
tokens = self.extract_tokens(url, content)
|
||||
logger.info(
|
||||
f"#{i}/{len(urls)} extracted {len(tokens)} tokens from {url}"
|
||||
)
|
||||
logger.info(f"#{i}/{len(urls)} extracted {len(tokens)} tokens from {url}")
|
||||
doc_tokens.append(tokens)
|
||||
yield arrow.Table.from_arrays(
|
||||
[urls, domains, dates, arrow.array(doc_tokens)], schema=self.schema
|
||||
)
|
||||
yield arrow.Table.from_arrays([urls, domains, dates, arrow.array(doc_tokens)], schema=self.schema)
|
||||
except Exception as ex:
|
||||
logger.opt(exception=ex).error(f"failed to extract tokens")
|
||||
break
|
||||
|
||||
@@ -35,9 +35,7 @@ class Session(SessionBase):
|
||||
Subsequent DataFrames can reuse the tasks to avoid recomputation.
|
||||
"""
|
||||
|
||||
def read_csv(
|
||||
self, paths: Union[str, List[str]], schema: Dict[str, str], delim=","
|
||||
) -> DataFrame:
|
||||
def read_csv(self, paths: Union[str, List[str]], schema: Dict[str, str], delim=",") -> DataFrame:
|
||||
"""
|
||||
Create a DataFrame from CSV files.
|
||||
"""
|
||||
@@ -55,15 +53,11 @@ class Session(SessionBase):
|
||||
"""
|
||||
Create a DataFrame from Parquet files.
|
||||
"""
|
||||
dataset = ParquetDataSet(
|
||||
paths, columns=columns, union_by_name=union_by_name, recursive=recursive
|
||||
)
|
||||
dataset = ParquetDataSet(paths, columns=columns, union_by_name=union_by_name, recursive=recursive)
|
||||
plan = DataSourceNode(self._ctx, dataset)
|
||||
return DataFrame(self, plan)
|
||||
|
||||
def read_json(
|
||||
self, paths: Union[str, List[str]], schema: Dict[str, str]
|
||||
) -> DataFrame:
|
||||
def read_json(self, paths: Union[str, List[str]], schema: Dict[str, str]) -> DataFrame:
|
||||
"""
|
||||
Create a DataFrame from JSON files.
|
||||
"""
|
||||
@@ -115,9 +109,7 @@ class Session(SessionBase):
|
||||
c = sp.partial_sql("select * from {0} join {1} on a.id = b.id", a, b)
|
||||
"""
|
||||
|
||||
plan = SqlEngineNode(
|
||||
self._ctx, tuple(input.plan for input in inputs), query, **kwargs
|
||||
)
|
||||
plan = SqlEngineNode(self._ctx, tuple(input.plan for input in inputs), query, **kwargs)
|
||||
recompute = any(input.need_recompute for input in inputs)
|
||||
return DataFrame(self, plan, recompute=recompute)
|
||||
|
||||
@@ -177,26 +169,15 @@ class Session(SessionBase):
|
||||
"""
|
||||
Return the total number of tasks and the number of tasks that are finished.
|
||||
"""
|
||||
dataset_refs = [
|
||||
task._dataset_ref
|
||||
for tasks in self._node_to_tasks.values()
|
||||
for task in tasks
|
||||
if task._dataset_ref is not None
|
||||
]
|
||||
ready_tasks, _ = ray.wait(
|
||||
dataset_refs, num_returns=len(dataset_refs), timeout=0, fetch_local=False
|
||||
)
|
||||
dataset_refs = [task._dataset_ref for tasks in self._node_to_tasks.values() for task in tasks if task._dataset_ref is not None]
|
||||
ready_tasks, _ = ray.wait(dataset_refs, num_returns=len(dataset_refs), timeout=0, fetch_local=False)
|
||||
return len(dataset_refs), len(ready_tasks)
|
||||
|
||||
def _all_tasks_finished(self) -> bool:
|
||||
"""
|
||||
Check if all tasks are finished.
|
||||
"""
|
||||
dataset_refs = [
|
||||
task._dataset_ref
|
||||
for tasks in self._node_to_tasks.values()
|
||||
for task in tasks
|
||||
]
|
||||
dataset_refs = [task._dataset_ref for tasks in self._node_to_tasks.values() for task in tasks]
|
||||
try:
|
||||
ray.get(dataset_refs, timeout=0)
|
||||
except Exception:
|
||||
@@ -232,12 +213,8 @@ class DataFrame:
|
||||
# optimize the plan
|
||||
if self.optimized_plan is None:
|
||||
logger.info(f"optimizing\n{LogicalPlan(self.session._ctx, self.plan)}")
|
||||
self.optimized_plan = Optimizer(
|
||||
exclude_nodes=set(self.session._node_to_tasks.keys())
|
||||
).visit(self.plan)
|
||||
logger.info(
|
||||
f"optimized\n{LogicalPlan(self.session._ctx, self.optimized_plan)}"
|
||||
)
|
||||
self.optimized_plan = Optimizer(exclude_nodes=set(self.session._node_to_tasks.keys())).visit(self.plan)
|
||||
logger.info(f"optimized\n{LogicalPlan(self.session._ctx, self.optimized_plan)}")
|
||||
# return the tasks if already created
|
||||
if tasks := self.session._node_to_tasks.get(self.optimized_plan):
|
||||
return tasks
|
||||
@@ -281,9 +258,7 @@ class DataFrame:
|
||||
"""
|
||||
for retry_count in range(3):
|
||||
try:
|
||||
return ray.get(
|
||||
[task.run_on_ray() for task in self._get_or_create_tasks()]
|
||||
)
|
||||
return ray.get([task.run_on_ray() for task in self._get_or_create_tasks()])
|
||||
except ray.exceptions.RuntimeEnvSetupError as e:
|
||||
# XXX: Ray may raise this error when a worker is interrupted.
|
||||
# ```
|
||||
@@ -361,9 +336,7 @@ class DataFrame:
|
||||
)
|
||||
elif hash_by is not None:
|
||||
hash_columns = [hash_by] if isinstance(hash_by, str) else hash_by
|
||||
plan = HashPartitionNode(
|
||||
self.session._ctx, (self.plan,), npartitions, hash_columns, **kwargs
|
||||
)
|
||||
plan = HashPartitionNode(self.session._ctx, (self.plan,), npartitions, hash_columns, **kwargs)
|
||||
else:
|
||||
plan = EvenlyDistributedPartitionNode(
|
||||
self.session._ctx,
|
||||
@@ -420,9 +393,7 @@ class DataFrame:
|
||||
)
|
||||
return DataFrame(self.session, plan, recompute=self.need_recompute)
|
||||
|
||||
def filter(
|
||||
self, sql_or_func: Union[str, Callable[[Dict[str, Any]], bool]], **kwargs
|
||||
) -> DataFrame:
|
||||
def filter(self, sql_or_func: Union[str, Callable[[Dict[str, Any]], bool]], **kwargs) -> DataFrame:
|
||||
"""
|
||||
Filter out rows that don't satisfy the given predicate.
|
||||
|
||||
@@ -453,13 +424,9 @@ class DataFrame:
|
||||
table = tables[0]
|
||||
return table.filter([func(row) for row in table.to_pylist()])
|
||||
|
||||
plan = ArrowBatchNode(
|
||||
self.session._ctx, (self.plan,), process_func=process_func, **kwargs
|
||||
)
|
||||
plan = ArrowBatchNode(self.session._ctx, (self.plan,), process_func=process_func, **kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
"condition must be a SQL expression or a predicate function"
|
||||
)
|
||||
raise ValueError("condition must be a SQL expression or a predicate function")
|
||||
return DataFrame(self.session, plan, recompute=self.need_recompute)
|
||||
|
||||
def map(
|
||||
@@ -510,18 +477,14 @@ class DataFrame:
|
||||
|
||||
"""
|
||||
if isinstance(sql := sql_or_func, str):
|
||||
plan = SqlEngineNode(
|
||||
self.session._ctx, (self.plan,), f"select {sql} from {{0}}", **kwargs
|
||||
)
|
||||
plan = SqlEngineNode(self.session._ctx, (self.plan,), f"select {sql} from {{0}}", **kwargs)
|
||||
elif isinstance(func := sql_or_func, Callable):
|
||||
|
||||
def process_func(_runtime_ctx, tables: List[arrow.Table]) -> arrow.Table:
|
||||
output_rows = [func(row) for row in tables[0].to_pylist()]
|
||||
return arrow.Table.from_pylist(output_rows, schema=schema)
|
||||
|
||||
plan = ArrowBatchNode(
|
||||
self.session._ctx, (self.plan,), process_func=process_func, **kwargs
|
||||
)
|
||||
plan = ArrowBatchNode(self.session._ctx, (self.plan,), process_func=process_func, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"must be a SQL expression or a function: {sql_or_func!r}")
|
||||
return DataFrame(self.session, plan, recompute=self.need_recompute)
|
||||
@@ -555,20 +518,14 @@ class DataFrame:
|
||||
"""
|
||||
if isinstance(sql := sql_or_func, str):
|
||||
|
||||
plan = SqlEngineNode(
|
||||
self.session._ctx, (self.plan,), f"select {sql} from {{0}}", **kwargs
|
||||
)
|
||||
plan = SqlEngineNode(self.session._ctx, (self.plan,), f"select {sql} from {{0}}", **kwargs)
|
||||
elif isinstance(func := sql_or_func, Callable):
|
||||
|
||||
def process_func(_runtime_ctx, tables: List[arrow.Table]) -> arrow.Table:
|
||||
output_rows = [
|
||||
item for row in tables[0].to_pylist() for item in func(row)
|
||||
]
|
||||
output_rows = [item for row in tables[0].to_pylist() for item in func(row)]
|
||||
return arrow.Table.from_pylist(output_rows, schema=schema)
|
||||
|
||||
plan = ArrowBatchNode(
|
||||
self.session._ctx, (self.plan,), process_func=process_func, **kwargs
|
||||
)
|
||||
plan = ArrowBatchNode(self.session._ctx, (self.plan,), process_func=process_func, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"must be a SQL expression or a function: {sql_or_func!r}")
|
||||
return DataFrame(self.session, plan, recompute=self.need_recompute)
|
||||
@@ -642,9 +599,7 @@ class DataFrame:
|
||||
sp.wait(o1, o2)
|
||||
"""
|
||||
|
||||
plan = DataSinkNode(
|
||||
self.session._ctx, (self.plan,), os.path.abspath(path), type="link_or_copy"
|
||||
)
|
||||
plan = DataSinkNode(self.session._ctx, (self.plan,), os.path.abspath(path), type="link_or_copy")
|
||||
return DataFrame(self.session, plan, recompute=self.need_recompute)
|
||||
|
||||
# inspection
|
||||
@@ -710,6 +665,4 @@ class DataFrame:
|
||||
"""
|
||||
datasets = self._compute()
|
||||
with ThreadPoolExecutor() as pool:
|
||||
return arrow.concat_tables(
|
||||
pool.map(lambda dataset: dataset.to_arrow_table(), datasets)
|
||||
)
|
||||
return arrow.concat_tables(pool.map(lambda dataset: dataset.to_arrow_table(), datasets))
|
||||
|
||||
@@ -28,43 +28,23 @@ class Driver(object):
|
||||
self.all_args = None
|
||||
|
||||
def _create_driver_args_parser(self):
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="driver.py", description="Smallpond Driver", add_help=False
|
||||
)
|
||||
parser.add_argument(
|
||||
"mode", choices=["executor", "scheduler", "ray"], default="executor"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--exec_id", default=socket.gethostname(), help="Unique executor id"
|
||||
)
|
||||
parser = argparse.ArgumentParser(prog="driver.py", description="Smallpond Driver", add_help=False)
|
||||
parser.add_argument("mode", choices=["executor", "scheduler", "ray"], default="executor")
|
||||
parser.add_argument("--exec_id", default=socket.gethostname(), help="Unique executor id")
|
||||
parser.add_argument("--job_id", type=str, help="Unique job id")
|
||||
parser.add_argument(
|
||||
"--job_time", type=float, help="Job create time (seconds since epoch)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--job_name", default="smallpond", help="Display name of the job"
|
||||
)
|
||||
parser.add_argument("--job_time", type=float, help="Job create time (seconds since epoch)")
|
||||
parser.add_argument("--job_name", default="smallpond", help="Display name of the job")
|
||||
parser.add_argument(
|
||||
"--job_priority",
|
||||
type=int,
|
||||
help="Job priority",
|
||||
)
|
||||
parser.add_argument("--resource_group", type=str, help="Resource group")
|
||||
parser.add_argument(
|
||||
"--env_variables", nargs="*", default=[], help="Env variables for the job"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sidecars", nargs="*", default=[], help="Sidecars for the job"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tags", nargs="*", default=[], help="Tags for submitted platform task"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task_image", default="default", help="Container image of platform task"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--python_venv", type=str, help="Python virtual env for the job"
|
||||
)
|
||||
parser.add_argument("--env_variables", nargs="*", default=[], help="Env variables for the job")
|
||||
parser.add_argument("--sidecars", nargs="*", default=[], help="Sidecars for the job")
|
||||
parser.add_argument("--tags", nargs="*", default=[], help="Tags for submitted platform task")
|
||||
parser.add_argument("--task_image", default="default", help="Container image of platform task")
|
||||
parser.add_argument("--python_venv", type=str, help="Python virtual env for the job")
|
||||
parser.add_argument(
|
||||
"--data_root",
|
||||
type=str,
|
||||
@@ -257,9 +237,7 @@ class Driver(object):
|
||||
default="DEBUG",
|
||||
choices=log_level_choices,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable_log_rotation", action="store_true", help="Disable log rotation"
|
||||
)
|
||||
parser.add_argument("--disable_log_rotation", action="store_true", help="Disable log rotation")
|
||||
parser.add_argument(
|
||||
"--output_path",
|
||||
help="Set the output directory of final results and all nodes that have output_name but no output_path specified",
|
||||
@@ -279,9 +257,7 @@ class Driver(object):
|
||||
|
||||
def parse_arguments(self, args=None):
|
||||
if self.user_args is None or self.driver_args is None:
|
||||
args_parser = argparse.ArgumentParser(
|
||||
parents=[self.driver_args_parser, self.user_args_parser]
|
||||
)
|
||||
args_parser = argparse.ArgumentParser(parents=[self.driver_args_parser, self.user_args_parser])
|
||||
self.all_args = args_parser.parse_args(args)
|
||||
self.user_args, other_args = self.user_args_parser.parse_known_args(args)
|
||||
self.driver_args = self.driver_args_parser.parse_args(other_args)
|
||||
@@ -349,9 +325,7 @@ class Driver(object):
|
||||
DataFrame(sp, plan.root_node).compute()
|
||||
retval = True
|
||||
elif args.mode == "executor":
|
||||
assert os.path.isfile(
|
||||
args.runtime_ctx_path
|
||||
), f"cannot find runtime context: {args.runtime_ctx_path}"
|
||||
assert os.path.isfile(args.runtime_ctx_path), f"cannot find runtime context: {args.runtime_ctx_path}"
|
||||
runtime_ctx: RuntimeContext = load(args.runtime_ctx_path)
|
||||
|
||||
if runtime_ctx.bind_numa_node:
|
||||
@@ -371,9 +345,7 @@ class Driver(object):
|
||||
retval = run_executor(runtime_ctx, args.exec_id)
|
||||
elif args.mode == "scheduler":
|
||||
assert plan is not None
|
||||
jobmgr = JobManager(
|
||||
args.data_root, args.python_venv, args.task_image, args.platform
|
||||
)
|
||||
jobmgr = JobManager(args.data_root, args.python_venv, args.task_image, args.platform)
|
||||
exec_plan = jobmgr.run(
|
||||
plan,
|
||||
job_id=args.job_id,
|
||||
|
||||
@@ -35,9 +35,7 @@ class SimplePoolTask(object):
|
||||
def join(self, timeout=None):
|
||||
self.proc.join(timeout)
|
||||
if not self.ready() and timeout is not None:
|
||||
logger.warning(
|
||||
f"worker process {self.proc.name}({self.proc.pid}) does not exit after {timeout} secs, stopping it"
|
||||
)
|
||||
logger.warning(f"worker process {self.proc.name}({self.proc.pid}) does not exit after {timeout} secs, stopping it")
|
||||
self.terminate()
|
||||
self.proc.join()
|
||||
|
||||
@@ -45,17 +43,11 @@ class SimplePoolTask(object):
|
||||
return self.proc.pid and not self.proc.is_alive()
|
||||
|
||||
def exitcode(self):
|
||||
assert (
|
||||
self.ready()
|
||||
), f"worker process {self.proc.name}({self.proc.pid}) has not exited yet"
|
||||
assert self.ready(), f"worker process {self.proc.name}({self.proc.pid}) has not exited yet"
|
||||
if self.stopping:
|
||||
logger.info(
|
||||
f"worker process stopped: {self.proc.name}({self.proc.pid}), exitcode: {self.proc.exitcode}"
|
||||
)
|
||||
logger.info(f"worker process stopped: {self.proc.name}({self.proc.pid}), exitcode: {self.proc.exitcode}")
|
||||
elif self.proc.exitcode != 0:
|
||||
logger.error(
|
||||
f"worker process crashed: {self.proc.name}({self.proc.pid}), exitcode: {self.proc.exitcode}"
|
||||
)
|
||||
logger.error(f"worker process crashed: {self.proc.name}({self.proc.pid}), exitcode: {self.proc.exitcode}")
|
||||
return self.proc.exitcode
|
||||
|
||||
|
||||
@@ -79,9 +71,7 @@ class SimplePool(object):
|
||||
def update_queue(self):
|
||||
self.running_tasks = [t for t in self.running_tasks if not t.ready()]
|
||||
tasks_to_run = self.queued_tasks[: self.pool_size - len(self.running_tasks)]
|
||||
self.queued_tasks = self.queued_tasks[
|
||||
self.pool_size - len(self.running_tasks) :
|
||||
]
|
||||
self.queued_tasks = self.queued_tasks[self.pool_size - len(self.running_tasks) :]
|
||||
for task in tasks_to_run:
|
||||
task.start()
|
||||
self.running_tasks += tasks_to_run
|
||||
@@ -97,9 +87,7 @@ class Executor(object):
|
||||
The task executor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, ctx: RuntimeContext, id: str, wq: WorkQueue, cq: WorkQueue
|
||||
) -> None:
|
||||
def __init__(self, ctx: RuntimeContext, id: str, wq: WorkQueue, cq: WorkQueue) -> None:
|
||||
self.ctx = ctx
|
||||
self.id = id
|
||||
self.wq = wq
|
||||
@@ -135,9 +123,7 @@ class Executor(object):
|
||||
def process_work(item: WorkItem, cq: WorkQueue):
|
||||
item.exec(cq)
|
||||
cq.push(item)
|
||||
logger.info(
|
||||
f"finished work: {repr(item)}, status: {item.status}, elapsed time: {item.elapsed_time:.3f} secs"
|
||||
)
|
||||
logger.info(f"finished work: {repr(item)}, status: {item.status}, elapsed time: {item.elapsed_time:.3f} secs")
|
||||
logger.complete()
|
||||
|
||||
# for test
|
||||
@@ -146,16 +132,12 @@ class Executor(object):
|
||||
|
||||
# for test
|
||||
def skip_probes(self, epochs: int):
|
||||
self.wq.push(
|
||||
Probe(self.ctx, f".FalseFail-{self.id}", epoch=0, epochs_to_skip=epochs)
|
||||
)
|
||||
self.wq.push(Probe(self.ctx, f".FalseFail-{self.id}", epoch=0, epochs_to_skip=epochs))
|
||||
|
||||
@logger.catch(reraise=True, message="executor terminated unexpectedly")
|
||||
def run(self) -> bool:
|
||||
mp.current_process().name = "ExecutorMainProcess"
|
||||
logger.info(
|
||||
f"start to run executor {self.id} on numa node #{self.ctx.numa_node_id} of {socket.gethostname()}"
|
||||
)
|
||||
logger.info(f"start to run executor {self.id} on numa node #{self.ctx.numa_node_id} of {socket.gethostname()}")
|
||||
|
||||
with SimplePool(self.ctx.usable_cpu_count + 1) as pool:
|
||||
retval = self.exec_loop(pool)
|
||||
@@ -173,34 +155,20 @@ class Executor(object):
|
||||
try:
|
||||
items = self.wq.pop(count=self.ctx.usable_cpu_count)
|
||||
except Exception as ex:
|
||||
logger.opt(exception=ex).critical(
|
||||
f"failed to pop from work queue: {self.wq}"
|
||||
)
|
||||
logger.opt(exception=ex).critical(f"failed to pop from work queue: {self.wq}")
|
||||
self.running = False
|
||||
items = []
|
||||
|
||||
if not items:
|
||||
secs_quiet_period = time.time() - latest_probe_time
|
||||
if (
|
||||
secs_quiet_period > self.ctx.secs_executor_probe_interval * 2
|
||||
and os.path.exists(self.ctx.job_status_path)
|
||||
):
|
||||
if secs_quiet_period > self.ctx.secs_executor_probe_interval * 2 and os.path.exists(self.ctx.job_status_path):
|
||||
with open(self.ctx.job_status_path) as status_file:
|
||||
if (
|
||||
status := status_file.read().strip()
|
||||
) and not status.startswith("running"):
|
||||
logger.critical(
|
||||
f"job scheduler already stopped: {status}, stopping executor"
|
||||
)
|
||||
if (status := status_file.read().strip()) and not status.startswith("running"):
|
||||
logger.critical(f"job scheduler already stopped: {status}, stopping executor")
|
||||
self.running = False
|
||||
break
|
||||
if (
|
||||
secs_quiet_period > self.ctx.secs_executor_probe_timeout * 2
|
||||
and not pytest_running()
|
||||
):
|
||||
logger.critical(
|
||||
f"no probe received for {secs_quiet_period:.1f} secs, stopping executor"
|
||||
)
|
||||
if secs_quiet_period > self.ctx.secs_executor_probe_timeout * 2 and not pytest_running():
|
||||
logger.critical(f"no probe received for {secs_quiet_period:.1f} secs, stopping executor")
|
||||
self.running = False
|
||||
break
|
||||
# no pending works, so wait a few seconds before checking results
|
||||
@@ -216,9 +184,7 @@ class Executor(object):
|
||||
if isinstance(item, StopWorkItem):
|
||||
running_work = self.running_works.get(item.work_to_stop, None)
|
||||
if running_work is None:
|
||||
logger.debug(
|
||||
f"cannot find {item.work_to_stop} in running works of {self.id}"
|
||||
)
|
||||
logger.debug(f"cannot find {item.work_to_stop} in running works of {self.id}")
|
||||
self.cq.push(item)
|
||||
else:
|
||||
logger.info(f"stopping work: {item.work_to_stop}")
|
||||
@@ -250,20 +216,14 @@ class Executor(object):
|
||||
self.collect_finished_works()
|
||||
time.sleep(self.ctx.secs_wq_poll_interval)
|
||||
item._local_gpu = granted_gpus
|
||||
logger.info(
|
||||
f"{repr(item)} is assigned to run on GPU: { {gpu.id: quota for gpu, quota in item._local_gpu.items()} }"
|
||||
)
|
||||
logger.info(f"{repr(item)} is assigned to run on GPU: { {gpu.id: quota for gpu, quota in item._local_gpu.items()} }")
|
||||
|
||||
# enqueue work item to the pool
|
||||
self.running_works[item.key] = (
|
||||
pool.apply_async(
|
||||
func=Executor.process_work, args=(item, self.cq), name=item.key
|
||||
),
|
||||
pool.apply_async(func=Executor.process_work, args=(item, self.cq), name=item.key),
|
||||
item,
|
||||
)
|
||||
logger.info(
|
||||
f"started work: {repr(item)}, {len(self.running_works)} running works: {list(self.running_works.keys())[:10]}..."
|
||||
)
|
||||
logger.info(f"started work: {repr(item)}, {len(self.running_works)} running works: {list(self.running_works.keys())[:10]}...")
|
||||
|
||||
# start to run works
|
||||
pool.update_queue()
|
||||
@@ -287,15 +247,11 @@ class Executor(object):
|
||||
work.join()
|
||||
if (exitcode := work.exitcode()) != 0:
|
||||
item.status = WorkStatus.CRASHED
|
||||
item.exception = NonzeroExitCode(
|
||||
f"worker process {work.proc.name}({work.proc.pid}) exited with non-zero code {exitcode}"
|
||||
)
|
||||
item.exception = NonzeroExitCode(f"worker process {work.proc.name}({work.proc.pid}) exited with non-zero code {exitcode}")
|
||||
try:
|
||||
self.cq.push(item)
|
||||
except Exception as ex:
|
||||
logger.opt(exception=ex).critical(
|
||||
f"failed to push into completion queue: {self.cq}"
|
||||
)
|
||||
logger.opt(exception=ex).critical(f"failed to push into completion queue: {self.cq}")
|
||||
self.running = False
|
||||
finished_works.append(item)
|
||||
|
||||
@@ -304,9 +260,7 @@ class Executor(object):
|
||||
self.running_works.pop(item.key)
|
||||
if item._local_gpu:
|
||||
self.release_gpu(item._local_gpu)
|
||||
logger.info(
|
||||
f"{repr(item)} released GPU: { {gpu.id: quota for gpu, quota in item._local_gpu.items()} }"
|
||||
)
|
||||
logger.info(f"{repr(item)} released GPU: { {gpu.id: quota for gpu, quota in item._local_gpu.items()} }")
|
||||
|
||||
def acquire_gpu(self, quota: float) -> Dict[GPU, float]:
|
||||
"""
|
||||
@@ -336,6 +290,4 @@ class Executor(object):
|
||||
"""
|
||||
for gpu, quota in gpus.items():
|
||||
self.local_gpus[gpu] += quota
|
||||
assert (
|
||||
self.local_gpus[gpu] <= 1.0
|
||||
), f"GPU {gpu} quota is greater than 1.0: {self.local_gpus[gpu]}"
|
||||
assert self.local_gpus[gpu] <= 1.0, f"GPU {gpu} quota is greater than 1.0: {self.local_gpus[gpu]}"
|
||||
|
||||
@@ -26,13 +26,9 @@ class SchedStateExporter(Scheduler.StateObserver):
|
||||
if sched_state.large_runtime_state:
|
||||
logger.debug(f"pause exporting scheduler state")
|
||||
elif sched_state.num_local_running_works > 0:
|
||||
logger.debug(
|
||||
f"pause exporting scheduler state: {sched_state.num_local_running_works} local running works"
|
||||
)
|
||||
logger.debug(f"pause exporting scheduler state: {sched_state.num_local_running_works} local running works")
|
||||
else:
|
||||
dump(
|
||||
sched_state, self.sched_state_path, buffering=32 * MB, atomic_write=True
|
||||
)
|
||||
dump(sched_state, self.sched_state_path, buffering=32 * MB, atomic_write=True)
|
||||
sched_state.log_overall_progress()
|
||||
logger.debug(f"exported scheduler state to {self.sched_state_path}")
|
||||
|
||||
@@ -97,12 +93,8 @@ class JobManager(object):
|
||||
bind_numa_node=False,
|
||||
enforce_memory_limit=False,
|
||||
share_log_analytics: Optional[bool] = None,
|
||||
console_log_level: Literal[
|
||||
"CRITICAL", "ERROR", "WARNING", "SUCCESS", "INFO", "DEBUG", "TRACE"
|
||||
] = "INFO",
|
||||
file_log_level: Literal[
|
||||
"CRITICAL", "ERROR", "WARNING", "SUCCESS", "INFO", "DEBUG", "TRACE"
|
||||
] = "DEBUG",
|
||||
console_log_level: Literal["CRITICAL", "ERROR", "WARNING", "SUCCESS", "INFO", "DEBUG", "TRACE"] = "INFO",
|
||||
file_log_level: Literal["CRITICAL", "ERROR", "WARNING", "SUCCESS", "INFO", "DEBUG", "TRACE"] = "DEBUG",
|
||||
disable_log_rotation=False,
|
||||
sched_state_observers: Optional[List[Scheduler.StateObserver]] = None,
|
||||
output_path: Optional[str] = None,
|
||||
@@ -111,11 +103,7 @@ class JobManager(object):
|
||||
logger.info(f"using platform: {self.platform}")
|
||||
|
||||
job_id = JobId(hex=job_id or self.platform.default_job_id())
|
||||
job_time = (
|
||||
datetime.fromtimestamp(job_time)
|
||||
if job_time is not None
|
||||
else self.platform.default_job_time()
|
||||
)
|
||||
job_time = datetime.fromtimestamp(job_time) if job_time is not None else self.platform.default_job_time()
|
||||
|
||||
malloc_path = ""
|
||||
if memory_allocator == "system":
|
||||
@@ -131,19 +119,10 @@ class JobManager(object):
|
||||
arrow_default_malloc=memory_allocator,
|
||||
).splitlines()
|
||||
env_overrides = env_overrides + (env_variables or [])
|
||||
env_overrides = dict(
|
||||
tuple(kv.strip().split("=", maxsplit=1))
|
||||
for kv in filter(None, env_overrides)
|
||||
)
|
||||
env_overrides = dict(tuple(kv.strip().split("=", maxsplit=1)) for kv in filter(None, env_overrides))
|
||||
|
||||
share_log_analytics = (
|
||||
share_log_analytics
|
||||
if share_log_analytics is not None
|
||||
else self.platform.default_share_log_analytics()
|
||||
)
|
||||
shared_log_root = (
|
||||
self.platform.shared_log_root() if share_log_analytics else None
|
||||
)
|
||||
share_log_analytics = share_log_analytics if share_log_analytics is not None else self.platform.default_share_log_analytics()
|
||||
shared_log_root = self.platform.shared_log_root() if share_log_analytics else None
|
||||
|
||||
runtime_ctx = RuntimeContext(
|
||||
job_id,
|
||||
@@ -167,9 +146,7 @@ class JobManager(object):
|
||||
**kwargs,
|
||||
)
|
||||
runtime_ctx.initialize(socket.gethostname(), root_exist_ok=True)
|
||||
logger.info(
|
||||
f"command-line arguments: {' '.join([os.path.basename(sys.argv[0]), *sys.argv[1:]])}"
|
||||
)
|
||||
logger.info(f"command-line arguments: {' '.join([os.path.basename(sys.argv[0]), *sys.argv[1:]])}")
|
||||
|
||||
dump(runtime_ctx, runtime_ctx.runtime_ctx_path, atomic_write=True)
|
||||
logger.info(f"saved runtime context at {runtime_ctx.runtime_ctx_path}")
|
||||
@@ -178,13 +155,9 @@ class JobManager(object):
|
||||
logger.info(f"saved logcial plan at {runtime_ctx.logcial_plan_path}")
|
||||
|
||||
plan.graph().render(runtime_ctx.logcial_plan_graph_path, format="png")
|
||||
logger.info(
|
||||
f"saved logcial plan graph at {runtime_ctx.logcial_plan_graph_path}.png"
|
||||
)
|
||||
logger.info(f"saved logcial plan graph at {runtime_ctx.logcial_plan_graph_path}.png")
|
||||
|
||||
exec_plan = Planner(runtime_ctx).create_exec_plan(
|
||||
plan, manifest_only_final_results
|
||||
)
|
||||
exec_plan = Planner(runtime_ctx).create_exec_plan(plan, manifest_only_final_results)
|
||||
dump(exec_plan, runtime_ctx.exec_plan_path, atomic_write=True)
|
||||
logger.info(f"saved execution plan at {runtime_ctx.exec_plan_path}")
|
||||
|
||||
@@ -229,9 +202,7 @@ class JobManager(object):
|
||||
sched_state_observers.insert(0, sched_state_exporter)
|
||||
|
||||
if os.path.exists(runtime_ctx.sched_state_path):
|
||||
logger.warning(
|
||||
f"loading scheduler state from: {runtime_ctx.sched_state_path}"
|
||||
)
|
||||
logger.warning(f"loading scheduler state from: {runtime_ctx.sched_state_path}")
|
||||
scheduler: Scheduler = load(runtime_ctx.sched_state_path)
|
||||
scheduler.sched_epoch += 1
|
||||
scheduler.sched_state_observers = sched_state_observers
|
||||
|
||||
@@ -54,9 +54,7 @@ class ExecutorState(Enum):
|
||||
|
||||
|
||||
class RemoteExecutor(object):
|
||||
def __init__(
|
||||
self, ctx: RuntimeContext, id: str, wq: WorkQueue, cq: WorkQueue, init_epoch=0
|
||||
) -> None:
|
||||
def __init__(self, ctx: RuntimeContext, id: str, wq: WorkQueue, cq: WorkQueue, init_epoch=0) -> None:
|
||||
self.ctx = ctx
|
||||
self.id = id
|
||||
self.wq = wq
|
||||
@@ -79,9 +77,7 @@ state={self.state}, probe={self.last_acked_probe}"
|
||||
return f"RemoteExecutor({self.id}):{self.state}"
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
ctx: RuntimeContext, id: str, queue_dir: str, init_epoch=0
|
||||
) -> "RemoteExecutor":
|
||||
def create(ctx: RuntimeContext, id: str, queue_dir: str, init_epoch=0) -> "RemoteExecutor":
|
||||
wq = WorkQueueOnFilesystem(os.path.join(queue_dir, "wq"))
|
||||
cq = WorkQueueOnFilesystem(os.path.join(queue_dir, "cq"))
|
||||
return RemoteExecutor(ctx, id, wq, cq, init_epoch)
|
||||
@@ -173,9 +169,7 @@ state={self.state}, probe={self.last_acked_probe}"
|
||||
return self.cpu_count - self.cpu_count // 16
|
||||
|
||||
def add_running_work(self, item: WorkItem):
|
||||
assert (
|
||||
item.key not in self.running_works
|
||||
), f"duplicate work item assigned to {repr(self)}: {item.key}"
|
||||
assert item.key not in self.running_works, f"duplicate work item assigned to {repr(self)}: {item.key}"
|
||||
self.running_works[item.key] = item
|
||||
self._allocated_cpus += item.cpu_limit
|
||||
self._allocated_gpus += item.gpu_limit
|
||||
@@ -219,9 +213,7 @@ state={self.state}, probe={self.last_acked_probe}"
|
||||
|
||||
def push(self, item: WorkItem, buffering=False) -> bool:
|
||||
if item.key in self.running_works:
|
||||
logger.warning(
|
||||
f"work item {item.key} already exists in running works of {self}"
|
||||
)
|
||||
logger.warning(f"work item {item.key} already exists in running works of {self}")
|
||||
return False
|
||||
item.start_time = time.time()
|
||||
item.exec_id = self.id
|
||||
@@ -250,9 +242,7 @@ state={self.state}, probe={self.last_acked_probe}"
|
||||
elif num_missed_probes > self.ctx.max_num_missed_probes:
|
||||
if self.state != ExecutorState.FAIL:
|
||||
self.state = ExecutorState.FAIL
|
||||
logger.error(
|
||||
f"find failed executor: {self}, missed probes: {num_missed_probes}, current epoch: {current_epoch}"
|
||||
)
|
||||
logger.error(f"find failed executor: {self}, missed probes: {num_missed_probes}, current epoch: {current_epoch}")
|
||||
return True
|
||||
elif self.state == ExecutorState.STOPPING:
|
||||
if self.stop_request_acked:
|
||||
@@ -277,9 +267,7 @@ state={self.state}, probe={self.last_acked_probe}"
|
||||
|
||||
|
||||
class LocalExecutor(RemoteExecutor):
|
||||
def __init__(
|
||||
self, ctx: RuntimeContext, id: str, wq: WorkQueue, cq: WorkQueue
|
||||
) -> None:
|
||||
def __init__(self, ctx: RuntimeContext, id: str, wq: WorkQueue, cq: WorkQueue) -> None:
|
||||
super().__init__(ctx, id, wq, cq)
|
||||
self.work = None
|
||||
self.running = False
|
||||
@@ -321,9 +309,7 @@ class LocalExecutor(RemoteExecutor):
|
||||
if item.gpu_limit > 0:
|
||||
assert len(local_gpus) > 0
|
||||
item._local_gpu = local_gpus[0]
|
||||
logger.info(
|
||||
f"{repr(item)} is assigned to run on GPU #{item.local_rank}: {item.local_gpu}"
|
||||
)
|
||||
logger.info(f"{repr(item)} is assigned to run on GPU #{item.local_rank}: {item.local_gpu}")
|
||||
|
||||
item = copy.copy(item)
|
||||
item.exec()
|
||||
@@ -368,9 +354,7 @@ class Scheduler(object):
|
||||
self.callback = callback
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
repr(self.callback) if self.callback is not None else super().__repr__()
|
||||
)
|
||||
return repr(self.callback) if self.callback is not None else super().__repr__()
|
||||
|
||||
__str__ = __repr__
|
||||
|
||||
@@ -403,9 +387,7 @@ class Scheduler(object):
|
||||
self.stop_executor_on_failure = stop_executor_on_failure
|
||||
self.nonzero_exitcode_as_oom = nonzero_exitcode_as_oom
|
||||
self.remove_output_root = remove_output_root
|
||||
self.sched_state_observers: List[Scheduler.StateObserver] = (
|
||||
sched_state_observers or []
|
||||
)
|
||||
self.sched_state_observers: List[Scheduler.StateObserver] = sched_state_observers or []
|
||||
self.secs_state_notify_interval = self.ctx.secs_executor_probe_interval * 2
|
||||
# task states
|
||||
self.local_queue: List[Task] = []
|
||||
@@ -414,11 +396,7 @@ class Scheduler(object):
|
||||
self.scheduled_tasks: Dict[TaskRuntimeId, Task] = OrderedDict()
|
||||
self.finished_tasks: Dict[TaskRuntimeId, Task] = OrderedDict()
|
||||
self.succeeded_tasks: Dict[str, Task] = OrderedDict()
|
||||
self.nontrivial_tasks = dict(
|
||||
(key, task)
|
||||
for (key, task) in self.tasks.items()
|
||||
if not task.exec_on_scheduler
|
||||
)
|
||||
self.nontrivial_tasks = dict((key, task) for (key, task) in self.tasks.items() if not task.exec_on_scheduler)
|
||||
self.succeeded_nontrivial_tasks: Dict[str, Task] = OrderedDict()
|
||||
# executor pool
|
||||
self.local_executor = LocalExecutor.create(self.ctx, "localhost")
|
||||
@@ -463,18 +441,11 @@ class Scheduler(object):
|
||||
|
||||
@property
|
||||
def running_works(self) -> Iterable[WorkItem]:
|
||||
return (
|
||||
work
|
||||
for executor in (self.alive_executors + self.local_executors)
|
||||
for work in executor.running_works.values()
|
||||
)
|
||||
return (work for executor in (self.alive_executors + self.local_executors) for work in executor.running_works.values())
|
||||
|
||||
@property
|
||||
def num_running_works(self) -> int:
|
||||
return sum(
|
||||
len(executor.running_works)
|
||||
for executor in (self.alive_executors + self.local_executors)
|
||||
)
|
||||
return sum(len(executor.running_works) for executor in (self.alive_executors + self.local_executors))
|
||||
|
||||
@property
|
||||
def num_local_running_works(self) -> int:
|
||||
@@ -489,11 +460,7 @@ class Scheduler(object):
|
||||
|
||||
@property
|
||||
def pending_nontrivial_tasks(self) -> Dict[str, Task]:
|
||||
return dict(
|
||||
(key, task)
|
||||
for key, task in self.nontrivial_tasks.items()
|
||||
if key not in self.succeeded_nontrivial_tasks
|
||||
)
|
||||
return dict((key, task) for key, task in self.nontrivial_tasks.items() if key not in self.succeeded_nontrivial_tasks)
|
||||
|
||||
@property
|
||||
def num_pending_nontrivial_tasks(self) -> int:
|
||||
@@ -504,33 +471,20 @@ class Scheduler(object):
|
||||
|
||||
@property
|
||||
def succeeded_task_ids(self) -> Set[TaskRuntimeId]:
|
||||
return set(
|
||||
TaskRuntimeId(task.id, task.sched_epoch, task.retry_count)
|
||||
for task in self.succeeded_tasks.values()
|
||||
)
|
||||
return set(TaskRuntimeId(task.id, task.sched_epoch, task.retry_count) for task in self.succeeded_tasks.values())
|
||||
|
||||
@property
|
||||
def abandoned_tasks(self) -> List[Task]:
|
||||
succeeded_task_ids = self.succeeded_task_ids
|
||||
return [
|
||||
task
|
||||
for task in {**self.scheduled_tasks, **self.finished_tasks}.values()
|
||||
if task.runtime_id not in succeeded_task_ids
|
||||
]
|
||||
return [task for task in {**self.scheduled_tasks, **self.finished_tasks}.values() if task.runtime_id not in succeeded_task_ids]
|
||||
|
||||
@cached_property
|
||||
def remote_executors(self) -> List[RemoteExecutor]:
|
||||
return [
|
||||
executor
|
||||
for executor in self.available_executors.values()
|
||||
if not executor.local
|
||||
]
|
||||
return [executor for executor in self.available_executors.values() if not executor.local]
|
||||
|
||||
@cached_property
|
||||
def local_executors(self) -> List[RemoteExecutor]:
|
||||
return [
|
||||
executor for executor in self.available_executors.values() if executor.local
|
||||
]
|
||||
return [executor for executor in self.available_executors.values() if executor.local]
|
||||
|
||||
@cached_property
|
||||
def working_executors(self) -> List[RemoteExecutor]:
|
||||
@@ -592,10 +546,7 @@ class Scheduler(object):
|
||||
def start_speculative_execution(self):
|
||||
for executor in self.working_executors:
|
||||
for idx, item in enumerate(executor.running_works.values()):
|
||||
aggressive_retry = (
|
||||
self.aggressive_speculative_exec
|
||||
and len(self.good_executors) >= self.ctx.num_executors
|
||||
)
|
||||
aggressive_retry = self.aggressive_speculative_exec and len(self.good_executors) >= self.ctx.num_executors
|
||||
short_sched_queue = len(self.sched_queue) < len(self.good_executors)
|
||||
if (
|
||||
isinstance(item, Task)
|
||||
@@ -603,8 +554,7 @@ class Scheduler(object):
|
||||
and item.allow_speculative_exec
|
||||
and item.retry_count < self.max_retry_count
|
||||
and item.retry_count == self.tasks[item.key].retry_count
|
||||
and (logical_node := self.logical_nodes.get(item.node_id, None))
|
||||
is not None
|
||||
and (logical_node := self.logical_nodes.get(item.node_id, None)) is not None
|
||||
):
|
||||
perf_stats = logical_node.get_perf_stats("elapsed wall time (secs)")
|
||||
if perf_stats is not None and perf_stats.cnt >= 20:
|
||||
@@ -639,12 +589,8 @@ class Scheduler(object):
|
||||
if entry.is_dir():
|
||||
_, exec_id = os.path.split(entry.path)
|
||||
if exec_id not in self.available_executors:
|
||||
self.available_executors[exec_id] = RemoteExecutor.create(
|
||||
self.ctx, exec_id, entry.path, self.probe_epoch
|
||||
)
|
||||
logger.info(
|
||||
f"find a new executor #{len(self.available_executors)}: {self.available_executors[exec_id]}"
|
||||
)
|
||||
self.available_executors[exec_id] = RemoteExecutor.create(self.ctx, exec_id, entry.path, self.probe_epoch)
|
||||
logger.info(f"find a new executor #{len(self.available_executors)}: {self.available_executors[exec_id]}")
|
||||
self.clear_cached_executor_lists()
|
||||
# start a new probe epoch
|
||||
self.last_executor_probe_time = time.time()
|
||||
@@ -668,9 +614,7 @@ class Scheduler(object):
|
||||
item.status = WorkStatus.EXEC_FAILED
|
||||
item.finish_time = time.time()
|
||||
if isinstance(item, Task) and item.key not in self.succeeded_tasks:
|
||||
logger.warning(
|
||||
f"reschedule {repr(item)} on failed executor: {repr(executor)}"
|
||||
)
|
||||
logger.warning(f"reschedule {repr(item)} on failed executor: {repr(executor)}")
|
||||
self.try_enqueue(self.get_retry_task(item.key))
|
||||
|
||||
if any(executor_state_changed):
|
||||
@@ -690,9 +634,7 @@ class Scheduler(object):
|
||||
# remove the reference to input deps
|
||||
task.input_deps = {dep_key: None for dep_key in task.input_deps}
|
||||
# feed input datasets
|
||||
task.input_datasets = [
|
||||
self.succeeded_tasks[dep_key].output for dep_key in task.input_deps
|
||||
]
|
||||
task.input_datasets = [self.succeeded_tasks[dep_key].output for dep_key in task.input_deps]
|
||||
task.sched_epoch = self.sched_epoch
|
||||
return task
|
||||
|
||||
@@ -713,9 +655,7 @@ class Scheduler(object):
|
||||
task.dataset = finished_task.dataset
|
||||
|
||||
def get_runnable_tasks(self, finished_task: Task) -> Iterable[Task]:
|
||||
assert (
|
||||
finished_task.status == WorkStatus.SUCCEED
|
||||
), f"task not succeeded: {finished_task}"
|
||||
assert finished_task.status == WorkStatus.SUCCEED, f"task not succeeded: {finished_task}"
|
||||
for output_key in finished_task.output_deps:
|
||||
output_dep = self.tasks[output_key]
|
||||
if all(key in self.succeeded_tasks for key in output_dep.input_deps):
|
||||
@@ -730,14 +670,8 @@ class Scheduler(object):
|
||||
for executor in self.remote_executors:
|
||||
running_task = executor.running_works.get(task_key, None)
|
||||
if running_task is not None:
|
||||
logger.info(
|
||||
f"try to stop {repr(running_task)} running on {repr(executor)}"
|
||||
)
|
||||
executor.wq.push(
|
||||
StopWorkItem(
|
||||
f".StopWorkItem-{repr(running_task)}", running_task.key
|
||||
)
|
||||
)
|
||||
logger.info(f"try to stop {repr(running_task)} running on {repr(executor)}")
|
||||
executor.wq.push(StopWorkItem(f".StopWorkItem-{repr(running_task)}", running_task.key))
|
||||
|
||||
def try_relax_memory_limit(self, task: Task, executor: RemoteExecutor) -> bool:
|
||||
if task.memory_limit >= executor.memory_size:
|
||||
@@ -745,9 +679,7 @@ class Scheduler(object):
|
||||
return False
|
||||
relaxed_memory_limit = min(executor.memory_size, task.memory_limit * 2)
|
||||
task._memory_boost = relaxed_memory_limit / task._memory_limit
|
||||
logger.warning(
|
||||
f"relax memory limit of {task.key} to {task.memory_limit/GB:.3f}GB and retry ..."
|
||||
)
|
||||
logger.warning(f"relax memory limit of {task.key} to {task.memory_limit/GB:.3f}GB and retry ...")
|
||||
return True
|
||||
|
||||
def try_boost_resource(self, item: WorkItem, executor: RemoteExecutor):
|
||||
@@ -777,9 +709,7 @@ class Scheduler(object):
|
||||
if item._cpu_limit < boost_cpu or item._memory_limit < boost_mem:
|
||||
item._cpu_boost = boost_cpu / item._cpu_limit
|
||||
item._memory_boost = boost_mem / item._memory_limit
|
||||
logger.info(
|
||||
f"boost resource usage of {repr(item)}: {item.cpu_limit} CPUs, {item.memory_limit/GB:.3f}GB"
|
||||
)
|
||||
logger.info(f"boost resource usage of {repr(item)}: {item.cpu_limit} CPUs, {item.memory_limit/GB:.3f}GB")
|
||||
|
||||
def get_retry_task(self, key: str) -> Task:
|
||||
task = self.tasks[key]
|
||||
@@ -794,9 +724,7 @@ class Scheduler(object):
|
||||
remove_path(self.ctx.staging_root)
|
||||
|
||||
if abandoned_tasks := self.abandoned_tasks:
|
||||
logger.info(
|
||||
f"removing outputs of {len(abandoned_tasks)} abandoned tasks: {abandoned_tasks[:3]} ..."
|
||||
)
|
||||
logger.info(f"removing outputs of {len(abandoned_tasks)} abandoned tasks: {abandoned_tasks[:3]} ...")
|
||||
assert list(pool.map(lambda t: t.clean_output(force=True), abandoned_tasks))
|
||||
|
||||
@logger.catch(reraise=pytest_running(), message="failed to export task metrics")
|
||||
@@ -825,15 +753,9 @@ class Scheduler(object):
|
||||
buffering=32 * MB,
|
||||
)
|
||||
|
||||
task_props = arrow.array(
|
||||
pristine_attrs_dict(task) for task in self.finished_tasks.values()
|
||||
)
|
||||
partition_infos = arrow.array(
|
||||
task.partition_infos_as_dict for task in self.finished_tasks.values()
|
||||
)
|
||||
perf_metrics = arrow.array(
|
||||
dict(task.perf_metrics) for task in self.finished_tasks.values()
|
||||
)
|
||||
task_props = arrow.array(pristine_attrs_dict(task) for task in self.finished_tasks.values())
|
||||
partition_infos = arrow.array(task.partition_infos_as_dict for task in self.finished_tasks.values())
|
||||
perf_metrics = arrow.array(dict(task.perf_metrics) for task in self.finished_tasks.values())
|
||||
task_metrics = arrow.Table.from_arrays(
|
||||
[task_props, partition_infos, perf_metrics],
|
||||
names=["task_props", "partition_infos", "perf_metrics"],
|
||||
@@ -862,12 +784,7 @@ class Scheduler(object):
|
||||
[
|
||||
dict(
|
||||
task=repr(task),
|
||||
node=(
|
||||
repr(node)
|
||||
if (node := self.logical_nodes.get(task.node_id, None))
|
||||
is not None
|
||||
else "StandaloneTasks"
|
||||
),
|
||||
node=(repr(node) if (node := self.logical_nodes.get(task.node_id, None)) is not None else "StandaloneTasks"),
|
||||
status=str(task.status),
|
||||
executor=task.exec_id,
|
||||
start_time=datetime.fromtimestamp(task.start_time),
|
||||
@@ -925,23 +842,16 @@ class Scheduler(object):
|
||||
fig_filename, _ = fig_title.split(" - ", maxsplit=1)
|
||||
fig_filename += ".html"
|
||||
fig_path = os.path.join(self.ctx.log_root, fig_filename)
|
||||
fig.update_yaxes(
|
||||
autorange="reversed"
|
||||
) # otherwise tasks are listed from the bottom up
|
||||
fig.update_yaxes(autorange="reversed") # otherwise tasks are listed from the bottom up
|
||||
fig.update_traces(marker_line_color="black", marker_line_width=1, opacity=1)
|
||||
fig.write_html(
|
||||
fig_path, include_plotlyjs="cdn" if pytest_running() else True
|
||||
)
|
||||
fig.write_html(fig_path, include_plotlyjs="cdn" if pytest_running() else True)
|
||||
if self.ctx.shared_log_root:
|
||||
shutil.copy(fig_path, self.ctx.shared_log_root)
|
||||
logger.debug(f"exported timeline figure to {fig_path}")
|
||||
|
||||
def notify_state_observers(self, force_notify=False) -> bool:
|
||||
secs_since_last_state_notify = time.time() - self.last_state_notify_time
|
||||
if (
|
||||
force_notify
|
||||
or secs_since_last_state_notify >= self.secs_state_notify_interval
|
||||
):
|
||||
if force_notify or secs_since_last_state_notify >= self.secs_state_notify_interval:
|
||||
self.last_state_notify_time = time.time()
|
||||
for observer in self.sched_state_observers:
|
||||
if force_notify or observer.enabled:
|
||||
@@ -949,14 +859,10 @@ class Scheduler(object):
|
||||
observer.update(self)
|
||||
elapsed_time = time.time() - start_time
|
||||
if elapsed_time >= self.ctx.secs_executor_probe_interval / 2:
|
||||
self.secs_state_notify_interval = (
|
||||
self.ctx.secs_executor_probe_timeout
|
||||
)
|
||||
self.secs_state_notify_interval = self.ctx.secs_executor_probe_timeout
|
||||
if elapsed_time >= self.ctx.secs_executor_probe_interval:
|
||||
observer.enabled = False
|
||||
logger.warning(
|
||||
f"disabled slow scheduler state observer (elapsed time: {elapsed_time:.1f} secs): {observer}"
|
||||
)
|
||||
logger.warning(f"disabled slow scheduler state observer (elapsed time: {elapsed_time:.1f} secs): {observer}")
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
@@ -984,9 +890,7 @@ class Scheduler(object):
|
||||
|
||||
def run(self) -> bool:
|
||||
mp.current_process().name = f"SchedulerMainProcess#{self.sched_epoch}"
|
||||
logger.info(
|
||||
f"start to run scheduler #{self.sched_epoch} on {socket.gethostname()}"
|
||||
)
|
||||
logger.info(f"start to run scheduler #{self.sched_epoch} on {socket.gethostname()}")
|
||||
|
||||
perf_profile = None
|
||||
if self.ctx.enable_profiling:
|
||||
@@ -1001,48 +905,30 @@ class Scheduler(object):
|
||||
self.prioritize_retry |= self.sched_epoch > 0
|
||||
|
||||
if self.local_queue or self.sched_queue:
|
||||
pending_tasks = [
|
||||
item
|
||||
for item in self.local_queue + self.sched_queue
|
||||
if isinstance(item, Task)
|
||||
]
|
||||
pending_tasks = [item for item in self.local_queue + self.sched_queue if isinstance(item, Task)]
|
||||
self.local_queue.clear()
|
||||
self.sched_queue.clear()
|
||||
logger.info(
|
||||
f"requeue {len(pending_tasks)} pending tasks with latest epoch #{self.sched_epoch}: {pending_tasks[:3]} ..."
|
||||
)
|
||||
logger.info(f"requeue {len(pending_tasks)} pending tasks with latest epoch #{self.sched_epoch}: {pending_tasks[:3]} ...")
|
||||
self.try_enqueue(pending_tasks)
|
||||
|
||||
if self.sched_epoch == 0:
|
||||
leaf_tasks = self.exec_plan.leaves
|
||||
logger.info(
|
||||
f"enqueue {len(leaf_tasks)} leaf tasks: {leaf_tasks[:3]} ..."
|
||||
)
|
||||
logger.info(f"enqueue {len(leaf_tasks)} leaf tasks: {leaf_tasks[:3]} ...")
|
||||
self.try_enqueue(leaf_tasks)
|
||||
|
||||
self.log_overall_progress()
|
||||
while (num_finished_tasks := self.process_finished_tasks(pool)) > 0:
|
||||
logger.info(
|
||||
f"processed {num_finished_tasks} finished tasks during startup"
|
||||
)
|
||||
logger.info(f"processed {num_finished_tasks} finished tasks during startup")
|
||||
self.log_overall_progress()
|
||||
|
||||
earlier_running_tasks = [
|
||||
item for item in self.running_works if isinstance(item, Task)
|
||||
]
|
||||
earlier_running_tasks = [item for item in self.running_works if isinstance(item, Task)]
|
||||
if earlier_running_tasks:
|
||||
logger.info(
|
||||
f"enqueue {len(earlier_running_tasks)} earlier running tasks: {earlier_running_tasks[:3]} ..."
|
||||
)
|
||||
logger.info(f"enqueue {len(earlier_running_tasks)} earlier running tasks: {earlier_running_tasks[:3]} ...")
|
||||
self.try_enqueue(earlier_running_tasks)
|
||||
|
||||
self.suspend_good_executors()
|
||||
self.add_state_observer(
|
||||
Scheduler.StateObserver(Scheduler.log_current_status)
|
||||
)
|
||||
self.add_state_observer(
|
||||
Scheduler.StateObserver(Scheduler.export_timeline_figs)
|
||||
)
|
||||
self.add_state_observer(Scheduler.StateObserver(Scheduler.log_current_status))
|
||||
self.add_state_observer(Scheduler.StateObserver(Scheduler.export_timeline_figs))
|
||||
self.notify_state_observers(force_notify=True)
|
||||
|
||||
try:
|
||||
@@ -1063,14 +949,10 @@ class Scheduler(object):
|
||||
if self.success:
|
||||
self.clean_temp_files(pool)
|
||||
logger.success(f"final output path: {self.exec_plan.final_output_path}")
|
||||
logger.info(
|
||||
f"analyzed plan:{os.linesep}{self.exec_plan.analyzed_logical_plan.explain_str()}"
|
||||
)
|
||||
logger.info(f"analyzed plan:{os.linesep}{self.exec_plan.analyzed_logical_plan.explain_str()}")
|
||||
|
||||
if perf_profile is not None:
|
||||
logger.debug(
|
||||
f"scheduler perf profile:{os.linesep}{cprofile_to_string(perf_profile)}"
|
||||
)
|
||||
logger.debug(f"scheduler perf profile:{os.linesep}{cprofile_to_string(perf_profile)}")
|
||||
|
||||
logger.info(f"scheduler of job {self.ctx.job_id} exits")
|
||||
logger.complete()
|
||||
@@ -1082,20 +964,14 @@ class Scheduler(object):
|
||||
task = self.copy_task_for_execution(task)
|
||||
if task.key in self.succeeded_tasks:
|
||||
logger.debug(f"task {repr(task)} already succeeded, skipping")
|
||||
self.try_enqueue(
|
||||
self.get_runnable_tasks(self.succeeded_tasks[task.key])
|
||||
)
|
||||
self.try_enqueue(self.get_runnable_tasks(self.succeeded_tasks[task.key]))
|
||||
continue
|
||||
if task.runtime_id in self.scheduled_tasks:
|
||||
logger.debug(f"task {repr(task)} already scheduled, skipping")
|
||||
continue
|
||||
# save enqueued task
|
||||
self.scheduled_tasks[task.runtime_id] = task
|
||||
if (
|
||||
self.standalone_mode
|
||||
or task.exec_on_scheduler
|
||||
or task.skip_when_any_input_empty
|
||||
):
|
||||
if self.standalone_mode or task.exec_on_scheduler or task.skip_when_any_input_empty:
|
||||
self.local_queue.append(task)
|
||||
else:
|
||||
self.sched_queue.append(task)
|
||||
@@ -1114,34 +990,20 @@ class Scheduler(object):
|
||||
|
||||
if self.local_queue:
|
||||
assert self.local_executor.alive
|
||||
logger.info(
|
||||
f"running {len(self.local_queue)} works on local executor: {self.local_queue[:3]} ..."
|
||||
)
|
||||
self.local_queue = [
|
||||
item
|
||||
for item in self.local_queue
|
||||
if not self.local_executor.push(item, buffering=True)
|
||||
]
|
||||
logger.info(f"running {len(self.local_queue)} works on local executor: {self.local_queue[:3]} ...")
|
||||
self.local_queue = [item for item in self.local_queue if not self.local_executor.push(item, buffering=True)]
|
||||
self.local_executor.flush()
|
||||
|
||||
has_progress |= self.dispatch_tasks(pool) > 0
|
||||
|
||||
if len(
|
||||
self.sched_queue
|
||||
) == 0 and self.num_pending_nontrivial_tasks + 1 < len(self.good_executors):
|
||||
if len(self.sched_queue) == 0 and self.num_pending_nontrivial_tasks + 1 < len(self.good_executors):
|
||||
for executor in self.good_executors:
|
||||
if executor.idle:
|
||||
logger.info(
|
||||
f"{len(self.good_executors)} remote executors running, stopping {executor}"
|
||||
)
|
||||
logger.info(f"{len(self.good_executors)} remote executors running, stopping {executor}")
|
||||
executor.stop()
|
||||
break
|
||||
|
||||
if (
|
||||
len(self.sched_queue) == 0
|
||||
and len(self.local_queue) == 0
|
||||
and self.num_running_works == 0
|
||||
):
|
||||
if len(self.sched_queue) == 0 and len(self.local_queue) == 0 and self.num_running_works == 0:
|
||||
self.log_overall_progress()
|
||||
assert (
|
||||
self.num_pending_tasks == 0
|
||||
@@ -1166,29 +1028,13 @@ class Scheduler(object):
|
||||
|
||||
def dispatch_tasks(self, pool: ThreadPoolExecutor):
|
||||
# sort pending tasks
|
||||
item_sort_key = (
|
||||
(lambda item: (-item.retry_count, item.id))
|
||||
if self.prioritize_retry
|
||||
else (lambda item: (item.retry_count, item.id))
|
||||
)
|
||||
item_sort_key = (lambda item: (-item.retry_count, item.id)) if self.prioritize_retry else (lambda item: (item.retry_count, item.id))
|
||||
items_sorted_by_node_id = sorted(self.sched_queue, key=lambda t: t.node_id)
|
||||
items_group_by_node_id = itertools.groupby(
|
||||
items_sorted_by_node_id, key=lambda t: t.node_id
|
||||
)
|
||||
sorted_item_groups = [
|
||||
sorted(items, key=item_sort_key) for _, items in items_group_by_node_id
|
||||
]
|
||||
self.sched_queue = [
|
||||
item
|
||||
for batch in itertools.zip_longest(*sorted_item_groups, fillvalue=None)
|
||||
for item in batch
|
||||
if item is not None
|
||||
]
|
||||
items_group_by_node_id = itertools.groupby(items_sorted_by_node_id, key=lambda t: t.node_id)
|
||||
sorted_item_groups = [sorted(items, key=item_sort_key) for _, items in items_group_by_node_id]
|
||||
self.sched_queue = [item for batch in itertools.zip_longest(*sorted_item_groups, fillvalue=None) for item in batch if item is not None]
|
||||
|
||||
final_phase = (
|
||||
self.num_pending_nontrivial_tasks - self.num_running_works
|
||||
<= len(self.good_executors) * 2
|
||||
)
|
||||
final_phase = self.num_pending_nontrivial_tasks - self.num_running_works <= len(self.good_executors) * 2
|
||||
num_dispatched_tasks = 0
|
||||
unassigned_tasks = []
|
||||
|
||||
@@ -1196,42 +1042,31 @@ class Scheduler(object):
|
||||
first_item = self.sched_queue[0]
|
||||
|
||||
# assign tasks to executors in round-robin fashion
|
||||
usable_executors = [
|
||||
executor for executor in self.good_executors if not executor.busy
|
||||
]
|
||||
for executor in sorted(
|
||||
usable_executors, key=lambda exec: len(exec.running_works)
|
||||
):
|
||||
usable_executors = [executor for executor in self.good_executors if not executor.busy]
|
||||
for executor in sorted(usable_executors, key=lambda exec: len(exec.running_works)):
|
||||
if not self.sched_queue:
|
||||
break
|
||||
item = self.sched_queue[0]
|
||||
|
||||
if item._memory_limit is None:
|
||||
item._memory_limit = np.int64(
|
||||
executor.memory_size * item._cpu_limit // executor.cpu_count
|
||||
)
|
||||
item._memory_limit = np.int64(executor.memory_size * item._cpu_limit // executor.cpu_count)
|
||||
|
||||
if item.key in self.succeeded_tasks:
|
||||
logger.debug(f"task {repr(item)} already succeeded, skipping")
|
||||
self.sched_queue.pop(0)
|
||||
self.try_enqueue(
|
||||
self.get_runnable_tasks(self.succeeded_tasks[item.key])
|
||||
)
|
||||
self.try_enqueue(self.get_runnable_tasks(self.succeeded_tasks[item.key]))
|
||||
elif (
|
||||
len(executor.running_works) < executor.max_running_works
|
||||
and executor.allocated_cpus + item.cpu_limit <= executor.cpu_count
|
||||
and executor.allocated_gpus + item.gpu_limit <= executor.gpu_count
|
||||
and executor.allocated_memory + item.memory_limit
|
||||
<= executor.memory_size
|
||||
and executor.allocated_memory + item.memory_limit <= executor.memory_size
|
||||
and item.key not in executor.running_works
|
||||
):
|
||||
if final_phase:
|
||||
self.try_boost_resource(item, executor)
|
||||
# push to wq of executor but not flushed yet
|
||||
executor.push(item, buffering=True)
|
||||
logger.info(
|
||||
f"appended {repr(item)} ({item.cpu_limit} CPUs, {item.memory_limit/GB:.3f}GB) to the queue of {executor}"
|
||||
)
|
||||
logger.info(f"appended {repr(item)} ({item.cpu_limit} CPUs, {item.memory_limit/GB:.3f}GB) to the queue of {executor}")
|
||||
self.sched_queue.pop(0)
|
||||
num_dispatched_tasks += 1
|
||||
|
||||
@@ -1242,55 +1077,35 @@ class Scheduler(object):
|
||||
self.sched_queue.extend(unassigned_tasks)
|
||||
|
||||
# flush the buffered work items into wq
|
||||
assert all(
|
||||
pool.map(RemoteExecutor.flush, self.good_executors)
|
||||
), f"failed to flush work queues"
|
||||
assert all(pool.map(RemoteExecutor.flush, self.good_executors)), f"failed to flush work queues"
|
||||
return num_dispatched_tasks
|
||||
|
||||
def process_finished_tasks(self, pool: ThreadPoolExecutor) -> int:
|
||||
pop_results = pool.map(RemoteExecutor.pop, self.available_executors.values())
|
||||
num_finished_tasks = 0
|
||||
|
||||
for executor, finished_tasks in zip(
|
||||
self.available_executors.values(), pop_results
|
||||
):
|
||||
for executor, finished_tasks in zip(self.available_executors.values(), pop_results):
|
||||
|
||||
for finished_task in finished_tasks:
|
||||
assert isinstance(finished_task, Task)
|
||||
|
||||
scheduled_task = self.scheduled_tasks.get(
|
||||
finished_task.runtime_id, None
|
||||
)
|
||||
scheduled_task = self.scheduled_tasks.get(finished_task.runtime_id, None)
|
||||
if scheduled_task is None:
|
||||
logger.info(
|
||||
f"task not initiated by current scheduler: {finished_task}"
|
||||
)
|
||||
logger.info(f"task not initiated by current scheduler: {finished_task}")
|
||||
if finished_task.status != WorkStatus.SUCCEED and (
|
||||
missing_inputs := [
|
||||
key
|
||||
for key in finished_task.input_deps
|
||||
if key not in self.succeeded_tasks
|
||||
]
|
||||
missing_inputs := [key for key in finished_task.input_deps if key not in self.succeeded_tasks]
|
||||
):
|
||||
logger.info(
|
||||
f"ignore {repr(finished_task)} since some of the input deps are missing: {missing_inputs}"
|
||||
)
|
||||
logger.info(f"ignore {repr(finished_task)} since some of the input deps are missing: {missing_inputs}")
|
||||
continue
|
||||
|
||||
if finished_task.status == WorkStatus.INCOMPLETE:
|
||||
logger.trace(
|
||||
f"{repr(finished_task)} checkpoint created on {executor.id}: {finished_task.runtime_state}"
|
||||
)
|
||||
self.tasks[finished_task.key].runtime_state = (
|
||||
finished_task.runtime_state
|
||||
)
|
||||
logger.trace(f"{repr(finished_task)} checkpoint created on {executor.id}: {finished_task.runtime_state}")
|
||||
self.tasks[finished_task.key].runtime_state = finished_task.runtime_state
|
||||
continue
|
||||
|
||||
prior_task = self.finished_tasks.get(finished_task.runtime_id, None)
|
||||
if prior_task is not None:
|
||||
logger.info(
|
||||
f"found duplicate tasks, current: {repr(finished_task)}, prior: {repr(prior_task)}"
|
||||
)
|
||||
logger.info(f"found duplicate tasks, current: {repr(finished_task)}, prior: {repr(prior_task)}")
|
||||
continue
|
||||
else:
|
||||
self.finished_tasks[finished_task.runtime_id] = finished_task
|
||||
@@ -1298,30 +1113,22 @@ class Scheduler(object):
|
||||
|
||||
succeeded_task = self.succeeded_tasks.get(finished_task.key, None)
|
||||
if succeeded_task is not None:
|
||||
logger.info(
|
||||
f"task already succeeded, current: {repr(finished_task)}, succeeded: {repr(succeeded_task)}"
|
||||
)
|
||||
logger.info(f"task already succeeded, current: {repr(finished_task)}, succeeded: {repr(succeeded_task)}")
|
||||
continue
|
||||
|
||||
if finished_task.status in (WorkStatus.FAILED, WorkStatus.CRASHED):
|
||||
logger.warning(
|
||||
f"task failed on {executor.id}: {finished_task}, error: {finished_task.exception}"
|
||||
)
|
||||
logger.warning(f"task failed on {executor.id}: {finished_task}, error: {finished_task.exception}")
|
||||
finished_task.dump()
|
||||
|
||||
task = self.tasks[finished_task.key]
|
||||
task.fail_count += 1
|
||||
|
||||
if task.fail_count > self.max_fail_count:
|
||||
logger.critical(
|
||||
f"task failed too many times: {finished_task}, stopping ..."
|
||||
)
|
||||
logger.critical(f"task failed too many times: {finished_task}, stopping ...")
|
||||
self.stop_executors()
|
||||
self.sched_running = False
|
||||
|
||||
if not executor.local and finished_task.oom(
|
||||
self.nonzero_exitcode_as_oom
|
||||
):
|
||||
if not executor.local and finished_task.oom(self.nonzero_exitcode_as_oom):
|
||||
if task._memory_limit is None:
|
||||
task._memory_limit = finished_task._memory_limit
|
||||
self.try_relax_memory_limit(task, executor)
|
||||
@@ -1332,9 +1139,7 @@ class Scheduler(object):
|
||||
|
||||
self.try_enqueue(self.get_retry_task(finished_task.key))
|
||||
else:
|
||||
assert (
|
||||
finished_task.status == WorkStatus.SUCCEED
|
||||
), f"unexpected task status: {finished_task}"
|
||||
assert finished_task.status == WorkStatus.SUCCEED, f"unexpected task status: {finished_task}"
|
||||
logger.log(
|
||||
"TRACE" if finished_task.exec_on_scheduler else "INFO",
|
||||
"task succeeded on {}: {}",
|
||||
@@ -1344,9 +1149,7 @@ class Scheduler(object):
|
||||
|
||||
self.succeeded_tasks[finished_task.key] = finished_task
|
||||
if not finished_task.exec_on_scheduler:
|
||||
self.succeeded_nontrivial_tasks[finished_task.key] = (
|
||||
finished_task
|
||||
)
|
||||
self.succeeded_nontrivial_tasks[finished_task.key] = finished_task
|
||||
|
||||
# stop the redundant retries of finished task
|
||||
self.stop_running_tasks(finished_task.key)
|
||||
@@ -1356,9 +1159,7 @@ class Scheduler(object):
|
||||
if finished_task.id == self.exec_plan.root_task.id:
|
||||
self.sched_queue = []
|
||||
self.stop_executors()
|
||||
logger.success(
|
||||
f"all tasks completed, root task: {finished_task}"
|
||||
)
|
||||
logger.success(f"all tasks completed, root task: {finished_task}")
|
||||
logger.success(
|
||||
f"{len(self.succeeded_tasks)} succeeded tasks, success: {self.success}, elapsed time: {self.elapsed_time:.3f} secs"
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -55,9 +55,7 @@ class WorkItem(object):
|
||||
) -> None:
|
||||
self._cpu_limit = cpu_limit
|
||||
self._gpu_limit = gpu_limit
|
||||
self._memory_limit = (
|
||||
np.int64(memory_limit) if memory_limit is not None else None
|
||||
)
|
||||
self._memory_limit = np.int64(memory_limit) if memory_limit is not None else None
|
||||
self._cpu_boost = 1
|
||||
self._memory_boost = 1
|
||||
self._numa_node = None
|
||||
@@ -88,11 +86,7 @@ class WorkItem(object):
|
||||
|
||||
@property
|
||||
def memory_limit(self) -> np.int64:
|
||||
return (
|
||||
np.int64(self._memory_boost * self._memory_limit)
|
||||
if self._memory_limit
|
||||
else 0
|
||||
)
|
||||
return np.int64(self._memory_boost * self._memory_limit) if self._memory_limit else 0
|
||||
|
||||
@property
|
||||
def elapsed_time(self) -> float:
|
||||
@@ -142,13 +136,7 @@ class WorkItem(object):
|
||||
return (
|
||||
self._memory_limit is not None
|
||||
and self.status == WorkStatus.CRASHED
|
||||
and (
|
||||
isinstance(self.exception, (OutOfMemory, MemoryError))
|
||||
or (
|
||||
isinstance(self.exception, NonzeroExitCode)
|
||||
and nonzero_exitcode_as_oom
|
||||
)
|
||||
)
|
||||
and (isinstance(self.exception, (OutOfMemory, MemoryError)) or (isinstance(self.exception, NonzeroExitCode) and nonzero_exitcode_as_oom))
|
||||
)
|
||||
|
||||
def run(self) -> bool:
|
||||
@@ -175,9 +163,7 @@ class WorkItem(object):
|
||||
else:
|
||||
self.status = WorkStatus.FAILED
|
||||
except Exception as ex:
|
||||
logger.opt(exception=ex).error(
|
||||
f"{repr(self)} crashed with error. node location at {self.location}"
|
||||
)
|
||||
logger.opt(exception=ex).error(f"{repr(self)} crashed with error. node location at {self.location}")
|
||||
self.status = WorkStatus.CRASHED
|
||||
self.exception = ex
|
||||
finally:
|
||||
@@ -204,25 +190,18 @@ class WorkBatch(WorkItem):
|
||||
cpu_limit = max(w.cpu_limit for w in works)
|
||||
gpu_limit = max(w.gpu_limit for w in works)
|
||||
memory_limit = max(w.memory_limit for w in works)
|
||||
super().__init__(
|
||||
f"{self.__class__.__name__}-{key}", cpu_limit, gpu_limit, memory_limit
|
||||
)
|
||||
super().__init__(f"{self.__class__.__name__}-{key}", cpu_limit, gpu_limit, memory_limit)
|
||||
self.works = works
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
super().__str__()
|
||||
+ f", works[{len(self.works)}]={self.works[:1]}...{self.works[-1:]}"
|
||||
)
|
||||
return super().__str__() + f", works[{len(self.works)}]={self.works[:1]}...{self.works[-1:]}"
|
||||
|
||||
def run(self) -> bool:
|
||||
logger.info(f"processing {len(self.works)} works in the batch")
|
||||
for index, work in enumerate(self.works):
|
||||
work.exec_id = self.exec_id
|
||||
if work.exec(self.exec_cq) != WorkStatus.SUCCEED:
|
||||
logger.error(
|
||||
f"work item #{index+1}/{len(self.works)} in {self.key} failed: {work}"
|
||||
)
|
||||
logger.error(f"work item #{index+1}/{len(self.works)} in {self.key} failed: {work}")
|
||||
return False
|
||||
logger.info(f"done {len(self.works)} works in the batch")
|
||||
return True
|
||||
@@ -375,9 +354,7 @@ class WorkQueueOnFilesystem(WorkQueue):
|
||||
os.rename(tempfile_path, enqueued_path)
|
||||
return True
|
||||
except OSError as err:
|
||||
logger.critical(
|
||||
f"failed to rename {tempfile_path} to {enqueued_path}: {err}"
|
||||
)
|
||||
logger.critical(f"failed to rename {tempfile_path} to {enqueued_path}: {err}")
|
||||
return False
|
||||
|
||||
|
||||
@@ -405,27 +382,17 @@ def count_objects(obj, object_cnt=None, visited_objs=None, depth=0):
|
||||
object_cnt[class_name] = (cnt + 1, size + sys.getsizeof(obj))
|
||||
|
||||
key_attributes = ("__self__", "__dict__", "__slots__")
|
||||
if not isinstance(obj, (bool, str, int, float, type(None))) and any(
|
||||
attr_name in key_attributes for attr_name in dir(obj)
|
||||
):
|
||||
if not isinstance(obj, (bool, str, int, float, type(None))) and any(attr_name in key_attributes for attr_name in dir(obj)):
|
||||
logger.debug(f"{' ' * depth}{class_name}@{id(obj):x}")
|
||||
for attr_name in dir(obj):
|
||||
try:
|
||||
if (
|
||||
not attr_name.startswith("__") or attr_name in key_attributes
|
||||
) and not isinstance(
|
||||
if (not attr_name.startswith("__") or attr_name in key_attributes) and not isinstance(
|
||||
getattr(obj.__class__, attr_name, None), property
|
||||
):
|
||||
logger.debug(
|
||||
f"{' ' * depth}{class_name}.{attr_name}@{id(obj):x}"
|
||||
)
|
||||
count_objects(
|
||||
getattr(obj, attr_name), object_cnt, visited_objs, depth + 1
|
||||
)
|
||||
logger.debug(f"{' ' * depth}{class_name}.{attr_name}@{id(obj):x}")
|
||||
count_objects(getattr(obj, attr_name), object_cnt, visited_objs, depth + 1)
|
||||
except Exception as ex:
|
||||
logger.warning(
|
||||
f"failed to get '{attr_name}' from {repr(obj)}: {ex}"
|
||||
)
|
||||
logger.warning(f"failed to get '{attr_name}' from {repr(obj)}: {ex}")
|
||||
|
||||
|
||||
def main():
|
||||
@@ -433,23 +400,13 @@ def main():
|
||||
|
||||
from smallpond.execution.task import Probe
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="workqueue.py", description="Work Queue Reader"
|
||||
)
|
||||
parser = argparse.ArgumentParser(prog="workqueue.py", description="Work Queue Reader")
|
||||
parser.add_argument("wq_root", help="Work queue root path")
|
||||
parser.add_argument("-f", "--work_filter", default="", help="Work item filter")
|
||||
parser.add_argument(
|
||||
"-x", "--expand_batch", action="store_true", help="Expand batched works"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c", "--count_object", action="store_true", help="Count number of objects"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-n", "--top_n_class", default=20, type=int, help="Show the top n classes"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l", "--log_level", default="INFO", help="Logging message level"
|
||||
)
|
||||
parser.add_argument("-x", "--expand_batch", action="store_true", help="Expand batched works")
|
||||
parser.add_argument("-c", "--count_object", action="store_true", help="Count number of objects")
|
||||
parser.add_argument("-n", "--top_n_class", default=20, type=int, help="Show the top n classes")
|
||||
parser.add_argument("-l", "--log_level", default="INFO", help="Logging message level")
|
||||
args = parser.parse_args()
|
||||
|
||||
logger.remove()
|
||||
@@ -468,9 +425,7 @@ def main():
|
||||
if args.count_object:
|
||||
object_cnt = {}
|
||||
count_objects(work, object_cnt)
|
||||
sorted_counts = sorted(
|
||||
[(v, k) for k, v in object_cnt.items()], reverse=True
|
||||
)
|
||||
sorted_counts = sorted([(v, k) for k, v in object_cnt.items()], reverse=True)
|
||||
for count, class_name in sorted_counts[: args.top_n_class]:
|
||||
logger.info(f" {class_name}: {count}")
|
||||
|
||||
|
||||
@@ -44,11 +44,7 @@ class RowRange:
|
||||
@property
|
||||
def estimated_data_size(self) -> int:
|
||||
"""The estimated uncompressed data size in bytes."""
|
||||
return (
|
||||
self.data_size * self.num_rows // self.file_num_rows
|
||||
if self.file_num_rows > 0
|
||||
else 0
|
||||
)
|
||||
return self.data_size * self.num_rows // self.file_num_rows if self.file_num_rows > 0 else 0
|
||||
|
||||
def take(self, num_rows: int) -> "RowRange":
|
||||
"""
|
||||
@@ -62,9 +58,7 @@ class RowRange:
|
||||
return head
|
||||
|
||||
@staticmethod
|
||||
def partition_by_rows(
|
||||
row_ranges: List["RowRange"], npartition: int
|
||||
) -> List[List["RowRange"]]:
|
||||
def partition_by_rows(row_ranges: List["RowRange"], npartition: int) -> List[List["RowRange"]]:
|
||||
"""Evenly split a list of row ranges into `npartition` partitions."""
|
||||
# NOTE: `row_ranges` should not be modified by this function
|
||||
row_ranges = copy.deepcopy(row_ranges)
|
||||
@@ -128,9 +122,7 @@ def convert_types_to_large_string(schema: arrow.Schema) -> arrow.Schema:
|
||||
new_fields = []
|
||||
for field in schema:
|
||||
new_type = convert_type_to_large(field.type)
|
||||
new_field = arrow.field(
|
||||
field.name, new_type, nullable=field.nullable, metadata=field.metadata
|
||||
)
|
||||
new_field = arrow.field(field.name, new_type, nullable=field.nullable, metadata=field.metadata)
|
||||
new_fields.append(new_field)
|
||||
return arrow.schema(new_fields, metadata=schema.metadata)
|
||||
|
||||
@@ -151,11 +143,7 @@ def filter_schema(
|
||||
if included_cols is not None:
|
||||
fields = [schema.field(col_name) for col_name in included_cols]
|
||||
if excluded_cols is not None:
|
||||
fields = [
|
||||
schema.field(col_name)
|
||||
for col_name in schema.names
|
||||
if col_name not in excluded_cols
|
||||
]
|
||||
fields = [schema.field(col_name) for col_name in schema.names if col_name not in excluded_cols]
|
||||
return arrow.schema(fields, metadata=schema.metadata)
|
||||
|
||||
|
||||
@@ -172,9 +160,7 @@ def _iter_record_batches(
|
||||
current_offset = 0
|
||||
required_l, required_r = offset, offset + length
|
||||
|
||||
for batch in file.iter_batches(
|
||||
batch_size=batch_size, columns=columns, use_threads=False
|
||||
):
|
||||
for batch in file.iter_batches(batch_size=batch_size, columns=columns, use_threads=False):
|
||||
current_l, current_r = current_offset, current_offset + batch.num_rows
|
||||
# check if intersection is null
|
||||
if current_r <= required_l:
|
||||
@@ -184,9 +170,7 @@ def _iter_record_batches(
|
||||
else:
|
||||
intersection_l = max(required_l, current_l)
|
||||
intersection_r = min(required_r, current_r)
|
||||
trimmed = batch.slice(
|
||||
intersection_l - current_offset, intersection_r - intersection_l
|
||||
)
|
||||
trimmed = batch.slice(intersection_l - current_offset, intersection_r - intersection_l)
|
||||
assert (
|
||||
trimmed.num_rows == intersection_r - intersection_l
|
||||
), f"trimmed.num_rows {trimmed.num_rows} != batch_length {intersection_r - intersection_l}"
|
||||
@@ -204,9 +188,7 @@ def build_batch_reader_from_files(
|
||||
) -> arrow.RecordBatchReader:
|
||||
assert len(paths_or_ranges) > 0, "paths_or_ranges must be a non-empty list"
|
||||
schema = _read_schema_from_file(paths_or_ranges[0], columns, filesystem)
|
||||
iterator = _iter_record_batches_from_files(
|
||||
paths_or_ranges, columns, batch_size, max_batch_byte_size, filesystem
|
||||
)
|
||||
iterator = _iter_record_batches_from_files(paths_or_ranges, columns, batch_size, max_batch_byte_size, filesystem)
|
||||
return arrow.RecordBatchReader.from_batches(schema, iterator)
|
||||
|
||||
|
||||
@@ -216,9 +198,7 @@ def _read_schema_from_file(
|
||||
filesystem: fsspec.AbstractFileSystem = None,
|
||||
) -> arrow.Schema:
|
||||
path = path_or_range.path if isinstance(path_or_range, RowRange) else path_or_range
|
||||
schema = parquet.read_schema(
|
||||
filesystem.unstrip_protocol(path) if filesystem else path, filesystem=filesystem
|
||||
)
|
||||
schema = parquet.read_schema(filesystem.unstrip_protocol(path) if filesystem else path, filesystem=filesystem)
|
||||
if columns is not None:
|
||||
assert all(
|
||||
c in schema.names for c in columns
|
||||
@@ -253,9 +233,7 @@ def _iter_record_batches_from_files(
|
||||
yield from table.combine_chunks().to_batches(batch_size)
|
||||
|
||||
for path_or_range in paths_or_ranges:
|
||||
path = (
|
||||
path_or_range.path if isinstance(path_or_range, RowRange) else path_or_range
|
||||
)
|
||||
path = path_or_range.path if isinstance(path_or_range, RowRange) else path_or_range
|
||||
with parquet.ParquetFile(
|
||||
filesystem.unstrip_protocol(path) if filesystem else path,
|
||||
buffer_size=16 * MB,
|
||||
@@ -265,23 +243,16 @@ def _iter_record_batches_from_files(
|
||||
offset, length = path_or_range.begin, path_or_range.num_rows
|
||||
else:
|
||||
offset, length = 0, file.metadata.num_rows
|
||||
for batch in _iter_record_batches(
|
||||
file, columns, offset, length, batch_size
|
||||
):
|
||||
for batch in _iter_record_batches(file, columns, offset, length, batch_size):
|
||||
batch_size_exceeded = batch.num_rows + buffered_rows >= batch_size
|
||||
batch_byte_size_exceeded = (
|
||||
max_batch_byte_size is not None
|
||||
and batch.nbytes + buffered_bytes >= max_batch_byte_size
|
||||
)
|
||||
batch_byte_size_exceeded = max_batch_byte_size is not None and batch.nbytes + buffered_bytes >= max_batch_byte_size
|
||||
if not batch_size_exceeded and not batch_byte_size_exceeded:
|
||||
buffered_batches.append(batch)
|
||||
buffered_rows += batch.num_rows
|
||||
buffered_bytes += batch.nbytes
|
||||
else:
|
||||
if batch_size_exceeded:
|
||||
buffered_batches.append(
|
||||
batch.slice(0, batch_size - buffered_rows)
|
||||
)
|
||||
buffered_batches.append(batch.slice(0, batch_size - buffered_rows))
|
||||
batch = batch.slice(batch_size - buffered_rows)
|
||||
if buffered_batches:
|
||||
yield from combine_buffered_batches(buffered_batches)
|
||||
@@ -298,9 +269,7 @@ def read_parquet_files_into_table(
|
||||
columns: List[str] = None,
|
||||
filesystem: fsspec.AbstractFileSystem = None,
|
||||
) -> arrow.Table:
|
||||
batch_reader = build_batch_reader_from_files(
|
||||
paths_or_ranges, columns=columns, filesystem=filesystem
|
||||
)
|
||||
batch_reader = build_batch_reader_from_files(paths_or_ranges, columns=columns, filesystem=filesystem)
|
||||
return batch_reader.read_all()
|
||||
|
||||
|
||||
@@ -312,37 +281,22 @@ def load_from_parquet_files(
|
||||
) -> arrow.Table:
|
||||
start_time = time.time()
|
||||
assert len(paths_or_ranges) > 0, "paths_or_ranges must be a non-empty list"
|
||||
paths = [
|
||||
path_or_range.path if isinstance(path_or_range, RowRange) else path_or_range
|
||||
for path_or_range in paths_or_ranges
|
||||
]
|
||||
paths = [path_or_range.path if isinstance(path_or_range, RowRange) else path_or_range for path_or_range in paths_or_ranges]
|
||||
total_compressed_size = sum(
|
||||
(
|
||||
path_or_range.data_size
|
||||
if isinstance(path_or_range, RowRange)
|
||||
else os.path.getsize(path_or_range)
|
||||
)
|
||||
for path_or_range in paths_or_ranges
|
||||
)
|
||||
logger.debug(
|
||||
f"loading {len(paths)} parquet files (compressed size: {total_compressed_size/MB:.3f}MB): {paths[:3]}..."
|
||||
(path_or_range.data_size if isinstance(path_or_range, RowRange) else os.path.getsize(path_or_range)) for path_or_range in paths_or_ranges
|
||||
)
|
||||
logger.debug(f"loading {len(paths)} parquet files (compressed size: {total_compressed_size/MB:.3f}MB): {paths[:3]}...")
|
||||
num_workers = min(len(paths), max_workers)
|
||||
with ThreadPoolExecutor(num_workers) as pool:
|
||||
running_works = [
|
||||
pool.submit(read_parquet_files_into_table, batch, columns, filesystem)
|
||||
for batch in split_into_rows(paths_or_ranges, num_workers)
|
||||
pool.submit(read_parquet_files_into_table, batch, columns, filesystem) for batch in split_into_rows(paths_or_ranges, num_workers)
|
||||
]
|
||||
tables = [work.result() for work in running_works]
|
||||
logger.debug(
|
||||
f"collected {len(tables)} tables from: {paths[:3]}... (elapsed: {time.time() - start_time:.3f} secs)"
|
||||
)
|
||||
logger.debug(f"collected {len(tables)} tables from: {paths[:3]}... (elapsed: {time.time() - start_time:.3f} secs)")
|
||||
return arrow.concat_tables(tables)
|
||||
|
||||
|
||||
def parquet_write_table(
|
||||
table, where, filesystem: fsspec.AbstractFileSystem = None, **write_table_args
|
||||
) -> int:
|
||||
def parquet_write_table(table, where, filesystem: fsspec.AbstractFileSystem = None, **write_table_args) -> int:
|
||||
if filesystem is not None:
|
||||
return parquet.write_table(
|
||||
table,
|
||||
@@ -388,10 +342,7 @@ def dump_to_parquet_files(
|
||||
num_workers = min(len(batches), max_workers)
|
||||
num_tables = max(math.ceil(table.nbytes / MAX_PARQUET_FILE_BYTES), num_workers)
|
||||
logger.debug(f"evenly distributed {len(batches)} batches into {num_tables} files")
|
||||
tables = [
|
||||
arrow.Table.from_batches(batch, table.schema)
|
||||
for batch in split_into_rows(batches, num_tables)
|
||||
]
|
||||
tables = [arrow.Table.from_batches(batch, table.schema) for batch in split_into_rows(batches, num_tables)]
|
||||
assert sum(t.num_rows for t in tables) == table.num_rows
|
||||
|
||||
logger.debug(f"writing {len(tables)} files to {output_dir}")
|
||||
@@ -413,7 +364,5 @@ def dump_to_parquet_files(
|
||||
]
|
||||
assert all(work.result() or True for work in running_works)
|
||||
|
||||
logger.debug(
|
||||
f"finished writing {len(tables)} files to {output_dir} (elapsed: {time.time() - start_time:.3f} secs)"
|
||||
)
|
||||
logger.debug(f"finished writing {len(tables)} files to {output_dir} (elapsed: {time.time() - start_time:.3f} secs)")
|
||||
return True
|
||||
|
||||
@@ -44,9 +44,7 @@ def remove_path(path: str):
|
||||
os.symlink(realpath, link)
|
||||
return
|
||||
except Exception as ex:
|
||||
logger.opt(exception=ex).debug(
|
||||
f"fast recursive remove failed, fall back to shutil.rmtree('{realpath}')"
|
||||
)
|
||||
logger.opt(exception=ex).debug(f"fast recursive remove failed, fall back to shutil.rmtree('{realpath}')")
|
||||
shutil.rmtree(realpath, ignore_errors=True)
|
||||
|
||||
|
||||
@@ -94,9 +92,7 @@ def dump(obj: Any, path: str, buffering=2 * MB, atomic_write=False) -> int:
|
||||
raise
|
||||
except Exception as ex:
|
||||
trace_str, trace_err = get_pickle_trace(obj)
|
||||
logger.opt(exception=ex).error(
|
||||
f"pickle trace of {repr(obj)}:{os.linesep}{trace_str}"
|
||||
)
|
||||
logger.opt(exception=ex).error(f"pickle trace of {repr(obj)}:{os.linesep}{trace_str}")
|
||||
if trace_err is None:
|
||||
raise
|
||||
else:
|
||||
@@ -107,9 +103,7 @@ def dump(obj: Any, path: str, buffering=2 * MB, atomic_write=False) -> int:
|
||||
|
||||
if atomic_write:
|
||||
directory, filename = os.path.split(path)
|
||||
with tempfile.NamedTemporaryFile(
|
||||
"wb", buffering=buffering, dir=directory, prefix=filename, delete=False
|
||||
) as fout:
|
||||
with tempfile.NamedTemporaryFile("wb", buffering=buffering, dir=directory, prefix=filename, delete=False) as fout:
|
||||
write_to_file(fout)
|
||||
fout.seek(0, os.SEEK_END)
|
||||
size = fout.tell()
|
||||
|
||||
@@ -206,9 +206,7 @@ class DataSet(object):
|
||||
resolved_paths.append(path)
|
||||
if wildcard_paths:
|
||||
if len(wildcard_paths) == 1:
|
||||
expanded_paths = glob.glob(
|
||||
wildcard_paths[0], recursive=self.recursive
|
||||
)
|
||||
expanded_paths = glob.glob(wildcard_paths[0], recursive=self.recursive)
|
||||
else:
|
||||
logger.debug(
|
||||
"resolving {} paths with wildcards in {}",
|
||||
@@ -247,9 +245,7 @@ class DataSet(object):
|
||||
if self.root_dir is None:
|
||||
self._absolute_paths = sorted(self.paths)
|
||||
else:
|
||||
self._absolute_paths = [
|
||||
os.path.join(self.root_dir, p) for p in sorted(self.paths)
|
||||
]
|
||||
self._absolute_paths = [os.path.join(self.root_dir, p) for p in sorted(self.paths)]
|
||||
return self._absolute_paths
|
||||
|
||||
def sql_query_fragment(
|
||||
@@ -340,23 +336,15 @@ class DataSet(object):
|
||||
return file_partitions
|
||||
|
||||
@functools.lru_cache
|
||||
def partition_by_files(
|
||||
self, npartition: int, random_shuffle: bool = False
|
||||
) -> "List[DataSet]":
|
||||
def partition_by_files(self, npartition: int, random_shuffle: bool = False) -> "List[DataSet]":
|
||||
"""
|
||||
Evenly split into `npartition` datasets by files.
|
||||
"""
|
||||
assert npartition > 0, f"npartition has negative value: {npartition}"
|
||||
if npartition > self.num_files:
|
||||
logger.debug(
|
||||
f"number of partitions {npartition} is greater than the number of files {self.num_files}"
|
||||
)
|
||||
logger.debug(f"number of partitions {npartition} is greater than the number of files {self.num_files}")
|
||||
|
||||
resolved_paths = (
|
||||
random.sample(self.resolved_paths, len(self.resolved_paths))
|
||||
if random_shuffle
|
||||
else self.resolved_paths
|
||||
)
|
||||
resolved_paths = random.sample(self.resolved_paths, len(self.resolved_paths)) if random_shuffle else self.resolved_paths
|
||||
evenly_split_groups = split_into_rows(resolved_paths, npartition)
|
||||
num_paths_in_groups = list(map(len, evenly_split_groups))
|
||||
|
||||
@@ -367,11 +355,7 @@ class DataSet(object):
|
||||
logger.debug(
|
||||
f"created {npartition} file partitions (min #files: {min(num_paths_in_groups)}, max #files: {max(num_paths_in_groups)}, avg #files: {sum(num_paths_in_groups)/len(num_paths_in_groups):.3f}) from {self}"
|
||||
)
|
||||
return (
|
||||
random.sample(file_partitions, len(file_partitions))
|
||||
if random_shuffle
|
||||
else file_partitions
|
||||
)
|
||||
return random.sample(file_partitions, len(file_partitions)) if random_shuffle else file_partitions
|
||||
|
||||
|
||||
class PartitionedDataSet(DataSet):
|
||||
@@ -463,9 +447,7 @@ class CsvDataSet(DataSet):
|
||||
union_by_name=False,
|
||||
) -> None:
|
||||
super().__init__(paths, root_dir, recursive, columns, union_by_name)
|
||||
assert isinstance(
|
||||
schema, OrderedDict
|
||||
), f"type of csv schema is not OrderedDict: {type(schema)}"
|
||||
assert isinstance(schema, OrderedDict), f"type of csv schema is not OrderedDict: {type(schema)}"
|
||||
self.schema = schema
|
||||
self.delim = delim
|
||||
self.max_line_size = max_line_size
|
||||
@@ -492,14 +474,8 @@ class CsvDataSet(DataSet):
|
||||
filesystem: fsspec.AbstractFileSystem = None,
|
||||
conn: duckdb.DuckDBPyConnection = None,
|
||||
) -> str:
|
||||
schema_str = ", ".join(
|
||||
map(lambda x: f"'{x[0]}': '{x[1]}'", self.schema.items())
|
||||
)
|
||||
max_line_size_str = (
|
||||
f"max_line_size={self.max_line_size}, "
|
||||
if self.max_line_size is not None
|
||||
else ""
|
||||
)
|
||||
schema_str = ", ".join(map(lambda x: f"'{x[0]}': '{x[1]}'", self.schema.items()))
|
||||
max_line_size_str = f"max_line_size={self.max_line_size}, " if self.max_line_size is not None else ""
|
||||
return (
|
||||
f"( select {self._column_str} from read_csv([ {self._resolved_path_str} ], delim='{self.delim}', columns={{ {schema_str} }}, header={self.header}, "
|
||||
f"{max_line_size_str} parallel={self.parallel}, union_by_name={self.union_by_name}) )"
|
||||
@@ -552,9 +528,7 @@ class JsonDataSet(DataSet):
|
||||
filesystem: fsspec.AbstractFileSystem = None,
|
||||
conn: duckdb.DuckDBPyConnection = None,
|
||||
) -> str:
|
||||
schema_str = ", ".join(
|
||||
map(lambda x: f"'{x[0]}': '{x[1]}'", self.schema.items())
|
||||
)
|
||||
schema_str = ", ".join(map(lambda x: f"'{x[0]}': '{x[1]}'", self.schema.items()))
|
||||
return (
|
||||
f"( select {self._column_str} from read_json([ {self._resolved_path_str} ], format='{self.format}', columns={{ {schema_str} }}, "
|
||||
f"maximum_object_size={self.max_object_size}, union_by_name={self.union_by_name}) )"
|
||||
@@ -614,11 +588,7 @@ class ParquetDataSet(DataSet):
|
||||
)
|
||||
# merge row ranges if any dataset has resolved row ranges
|
||||
if any(dataset._resolved_row_ranges is not None for dataset in datasets):
|
||||
dataset._resolved_row_ranges = [
|
||||
row_range
|
||||
for dataset in datasets
|
||||
for row_range in dataset.resolved_row_ranges
|
||||
]
|
||||
dataset._resolved_row_ranges = [row_range for dataset in datasets for row_range in dataset.resolved_row_ranges]
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
@@ -655,10 +625,7 @@ class ParquetDataSet(DataSet):
|
||||
# read parquet metadata to get number of rows
|
||||
metadata = parquet.read_metadata(path)
|
||||
num_rows = metadata.num_rows
|
||||
uncompressed_data_size = sum(
|
||||
metadata.row_group(i).total_byte_size
|
||||
for i in range(metadata.num_row_groups)
|
||||
)
|
||||
uncompressed_data_size = sum(metadata.row_group(i).total_byte_size for i in range(metadata.num_row_groups))
|
||||
return RowRange(
|
||||
path,
|
||||
data_size=uncompressed_data_size,
|
||||
@@ -667,20 +634,14 @@ class ParquetDataSet(DataSet):
|
||||
end=num_rows,
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor(
|
||||
max_workers=min(32, len(self.resolved_paths))
|
||||
) as pool:
|
||||
self._resolved_row_ranges = list(
|
||||
pool.map(resolve_row_range, self.resolved_paths)
|
||||
)
|
||||
with ThreadPoolExecutor(max_workers=min(32, len(self.resolved_paths))) as pool:
|
||||
self._resolved_row_ranges = list(pool.map(resolve_row_range, self.resolved_paths))
|
||||
return self._resolved_row_ranges
|
||||
|
||||
@property
|
||||
def num_rows(self) -> int:
|
||||
if self._resolved_num_rows is None:
|
||||
self._resolved_num_rows = sum(
|
||||
row_range.num_rows for row_range in self.resolved_row_ranges
|
||||
)
|
||||
self._resolved_num_rows = sum(row_range.num_rows for row_range in self.resolved_row_ranges)
|
||||
return self._resolved_num_rows
|
||||
|
||||
@property
|
||||
@@ -695,29 +656,19 @@ class ParquetDataSet(DataSet):
|
||||
"""
|
||||
Return the estimated data size in bytes.
|
||||
"""
|
||||
return sum(
|
||||
row_range.estimated_data_size for row_range in self.resolved_row_ranges
|
||||
)
|
||||
return sum(row_range.estimated_data_size for row_range in self.resolved_row_ranges)
|
||||
|
||||
def sql_query_fragment(
|
||||
self,
|
||||
filesystem: fsspec.AbstractFileSystem = None,
|
||||
conn: duckdb.DuckDBPyConnection = None,
|
||||
) -> str:
|
||||
extra_parameters = (
|
||||
"".join(f", {col}=true" for col in self.generated_columns)
|
||||
if self.generated_columns
|
||||
else ""
|
||||
)
|
||||
extra_parameters = "".join(f", {col}=true" for col in self.generated_columns) if self.generated_columns else ""
|
||||
parquet_file_queries = []
|
||||
full_row_ranges = []
|
||||
|
||||
for row_range in self.resolved_row_ranges:
|
||||
path = (
|
||||
filesystem.unstrip_protocol(row_range.path)
|
||||
if filesystem
|
||||
else row_range.path
|
||||
)
|
||||
path = filesystem.unstrip_protocol(row_range.path) if filesystem else row_range.path
|
||||
if row_range.num_rows == row_range.file_num_rows:
|
||||
full_row_ranges.append(row_range)
|
||||
else:
|
||||
@@ -744,9 +695,7 @@ class ParquetDataSet(DataSet):
|
||||
full_row_ranges[largest_index],
|
||||
full_row_ranges[0],
|
||||
)
|
||||
parquet_file_str = ",\n ".join(
|
||||
map(lambda x: f"'{x.path}'", full_row_ranges)
|
||||
)
|
||||
parquet_file_str = ",\n ".join(map(lambda x: f"'{x.path}'", full_row_ranges))
|
||||
parquet_file_queries.insert(
|
||||
0,
|
||||
f"""
|
||||
@@ -771,11 +720,7 @@ class ParquetDataSet(DataSet):
|
||||
|
||||
tables = []
|
||||
if self.resolved_row_ranges:
|
||||
tables.append(
|
||||
load_from_parquet_files(
|
||||
self.resolved_row_ranges, self.columns, max_workers, filesystem
|
||||
)
|
||||
)
|
||||
tables.append(load_from_parquet_files(self.resolved_row_ranges, self.columns, max_workers, filesystem))
|
||||
return arrow.concat_tables(tables)
|
||||
|
||||
def to_batch_reader(
|
||||
@@ -795,18 +740,14 @@ class ParquetDataSet(DataSet):
|
||||
)
|
||||
|
||||
@functools.lru_cache
|
||||
def partition_by_files(
|
||||
self, npartition: int, random_shuffle: bool = False
|
||||
) -> "List[ParquetDataSet]":
|
||||
def partition_by_files(self, npartition: int, random_shuffle: bool = False) -> "List[ParquetDataSet]":
|
||||
if self._resolved_row_ranges is not None:
|
||||
return self.partition_by_rows(npartition, random_shuffle)
|
||||
else:
|
||||
return super().partition_by_files(npartition, random_shuffle)
|
||||
|
||||
@functools.lru_cache
|
||||
def partition_by_rows(
|
||||
self, npartition: int, random_shuffle: bool = False
|
||||
) -> "List[ParquetDataSet]":
|
||||
def partition_by_rows(self, npartition: int, random_shuffle: bool = False) -> "List[ParquetDataSet]":
|
||||
"""
|
||||
Evenly split the dataset into `npartition` partitions by rows.
|
||||
If `random_shuffle` is True, shuffle the files before partitioning.
|
||||
@@ -814,11 +755,7 @@ class ParquetDataSet(DataSet):
|
||||
assert npartition > 0, f"npartition has negative value: {npartition}"
|
||||
|
||||
resolved_row_ranges = self.resolved_row_ranges
|
||||
resolved_row_ranges = (
|
||||
random.sample(resolved_row_ranges, len(resolved_row_ranges))
|
||||
if random_shuffle
|
||||
else resolved_row_ranges
|
||||
)
|
||||
resolved_row_ranges = random.sample(resolved_row_ranges, len(resolved_row_ranges)) if random_shuffle else resolved_row_ranges
|
||||
|
||||
def create_dataset(row_ranges: List[RowRange]) -> ParquetDataSet:
|
||||
row_ranges = sorted(row_ranges, key=lambda x: x.path)
|
||||
@@ -833,12 +770,7 @@ class ParquetDataSet(DataSet):
|
||||
dataset._resolved_row_ranges = row_ranges
|
||||
return dataset
|
||||
|
||||
return [
|
||||
create_dataset(row_ranges)
|
||||
for row_ranges in RowRange.partition_by_rows(
|
||||
resolved_row_ranges, npartition
|
||||
)
|
||||
]
|
||||
return [create_dataset(row_ranges) for row_ranges in RowRange.partition_by_rows(resolved_row_ranges, npartition)]
|
||||
|
||||
@functools.lru_cache
|
||||
def partition_by_size(self, max_partition_size: int) -> "List[ParquetDataSet]":
|
||||
@@ -847,16 +779,12 @@ class ParquetDataSet(DataSet):
|
||||
"""
|
||||
if self.empty:
|
||||
return []
|
||||
estimated_data_size = sum(
|
||||
row_range.estimated_data_size for row_range in self.resolved_row_ranges
|
||||
)
|
||||
estimated_data_size = sum(row_range.estimated_data_size for row_range in self.resolved_row_ranges)
|
||||
npartition = estimated_data_size // max_partition_size + 1
|
||||
return self.partition_by_rows(npartition)
|
||||
|
||||
@staticmethod
|
||||
def _read_partition_key(
|
||||
path: str, data_partition_column: str, hive_partitioning: bool
|
||||
) -> int:
|
||||
def _read_partition_key(path: str, data_partition_column: str, hive_partitioning: bool) -> int:
|
||||
"""
|
||||
Get the partition key of the parquet file.
|
||||
|
||||
@@ -874,9 +802,7 @@ class ParquetDataSet(DataSet):
|
||||
try:
|
||||
return int(key)
|
||||
except ValueError:
|
||||
logger.error(
|
||||
f"cannot parse partition key '{data_partition_column}' of {path} from: {key}"
|
||||
)
|
||||
logger.error(f"cannot parse partition key '{data_partition_column}' of {path} from: {key}")
|
||||
raise
|
||||
|
||||
if hive_partitioning:
|
||||
@@ -884,9 +810,7 @@ class ParquetDataSet(DataSet):
|
||||
for part in path.split(os.path.sep):
|
||||
if part.startswith(path_part_prefix):
|
||||
return parse_partition_key(part[len(path_part_prefix) :])
|
||||
raise RuntimeError(
|
||||
f"cannot extract hive partition key '{data_partition_column}' from path: {path}"
|
||||
)
|
||||
raise RuntimeError(f"cannot extract hive partition key '{data_partition_column}' from path: {path}")
|
||||
|
||||
with parquet.ParquetFile(path) as file:
|
||||
kv_metadata = file.schema_arrow.metadata or file.metadata.metadata
|
||||
@@ -896,36 +820,22 @@ class ParquetDataSet(DataSet):
|
||||
if key == PARQUET_METADATA_KEY_PREFIX + data_partition_column:
|
||||
return parse_partition_key(val)
|
||||
if file.metadata.num_rows == 0:
|
||||
logger.warning(
|
||||
f"cannot read partition keys from empty parquet file: {path}"
|
||||
)
|
||||
logger.warning(f"cannot read partition keys from empty parquet file: {path}")
|
||||
return None
|
||||
for batch in file.iter_batches(
|
||||
batch_size=128, columns=[data_partition_column], use_threads=False
|
||||
):
|
||||
assert (
|
||||
data_partition_column in batch.column_names
|
||||
), f"cannot find column '{data_partition_column}' in {batch.column_names}"
|
||||
assert (
|
||||
batch.num_columns == 1
|
||||
), f"unexpected num of columns: {batch.column_names}"
|
||||
for batch in file.iter_batches(batch_size=128, columns=[data_partition_column], use_threads=False):
|
||||
assert data_partition_column in batch.column_names, f"cannot find column '{data_partition_column}' in {batch.column_names}"
|
||||
assert batch.num_columns == 1, f"unexpected num of columns: {batch.column_names}"
|
||||
uniq_partition_keys = set(batch.columns[0].to_pylist())
|
||||
assert (
|
||||
uniq_partition_keys and len(uniq_partition_keys) == 1
|
||||
), f"partition keys found in {path} not unique: {uniq_partition_keys}"
|
||||
assert uniq_partition_keys and len(uniq_partition_keys) == 1, f"partition keys found in {path} not unique: {uniq_partition_keys}"
|
||||
return uniq_partition_keys.pop()
|
||||
|
||||
def load_partitioned_datasets(
|
||||
self, npartition: int, data_partition_column: str, hive_partitioning=False
|
||||
) -> "List[ParquetDataSet]":
|
||||
def load_partitioned_datasets(self, npartition: int, data_partition_column: str, hive_partitioning=False) -> "List[ParquetDataSet]":
|
||||
"""
|
||||
Split the dataset into a list of partitioned datasets.
|
||||
"""
|
||||
assert npartition > 0, f"npartition has negative value: {npartition}"
|
||||
if npartition > self.num_files:
|
||||
logger.debug(
|
||||
f"number of partitions {npartition} is greater than the number of files {self.num_files}"
|
||||
)
|
||||
logger.debug(f"number of partitions {npartition} is greater than the number of files {self.num_files}")
|
||||
|
||||
file_partitions: List[ParquetDataSet] = self._init_file_partitions(npartition)
|
||||
for dataset in file_partitions:
|
||||
@@ -940,17 +850,13 @@ class ParquetDataSet(DataSet):
|
||||
|
||||
with ThreadPoolExecutor(min(32, len(self.resolved_paths))) as pool:
|
||||
partition_keys = pool.map(
|
||||
lambda path: ParquetDataSet._read_partition_key(
|
||||
path, data_partition_column, hive_partitioning
|
||||
),
|
||||
lambda path: ParquetDataSet._read_partition_key(path, data_partition_column, hive_partitioning),
|
||||
self.resolved_paths,
|
||||
)
|
||||
|
||||
for row_range, partition_key in zip(self.resolved_row_ranges, partition_keys):
|
||||
if partition_key is not None:
|
||||
assert (
|
||||
0 <= partition_key <= npartition
|
||||
), f"invalid partition key {partition_key} found in {row_range.path}"
|
||||
assert 0 <= partition_key <= npartition, f"invalid partition key {partition_key} found in {row_range.path}"
|
||||
dataset = file_partitions[partition_key]
|
||||
dataset.paths.append(row_range.path)
|
||||
dataset._absolute_paths.append(row_range.path)
|
||||
@@ -964,20 +870,14 @@ class ParquetDataSet(DataSet):
|
||||
"""
|
||||
Remove empty parquet files from the dataset.
|
||||
"""
|
||||
new_row_ranges = [
|
||||
row_range
|
||||
for row_range in self.resolved_row_ranges
|
||||
if row_range.num_rows > 0
|
||||
]
|
||||
new_row_ranges = [row_range for row_range in self.resolved_row_ranges if row_range.num_rows > 0]
|
||||
if len(new_row_ranges) == 0:
|
||||
# keep at least one file to avoid empty dataset
|
||||
new_row_ranges = self.resolved_row_ranges[:1]
|
||||
if len(new_row_ranges) == len(self.resolved_row_ranges):
|
||||
# no empty files found
|
||||
return
|
||||
logger.info(
|
||||
f"removed {len(self.resolved_row_ranges) - len(new_row_ranges)}/{len(self.resolved_row_ranges)} empty parquet files from {self}"
|
||||
)
|
||||
logger.info(f"removed {len(self.resolved_row_ranges) - len(new_row_ranges)}/{len(self.resolved_row_ranges)} empty parquet files from {self}")
|
||||
self._resolved_row_ranges = new_row_ranges
|
||||
self._resolved_paths = [row_range.path for row_range in new_row_ranges]
|
||||
self._absolute_paths = self._resolved_paths
|
||||
@@ -997,9 +897,7 @@ class SqlQueryDataSet(DataSet):
|
||||
def __init__(
|
||||
self,
|
||||
sql_query: str,
|
||||
query_builder: Callable[
|
||||
[duckdb.DuckDBPyConnection, fsspec.AbstractFileSystem], str
|
||||
] = None,
|
||||
query_builder: Callable[[duckdb.DuckDBPyConnection, fsspec.AbstractFileSystem], str] = None,
|
||||
) -> None:
|
||||
super().__init__([])
|
||||
self.sql_query = sql_query
|
||||
@@ -1007,9 +905,7 @@ class SqlQueryDataSet(DataSet):
|
||||
|
||||
@property
|
||||
def num_rows(self) -> int:
|
||||
num_rows = duckdb.sql(
|
||||
f"select count(*) as num_rows from {self.sql_query_fragment()}"
|
||||
).fetchall()
|
||||
num_rows = duckdb.sql(f"select count(*) as num_rows from {self.sql_query_fragment()}").fetchall()
|
||||
return num_rows[0][0]
|
||||
|
||||
def sql_query_fragment(
|
||||
@@ -1017,11 +913,7 @@ class SqlQueryDataSet(DataSet):
|
||||
filesystem: fsspec.AbstractFileSystem = None,
|
||||
conn: duckdb.DuckDBPyConnection = None,
|
||||
) -> str:
|
||||
sql_query = (
|
||||
self.sql_query
|
||||
if self.query_builder is None
|
||||
else self.query_builder(conn, filesystem)
|
||||
)
|
||||
sql_query = self.sql_query if self.query_builder is None else self.query_builder(conn, filesystem)
|
||||
return f"( {sql_query} )"
|
||||
|
||||
|
||||
|
||||
@@ -132,9 +132,7 @@ class Context(object):
|
||||
-------
|
||||
The unique function name.
|
||||
"""
|
||||
self.udfs[name] = PythonUDFContext(
|
||||
name, func, params, return_type, use_arrow_type
|
||||
)
|
||||
self.udfs[name] = PythonUDFContext(name, func, params, return_type, use_arrow_type)
|
||||
return name
|
||||
|
||||
def create_external_module(self, module_path: str, name: str = None) -> str:
|
||||
@@ -213,18 +211,10 @@ class Node(object):
|
||||
This is a resource requirement specified by the user and used to guide
|
||||
task scheduling. smallpond does NOT enforce this limit.
|
||||
"""
|
||||
assert isinstance(
|
||||
input_deps, Iterable
|
||||
), f"input_deps is not iterable: {input_deps}"
|
||||
assert all(
|
||||
isinstance(node, Node) for node in input_deps
|
||||
), f"some of input_deps are not instances of Node: {input_deps}"
|
||||
assert output_name is None or re.match(
|
||||
"[a-zA-Z0-9_]+", output_name
|
||||
), f"output_name has invalid format: {output_name}"
|
||||
assert output_path is None or os.path.isabs(
|
||||
output_path
|
||||
), f"output_path is not an absolute path: {output_path}"
|
||||
assert isinstance(input_deps, Iterable), f"input_deps is not iterable: {input_deps}"
|
||||
assert all(isinstance(node, Node) for node in input_deps), f"some of input_deps are not instances of Node: {input_deps}"
|
||||
assert output_name is None or re.match("[a-zA-Z0-9_]+", output_name), f"output_name has invalid format: {output_name}"
|
||||
assert output_path is None or os.path.isabs(output_path), f"output_path is not an absolute path: {output_path}"
|
||||
self.ctx = ctx
|
||||
self.id = self.ctx._new_node_id()
|
||||
self.input_deps = input_deps
|
||||
@@ -238,10 +228,7 @@ class Node(object):
|
||||
self.perf_metrics: Dict[str, List[float]] = defaultdict(list)
|
||||
# record the location where the node is constructed in user code
|
||||
frame = next(
|
||||
frame
|
||||
for frame in reversed(traceback.extract_stack())
|
||||
if frame.filename != __file__
|
||||
and not frame.filename.endswith("/dataframe.py")
|
||||
frame for frame in reversed(traceback.extract_stack()) if frame.filename != __file__ and not frame.filename.endswith("/dataframe.py")
|
||||
)
|
||||
self.location = f"{frame.filename}:{frame.lineno}"
|
||||
|
||||
@@ -290,9 +277,7 @@ class Node(object):
|
||||
values = self.perf_metrics[name]
|
||||
min, max, avg = np.min(values), np.max(values), np.average(values)
|
||||
p50, p75, p95, p99 = np.percentile(values, (50, 75, 95, 99))
|
||||
self.perf_stats[name] = PerfStats(
|
||||
len(values), sum(values), min, max, avg, p50, p75, p95, p99
|
||||
)
|
||||
self.perf_stats[name] = PerfStats(len(values), sum(values), min, max, avg, p50, p75, p95, p99)
|
||||
return self.perf_stats[name]
|
||||
|
||||
@property
|
||||
@@ -378,9 +363,7 @@ class DataSinkNode(Node):
|
||||
"link_or_copy",
|
||||
"manifest",
|
||||
), f"invalid sink type: {type}"
|
||||
super().__init__(
|
||||
ctx, input_deps, None, output_path, cpu_limit=1, gpu_limit=0, memory_limit=0
|
||||
)
|
||||
super().__init__(ctx, input_deps, None, output_path, cpu_limit=1, gpu_limit=0, memory_limit=0)
|
||||
self.type: DataSinkType = "manifest" if manifest_only else type
|
||||
self.is_final_node = is_final_node
|
||||
|
||||
@@ -402,12 +385,7 @@ class DataSinkNode(Node):
|
||||
if self.type == "copy" or self.type == "link_or_copy":
|
||||
# so we create two phase tasks:
|
||||
# phase1: copy data to a temp directory, for each input partition in parallel
|
||||
input_deps = [
|
||||
self._create_phase1_task(
|
||||
runtime_ctx, task, [PartitionInfo(i, len(input_deps))]
|
||||
)
|
||||
for i, task in enumerate(input_deps)
|
||||
]
|
||||
input_deps = [self._create_phase1_task(runtime_ctx, task, [PartitionInfo(i, len(input_deps))]) for i, task in enumerate(input_deps)]
|
||||
# phase2: resolve file name conflicts, hard link files, create manifest file, and clean up temp directory
|
||||
return DataSinkTask(
|
||||
runtime_ctx,
|
||||
@@ -445,9 +423,7 @@ class DataSinkNode(Node):
|
||||
input_dep: Task,
|
||||
partition_infos: List[PartitionInfo],
|
||||
) -> DataSinkTask:
|
||||
return DataSinkTask(
|
||||
runtime_ctx, [input_dep], partition_infos, self.output_path, type=self.type
|
||||
)
|
||||
return DataSinkTask(runtime_ctx, [input_dep], partition_infos, self.output_path, type=self.type)
|
||||
|
||||
|
||||
class PythonScriptNode(Node):
|
||||
@@ -467,9 +443,7 @@ class PythonScriptNode(Node):
|
||||
ctx: Context,
|
||||
input_deps: Tuple[Node, ...],
|
||||
*,
|
||||
process_func: Optional[
|
||||
Callable[[RuntimeContext, List[DataSet], str], bool]
|
||||
] = None,
|
||||
process_func: Optional[Callable[[RuntimeContext, List[DataSet], str], bool]] = None,
|
||||
output_name: Optional[str] = None,
|
||||
output_path: Optional[str] = None,
|
||||
cpu_limit: int = 1,
|
||||
@@ -646,9 +620,7 @@ class ArrowComputeNode(Node):
|
||||
gpu_limit,
|
||||
memory_limit,
|
||||
)
|
||||
self.parquet_row_group_size = (
|
||||
parquet_row_group_size or self.default_row_group_size
|
||||
)
|
||||
self.parquet_row_group_size = parquet_row_group_size or self.default_row_group_size
|
||||
self.parquet_dictionary_encoding = parquet_dictionary_encoding
|
||||
self.parquet_compression = parquet_compression
|
||||
self.parquet_compression_level = parquet_compression_level
|
||||
@@ -708,9 +680,7 @@ class ArrowComputeNode(Node):
|
||||
"""
|
||||
return ArrowComputeTask(*args, **kwargs)
|
||||
|
||||
def process(
|
||||
self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]
|
||||
) -> arrow.Table:
|
||||
def process(self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table:
|
||||
"""
|
||||
Put user-defined code here.
|
||||
|
||||
@@ -749,9 +719,7 @@ class ArrowStreamNode(Node):
|
||||
ctx: Context,
|
||||
input_deps: Tuple[Node, ...],
|
||||
*,
|
||||
process_func: Callable[
|
||||
[RuntimeContext, List[arrow.RecordBatchReader]], Iterable[arrow.Table]
|
||||
] = None,
|
||||
process_func: Callable[[RuntimeContext, List[arrow.RecordBatchReader]], Iterable[arrow.Table]] = None,
|
||||
background_io_thread=True,
|
||||
streaming_batch_size: int = None,
|
||||
secs_checkpoint_interval: int = None,
|
||||
@@ -816,12 +784,9 @@ class ArrowStreamNode(Node):
|
||||
self.background_io_thread = background_io_thread and self.cpu_limit > 1
|
||||
self.streaming_batch_size = streaming_batch_size or self.default_batch_size
|
||||
self.secs_checkpoint_interval = secs_checkpoint_interval or math.ceil(
|
||||
self.default_secs_checkpoint_interval
|
||||
/ min(6, self.gpu_limit + 2, self.cpu_limit)
|
||||
)
|
||||
self.parquet_row_group_size = (
|
||||
parquet_row_group_size or self.default_row_group_size
|
||||
self.default_secs_checkpoint_interval / min(6, self.gpu_limit + 2, self.cpu_limit)
|
||||
)
|
||||
self.parquet_row_group_size = parquet_row_group_size or self.default_row_group_size
|
||||
self.parquet_dictionary_encoding = parquet_dictionary_encoding
|
||||
self.parquet_compression = parquet_compression
|
||||
self.parquet_compression_level = parquet_compression_level
|
||||
@@ -890,9 +855,7 @@ class ArrowStreamNode(Node):
|
||||
"""
|
||||
return ArrowStreamTask(*args, **kwargs)
|
||||
|
||||
def process(
|
||||
self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]
|
||||
) -> Iterable[arrow.Table]:
|
||||
def process(self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]) -> Iterable[arrow.Table]:
|
||||
"""
|
||||
Put user-defined code here.
|
||||
|
||||
@@ -918,9 +881,7 @@ class ArrowBatchNode(ArrowStreamNode):
|
||||
def spawn(self, *args, **kwargs) -> ArrowBatchTask:
|
||||
return ArrowBatchTask(*args, **kwargs)
|
||||
|
||||
def process(
|
||||
self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]
|
||||
) -> arrow.Table:
|
||||
def process(self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -932,9 +893,7 @@ class PandasComputeNode(ArrowComputeNode):
|
||||
def spawn(self, *args, **kwargs) -> PandasComputeTask:
|
||||
return PandasComputeTask(*args, **kwargs)
|
||||
|
||||
def process(
|
||||
self, runtime_ctx: RuntimeContext, input_dfs: List[pd.DataFrame]
|
||||
) -> pd.DataFrame:
|
||||
def process(self, runtime_ctx: RuntimeContext, input_dfs: List[pd.DataFrame]) -> pd.DataFrame:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -946,9 +905,7 @@ class PandasBatchNode(ArrowStreamNode):
|
||||
def spawn(self, *args, **kwargs) -> PandasBatchTask:
|
||||
return PandasBatchTask(*args, **kwargs)
|
||||
|
||||
def process(
|
||||
self, runtime_ctx: RuntimeContext, input_dfs: List[pd.DataFrame]
|
||||
) -> pd.DataFrame:
|
||||
def process(self, runtime_ctx: RuntimeContext, input_dfs: List[pd.DataFrame]) -> pd.DataFrame:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -1064,13 +1021,8 @@ class SqlEngineNode(Node):
|
||||
cpu_limit = cpu_limit or self.default_cpu_limit
|
||||
memory_limit = memory_limit or self.default_memory_limit
|
||||
if udfs is not None:
|
||||
if (
|
||||
self.max_udf_cpu_limit is not None
|
||||
and cpu_limit > self.max_udf_cpu_limit
|
||||
):
|
||||
warnings.warn(
|
||||
f"UDF execution is not highly paralleled, downgrade cpu_limit from {cpu_limit} to {self.max_udf_cpu_limit}"
|
||||
)
|
||||
if self.max_udf_cpu_limit is not None and cpu_limit > self.max_udf_cpu_limit:
|
||||
warnings.warn(f"UDF execution is not highly paralleled, downgrade cpu_limit from {cpu_limit} to {self.max_udf_cpu_limit}")
|
||||
cpu_limit = self.max_udf_cpu_limit
|
||||
memory_limit = None
|
||||
if relax_memory_if_oom is not None:
|
||||
@@ -1080,10 +1032,7 @@ class SqlEngineNode(Node):
|
||||
stacklevel=3,
|
||||
)
|
||||
|
||||
assert isinstance(sql_query, str) or (
|
||||
isinstance(sql_query, Iterable)
|
||||
and all(isinstance(q, str) for q in sql_query)
|
||||
)
|
||||
assert isinstance(sql_query, str) or (isinstance(sql_query, Iterable) and all(isinstance(q, str) for q in sql_query))
|
||||
super().__init__(
|
||||
ctx,
|
||||
input_deps,
|
||||
@@ -1092,17 +1041,13 @@ class SqlEngineNode(Node):
|
||||
cpu_limit=cpu_limit,
|
||||
memory_limit=memory_limit,
|
||||
)
|
||||
self.sql_queries = (
|
||||
[sql_query] if isinstance(sql_query, str) else list(sql_query)
|
||||
)
|
||||
self.udfs = [
|
||||
ctx.create_duckdb_extension(path) for path in extension_paths or []
|
||||
] + [ctx.create_external_module(path) for path in udf_module_paths or []]
|
||||
self.sql_queries = [sql_query] if isinstance(sql_query, str) else list(sql_query)
|
||||
self.udfs = [ctx.create_duckdb_extension(path) for path in extension_paths or []] + [
|
||||
ctx.create_external_module(path) for path in udf_module_paths or []
|
||||
]
|
||||
for udf in udfs or []:
|
||||
if isinstance(udf, UserDefinedFunction):
|
||||
name = ctx.create_function(
|
||||
udf.name, udf.func, udf.params, udf.return_type, udf.use_arrow_type
|
||||
)
|
||||
name = ctx.create_function(udf.name, udf.func, udf.params, udf.return_type, udf.use_arrow_type)
|
||||
else:
|
||||
assert isinstance(udf, str), f"udf must be a string: {udf}"
|
||||
if udf in ctx.udfs:
|
||||
@@ -1120,9 +1065,7 @@ class SqlEngineNode(Node):
|
||||
self.materialize_in_memory = materialize_in_memory
|
||||
self.batched_processing = batched_processing and len(input_deps) == 1
|
||||
self.enable_temp_directory = enable_temp_directory
|
||||
self.parquet_row_group_size = (
|
||||
parquet_row_group_size or self.default_row_group_size
|
||||
)
|
||||
self.parquet_row_group_size = parquet_row_group_size or self.default_row_group_size
|
||||
self.parquet_dictionary_encoding = parquet_dictionary_encoding
|
||||
self.parquet_compression = parquet_compression
|
||||
self.parquet_compression_level = parquet_compression_level
|
||||
@@ -1130,17 +1073,11 @@ class SqlEngineNode(Node):
|
||||
self.memory_overcommit_ratio = memory_overcommit_ratio
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
super().__str__()
|
||||
+ f", sql_query=<{self.oneline_query[:100]}...>, udfs={self.udfs}, batched_processing={self.batched_processing}"
|
||||
)
|
||||
return super().__str__() + f", sql_query=<{self.oneline_query[:100]}...>, udfs={self.udfs}, batched_processing={self.batched_processing}"
|
||||
|
||||
@property
|
||||
def oneline_query(self) -> str:
|
||||
return "; ".join(
|
||||
" ".join(filter(None, map(str.strip, query.splitlines())))
|
||||
for query in self.sql_queries
|
||||
)
|
||||
return "; ".join(" ".join(filter(None, map(str.strip, query.splitlines()))) for query in self.sql_queries)
|
||||
|
||||
@Node.task_factory
|
||||
def create_task(
|
||||
@@ -1224,12 +1161,8 @@ class ConsolidateNode(Node):
|
||||
dimensions
|
||||
Partitions would be grouped by these `dimensions` and consolidated into larger partitions.
|
||||
"""
|
||||
assert isinstance(
|
||||
dimensions, Iterable
|
||||
), f"dimensions is not iterable: {dimensions}"
|
||||
assert all(
|
||||
isinstance(dim, str) for dim in dimensions
|
||||
), f"some dimensions are not strings: {dimensions}"
|
||||
assert isinstance(dimensions, Iterable), f"dimensions is not iterable: {dimensions}"
|
||||
assert all(isinstance(dim, str) for dim in dimensions), f"some dimensions are not strings: {dimensions}"
|
||||
super().__init__(ctx, [input_dep])
|
||||
self.dimensions = set(list(dimensions) + [PartitionInfo.toplevel_dimension])
|
||||
|
||||
@@ -1283,29 +1216,16 @@ class PartitionNode(Node):
|
||||
See unit tests in `test/test_partition.py`. For nested partition see `test_nested_partition`.
|
||||
Why nested partition? See **5.1 Partial Partitioning** of [Advanced partitioning techniques for massively distributed computation](https://dl.acm.org/doi/10.1145/2213836.2213839).
|
||||
"""
|
||||
assert isinstance(
|
||||
npartitions, int
|
||||
), f"npartitions is not an integer: {npartitions}"
|
||||
assert dimension is None or re.match(
|
||||
"[a-zA-Z0-9_]+", dimension
|
||||
), f"dimension has invalid format: {dimension}"
|
||||
assert not (
|
||||
nested and dimension is None
|
||||
), f"nested partition should have dimension"
|
||||
super().__init__(
|
||||
ctx, input_deps, output_name, output_path, cpu_limit, 0, memory_limit
|
||||
)
|
||||
assert isinstance(npartitions, int), f"npartitions is not an integer: {npartitions}"
|
||||
assert dimension is None or re.match("[a-zA-Z0-9_]+", dimension), f"dimension has invalid format: {dimension}"
|
||||
assert not (nested and dimension is None), f"nested partition should have dimension"
|
||||
super().__init__(ctx, input_deps, output_name, output_path, cpu_limit, 0, memory_limit)
|
||||
self.npartitions = npartitions
|
||||
self.dimension = (
|
||||
dimension if dimension is not None else PartitionInfo.default_dimension
|
||||
)
|
||||
self.dimension = dimension if dimension is not None else PartitionInfo.default_dimension
|
||||
self.nested = nested
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
super().__str__()
|
||||
+ f", npartitions={self.npartitions}, dimension={self.dimension}, nested={self.nested}"
|
||||
)
|
||||
return super().__str__() + f", npartitions={self.npartitions}, dimension={self.dimension}, nested={self.nested}"
|
||||
|
||||
@Node.task_factory
|
||||
def create_producer_task(
|
||||
@@ -1441,12 +1361,8 @@ class UserDefinedPartitionNode(PartitionNode):
|
||||
class UserPartitionedDataSourceNode(UserDefinedPartitionNode):
|
||||
max_num_producer_tasks = 1
|
||||
|
||||
def __init__(
|
||||
self, ctx: Context, partitioned_datasets: List[DataSet], dimension: str = None
|
||||
) -> None:
|
||||
assert isinstance(partitioned_datasets, Iterable) and all(
|
||||
isinstance(dataset, DataSet) for dataset in partitioned_datasets
|
||||
)
|
||||
def __init__(self, ctx: Context, partitioned_datasets: List[DataSet], dimension: str = None) -> None:
|
||||
assert isinstance(partitioned_datasets, Iterable) and all(isinstance(dataset, DataSet) for dataset in partitioned_datasets)
|
||||
super().__init__(
|
||||
ctx,
|
||||
[DataSourceNode(ctx, dataset=None)],
|
||||
@@ -1507,10 +1423,7 @@ class EvenlyDistributedPartitionNode(PartitionNode):
|
||||
self.random_shuffle = random_shuffle
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
super().__str__()
|
||||
+ f", partition_by_rows={self.partition_by_rows}, random_shuffle={self.random_shuffle}"
|
||||
)
|
||||
return super().__str__() + f", partition_by_rows={self.partition_by_rows}, random_shuffle={self.random_shuffle}"
|
||||
|
||||
@Node.task_factory
|
||||
def create_producer_task(
|
||||
@@ -1551,9 +1464,7 @@ class LoadPartitionedDataSetNode(PartitionNode):
|
||||
cpu_limit: int = 1,
|
||||
memory_limit: Optional[int] = None,
|
||||
) -> None:
|
||||
assert (
|
||||
dimension or data_partition_column
|
||||
), f"Both 'dimension' and 'data_partition_column' are none or empty"
|
||||
assert dimension or data_partition_column, f"Both 'dimension' and 'data_partition_column' are none or empty"
|
||||
super().__init__(
|
||||
ctx,
|
||||
input_deps,
|
||||
@@ -1567,10 +1478,7 @@ class LoadPartitionedDataSetNode(PartitionNode):
|
||||
self.hive_partitioning = hive_partitioning
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
super().__str__()
|
||||
+ f", data_partition_column={self.data_partition_column}, hive_partitioning={self.hive_partitioning}"
|
||||
)
|
||||
return super().__str__() + f", data_partition_column={self.data_partition_column}, hive_partitioning={self.hive_partitioning}"
|
||||
|
||||
@Node.task_factory
|
||||
def create_producer_task(
|
||||
@@ -1620,9 +1528,7 @@ def DataSetPartitionNode(
|
||||
--------
|
||||
See unit test `test_load_partitioned_datasets` in `test/test_partition.py`.
|
||||
"""
|
||||
assert not (
|
||||
partition_by_rows and data_partition_column
|
||||
), "partition_by_rows and data_partition_column cannot be set at the same time"
|
||||
assert not (partition_by_rows and data_partition_column), "partition_by_rows and data_partition_column cannot be set at the same time"
|
||||
if data_partition_column is None:
|
||||
partition_node = EvenlyDistributedPartitionNode(
|
||||
ctx,
|
||||
@@ -1720,12 +1626,8 @@ class HashPartitionNode(PartitionNode):
|
||||
Specify if we should use dictionary encoding in general or only for some columns.
|
||||
See `use_dictionary` in https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetWriter.html.
|
||||
"""
|
||||
assert (
|
||||
not random_shuffle or not shuffle_only
|
||||
), f"random_shuffle and shuffle_only cannot be enabled at the same time"
|
||||
assert (
|
||||
not shuffle_only or data_partition_column is not None
|
||||
), f"data_partition_column not specified for shuffle-only partitioning"
|
||||
assert not random_shuffle or not shuffle_only, f"random_shuffle and shuffle_only cannot be enabled at the same time"
|
||||
assert not shuffle_only or data_partition_column is not None, f"data_partition_column not specified for shuffle-only partitioning"
|
||||
assert data_partition_column is None or re.match(
|
||||
"[a-zA-Z0-9_]+", data_partition_column
|
||||
), f"data_partition_column has invalid format: {data_partition_column}"
|
||||
@@ -1734,9 +1636,7 @@ class HashPartitionNode(PartitionNode):
|
||||
"duckdb",
|
||||
"arrow",
|
||||
), f"unknown query engine type: {engine_type}"
|
||||
data_partition_column = (
|
||||
data_partition_column or self.default_data_partition_column
|
||||
)
|
||||
data_partition_column = data_partition_column or self.default_data_partition_column
|
||||
super().__init__(
|
||||
ctx,
|
||||
input_deps,
|
||||
@@ -1756,9 +1656,7 @@ class HashPartitionNode(PartitionNode):
|
||||
self.drop_partition_column = drop_partition_column
|
||||
self.use_parquet_writer = use_parquet_writer
|
||||
self.hive_partitioning = hive_partitioning and self.engine_type == "duckdb"
|
||||
self.parquet_row_group_size = (
|
||||
parquet_row_group_size or self.default_row_group_size
|
||||
)
|
||||
self.parquet_row_group_size = parquet_row_group_size or self.default_row_group_size
|
||||
self.parquet_dictionary_encoding = parquet_dictionary_encoding
|
||||
self.parquet_compression = parquet_compression
|
||||
self.parquet_compression_level = parquet_compression_level
|
||||
@@ -1929,22 +1827,15 @@ class ProjectionNode(Node):
|
||||
"""
|
||||
columns = columns or ["*"]
|
||||
generated_columns = generated_columns or []
|
||||
assert all(
|
||||
col in GENERATED_COLUMNS for col in generated_columns
|
||||
), f"invalid values found in generated columns: {generated_columns}"
|
||||
assert not (
|
||||
set(columns) & set(generated_columns)
|
||||
), f"columns {columns} and generated columns {generated_columns} share common columns"
|
||||
assert all(col in GENERATED_COLUMNS for col in generated_columns), f"invalid values found in generated columns: {generated_columns}"
|
||||
assert not (set(columns) & set(generated_columns)), f"columns {columns} and generated columns {generated_columns} share common columns"
|
||||
super().__init__(ctx, [input_dep])
|
||||
self.columns = columns
|
||||
self.generated_columns = generated_columns
|
||||
self.union_by_name = union_by_name
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
super().__str__()
|
||||
+ f", columns={self.columns}, generated_columns={self.generated_columns}, union_by_name={self.union_by_name}"
|
||||
)
|
||||
return super().__str__() + f", columns={self.columns}, generated_columns={self.generated_columns}, union_by_name={self.union_by_name}"
|
||||
|
||||
@Node.task_factory
|
||||
def create_task(
|
||||
@@ -2100,10 +1991,7 @@ class LogicalPlan(object):
|
||||
if node.id in visited:
|
||||
return lines + [" " * depth + " (omitted ...)"]
|
||||
visited.add(node.id)
|
||||
lines += [
|
||||
" " * depth + f" | {name}: {stats}"
|
||||
for name, stats in node.perf_stats.items()
|
||||
]
|
||||
lines += [" " * depth + f" | {name}: {stats}" for name, stats in node.perf_stats.items()]
|
||||
for dep in node.input_deps:
|
||||
lines.extend(to_str(dep, depth + 1))
|
||||
return lines
|
||||
|
||||
@@ -32,9 +32,7 @@ class Optimizer(LogicalPlanVisitor[Node]):
|
||||
|
||||
def visit_query_engine_node(self, node: SqlEngineNode, depth: int) -> Node:
|
||||
# fuse consecutive SqlEngineNodes
|
||||
if len(node.input_deps) == 1 and isinstance(
|
||||
child := self.visit(node.input_deps[0], depth + 1), SqlEngineNode
|
||||
):
|
||||
if len(node.input_deps) == 1 and isinstance(child := self.visit(node.input_deps[0], depth + 1), SqlEngineNode):
|
||||
fused = copy.copy(node)
|
||||
fused.input_deps = child.input_deps
|
||||
fused.udfs = node.udfs + child.udfs
|
||||
@@ -52,8 +50,6 @@ class Optimizer(LogicalPlanVisitor[Node]):
|
||||
# node.sql_queries = ["select a, b from {0}"]
|
||||
# fused.sql_queries = ["select a, b from (select * from {0})"]
|
||||
# ```
|
||||
fused.sql_queries = child.sql_queries[:-1] + [
|
||||
query.format(f"({child.sql_queries[-1]})") for query in node.sql_queries
|
||||
]
|
||||
fused.sql_queries = child.sql_queries[:-1] + [query.format(f"({child.sql_queries[-1]})") for query in node.sql_queries]
|
||||
return fused
|
||||
return self.generic_visit(node, depth)
|
||||
|
||||
@@ -14,30 +14,20 @@ class Planner(LogicalPlanVisitor[TaskGroup]):
|
||||
self.node_to_tasks: Dict[Node, TaskGroup] = {}
|
||||
|
||||
@logger.catch(reraise=True, message="failed to build computation graph")
|
||||
def create_exec_plan(
|
||||
self, logical_plan: LogicalPlan, manifest_only_final_results=True
|
||||
) -> ExecutionPlan:
|
||||
def create_exec_plan(self, logical_plan: LogicalPlan, manifest_only_final_results=True) -> ExecutionPlan:
|
||||
logical_plan = copy.deepcopy(logical_plan)
|
||||
|
||||
# if --output_path is specified, copy files to the output path
|
||||
# otherwise, create manifest files only
|
||||
sink_type = (
|
||||
"copy" if self.runtime_ctx.final_output_path is not None else "manifest"
|
||||
)
|
||||
final_sink_type = (
|
||||
"copy"
|
||||
if self.runtime_ctx.final_output_path is not None
|
||||
else "manifest" if manifest_only_final_results else "link"
|
||||
)
|
||||
sink_type = "copy" if self.runtime_ctx.final_output_path is not None else "manifest"
|
||||
final_sink_type = "copy" if self.runtime_ctx.final_output_path is not None else "manifest" if manifest_only_final_results else "link"
|
||||
|
||||
# create DataSinkNode for each named output node (same name share the same sink node)
|
||||
nodes_groupby_output_name: Dict[str, List[Node]] = defaultdict(list)
|
||||
for node in logical_plan.nodes.values():
|
||||
if node.output_name is not None:
|
||||
if node.output_name in nodes_groupby_output_name:
|
||||
warnings.warn(
|
||||
f"{node} has duplicate output name: {node.output_name}"
|
||||
)
|
||||
warnings.warn(f"{node} has duplicate output name: {node.output_name}")
|
||||
nodes_groupby_output_name[node.output_name].append(node)
|
||||
sink_nodes = {} # { output_name: DataSinkNode }
|
||||
for output_name, nodes in nodes_groupby_output_name.items():
|
||||
@@ -45,9 +35,7 @@ class Planner(LogicalPlanVisitor[TaskGroup]):
|
||||
self.runtime_ctx.final_output_path or self.runtime_ctx.output_root,
|
||||
output_name,
|
||||
)
|
||||
sink_nodes[output_name] = DataSinkNode(
|
||||
logical_plan.ctx, tuple(nodes), output_path, type=sink_type
|
||||
)
|
||||
sink_nodes[output_name] = DataSinkNode(logical_plan.ctx, tuple(nodes), output_path, type=sink_type)
|
||||
|
||||
# create DataSinkNode for root node
|
||||
# XXX: special case optimization to avoid copying files twice
|
||||
@@ -63,9 +51,7 @@ class Planner(LogicalPlanVisitor[TaskGroup]):
|
||||
):
|
||||
sink_nodes["FinalResults"] = DataSinkNode(
|
||||
logical_plan.ctx,
|
||||
tuple(
|
||||
sink_nodes[node.output_name] for node in partition_node.input_deps
|
||||
),
|
||||
tuple(sink_nodes[node.output_name] for node in partition_node.input_deps),
|
||||
output_path=os.path.join(
|
||||
self.runtime_ctx.final_output_path or self.runtime_ctx.output_root,
|
||||
"FinalResults",
|
||||
@@ -124,38 +110,24 @@ class Planner(LogicalPlanVisitor[TaskGroup]):
|
||||
return [node.create_task(self.runtime_ctx, [], [PartitionInfo()])]
|
||||
|
||||
def visit_data_sink_node(self, node: DataSinkNode, depth: int) -> TaskGroup:
|
||||
all_input_deps = [
|
||||
task for dep in node.input_deps for task in self.visit(dep, depth + 1)
|
||||
]
|
||||
all_input_deps = [task for dep in node.input_deps for task in self.visit(dep, depth + 1)]
|
||||
return [node.create_task(self.runtime_ctx, all_input_deps, [PartitionInfo()])]
|
||||
|
||||
def visit_root_node(self, node: RootNode, depth: int) -> TaskGroup:
|
||||
all_input_deps = [
|
||||
task for dep in node.input_deps for task in self.visit(dep, depth + 1)
|
||||
]
|
||||
all_input_deps = [task for dep in node.input_deps for task in self.visit(dep, depth + 1)]
|
||||
return [RootTask(self.runtime_ctx, all_input_deps, [PartitionInfo()])]
|
||||
|
||||
def visit_union_node(self, node: UnionNode, depth: int) -> TaskGroup:
|
||||
all_input_deps = [
|
||||
task for dep in node.input_deps for task in self.visit(dep, depth + 1)
|
||||
]
|
||||
all_input_deps = [task for dep in node.input_deps for task in self.visit(dep, depth + 1)]
|
||||
unique_partition_dims = set(task.partition_dims for task in all_input_deps)
|
||||
assert (
|
||||
len(unique_partition_dims) == 1
|
||||
), f"cannot union partitions with different dimensions: {unique_partition_dims}"
|
||||
assert len(unique_partition_dims) == 1, f"cannot union partitions with different dimensions: {unique_partition_dims}"
|
||||
return all_input_deps
|
||||
|
||||
def visit_consolidate_node(self, node: ConsolidateNode, depth: int) -> TaskGroup:
|
||||
input_deps_taskgroups = [self.visit(dep, depth + 1) for dep in node.input_deps]
|
||||
assert (
|
||||
len(input_deps_taskgroups) == 1
|
||||
), f"consolidate node only accepts one input node, but found: {input_deps_taskgroups}"
|
||||
unique_partition_dims = set(
|
||||
task.partition_dims for task in input_deps_taskgroups[0]
|
||||
)
|
||||
assert (
|
||||
len(unique_partition_dims) == 1
|
||||
), f"cannot consolidate partitions with different dimensions: {unique_partition_dims}"
|
||||
assert len(input_deps_taskgroups) == 1, f"consolidate node only accepts one input node, but found: {input_deps_taskgroups}"
|
||||
unique_partition_dims = set(task.partition_dims for task in input_deps_taskgroups[0])
|
||||
assert len(unique_partition_dims) == 1, f"cannot consolidate partitions with different dimensions: {unique_partition_dims}"
|
||||
existing_dimensions = set(unique_partition_dims.pop())
|
||||
assert (
|
||||
node.dimensions.intersection(existing_dimensions) == node.dimensions
|
||||
@@ -163,46 +135,30 @@ class Planner(LogicalPlanVisitor[TaskGroup]):
|
||||
# group tasks by partitions
|
||||
input_deps_groupby_partitions: Dict[Tuple, List[Task]] = defaultdict(list)
|
||||
for task in input_deps_taskgroups[0]:
|
||||
partition_infos = tuple(
|
||||
info
|
||||
for info in task.partition_infos
|
||||
if info.dimension in node.dimensions
|
||||
)
|
||||
partition_infos = tuple(info for info in task.partition_infos if info.dimension in node.dimensions)
|
||||
input_deps_groupby_partitions[partition_infos].append(task)
|
||||
return [
|
||||
node.create_task(self.runtime_ctx, input_deps, partition_infos)
|
||||
for partition_infos, input_deps in input_deps_groupby_partitions.items()
|
||||
node.create_task(self.runtime_ctx, input_deps, partition_infos) for partition_infos, input_deps in input_deps_groupby_partitions.items()
|
||||
]
|
||||
|
||||
def visit_partition_node(self, node: PartitionNode, depth: int) -> TaskGroup:
|
||||
all_input_deps = [
|
||||
task for dep in node.input_deps for task in self.visit(dep, depth + 1)
|
||||
]
|
||||
all_input_deps = [task for dep in node.input_deps for task in self.visit(dep, depth + 1)]
|
||||
unique_partition_dims = set(task.partition_dims for task in all_input_deps)
|
||||
assert (
|
||||
len(unique_partition_dims) == 1
|
||||
), f"cannot partition input_deps with different dimensions: {unique_partition_dims}"
|
||||
assert len(unique_partition_dims) == 1, f"cannot partition input_deps with different dimensions: {unique_partition_dims}"
|
||||
|
||||
if node.nested:
|
||||
assert (
|
||||
node.dimension not in unique_partition_dims
|
||||
), f"found duplicate partition dimension '{node.dimension}', existing dimensions: {unique_partition_dims}"
|
||||
assert (
|
||||
len(all_input_deps) * node.npartitions
|
||||
<= node.max_card_of_producers_x_consumers
|
||||
len(all_input_deps) * node.npartitions <= node.max_card_of_producers_x_consumers
|
||||
), f"{len(all_input_deps)=} * {node.npartitions=} > {node.max_card_of_producers_x_consumers=}"
|
||||
producer_tasks = [
|
||||
node.create_producer_task(
|
||||
self.runtime_ctx, [task], task.partition_infos
|
||||
)
|
||||
for task in all_input_deps
|
||||
]
|
||||
producer_tasks = [node.create_producer_task(self.runtime_ctx, [task], task.partition_infos) for task in all_input_deps]
|
||||
return [
|
||||
node.create_consumer_task(
|
||||
self.runtime_ctx,
|
||||
[producer],
|
||||
list(producer.partition_infos)
|
||||
+ [PartitionInfo(partition_idx, node.npartitions, node.dimension)],
|
||||
list(producer.partition_infos) + [PartitionInfo(partition_idx, node.npartitions, node.dimension)],
|
||||
)
|
||||
for producer in producer_tasks
|
||||
for partition_idx in range(node.npartitions)
|
||||
@@ -212,16 +168,10 @@ class Planner(LogicalPlanVisitor[TaskGroup]):
|
||||
node.max_num_producer_tasks,
|
||||
math.ceil(node.max_card_of_producers_x_consumers / node.npartitions),
|
||||
)
|
||||
num_parallel_tasks = (
|
||||
2
|
||||
* self.runtime_ctx.num_executors
|
||||
* math.ceil(self.runtime_ctx.usable_cpu_count / node.cpu_limit)
|
||||
)
|
||||
num_parallel_tasks = 2 * self.runtime_ctx.num_executors * math.ceil(self.runtime_ctx.usable_cpu_count / node.cpu_limit)
|
||||
num_producer_tasks = max(1, min(max_num_producer_tasks, num_parallel_tasks))
|
||||
if len(all_input_deps) < num_producer_tasks:
|
||||
merge_datasets_task = node.create_merge_task(
|
||||
self.runtime_ctx, all_input_deps, [PartitionInfo()]
|
||||
)
|
||||
merge_datasets_task = node.create_merge_task(self.runtime_ctx, all_input_deps, [PartitionInfo()])
|
||||
split_dataset_tasks = [
|
||||
node.create_split_task(
|
||||
self.runtime_ctx,
|
||||
@@ -237,15 +187,10 @@ class Planner(LogicalPlanVisitor[TaskGroup]):
|
||||
tasks,
|
||||
[PartitionInfo(partition_idx, num_producer_tasks)],
|
||||
)
|
||||
for partition_idx, tasks in enumerate(
|
||||
split_into_rows(all_input_deps, num_producer_tasks)
|
||||
)
|
||||
for partition_idx, tasks in enumerate(split_into_rows(all_input_deps, num_producer_tasks))
|
||||
]
|
||||
producer_tasks = [
|
||||
node.create_producer_task(
|
||||
self.runtime_ctx, [split_dataset], split_dataset.partition_infos
|
||||
)
|
||||
for split_dataset in split_dataset_tasks
|
||||
node.create_producer_task(self.runtime_ctx, [split_dataset], split_dataset.partition_infos) for split_dataset in split_dataset_tasks
|
||||
]
|
||||
return [
|
||||
node.create_consumer_task(
|
||||
@@ -284,11 +229,7 @@ class Planner(LogicalPlanVisitor[TaskGroup]):
|
||||
for main_input in input_deps_most_ndims:
|
||||
input_deps = []
|
||||
for input_deps_dims, input_deps_map in input_deps_maps:
|
||||
partition_infos = tuple(
|
||||
info
|
||||
for info in main_input.partition_infos
|
||||
if info.dimension in input_deps_dims
|
||||
)
|
||||
partition_infos = tuple(info for info in main_input.partition_infos if info.dimension in input_deps_dims)
|
||||
input_dep = input_deps_map.get(partition_infos, None)
|
||||
assert (
|
||||
input_dep is not None
|
||||
@@ -299,50 +240,32 @@ class Planner(LogicalPlanVisitor[TaskGroup]):
|
||||
|
||||
def visit_python_script_node(self, node: PythonScriptNode, depth: int) -> TaskGroup:
|
||||
return [
|
||||
node.create_task(self.runtime_ctx, input_deps, partition_infos)
|
||||
for input_deps, partition_infos in self.broadcast_input_deps(node, depth)
|
||||
node.create_task(self.runtime_ctx, input_deps, partition_infos) for input_deps, partition_infos in self.broadcast_input_deps(node, depth)
|
||||
]
|
||||
|
||||
def visit_arrow_compute_node(self, node: ArrowComputeNode, depth: int) -> TaskGroup:
|
||||
return [
|
||||
node.create_task(self.runtime_ctx, input_deps, partition_infos)
|
||||
for input_deps, partition_infos in self.broadcast_input_deps(node, depth)
|
||||
node.create_task(self.runtime_ctx, input_deps, partition_infos) for input_deps, partition_infos in self.broadcast_input_deps(node, depth)
|
||||
]
|
||||
|
||||
def visit_arrow_stream_node(self, node: ArrowStreamNode, depth: int) -> TaskGroup:
|
||||
return [
|
||||
node.create_task(self.runtime_ctx, input_deps, partition_infos)
|
||||
for input_deps, partition_infos in self.broadcast_input_deps(node, depth)
|
||||
node.create_task(self.runtime_ctx, input_deps, partition_infos) for input_deps, partition_infos in self.broadcast_input_deps(node, depth)
|
||||
]
|
||||
|
||||
def visit_query_engine_node(self, node: SqlEngineNode, depth: int) -> TaskGroup:
|
||||
return [
|
||||
node.create_task(self.runtime_ctx, input_deps, partition_infos)
|
||||
for input_deps, partition_infos in self.broadcast_input_deps(node, depth)
|
||||
node.create_task(self.runtime_ctx, input_deps, partition_infos) for input_deps, partition_infos in self.broadcast_input_deps(node, depth)
|
||||
]
|
||||
|
||||
def visit_projection_node(self, node: ProjectionNode, depth: int) -> TaskGroup:
|
||||
assert (
|
||||
len(node.input_deps) == 1
|
||||
), f"projection node only accepts one input node, but found: {node.input_deps}"
|
||||
return [
|
||||
node.create_task(self.runtime_ctx, [task], task.partition_infos)
|
||||
for task in self.visit(node.input_deps[0], depth + 1)
|
||||
]
|
||||
assert len(node.input_deps) == 1, f"projection node only accepts one input node, but found: {node.input_deps}"
|
||||
return [node.create_task(self.runtime_ctx, [task], task.partition_infos) for task in self.visit(node.input_deps[0], depth + 1)]
|
||||
|
||||
def visit_limit_node(self, node: LimitNode, depth: int) -> TaskGroup:
|
||||
assert (
|
||||
len(node.input_deps) == 1
|
||||
), f"limit node only accepts one input node, but found: {node.input_deps}"
|
||||
assert len(node.input_deps) == 1, f"limit node only accepts one input node, but found: {node.input_deps}"
|
||||
all_input_deps = self.visit(node.input_deps[0], depth + 1)
|
||||
partial_limit_tasks = [
|
||||
node.create_task(self.runtime_ctx, [task], task.partition_infos)
|
||||
for task in all_input_deps
|
||||
]
|
||||
merge_task = node.create_merge_task(
|
||||
self.runtime_ctx, partial_limit_tasks, [PartitionInfo()]
|
||||
)
|
||||
global_limit_task = node.create_task(
|
||||
self.runtime_ctx, [merge_task], merge_task.partition_infos
|
||||
)
|
||||
partial_limit_tasks = [node.create_task(self.runtime_ctx, [task], task.partition_infos) for task in all_input_deps]
|
||||
merge_task = node.create_merge_task(self.runtime_ctx, partial_limit_tasks, [PartitionInfo()])
|
||||
global_limit_task = node.create_task(self.runtime_ctx, [merge_task], merge_task.partition_infos)
|
||||
return [global_limit_task]
|
||||
|
||||
@@ -264,6 +264,4 @@ def udf(
|
||||
|
||||
See `Context.create_function` for more details.
|
||||
"""
|
||||
return lambda func: UserDefinedFunction(
|
||||
name or func.__name__, func, params, return_type, use_arrow_type
|
||||
)
|
||||
return lambda func: UserDefinedFunction(name or func.__name__, func, params, return_type, use_arrow_type)
|
||||
|
||||
@@ -57,9 +57,7 @@ class SessionBase:
|
||||
logger.info(f"session config: {self.config}")
|
||||
|
||||
def setup_worker():
|
||||
runtime_ctx._init_logs(
|
||||
exec_id=socket.gethostname(), capture_stdout_stderr=True
|
||||
)
|
||||
runtime_ctx._init_logs(exec_id=socket.gethostname(), capture_stdout_stderr=True)
|
||||
|
||||
if self.config.ray_address is None:
|
||||
# find the memory allocator
|
||||
@@ -72,9 +70,7 @@ class SessionBase:
|
||||
malloc_path = shutil.which("libmimalloc.so.2.1")
|
||||
assert malloc_path is not None, "mimalloc is not installed"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"unsupported memory allocator: {self.config.memory_allocator}"
|
||||
)
|
||||
raise ValueError(f"unsupported memory allocator: {self.config.memory_allocator}")
|
||||
memory_purge_delay = 10000
|
||||
|
||||
# start ray head node
|
||||
@@ -84,11 +80,7 @@ class SessionBase:
|
||||
# start a new local cluster
|
||||
address="local",
|
||||
# disable local CPU resource if not running on localhost
|
||||
num_cpus=(
|
||||
0
|
||||
if self.config.num_executors > 0
|
||||
else self._runtime_ctx.usable_cpu_count
|
||||
),
|
||||
num_cpus=(0 if self.config.num_executors > 0 else self._runtime_ctx.usable_cpu_count),
|
||||
# set the memory limit to the available memory size
|
||||
_memory=self._runtime_ctx.usable_memory_size,
|
||||
# setup logging for workers
|
||||
@@ -142,9 +134,7 @@ class SessionBase:
|
||||
|
||||
# spawn a thread to periodically dump metrics
|
||||
self._stop_event = threading.Event()
|
||||
self._dump_thread = threading.Thread(
|
||||
name="dump_thread", target=self._dump_periodically, daemon=True
|
||||
)
|
||||
self._dump_thread = threading.Thread(name="dump_thread", target=self._dump_periodically, daemon=True)
|
||||
self._dump_thread.start()
|
||||
|
||||
def shutdown(self):
|
||||
@@ -184,11 +174,7 @@ class SessionBase:
|
||||
extra_opts=dict(
|
||||
tags=["smallpond", "scheduler", smallpond.__version__],
|
||||
),
|
||||
envs={
|
||||
k: v
|
||||
for k, v in os.environ.items()
|
||||
if k.startswith("SP_") and k != "SP_SPAWN"
|
||||
},
|
||||
envs={k: v for k, v in os.environ.items() if k.startswith("SP_") and k != "SP_SPAWN"},
|
||||
)
|
||||
|
||||
def _start_prometheus(self) -> Optional[subprocess.Popen]:
|
||||
@@ -233,8 +219,7 @@ class SessionBase:
|
||||
stdout=open(f"{self._runtime_ctx.log_root}/grafana/grafana.log", "w"),
|
||||
env={
|
||||
"GF_SERVER_HTTP_PORT": "8122", # redirect to an available port
|
||||
"GF_SERVER_ROOT_URL": os.environ.get("RAY_GRAFANA_IFRAME_HOST")
|
||||
or "http://localhost:8122",
|
||||
"GF_SERVER_ROOT_URL": os.environ.get("RAY_GRAFANA_IFRAME_HOST") or "http://localhost:8122",
|
||||
"GF_PATHS_DATA": f"{self._runtime_ctx.log_root}/grafana/data",
|
||||
},
|
||||
)
|
||||
@@ -309,12 +294,8 @@ class SessionBase:
|
||||
self.dump_graph()
|
||||
self.dump_timeline()
|
||||
num_total_tasks, num_finished_tasks = self._summarize_task()
|
||||
percent = (
|
||||
num_finished_tasks / num_total_tasks * 100 if num_total_tasks > 0 else 0
|
||||
)
|
||||
logger.info(
|
||||
f"progress: {num_finished_tasks}/{num_total_tasks} tasks ({percent:.1f}%)"
|
||||
)
|
||||
percent = num_finished_tasks / num_total_tasks * 100 if num_total_tasks > 0 else 0
|
||||
logger.info(f"progress: {num_finished_tasks}/{num_total_tasks} tasks ({percent:.1f}%)")
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -360,20 +341,12 @@ class Config:
|
||||
|
||||
platform = get_platform(get_env("PLATFORM") or platform)
|
||||
job_id = get_env("JOBID") or job_id or platform.default_job_id()
|
||||
job_time = (
|
||||
get_env("JOB_TIME", datetime.fromisoformat)
|
||||
or job_time
|
||||
or platform.default_job_time()
|
||||
)
|
||||
job_time = get_env("JOB_TIME", datetime.fromisoformat) or job_time or platform.default_job_time()
|
||||
data_root = get_env("DATA_ROOT") or data_root or platform.default_data_root()
|
||||
num_executors = get_env("NUM_EXECUTORS", int) or num_executors or 0
|
||||
ray_address = get_env("RAY_ADDRESS") or ray_address
|
||||
bind_numa_node = get_env("BIND_NUMA_NODE") == "1" or bind_numa_node
|
||||
memory_allocator = (
|
||||
get_env("MEMORY_ALLOCATOR")
|
||||
or memory_allocator
|
||||
or platform.default_memory_allocator()
|
||||
)
|
||||
memory_allocator = get_env("MEMORY_ALLOCATOR") or memory_allocator or platform.default_memory_allocator()
|
||||
|
||||
config = Config(
|
||||
job_id=job_id,
|
||||
|
||||
@@ -24,9 +24,7 @@ def overall_stats(
|
||||
):
|
||||
from smallpond.logical.node import DataSetPartitionNode, DataSinkNode, SqlEngineNode
|
||||
|
||||
n = SqlEngineNode(
|
||||
ctx, inp, sql_per_part, cpu_limit=cpu_limit, memory_limit=memory_limit
|
||||
)
|
||||
n = SqlEngineNode(ctx, inp, sql_per_part, cpu_limit=cpu_limit, memory_limit=memory_limit)
|
||||
p = DataSetPartitionNode(ctx, (n,), npartitions=1)
|
||||
n2 = SqlEngineNode(
|
||||
ctx,
|
||||
@@ -59,9 +57,7 @@ def execute_command(cmd: str, env: Dict[str, str] = None, shell=False):
|
||||
raise subprocess.CalledProcessError(return_code, cmd)
|
||||
|
||||
|
||||
def cprofile_to_string(
|
||||
perf_profile: cProfile.Profile, order_by=pstats.SortKey.TIME, top_k=20
|
||||
):
|
||||
def cprofile_to_string(perf_profile: cProfile.Profile, order_by=pstats.SortKey.TIME, top_k=20):
|
||||
perf_profile.disable()
|
||||
pstats_output = io.StringIO()
|
||||
profile_stats = pstats.Stats(perf_profile, stream=pstats_output)
|
||||
@@ -111,9 +107,7 @@ class ConcurrentIter(object):
|
||||
"""
|
||||
|
||||
def __init__(self, iterable: Iterable, max_buffer_size=1) -> None:
|
||||
assert isinstance(
|
||||
iterable, Iterable
|
||||
), f"expect an iterable but found: {repr(iterable)}"
|
||||
assert isinstance(iterable, Iterable), f"expect an iterable but found: {repr(iterable)}"
|
||||
self.__iterable = iterable
|
||||
self.__queue = queue.Queue(max_buffer_size)
|
||||
self.__last = object()
|
||||
@@ -194,6 +188,4 @@ class InterceptHandler(logging.Handler):
|
||||
frame = frame.f_back
|
||||
depth += 1
|
||||
|
||||
logger.opt(depth=depth, exception=record.exc_info).log(
|
||||
level, record.getMessage()
|
||||
)
|
||||
logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
|
||||
|
||||
@@ -14,9 +14,7 @@ if __name__ == "__main__":
|
||||
required=True,
|
||||
help="The address of the Ray cluster to connect to",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log_dir", required=True, help="The directory where logs will be stored"
|
||||
)
|
||||
parser.add_argument("--log_dir", required=True, help="The directory where logs will be stored")
|
||||
parser.add_argument(
|
||||
"--bind_numa_node",
|
||||
action="store_true",
|
||||
|
||||
@@ -14,19 +14,13 @@ from filelock import FileLock
|
||||
|
||||
|
||||
def generate_url_and_domain() -> Tuple[str, str]:
|
||||
domain_part = "".join(
|
||||
random.choices(string.ascii_lowercase, k=random.randint(5, 15))
|
||||
)
|
||||
domain_part = "".join(random.choices(string.ascii_lowercase, k=random.randint(5, 15)))
|
||||
tld = random.choice(["com", "net", "org", "cn", "edu", "gov", "co", "io"])
|
||||
domain = f"www.{domain_part}.{tld}"
|
||||
|
||||
path_segments = []
|
||||
for _ in range(random.randint(1, 3)):
|
||||
segment = "".join(
|
||||
random.choices(
|
||||
string.ascii_lowercase + string.digits, k=random.randint(3, 10)
|
||||
)
|
||||
)
|
||||
segment = "".join(random.choices(string.ascii_lowercase + string.digits, k=random.randint(3, 10)))
|
||||
path_segments.append(segment)
|
||||
path = "/" + "/".join(path_segments)
|
||||
|
||||
@@ -42,26 +36,18 @@ def generate_random_date() -> str:
|
||||
start = datetime(2023, 1, 1, tzinfo=timezone.utc)
|
||||
end = datetime(2023, 12, 31, tzinfo=timezone.utc)
|
||||
delta = end - start
|
||||
random_date = start + timedelta(
|
||||
seconds=random.randint(0, int(delta.total_seconds()))
|
||||
)
|
||||
random_date = start + timedelta(seconds=random.randint(0, int(delta.total_seconds())))
|
||||
return random_date.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
|
||||
|
||||
def generate_content() -> bytes:
|
||||
target_length = (
|
||||
random.randint(1000, 100000)
|
||||
if random.random() < 0.8
|
||||
else random.randint(100000, 1000000)
|
||||
)
|
||||
target_length = random.randint(1000, 100000) if random.random() < 0.8 else random.randint(100000, 1000000)
|
||||
before = b"<!DOCTYPE html><html><head><title>Random Page</title></head><body>"
|
||||
after = b"</body></html>"
|
||||
total_before_after = len(before) + len(after)
|
||||
|
||||
fill_length = max(target_length - total_before_after, 0)
|
||||
filler = "".join(random.choices(string.printable, k=fill_length)).encode("ascii")[
|
||||
:fill_length
|
||||
]
|
||||
filler = "".join(random.choices(string.printable, k=fill_length)).encode("ascii")[:fill_length]
|
||||
|
||||
return before + filler + after
|
||||
|
||||
@@ -103,9 +89,7 @@ def generate_random_string(length: int) -> str:
|
||||
def generate_random_url() -> str:
|
||||
"""Generate a random URL"""
|
||||
path = generate_random_string(random.randint(10, 20))
|
||||
return (
|
||||
f"com.{random.randint(10000, 999999)}.{random.randint(100, 9999)}/{path}.html"
|
||||
)
|
||||
return f"com.{random.randint(10000, 999999)}.{random.randint(100, 9999)}/{path}.html"
|
||||
|
||||
|
||||
def generate_random_data() -> str:
|
||||
@@ -138,9 +122,7 @@ def generate_url_parquet_files(output_dir: str, num_files: int = 10):
|
||||
)
|
||||
|
||||
|
||||
def generate_url_tsv_files(
|
||||
output_dir: str, num_files: int = 10, lines_per_file: int = 100
|
||||
):
|
||||
def generate_url_tsv_files(output_dir: str, num_files: int = 10, lines_per_file: int = 100):
|
||||
"""Generate multiple files, each containing a specified number of random data lines"""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
for i in range(num_files):
|
||||
@@ -166,16 +148,12 @@ def generate_data(path: str = "tests/data"):
|
||||
with FileLock(path + "/data.lock"):
|
||||
print("Generating data...")
|
||||
if not os.path.exists(path + "/mock_urls"):
|
||||
generate_url_tsv_files(
|
||||
output_dir=path + "/mock_urls", num_files=10, lines_per_file=100
|
||||
)
|
||||
generate_url_tsv_files(output_dir=path + "/mock_urls", num_files=10, lines_per_file=100)
|
||||
generate_url_parquet_files(output_dir=path + "/mock_urls", num_files=10)
|
||||
if not os.path.exists(path + "/arrow"):
|
||||
generate_arrow_files(output_dir=path + "/arrow", num_files=10)
|
||||
if not os.path.exists(path + "/large_array"):
|
||||
concat_arrow_files(
|
||||
input_dir=path + "/arrow", output_dir=path + "/large_array"
|
||||
)
|
||||
concat_arrow_files(input_dir=path + "/arrow", output_dir=path + "/large_array")
|
||||
if not os.path.exists(path + "/long_path_list.txt"):
|
||||
generate_long_path_list(path=path + "/long_path_list.txt")
|
||||
except Exception as e:
|
||||
|
||||
@@ -37,10 +37,7 @@ class TestArrow(TestFabric, unittest.TestCase):
|
||||
with self.subTest(dataset_path=dataset_path):
|
||||
metadata = parquet.read_metadata(dataset_path)
|
||||
file_num_rows = metadata.num_rows
|
||||
data_size = sum(
|
||||
metadata.row_group(i).total_byte_size
|
||||
for i in range(metadata.num_row_groups)
|
||||
)
|
||||
data_size = sum(metadata.row_group(i).total_byte_size for i in range(metadata.num_row_groups))
|
||||
row_range = RowRange(
|
||||
path=dataset_path,
|
||||
begin=100,
|
||||
@@ -48,9 +45,7 @@ class TestArrow(TestFabric, unittest.TestCase):
|
||||
data_size=data_size,
|
||||
file_num_rows=file_num_rows,
|
||||
)
|
||||
expected = self._load_parquet_files([dataset_path]).slice(
|
||||
offset=100, length=100
|
||||
)
|
||||
expected = self._load_parquet_files([dataset_path]).slice(offset=100, length=100)
|
||||
actual = load_from_parquet_files([row_range])
|
||||
self._compare_arrow_tables(expected, actual)
|
||||
|
||||
@@ -62,21 +57,15 @@ class TestArrow(TestFabric, unittest.TestCase):
|
||||
with self.subTest(dataset_path=dataset_path):
|
||||
parquet_files = glob.glob(dataset_path)
|
||||
expected = self._load_parquet_files(parquet_files)
|
||||
with tempfile.TemporaryDirectory(
|
||||
dir=self.output_root_abspath
|
||||
) as output_dir:
|
||||
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
|
||||
ok = dump_to_parquet_files(expected, output_dir)
|
||||
self.assertTrue(ok)
|
||||
actual = self._load_parquet_files(
|
||||
glob.glob(f"{output_dir}/*.parquet")
|
||||
)
|
||||
actual = self._load_parquet_files(glob.glob(f"{output_dir}/*.parquet"))
|
||||
self._compare_arrow_tables(expected, actual)
|
||||
|
||||
def test_dump_load_empty_table(self):
|
||||
# create empty table
|
||||
empty_table = self._load_parquet_files(
|
||||
["tests/data/arrow/data0.parquet"]
|
||||
).slice(length=0)
|
||||
empty_table = self._load_parquet_files(["tests/data/arrow/data0.parquet"]).slice(length=0)
|
||||
self.assertEqual(empty_table.num_rows, 0)
|
||||
# dump empty table
|
||||
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
|
||||
@@ -94,9 +83,7 @@ class TestArrow(TestFabric, unittest.TestCase):
|
||||
):
|
||||
with self.subTest(dataset_path=dataset_path):
|
||||
parquet_files = glob.glob(dataset_path)
|
||||
expected_num_rows = sum(
|
||||
parquet.read_metadata(file).num_rows for file in parquet_files
|
||||
)
|
||||
expected_num_rows = sum(parquet.read_metadata(file).num_rows for file in parquet_files)
|
||||
with build_batch_reader_from_files(
|
||||
parquet_files,
|
||||
batch_size=expected_num_rows,
|
||||
@@ -104,9 +91,7 @@ class TestArrow(TestFabric, unittest.TestCase):
|
||||
) as batch_reader, ConcurrentIter(batch_reader) as concurrent_iter:
|
||||
total_num_rows = 0
|
||||
for batch in concurrent_iter:
|
||||
print(
|
||||
f"batch.num_rows {batch.num_rows}, max_batch_row_size {expected_num_rows}"
|
||||
)
|
||||
print(f"batch.num_rows {batch.num_rows}, max_batch_row_size {expected_num_rows}")
|
||||
self.assertLessEqual(batch.num_rows, expected_num_rows)
|
||||
total_num_rows += batch.num_rows
|
||||
self.assertEqual(total_num_rows, expected_num_rows)
|
||||
@@ -121,9 +106,7 @@ class TestArrow(TestFabric, unittest.TestCase):
|
||||
table = self._load_parquet_files(parquet_files)
|
||||
total_num_rows = 0
|
||||
for batch in table.to_batches(max_chunksize=table.num_rows):
|
||||
print(
|
||||
f"batch.num_rows {batch.num_rows}, max_batch_row_size {table.num_rows}"
|
||||
)
|
||||
print(f"batch.num_rows {batch.num_rows}, max_batch_row_size {table.num_rows}")
|
||||
self.assertLessEqual(batch.num_rows, table.num_rows)
|
||||
total_num_rows += batch.num_rows
|
||||
self.assertEqual(total_num_rows, table.num_rows)
|
||||
@@ -135,26 +118,14 @@ class TestArrow(TestFabric, unittest.TestCase):
|
||||
print(f"table_with_meta.schema.metadata {table_with_meta.schema.metadata}")
|
||||
|
||||
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
|
||||
self.assertTrue(
|
||||
dump_to_parquet_files(
|
||||
table_with_meta, output_dir, "arrow_schema_metadata", max_workers=2
|
||||
)
|
||||
)
|
||||
parquet_files = glob.glob(
|
||||
os.path.join(output_dir, "arrow_schema_metadata*.parquet")
|
||||
)
|
||||
loaded_table = load_from_parquet_files(
|
||||
parquet_files, table.column_names[:1]
|
||||
)
|
||||
self.assertTrue(dump_to_parquet_files(table_with_meta, output_dir, "arrow_schema_metadata", max_workers=2))
|
||||
parquet_files = glob.glob(os.path.join(output_dir, "arrow_schema_metadata*.parquet"))
|
||||
loaded_table = load_from_parquet_files(parquet_files, table.column_names[:1])
|
||||
print(f"loaded_table.schema.metadata {loaded_table.schema.metadata}")
|
||||
self.assertEqual(
|
||||
table_with_meta.schema.metadata, loaded_table.schema.metadata
|
||||
)
|
||||
self.assertEqual(table_with_meta.schema.metadata, loaded_table.schema.metadata)
|
||||
with parquet.ParquetFile(parquet_files[0]) as file:
|
||||
print(f"file.schema_arrow.metadata {file.schema_arrow.metadata}")
|
||||
self.assertEqual(
|
||||
table_with_meta.schema.metadata, file.schema_arrow.metadata
|
||||
)
|
||||
self.assertEqual(table_with_meta.schema.metadata, file.schema_arrow.metadata)
|
||||
|
||||
def test_load_mixed_string_types(self):
|
||||
parquet_paths = glob.glob("tests/data/arrow/*.parquet")
|
||||
@@ -166,9 +137,7 @@ class TestArrow(TestFabric, unittest.TestCase):
|
||||
loaded_table = load_from_parquet_files(parquet_paths)
|
||||
self.assertEqual(table.num_rows * 2, loaded_table.num_rows)
|
||||
batch_reader = build_batch_reader_from_files(parquet_paths)
|
||||
self.assertEqual(
|
||||
table.num_rows * 2, sum(batch.num_rows for batch in batch_reader)
|
||||
)
|
||||
self.assertEqual(table.num_rows * 2, sum(batch.num_rows for batch in batch_reader))
|
||||
|
||||
@logger.catch(reraise=True, message="failed to load parquet files")
|
||||
def _load_from_parquet_files_with_log(self, paths, columns):
|
||||
|
||||
@@ -45,9 +45,7 @@ class TestBench(TestFabric, unittest.TestCase):
|
||||
num_sort_partitions = 1 << 3
|
||||
for shuffle_engine in ("duckdb", "arrow"):
|
||||
for sort_engine in ("duckdb", "arrow", "polars"):
|
||||
with self.subTest(
|
||||
shuffle_engine=shuffle_engine, sort_engine=sort_engine
|
||||
):
|
||||
with self.subTest(shuffle_engine=shuffle_engine, sort_engine=sort_engine):
|
||||
ctx = Context()
|
||||
random_records = generate_random_records(
|
||||
ctx,
|
||||
|
||||
@@ -34,9 +34,7 @@ class TestCommon(TestFabric, unittest.TestCase):
|
||||
npartitions = data.draw(st.integers(1, 2 * nelements))
|
||||
items = list(range(nelements))
|
||||
computed = split_into_rows(items, npartitions)
|
||||
expected = [
|
||||
get_nth_partition(items, n, npartitions) for n in range(npartitions)
|
||||
]
|
||||
expected = [get_nth_partition(items, n, npartitions) for n in range(npartitions)]
|
||||
self.assertEqual(expected, computed)
|
||||
|
||||
@given(st.data())
|
||||
|
||||
@@ -70,9 +70,7 @@ def test_flat_map(sp: Session):
|
||||
|
||||
# user need to specify the schema if can not be inferred from the mapping values
|
||||
df3 = df.flat_map(lambda r: [{"c": None}], schema=pa.schema([("c", pa.int64())]))
|
||||
assert df3.to_arrow() == pa.table(
|
||||
{"c": pa.array([None, None, None], type=pa.int64())}
|
||||
)
|
||||
assert df3.to_arrow() == pa.table({"c": pa.array([None, None, None], type=pa.int64())})
|
||||
|
||||
|
||||
def test_map_batches(sp: Session):
|
||||
@@ -99,10 +97,7 @@ def test_random_shuffle(sp: Session):
|
||||
assert sorted(shuffled) == list(range(1000))
|
||||
|
||||
def count_inversions(arr: List[int]) -> int:
|
||||
return sum(
|
||||
sum(1 for j in range(i + 1, len(arr)) if arr[i] > arr[j])
|
||||
for i in range(len(arr))
|
||||
)
|
||||
return sum(sum(1 for j in range(i + 1, len(arr)) if arr[i] > arr[j]) for i in range(len(arr)))
|
||||
|
||||
# check the shuffle is random enough
|
||||
# the expected number of inversions is n*(n-1)/4 = 249750
|
||||
@@ -158,9 +153,7 @@ def test_partial_sql(sp: Session):
|
||||
# join
|
||||
df1 = sp.from_arrow(pa.table({"id1": [1, 2, 3], "val1": ["a", "b", "c"]}))
|
||||
df2 = sp.from_arrow(pa.table({"id2": [1, 2, 3], "val2": ["d", "e", "f"]}))
|
||||
joined = sp.partial_sql(
|
||||
"select id1, val1, val2 from {0} join {1} on id1 = id2", df1, df2
|
||||
)
|
||||
joined = sp.partial_sql("select id1, val1, val2 from {0} join {1} on id1 = id2", df1, df2)
|
||||
assert joined.to_arrow() == pa.table(
|
||||
{"id1": [1, 2, 3], "val1": ["a", "b", "c"], "val2": ["d", "e", "f"]},
|
||||
schema=pa.schema(
|
||||
@@ -193,10 +186,7 @@ def test_unpicklable_task_exception(sp: Session):
|
||||
df.map(lambda x: logger.info("use outside logger")).to_arrow()
|
||||
except Exception as ex:
|
||||
assert "Can't pickle task" in str(ex)
|
||||
assert (
|
||||
"HINT: DO NOT use externally imported loguru logger in your task. Please import it within the task."
|
||||
in str(ex)
|
||||
)
|
||||
assert "HINT: DO NOT use externally imported loguru logger in your task. Please import it within the task." in str(ex)
|
||||
else:
|
||||
assert False, "expected exception"
|
||||
|
||||
|
||||
@@ -30,9 +30,7 @@ class TestDataSet(TestFabric, unittest.TestCase):
|
||||
dataset = ParquetDataSet([os.path.join(self.output_root_abspath, "*.parquet")])
|
||||
self.assertEqual(num_urls, dataset.num_rows)
|
||||
|
||||
def _generate_parquet_dataset(
|
||||
self, output_path, npartitions, num_rows, row_group_size
|
||||
):
|
||||
def _generate_parquet_dataset(self, output_path, npartitions, num_rows, row_group_size):
|
||||
duckdb.sql(
|
||||
f"""copy (
|
||||
select range as i, range % {npartitions} as partition from range(0, {num_rows}) )
|
||||
@@ -41,9 +39,7 @@ class TestDataSet(TestFabric, unittest.TestCase):
|
||||
)
|
||||
return ParquetDataSet([f"{output_path}/**/*.parquet"])
|
||||
|
||||
def _check_partition_datasets(
|
||||
self, orig_dataset: ParquetDataSet, partition_func, npartition
|
||||
):
|
||||
def _check_partition_datasets(self, orig_dataset: ParquetDataSet, partition_func, npartition):
|
||||
# build partitioned datasets
|
||||
partitioned_datasets = partition_func(npartition)
|
||||
self.assertEqual(npartition, len(partitioned_datasets))
|
||||
@@ -52,9 +48,7 @@ class TestDataSet(TestFabric, unittest.TestCase):
|
||||
sum(dataset.num_rows for dataset in partitioned_datasets),
|
||||
)
|
||||
# load as arrow table
|
||||
loaded_table = arrow.concat_tables(
|
||||
[dataset.to_arrow_table(max_workers=1) for dataset in partitioned_datasets]
|
||||
)
|
||||
loaded_table = arrow.concat_tables([dataset.to_arrow_table(max_workers=1) for dataset in partitioned_datasets])
|
||||
self.assertEqual(orig_dataset.num_rows, loaded_table.num_rows)
|
||||
# compare arrow tables
|
||||
orig_table = orig_dataset.to_arrow_table(max_workers=1)
|
||||
@@ -74,9 +68,7 @@ class TestDataSet(TestFabric, unittest.TestCase):
|
||||
|
||||
def test_partition_by_files(self):
|
||||
output_path = os.path.join(self.output_root_abspath, "test_partition_by_files")
|
||||
orig_dataset = self._generate_parquet_dataset(
|
||||
output_path, npartitions=11, num_rows=170 * 1000, row_group_size=10 * 1000
|
||||
)
|
||||
orig_dataset = self._generate_parquet_dataset(output_path, npartitions=11, num_rows=170 * 1000, row_group_size=10 * 1000)
|
||||
num_files = len(orig_dataset.resolved_paths)
|
||||
for npartition in range(1, num_files + 1):
|
||||
for random_shuffle in (False, True):
|
||||
@@ -84,17 +76,13 @@ class TestDataSet(TestFabric, unittest.TestCase):
|
||||
orig_dataset.reset(orig_dataset.paths, orig_dataset.root_dir)
|
||||
self._check_partition_datasets(
|
||||
orig_dataset,
|
||||
lambda n: orig_dataset.partition_by_files(
|
||||
n, random_shuffle=random_shuffle
|
||||
),
|
||||
lambda n: orig_dataset.partition_by_files(n, random_shuffle=random_shuffle),
|
||||
npartition,
|
||||
)
|
||||
|
||||
def test_partition_by_rows(self):
|
||||
output_path = os.path.join(self.output_root_abspath, "test_partition_by_rows")
|
||||
orig_dataset = self._generate_parquet_dataset(
|
||||
output_path, npartitions=11, num_rows=170 * 1000, row_group_size=10 * 1000
|
||||
)
|
||||
orig_dataset = self._generate_parquet_dataset(output_path, npartitions=11, num_rows=170 * 1000, row_group_size=10 * 1000)
|
||||
num_files = len(orig_dataset.resolved_paths)
|
||||
for npartition in range(1, 2 * num_files + 1):
|
||||
for random_shuffle in (False, True):
|
||||
@@ -102,9 +90,7 @@ class TestDataSet(TestFabric, unittest.TestCase):
|
||||
orig_dataset.reset(orig_dataset.paths, orig_dataset.root_dir)
|
||||
self._check_partition_datasets(
|
||||
orig_dataset,
|
||||
lambda n: orig_dataset.partition_by_rows(
|
||||
n, random_shuffle=random_shuffle
|
||||
),
|
||||
lambda n: orig_dataset.partition_by_rows(n, random_shuffle=random_shuffle),
|
||||
npartition,
|
||||
)
|
||||
|
||||
@@ -116,9 +102,7 @@ class TestDataSet(TestFabric, unittest.TestCase):
|
||||
self.assertEqual(len(dataset.resolved_paths), len(filenames))
|
||||
|
||||
def test_paths_with_char_ranges(self):
|
||||
dataset_with_char_ranges = ParquetDataSet(
|
||||
["tests/data/arrow/data[0-9].parquet"]
|
||||
)
|
||||
dataset_with_char_ranges = ParquetDataSet(["tests/data/arrow/data[0-9].parquet"])
|
||||
dataset_with_wildcards = ParquetDataSet(["tests/data/arrow/*.parquet"])
|
||||
self.assertEqual(
|
||||
len(dataset_with_char_ranges.resolved_paths),
|
||||
@@ -126,9 +110,7 @@ class TestDataSet(TestFabric, unittest.TestCase):
|
||||
)
|
||||
|
||||
def test_to_arrow_table_batch_reader(self):
|
||||
memdb = duckdb.connect(
|
||||
database=":memory:", config={"arrow_large_buffer_size": "true"}
|
||||
)
|
||||
memdb = duckdb.connect(database=":memory:", config={"arrow_large_buffer_size": "true"})
|
||||
for dataset_path in (
|
||||
"tests/data/arrow/*.parquet",
|
||||
"tests/data/large_array/*.parquet",
|
||||
@@ -137,24 +119,14 @@ class TestDataSet(TestFabric, unittest.TestCase):
|
||||
print(f"dataset_path: {dataset_path}, conn: {conn}")
|
||||
with self.subTest(dataset_path=dataset_path, conn=conn):
|
||||
dataset = ParquetDataSet([dataset_path])
|
||||
to_batches = dataset.to_arrow_table(
|
||||
max_workers=1, conn=conn
|
||||
).to_batches(max_chunksize=DEFAULT_ROW_GROUP_SIZE * 2)
|
||||
batch_reader = dataset.to_batch_reader(
|
||||
batch_size=DEFAULT_ROW_GROUP_SIZE * 2, conn=conn
|
||||
)
|
||||
with ConcurrentIter(
|
||||
batch_reader, max_buffer_size=2
|
||||
) as batch_reader:
|
||||
to_batches = dataset.to_arrow_table(max_workers=1, conn=conn).to_batches(max_chunksize=DEFAULT_ROW_GROUP_SIZE * 2)
|
||||
batch_reader = dataset.to_batch_reader(batch_size=DEFAULT_ROW_GROUP_SIZE * 2, conn=conn)
|
||||
with ConcurrentIter(batch_reader, max_buffer_size=2) as batch_reader:
|
||||
for batch_iter in (to_batches, batch_reader):
|
||||
total_num_rows = 0
|
||||
for batch in batch_iter:
|
||||
print(
|
||||
f"batch.num_rows {batch.num_rows}, max_batch_row_size {DEFAULT_ROW_GROUP_SIZE*2}"
|
||||
)
|
||||
self.assertLessEqual(
|
||||
batch.num_rows, DEFAULT_ROW_GROUP_SIZE * 2
|
||||
)
|
||||
print(f"batch.num_rows {batch.num_rows}, max_batch_row_size {DEFAULT_ROW_GROUP_SIZE*2}")
|
||||
self.assertLessEqual(batch.num_rows, DEFAULT_ROW_GROUP_SIZE * 2)
|
||||
total_num_rows += batch.num_rows
|
||||
print(f"{dataset_path}: total_num_rows {total_num_rows}")
|
||||
self.assertEqual(total_num_rows, dataset.num_rows)
|
||||
@@ -167,8 +139,6 @@ def test_arrow_reader(benchmark, reader: str, dataset_path: str):
|
||||
dataset = ParquetDataSet([dataset_path])
|
||||
conn = None
|
||||
if reader == "duckdb":
|
||||
conn = duckdb.connect(
|
||||
database=":memory:", config={"arrow_large_buffer_size": "true"}
|
||||
)
|
||||
conn = duckdb.connect(database=":memory:", config={"arrow_large_buffer_size": "true"})
|
||||
benchmark(dataset.to_arrow_table, conn=conn)
|
||||
# result: arrow reader is 4x faster than duckdb reader in small dataset, 1.4x faster in large dataset
|
||||
|
||||
@@ -7,9 +7,7 @@ from smallpond.io.arrow import cast_columns_to_large_string
|
||||
from tests.test_fabric import TestFabric
|
||||
|
||||
|
||||
@unittest.skipUnless(
|
||||
importlib.util.find_spec("deltalake") is not None, "cannot find deltalake"
|
||||
)
|
||||
@unittest.skipUnless(importlib.util.find_spec("deltalake") is not None, "cannot find deltalake")
|
||||
class TestDeltaLake(TestFabric, unittest.TestCase):
|
||||
def test_read_write_deltalake(self):
|
||||
from deltalake import DeltaTable, write_deltalake
|
||||
@@ -20,9 +18,7 @@ class TestDeltaLake(TestFabric, unittest.TestCase):
|
||||
):
|
||||
parquet_files = glob.glob(dataset_path)
|
||||
expected = self._load_parquet_files(parquet_files)
|
||||
with self.subTest(dataset_path=dataset_path), tempfile.TemporaryDirectory(
|
||||
dir=self.output_root_abspath
|
||||
) as output_dir:
|
||||
with self.subTest(dataset_path=dataset_path), tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
|
||||
write_deltalake(output_dir, expected, large_dtypes=True)
|
||||
dt = DeltaTable(output_dir)
|
||||
self._compare_arrow_tables(expected, dt.to_pyarrow_table())
|
||||
@@ -35,12 +31,8 @@ class TestDeltaLake(TestFabric, unittest.TestCase):
|
||||
"tests/data/large_array/*.parquet",
|
||||
):
|
||||
parquet_files = glob.glob(dataset_path)
|
||||
with self.subTest(dataset_path=dataset_path), tempfile.TemporaryDirectory(
|
||||
dir=self.output_root_abspath
|
||||
) as output_dir:
|
||||
table = cast_columns_to_large_string(
|
||||
self._load_parquet_files(parquet_files)
|
||||
)
|
||||
with self.subTest(dataset_path=dataset_path), tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
|
||||
table = cast_columns_to_large_string(self._load_parquet_files(parquet_files))
|
||||
write_deltalake(output_dir, table, large_dtypes=True, mode="overwrite")
|
||||
write_deltalake(output_dir, table, large_dtypes=False, mode="append")
|
||||
loaded_table = DeltaTable(output_dir).to_pyarrow_table()
|
||||
|
||||
@@ -69,9 +69,7 @@ class OutputMsgPythonTask(PythonScriptTask):
|
||||
input_datasets: List[DataSet],
|
||||
output_path: str,
|
||||
) -> bool:
|
||||
logger.info(
|
||||
f"msg: {self.msg}, num files: {input_datasets[0].num_files}, local gpu ranks: {self.local_gpu_ranks}"
|
||||
)
|
||||
logger.info(f"msg: {self.msg}, num files: {input_datasets[0].num_files}, local gpu ranks: {self.local_gpu_ranks}")
|
||||
self.inject_fault()
|
||||
return True
|
||||
|
||||
@@ -105,9 +103,7 @@ class CopyInputArrowNode(ArrowComputeNode):
|
||||
super().__init__(ctx, input_deps)
|
||||
self.msg = msg
|
||||
|
||||
def process(
|
||||
self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]
|
||||
) -> arrow.Table:
|
||||
def process(self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table:
|
||||
return copy_input_arrow(runtime_ctx, input_tables, self.msg)
|
||||
|
||||
|
||||
@@ -117,24 +113,18 @@ class CopyInputStreamNode(ArrowStreamNode):
|
||||
super().__init__(ctx, input_deps)
|
||||
self.msg = msg
|
||||
|
||||
def process(
|
||||
self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]
|
||||
) -> Iterable[arrow.Table]:
|
||||
def process(self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]) -> Iterable[arrow.Table]:
|
||||
return copy_input_stream(runtime_ctx, input_readers, self.msg)
|
||||
|
||||
|
||||
def copy_input_arrow(
|
||||
runtime_ctx: RuntimeContext, input_tables: List[arrow.Table], msg: str
|
||||
) -> arrow.Table:
|
||||
def copy_input_arrow(runtime_ctx: RuntimeContext, input_tables: List[arrow.Table], msg: str) -> arrow.Table:
|
||||
logger.info(f"msg: {msg}, num rows: {input_tables[0].num_rows}")
|
||||
time.sleep(runtime_ctx.secs_executor_probe_interval)
|
||||
runtime_ctx.task.inject_fault()
|
||||
return input_tables[0]
|
||||
|
||||
|
||||
def copy_input_stream(
|
||||
runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader], msg: str
|
||||
) -> Iterable[arrow.Table]:
|
||||
def copy_input_stream(runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader], msg: str) -> Iterable[arrow.Table]:
|
||||
for index, batch in enumerate(input_readers[0]):
|
||||
logger.info(f"msg: {msg}, batch index: {index}, num rows: {batch.num_rows}")
|
||||
time.sleep(runtime_ctx.secs_executor_probe_interval)
|
||||
@@ -146,62 +136,44 @@ def copy_input_stream(
|
||||
runtime_ctx.task.inject_fault()
|
||||
|
||||
|
||||
def copy_input_batch(
|
||||
runtime_ctx: RuntimeContext, input_batches: List[arrow.Table], msg: str
|
||||
) -> arrow.Table:
|
||||
def copy_input_batch(runtime_ctx: RuntimeContext, input_batches: List[arrow.Table], msg: str) -> arrow.Table:
|
||||
logger.info(f"msg: {msg}, num rows: {input_batches[0].num_rows}")
|
||||
time.sleep(runtime_ctx.secs_executor_probe_interval)
|
||||
runtime_ctx.task.inject_fault()
|
||||
return input_batches[0]
|
||||
|
||||
|
||||
def copy_input_data_frame(
|
||||
runtime_ctx: RuntimeContext, input_dfs: List[DataFrame]
|
||||
) -> DataFrame:
|
||||
def copy_input_data_frame(runtime_ctx: RuntimeContext, input_dfs: List[DataFrame]) -> DataFrame:
|
||||
runtime_ctx.task.inject_fault()
|
||||
return input_dfs[0]
|
||||
|
||||
|
||||
def copy_input_data_frame_batch(
|
||||
runtime_ctx: RuntimeContext, input_dfs: List[DataFrame]
|
||||
) -> DataFrame:
|
||||
def copy_input_data_frame_batch(runtime_ctx: RuntimeContext, input_dfs: List[DataFrame]) -> DataFrame:
|
||||
runtime_ctx.task.inject_fault()
|
||||
return input_dfs[0]
|
||||
|
||||
|
||||
def merge_input_tables(
|
||||
runtime_ctx: RuntimeContext, input_batches: List[arrow.Table]
|
||||
) -> arrow.Table:
|
||||
def merge_input_tables(runtime_ctx: RuntimeContext, input_batches: List[arrow.Table]) -> arrow.Table:
|
||||
runtime_ctx.task.inject_fault()
|
||||
output = arrow.concat_tables(input_batches)
|
||||
logger.info(
|
||||
f"input rows: {[len(batch) for batch in input_batches]}, output rows: {len(output)}"
|
||||
)
|
||||
logger.info(f"input rows: {[len(batch) for batch in input_batches]}, output rows: {len(output)}")
|
||||
return output
|
||||
|
||||
|
||||
def merge_input_data_frames(
|
||||
runtime_ctx: RuntimeContext, input_dfs: List[DataFrame]
|
||||
) -> DataFrame:
|
||||
def merge_input_data_frames(runtime_ctx: RuntimeContext, input_dfs: List[DataFrame]) -> DataFrame:
|
||||
runtime_ctx.task.inject_fault()
|
||||
output = pandas.concat(input_dfs)
|
||||
logger.info(
|
||||
f"input rows: {[len(df) for df in input_dfs]}, output rows: {len(output)}"
|
||||
)
|
||||
logger.info(f"input rows: {[len(df) for df in input_dfs]}, output rows: {len(output)}")
|
||||
return output
|
||||
|
||||
|
||||
def parse_url(
|
||||
runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]
|
||||
) -> arrow.Table:
|
||||
def parse_url(runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table:
|
||||
urls = input_tables[0].columns[0]
|
||||
hosts = [url.as_py().split("/", maxsplit=2)[0] for url in urls]
|
||||
return input_tables[0].append_column("host", arrow.array(hosts))
|
||||
|
||||
|
||||
def nonzero_exit_code(
|
||||
runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str
|
||||
) -> bool:
|
||||
def nonzero_exit_code(runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str) -> bool:
|
||||
import sys
|
||||
|
||||
if runtime_ctx.task._memory_boost == 1:
|
||||
@@ -210,9 +182,7 @@ def nonzero_exit_code(
|
||||
|
||||
|
||||
# create an empty file with a fixed name
|
||||
def empty_file(
|
||||
runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str
|
||||
) -> bool:
|
||||
def empty_file(runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str) -> bool:
|
||||
import os
|
||||
|
||||
with open(os.path.join(output_path, "file"), "w") as fout:
|
||||
@@ -231,9 +201,7 @@ def split_url(urls: arrow.array) -> arrow.array:
|
||||
return arrow.array(url_parts, type=arrow.list_(arrow.string()))
|
||||
|
||||
|
||||
def choose_random_urls(
|
||||
runtime_ctx: RuntimeContext, input_tables: List[arrow.Table], k: int = 5
|
||||
) -> arrow.Table:
|
||||
def choose_random_urls(runtime_ctx: RuntimeContext, input_tables: List[arrow.Table], k: int = 5) -> arrow.Table:
|
||||
# get the current running task
|
||||
runtime_task = runtime_ctx.task
|
||||
# access task-specific attributes
|
||||
@@ -255,16 +223,12 @@ class TestExecution(TestFabric, unittest.TestCase):
|
||||
def test_arrow_task(self):
|
||||
for use_duckdb_reader in (False, True):
|
||||
with self.subTest(use_duckdb_reader=use_duckdb_reader):
|
||||
with tempfile.TemporaryDirectory(
|
||||
dir=self.output_root_abspath
|
||||
) as output_dir:
|
||||
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
|
||||
ctx = Context()
|
||||
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
|
||||
data_table = dataset.to_arrow_table()
|
||||
data_files = DataSourceNode(ctx, dataset)
|
||||
data_partitions = DataSetPartitionNode(
|
||||
ctx, (data_files,), npartitions=7
|
||||
)
|
||||
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=7)
|
||||
if use_duckdb_reader:
|
||||
data_partitions = ProjectionNode(
|
||||
ctx,
|
||||
@@ -274,9 +238,7 @@ class TestExecution(TestFabric, unittest.TestCase):
|
||||
arrow_compute = ArrowComputeNode(
|
||||
ctx,
|
||||
(data_partitions,),
|
||||
process_func=functools.partial(
|
||||
copy_input_arrow, msg="arrow compute"
|
||||
),
|
||||
process_func=functools.partial(copy_input_arrow, msg="arrow compute"),
|
||||
use_duckdb_reader=use_duckdb_reader,
|
||||
output_name="arrow_compute",
|
||||
output_path=output_dir,
|
||||
@@ -285,9 +247,7 @@ class TestExecution(TestFabric, unittest.TestCase):
|
||||
arrow_stream = ArrowStreamNode(
|
||||
ctx,
|
||||
(data_partitions,),
|
||||
process_func=functools.partial(
|
||||
copy_input_stream, msg="arrow stream"
|
||||
),
|
||||
process_func=functools.partial(copy_input_stream, msg="arrow stream"),
|
||||
streaming_batch_size=10,
|
||||
secs_checkpoint_interval=0.5,
|
||||
use_duckdb_reader=use_duckdb_reader,
|
||||
@@ -298,9 +258,7 @@ class TestExecution(TestFabric, unittest.TestCase):
|
||||
arrow_batch = ArrowBatchNode(
|
||||
ctx,
|
||||
(data_partitions,),
|
||||
process_func=functools.partial(
|
||||
copy_input_batch, msg="arrow batch"
|
||||
),
|
||||
process_func=functools.partial(copy_input_batch, msg="arrow batch"),
|
||||
streaming_batch_size=10,
|
||||
secs_checkpoint_interval=0.5,
|
||||
use_duckdb_reader=use_duckdb_reader,
|
||||
@@ -314,12 +272,8 @@ class TestExecution(TestFabric, unittest.TestCase):
|
||||
output_path=output_dir,
|
||||
)
|
||||
plan = LogicalPlan(ctx, data_sink)
|
||||
exec_plan = self.execute_plan(
|
||||
plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5
|
||||
)
|
||||
self.assertTrue(
|
||||
all(map(os.path.exists, exec_plan.final_output.resolved_paths))
|
||||
)
|
||||
exec_plan = self.execute_plan(plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5)
|
||||
self.assertTrue(all(map(os.path.exists, exec_plan.final_output.resolved_paths)))
|
||||
arrow_compute_output = ParquetDataSet(
|
||||
[os.path.join(output_dir, "arrow_compute", "**/*.parquet")],
|
||||
recursive=True,
|
||||
@@ -334,21 +288,15 @@ class TestExecution(TestFabric, unittest.TestCase):
|
||||
)
|
||||
self._compare_arrow_tables(
|
||||
data_table,
|
||||
arrow_compute_output.to_arrow_table().select(
|
||||
data_table.column_names
|
||||
),
|
||||
arrow_compute_output.to_arrow_table().select(data_table.column_names),
|
||||
)
|
||||
self._compare_arrow_tables(
|
||||
data_table,
|
||||
arrow_stream_output.to_arrow_table().select(
|
||||
data_table.column_names
|
||||
),
|
||||
arrow_stream_output.to_arrow_table().select(data_table.column_names),
|
||||
)
|
||||
self._compare_arrow_tables(
|
||||
data_table,
|
||||
arrow_batch_output.to_arrow_table().select(
|
||||
data_table.column_names
|
||||
),
|
||||
arrow_batch_output.to_arrow_table().select(data_table.column_names),
|
||||
)
|
||||
|
||||
def test_pandas_task(self):
|
||||
@@ -376,16 +324,10 @@ class TestExecution(TestFabric, unittest.TestCase):
|
||||
output_path=output_dir,
|
||||
cpu_limit=2,
|
||||
)
|
||||
data_sink = DataSinkNode(
|
||||
ctx, (pandas_compute, pandas_batch), output_path=output_dir
|
||||
)
|
||||
data_sink = DataSinkNode(ctx, (pandas_compute, pandas_batch), output_path=output_dir)
|
||||
plan = LogicalPlan(ctx, data_sink)
|
||||
exec_plan = self.execute_plan(
|
||||
plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5
|
||||
)
|
||||
self.assertTrue(
|
||||
all(map(os.path.exists, exec_plan.final_output.resolved_paths))
|
||||
)
|
||||
exec_plan = self.execute_plan(plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5)
|
||||
self.assertTrue(all(map(os.path.exists, exec_plan.final_output.resolved_paths)))
|
||||
pandas_compute_output = ParquetDataSet(
|
||||
[os.path.join(output_dir, "pandas_compute", "**/*.parquet")],
|
||||
recursive=True,
|
||||
@@ -394,21 +336,15 @@ class TestExecution(TestFabric, unittest.TestCase):
|
||||
[os.path.join(output_dir, "pandas_batch", "**/*.parquet")],
|
||||
recursive=True,
|
||||
)
|
||||
self._compare_arrow_tables(
|
||||
data_table, pandas_compute_output.to_arrow_table()
|
||||
)
|
||||
self._compare_arrow_tables(data_table, pandas_compute_output.to_arrow_table())
|
||||
self._compare_arrow_tables(data_table, pandas_batch_output.to_arrow_table())
|
||||
|
||||
def test_variable_length_input_datasets(self):
|
||||
ctx = Context()
|
||||
small_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
|
||||
large_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"] * 10)
|
||||
small_partitions = DataSetPartitionNode(
|
||||
ctx, (DataSourceNode(ctx, small_dataset),), npartitions=7
|
||||
)
|
||||
large_partitions = DataSetPartitionNode(
|
||||
ctx, (DataSourceNode(ctx, large_dataset),), npartitions=7
|
||||
)
|
||||
small_partitions = DataSetPartitionNode(ctx, (DataSourceNode(ctx, small_dataset),), npartitions=7)
|
||||
large_partitions = DataSetPartitionNode(ctx, (DataSourceNode(ctx, large_dataset),), npartitions=7)
|
||||
arrow_batch = ArrowBatchNode(
|
||||
ctx,
|
||||
(small_partitions, large_partitions),
|
||||
@@ -428,9 +364,7 @@ class TestExecution(TestFabric, unittest.TestCase):
|
||||
cpu_limit=2,
|
||||
)
|
||||
plan = LogicalPlan(ctx, RootNode(ctx, (arrow_batch, pandas_batch)))
|
||||
exec_plan = self.execute_plan(
|
||||
plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5
|
||||
)
|
||||
exec_plan = self.execute_plan(plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5)
|
||||
self.assertTrue(all(map(os.path.exists, exec_plan.final_output.resolved_paths)))
|
||||
arrow_batch_output = ParquetDataSet(
|
||||
[os.path.join(exec_plan.ctx.output_root, "arrow_batch", "**/*.parquet")],
|
||||
@@ -440,9 +374,7 @@ class TestExecution(TestFabric, unittest.TestCase):
|
||||
[os.path.join(exec_plan.ctx.output_root, "pandas_batch", "**/*.parquet")],
|
||||
recursive=True,
|
||||
)
|
||||
self.assertEqual(
|
||||
small_dataset.num_rows + large_dataset.num_rows, arrow_batch_output.num_rows
|
||||
)
|
||||
self.assertEqual(small_dataset.num_rows + large_dataset.num_rows, arrow_batch_output.num_rows)
|
||||
self.assertEqual(
|
||||
small_dataset.num_rows + large_dataset.num_rows,
|
||||
pandas_batch_output.num_rows,
|
||||
@@ -453,9 +385,7 @@ class TestExecution(TestFabric, unittest.TestCase):
|
||||
# select columns when defining dataset
|
||||
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"], columns=["url"])
|
||||
data_files = DataSourceNode(ctx, dataset)
|
||||
data_partitions = DataSetPartitionNode(
|
||||
ctx, (data_files,), npartitions=3, partition_by_rows=True
|
||||
)
|
||||
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=3, partition_by_rows=True)
|
||||
# projection as input of arrow node
|
||||
generated_columns = ["filename", "file_row_number"]
|
||||
urls_with_host = ArrowComputeNode(
|
||||
@@ -480,9 +410,7 @@ class TestExecution(TestFabric, unittest.TestCase):
|
||||
# unify different schemas
|
||||
merged_diff_schemas = ProjectionNode(
|
||||
ctx,
|
||||
DataSetPartitionNode(
|
||||
ctx, (distinct_urls_with_host, urls_with_host), npartitions=1
|
||||
),
|
||||
DataSetPartitionNode(ctx, (distinct_urls_with_host, urls_with_host), npartitions=1),
|
||||
union_by_name=True,
|
||||
)
|
||||
host_partitions = HashPartitionNode(
|
||||
@@ -513,9 +441,7 @@ class TestExecution(TestFabric, unittest.TestCase):
|
||||
ctx = Context()
|
||||
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
|
||||
data_files = DataSourceNode(ctx, dataset)
|
||||
data_partitions = DataSetPartitionNode(
|
||||
ctx, (data_files,), npartitions=dataset.num_files
|
||||
)
|
||||
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=dataset.num_files)
|
||||
ctx.create_function(
|
||||
"split_url",
|
||||
split_url,
|
||||
@@ -537,9 +463,7 @@ class TestExecution(TestFabric, unittest.TestCase):
|
||||
npartitions = 1000
|
||||
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"] * npartitions)
|
||||
data_files = DataSourceNode(ctx, dataset)
|
||||
data_partitions = EvenlyDistributedPartitionNode(
|
||||
ctx, (data_files,), npartitions=npartitions
|
||||
)
|
||||
data_partitions = EvenlyDistributedPartitionNode(ctx, (data_files,), npartitions=npartitions)
|
||||
output_msg = OutputMsgPythonNode(ctx, (data_partitions,))
|
||||
plan = LogicalPlan(ctx, output_msg)
|
||||
self.execute_plan(
|
||||
@@ -552,13 +476,9 @@ class TestExecution(TestFabric, unittest.TestCase):
|
||||
def test_many_producers_and_partitions(self):
|
||||
ctx = Context()
|
||||
npartitions = 10000
|
||||
dataset = ParquetDataSet(
|
||||
["tests/data/mock_urls/*.parquet"] * (npartitions * 10)
|
||||
)
|
||||
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"] * (npartitions * 10))
|
||||
data_files = DataSourceNode(ctx, dataset)
|
||||
data_partitions = EvenlyDistributedPartitionNode(
|
||||
ctx, (data_files,), npartitions=npartitions, cpu_limit=1
|
||||
)
|
||||
data_partitions = EvenlyDistributedPartitionNode(ctx, (data_files,), npartitions=npartitions, cpu_limit=1)
|
||||
data_partitions.max_num_producer_tasks = 20
|
||||
output_msg = OutputMsgPythonNode(ctx, (data_partitions,))
|
||||
plan = LogicalPlan(ctx, output_msg)
|
||||
@@ -573,12 +493,8 @@ class TestExecution(TestFabric, unittest.TestCase):
|
||||
ctx = Context()
|
||||
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
|
||||
data_files = DataSourceNode(ctx, dataset)
|
||||
data_partitions = DataSetPartitionNode(
|
||||
ctx, (data_files,), npartitions=dataset.num_files
|
||||
)
|
||||
output_msg = OutputMsgPythonNode(
|
||||
ctx, (data_partitions,), cpu_limit=1, gpu_limit=0.5
|
||||
)
|
||||
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=dataset.num_files)
|
||||
output_msg = OutputMsgPythonNode(ctx, (data_partitions,), cpu_limit=1, gpu_limit=0.5)
|
||||
plan = LogicalPlan(ctx, output_msg)
|
||||
runtime_ctx = RuntimeContext(
|
||||
JobId.new(),
|
||||
@@ -596,9 +512,7 @@ class TestExecution(TestFabric, unittest.TestCase):
|
||||
data_files = DataSourceNode(ctx, dataset)
|
||||
copy_input_arrow_node = CopyInputArrowNode(ctx, (data_files,), "hello")
|
||||
copy_input_stream_node = CopyInputStreamNode(ctx, (data_files,), "hello")
|
||||
output_msg = OutputMsgPythonNode2(
|
||||
ctx, (copy_input_arrow_node, copy_input_stream_node), "hello"
|
||||
)
|
||||
output_msg = OutputMsgPythonNode2(ctx, (copy_input_arrow_node, copy_input_stream_node), "hello")
|
||||
plan = LogicalPlan(ctx, output_msg)
|
||||
self.execute_plan(plan)
|
||||
|
||||
@@ -606,9 +520,7 @@ class TestExecution(TestFabric, unittest.TestCase):
|
||||
ctx = Context()
|
||||
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
|
||||
data_files = DataSourceNode(ctx, dataset)
|
||||
uniq_urls = SqlEngineNode(
|
||||
ctx, (data_files,), r"select distinct * from {0}", memory_limit=2 * MB
|
||||
)
|
||||
uniq_urls = SqlEngineNode(ctx, (data_files,), r"select distinct * from {0}", memory_limit=2 * MB)
|
||||
uniq_url_partitions = DataSetPartitionNode(ctx, (uniq_urls,), 2)
|
||||
uniq_url_count = SqlEngineNode(
|
||||
ctx,
|
||||
@@ -637,9 +549,7 @@ class TestExecution(TestFabric, unittest.TestCase):
|
||||
memory_limit=1 * GB,
|
||||
)
|
||||
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
|
||||
data_sink = DataSinkNode(
|
||||
ctx, (arrow_compute, arrow_stream), output_path=output_dir
|
||||
)
|
||||
data_sink = DataSinkNode(ctx, (arrow_compute, arrow_stream), output_path=output_dir)
|
||||
plan = LogicalPlan(ctx, data_sink)
|
||||
self.execute_plan(
|
||||
plan,
|
||||
@@ -652,17 +562,11 @@ class TestExecution(TestFabric, unittest.TestCase):
|
||||
ctx = Context()
|
||||
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
|
||||
data_files = DataSourceNode(ctx, dataset)
|
||||
nonzero_exitcode = PythonScriptNode(
|
||||
ctx, (data_files,), process_func=nonzero_exit_code
|
||||
)
|
||||
nonzero_exitcode = PythonScriptNode(ctx, (data_files,), process_func=nonzero_exit_code)
|
||||
plan = LogicalPlan(ctx, nonzero_exitcode)
|
||||
exec_plan = self.execute_plan(
|
||||
plan, num_executors=1, check_result=False, nonzero_exitcode_as_oom=False
|
||||
)
|
||||
exec_plan = self.execute_plan(plan, num_executors=1, check_result=False, nonzero_exitcode_as_oom=False)
|
||||
self.assertFalse(exec_plan.successful)
|
||||
exec_plan = self.execute_plan(
|
||||
plan, num_executors=1, check_result=False, nonzero_exitcode_as_oom=True
|
||||
)
|
||||
exec_plan = self.execute_plan(plan, num_executors=1, check_result=False, nonzero_exitcode_as_oom=True)
|
||||
self.assertTrue(exec_plan.successful)
|
||||
|
||||
def test_manifest_only_data_sink(self):
|
||||
@@ -675,9 +579,7 @@ class TestExecution(TestFabric, unittest.TestCase):
|
||||
dataset = ParquetDataSet(filenames)
|
||||
data_files = DataSourceNode(ctx, dataset)
|
||||
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=512)
|
||||
data_sink = DataSinkNode(
|
||||
ctx, (data_partitions,), output_path=output_dir, manifest_only=True
|
||||
)
|
||||
data_sink = DataSinkNode(ctx, (data_partitions,), output_path=output_dir, manifest_only=True)
|
||||
plan = LogicalPlan(ctx, data_sink)
|
||||
self.execute_plan(plan)
|
||||
|
||||
@@ -734,12 +636,8 @@ class TestExecution(TestFabric, unittest.TestCase):
|
||||
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
|
||||
data_files = DataSourceNode(ctx, dataset)
|
||||
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=10)
|
||||
url_counts = SqlEngineNode(
|
||||
ctx, (data_partitions,), r"select count(url) as cnt from {0}"
|
||||
)
|
||||
distinct_url_counts = SqlEngineNode(
|
||||
ctx, (data_partitions,), r"select count(distinct url) as cnt from {0}"
|
||||
)
|
||||
url_counts = SqlEngineNode(ctx, (data_partitions,), r"select count(url) as cnt from {0}")
|
||||
distinct_url_counts = SqlEngineNode(ctx, (data_partitions,), r"select count(distinct url) as cnt from {0}")
|
||||
merged_counts = DataSetPartitionNode(
|
||||
ctx,
|
||||
(
|
||||
@@ -764,9 +662,7 @@ class TestExecution(TestFabric, unittest.TestCase):
|
||||
r"select count(url) as cnt from {0}",
|
||||
output_name="url_counts",
|
||||
)
|
||||
distinct_url_counts = SqlEngineNode(
|
||||
ctx, (data_partitions,), r"select count(distinct url) as cnt from {0}"
|
||||
)
|
||||
distinct_url_counts = SqlEngineNode(ctx, (data_partitions,), r"select count(distinct url) as cnt from {0}")
|
||||
merged_counts = DataSetPartitionNode(
|
||||
ctx,
|
||||
(
|
||||
@@ -787,24 +683,14 @@ class TestExecution(TestFabric, unittest.TestCase):
|
||||
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
|
||||
data_files = DataSourceNode(ctx, dataset)
|
||||
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=10)
|
||||
empty_files1 = PythonScriptNode(
|
||||
ctx, (data_partitions,), process_func=empty_file
|
||||
)
|
||||
empty_files2 = PythonScriptNode(
|
||||
ctx, (data_partitions,), process_func=empty_file
|
||||
)
|
||||
empty_files1 = PythonScriptNode(ctx, (data_partitions,), process_func=empty_file)
|
||||
empty_files2 = PythonScriptNode(ctx, (data_partitions,), process_func=empty_file)
|
||||
link_path = os.path.join(self.runtime_ctx.output_root, "link")
|
||||
copy_path = os.path.join(self.runtime_ctx.output_root, "copy")
|
||||
copy_input_path = os.path.join(self.runtime_ctx.output_root, "copy_input")
|
||||
data_link = DataSinkNode(
|
||||
ctx, (empty_files1, empty_files2), type="link", output_path=link_path
|
||||
)
|
||||
data_copy = DataSinkNode(
|
||||
ctx, (empty_files1, empty_files2), type="copy", output_path=copy_path
|
||||
)
|
||||
data_copy_input = DataSinkNode(
|
||||
ctx, (data_partitions,), type="copy", output_path=copy_input_path
|
||||
)
|
||||
data_link = DataSinkNode(ctx, (empty_files1, empty_files2), type="link", output_path=link_path)
|
||||
data_copy = DataSinkNode(ctx, (empty_files1, empty_files2), type="copy", output_path=copy_path)
|
||||
data_copy_input = DataSinkNode(ctx, (data_partitions,), type="copy", output_path=copy_input_path)
|
||||
plan = LogicalPlan(ctx, RootNode(ctx, (data_link, data_copy, data_copy_input)))
|
||||
|
||||
self.execute_plan(plan)
|
||||
@@ -813,33 +699,19 @@ class TestExecution(TestFabric, unittest.TestCase):
|
||||
self.assertEqual(21, len(os.listdir(copy_path)))
|
||||
# file name should not be modified if no conflict
|
||||
self.assertEqual(
|
||||
set(
|
||||
filename
|
||||
for filename in os.listdir("tests/data/mock_urls")
|
||||
if filename.endswith(".parquet")
|
||||
),
|
||||
set(
|
||||
filename
|
||||
for filename in os.listdir(copy_input_path)
|
||||
if filename.endswith(".parquet")
|
||||
),
|
||||
set(filename for filename in os.listdir("tests/data/mock_urls") if filename.endswith(".parquet")),
|
||||
set(filename for filename in os.listdir(copy_input_path) if filename.endswith(".parquet")),
|
||||
)
|
||||
|
||||
def test_literal_datasets_as_data_sources(self):
|
||||
ctx = Context()
|
||||
num_rows = 10
|
||||
query_dataset = SqlQueryDataSet(f"select i from range({num_rows}) as x(i)")
|
||||
table_dataset = ArrowTableDataSet(
|
||||
arrow.Table.from_arrays([list(range(num_rows))], names=["i"])
|
||||
)
|
||||
table_dataset = ArrowTableDataSet(arrow.Table.from_arrays([list(range(num_rows))], names=["i"]))
|
||||
query_source = DataSourceNode(ctx, query_dataset)
|
||||
table_source = DataSourceNode(ctx, table_dataset)
|
||||
query_partitions = DataSetPartitionNode(
|
||||
ctx, (query_source,), npartitions=num_rows, partition_by_rows=True
|
||||
)
|
||||
table_partitions = DataSetPartitionNode(
|
||||
ctx, (table_source,), npartitions=num_rows, partition_by_rows=True
|
||||
)
|
||||
query_partitions = DataSetPartitionNode(ctx, (query_source,), npartitions=num_rows, partition_by_rows=True)
|
||||
table_partitions = DataSetPartitionNode(ctx, (table_source,), npartitions=num_rows, partition_by_rows=True)
|
||||
joined_rows = SqlEngineNode(
|
||||
ctx,
|
||||
(query_partitions, table_partitions),
|
||||
|
||||
@@ -27,9 +27,7 @@ from tests.datagen import generate_data
|
||||
generate_data()
|
||||
|
||||
|
||||
def run_scheduler(
|
||||
runtime_ctx: RuntimeContext, scheduler: Scheduler, queue: queue.Queue
|
||||
):
|
||||
def run_scheduler(runtime_ctx: RuntimeContext, scheduler: Scheduler, queue: queue.Queue):
|
||||
runtime_ctx.initialize("scheduler")
|
||||
scheduler.add_state_observer(Scheduler.StateObserver(SaveSchedState(queue)))
|
||||
retval = scheduler.run()
|
||||
@@ -130,9 +128,7 @@ class TestFabric(unittest.TestCase):
|
||||
process.kill()
|
||||
process.join()
|
||||
|
||||
logger.info(
|
||||
f"#{i} process {process.name} exited with code {process.exitcode}"
|
||||
)
|
||||
logger.info(f"#{i} process {process.name} exited with code {process.exitcode}")
|
||||
|
||||
def start_execution(
|
||||
self,
|
||||
@@ -189,11 +185,7 @@ class TestFabric(unittest.TestCase):
|
||||
secs_wq_poll_interval=secs_wq_poll_interval,
|
||||
secs_executor_probe_interval=secs_executor_probe_interval,
|
||||
max_num_missed_probes=max_num_missed_probes,
|
||||
fault_inject_prob=(
|
||||
fault_inject_prob
|
||||
if fault_inject_prob is not None
|
||||
else self.fault_inject_prob
|
||||
),
|
||||
fault_inject_prob=(fault_inject_prob if fault_inject_prob is not None else self.fault_inject_prob),
|
||||
enable_profiling=enable_profiling,
|
||||
enable_diagnostic_metrics=enable_diagnostic_metrics,
|
||||
remove_empty_parquet=remove_empty_parquet,
|
||||
@@ -217,9 +209,7 @@ class TestFabric(unittest.TestCase):
|
||||
nonzero_exitcode_as_oom=nonzero_exitcode_as_oom,
|
||||
)
|
||||
self.latest_state = scheduler
|
||||
self.executors = [
|
||||
Executor.create(runtime_ctx, f"executor-{i}") for i in range(num_executors)
|
||||
]
|
||||
self.executors = [Executor.create(runtime_ctx, f"executor-{i}") for i in range(num_executors)]
|
||||
self.processes = [
|
||||
Process(
|
||||
target=run_scheduler,
|
||||
@@ -229,10 +219,7 @@ class TestFabric(unittest.TestCase):
|
||||
name="scheduler",
|
||||
)
|
||||
]
|
||||
self.processes += [
|
||||
Process(target=run_executor, args=(runtime_ctx, executor), name=executor.id)
|
||||
for executor in self.executors
|
||||
]
|
||||
self.processes += [Process(target=run_executor, args=(runtime_ctx, executor), name=executor.id) for executor in self.executors]
|
||||
|
||||
for process in reversed(self.processes):
|
||||
process.start()
|
||||
@@ -264,15 +251,9 @@ class TestFabric(unittest.TestCase):
|
||||
self.assertTrue(latest_state.success)
|
||||
return latest_state.exec_plan
|
||||
|
||||
def _load_parquet_files(
|
||||
self, paths, filesystem: fsspec.AbstractFileSystem = None
|
||||
) -> arrow.Table:
|
||||
def _load_parquet_files(self, paths, filesystem: fsspec.AbstractFileSystem = None) -> arrow.Table:
|
||||
def read_parquet_file(path):
|
||||
return arrow.Table.from_batches(
|
||||
parquet.ParquetFile(
|
||||
path, buffer_size=16 * MB, filesystem=filesystem
|
||||
).iter_batches()
|
||||
)
|
||||
return arrow.Table.from_batches(parquet.ParquetFile(path, buffer_size=16 * MB, filesystem=filesystem).iter_batches())
|
||||
|
||||
with ThreadPoolExecutor(16) as pool:
|
||||
return arrow.concat_tables(pool.map(read_parquet_file, paths))
|
||||
|
||||
@@ -17,9 +17,7 @@ class TestFilesystem(TestFabric, unittest.TestCase):
|
||||
|
||||
def test_pickle_trace(self):
|
||||
with self.assertRaises(TypeError) as context:
|
||||
with tempfile.TemporaryDirectory(
|
||||
dir=self.output_root_abspath
|
||||
) as output_dir:
|
||||
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
|
||||
thread = threading.Thread()
|
||||
pickle_path = os.path.join(output_dir, "thread.pickle")
|
||||
dump(thread, pickle_path)
|
||||
|
||||
@@ -21,19 +21,11 @@ class TestLogicalPlan(TestFabric, unittest.TestCase):
|
||||
def test_join_chunkmeta_inodes(self):
|
||||
ctx = Context()
|
||||
|
||||
chunkmeta_dump = DataSourceNode(
|
||||
ctx, dataset=ParquetDataSet(["tests/data/chunkmeta*.parquet"])
|
||||
)
|
||||
chunkmeta_partitions = HashPartitionNode(
|
||||
ctx, (chunkmeta_dump,), npartitions=2, hash_columns=["inodeId"]
|
||||
)
|
||||
chunkmeta_dump = DataSourceNode(ctx, dataset=ParquetDataSet(["tests/data/chunkmeta*.parquet"]))
|
||||
chunkmeta_partitions = HashPartitionNode(ctx, (chunkmeta_dump,), npartitions=2, hash_columns=["inodeId"])
|
||||
|
||||
inodes_dump = DataSourceNode(
|
||||
ctx, dataset=ParquetDataSet(["tests/data/inodes*.parquet"])
|
||||
)
|
||||
inodes_partitions = HashPartitionNode(
|
||||
ctx, (inodes_dump,), npartitions=2, hash_columns=["inode_id"]
|
||||
)
|
||||
inodes_dump = DataSourceNode(ctx, dataset=ParquetDataSet(["tests/data/inodes*.parquet"]))
|
||||
inodes_partitions = HashPartitionNode(ctx, (inodes_dump,), npartitions=2, hash_columns=["inode_id"])
|
||||
|
||||
num_gc_chunks = SqlEngineNode(
|
||||
ctx,
|
||||
@@ -53,12 +45,8 @@ class TestLogicalPlan(TestFabric, unittest.TestCase):
|
||||
ctx = Context()
|
||||
parquet_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
|
||||
data_source = DataSourceNode(ctx, parquet_dataset)
|
||||
partition_dim_a = EvenlyDistributedPartitionNode(
|
||||
ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="A"
|
||||
)
|
||||
partition_dim_b = EvenlyDistributedPartitionNode(
|
||||
ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="B"
|
||||
)
|
||||
partition_dim_a = EvenlyDistributedPartitionNode(ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="A")
|
||||
partition_dim_b = EvenlyDistributedPartitionNode(ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="B")
|
||||
join_two_inputs = SqlEngineNode(
|
||||
ctx,
|
||||
(partition_dim_a, partition_dim_b),
|
||||
@@ -73,9 +61,7 @@ class TestLogicalPlan(TestFabric, unittest.TestCase):
|
||||
ctx = Context()
|
||||
parquet_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
|
||||
data_source = DataSourceNode(ctx, parquet_dataset)
|
||||
partition_dim_a = EvenlyDistributedPartitionNode(
|
||||
ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="A"
|
||||
)
|
||||
partition_dim_a = EvenlyDistributedPartitionNode(ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="A")
|
||||
partition_dim_a2 = EvenlyDistributedPartitionNode(
|
||||
ctx,
|
||||
(data_source,),
|
||||
@@ -94,9 +80,7 @@ class TestLogicalPlan(TestFabric, unittest.TestCase):
|
||||
)
|
||||
plan = LogicalPlan(
|
||||
ctx,
|
||||
DataSetPartitionNode(
|
||||
ctx, (join_two_inputs1, join_two_inputs2), npartitions=1
|
||||
),
|
||||
DataSetPartitionNode(ctx, (join_two_inputs1, join_two_inputs2), npartitions=1),
|
||||
)
|
||||
logger.info(str(plan))
|
||||
with self.assertRaises(AssertionError) as context:
|
||||
|
||||
@@ -31,9 +31,7 @@ from tests.test_fabric import TestFabric
|
||||
|
||||
class CalculatePartitionFromFilename(UserDefinedPartitionNode):
|
||||
def partition(self, runtime_ctx: RuntimeContext, dataset: DataSet) -> List[DataSet]:
|
||||
partitioned_datasets: List[ParquetDataSet] = [
|
||||
ParquetDataSet([]) for _ in range(self.npartitions)
|
||||
]
|
||||
partitioned_datasets: List[ParquetDataSet] = [ParquetDataSet([]) for _ in range(self.npartitions)]
|
||||
for path in dataset.resolved_paths:
|
||||
partition_idx = hash(path) % self.npartitions
|
||||
partitioned_datasets[partition_idx].paths.append(path)
|
||||
@@ -45,9 +43,7 @@ class TestPartition(TestFabric, unittest.TestCase):
|
||||
ctx = Context()
|
||||
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"] * 10)
|
||||
data_files = DataSourceNode(ctx, dataset)
|
||||
data_partitions = DataSetPartitionNode(
|
||||
ctx, (data_files,), npartitions=dataset.num_files
|
||||
)
|
||||
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=dataset.num_files)
|
||||
count_rows = SqlEngineNode(
|
||||
ctx,
|
||||
(data_partitions,),
|
||||
@@ -62,9 +58,7 @@ class TestPartition(TestFabric, unittest.TestCase):
|
||||
ctx = Context()
|
||||
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
|
||||
data_files = DataSourceNode(ctx, dataset)
|
||||
data_partitions = DataSetPartitionNode(
|
||||
ctx, (data_files,), npartitions=dataset.num_rows, partition_by_rows=True
|
||||
)
|
||||
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=dataset.num_rows, partition_by_rows=True)
|
||||
count_rows = SqlEngineNode(
|
||||
ctx,
|
||||
(data_partitions,),
|
||||
@@ -74,18 +68,14 @@ class TestPartition(TestFabric, unittest.TestCase):
|
||||
)
|
||||
plan = LogicalPlan(ctx, count_rows)
|
||||
exec_plan = self.execute_plan(plan, num_executors=5)
|
||||
self.assertEqual(
|
||||
exec_plan.final_output.to_arrow_table().num_rows, dataset.num_rows
|
||||
)
|
||||
self.assertEqual(exec_plan.final_output.to_arrow_table().num_rows, dataset.num_rows)
|
||||
|
||||
def test_empty_dataset_partition(self):
|
||||
ctx = Context()
|
||||
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
|
||||
data_files = DataSourceNode(ctx, dataset)
|
||||
# create more partitions than files
|
||||
data_partitions = EvenlyDistributedPartitionNode(
|
||||
ctx, (data_files,), npartitions=dataset.num_files * 2
|
||||
)
|
||||
data_partitions = EvenlyDistributedPartitionNode(ctx, (data_files,), npartitions=dataset.num_files * 2)
|
||||
data_partitions.max_num_producer_tasks = 3
|
||||
unique_urls = SqlEngineNode(
|
||||
ctx,
|
||||
@@ -95,9 +85,7 @@ class TestPartition(TestFabric, unittest.TestCase):
|
||||
memory_limit=1 * GB,
|
||||
)
|
||||
# nested partition
|
||||
nested_partitioned_urls = EvenlyDistributedPartitionNode(
|
||||
ctx, (unique_urls,), npartitions=3, dimension="nested", nested=True
|
||||
)
|
||||
nested_partitioned_urls = EvenlyDistributedPartitionNode(ctx, (unique_urls,), npartitions=3, dimension="nested", nested=True)
|
||||
parsed_urls = ArrowComputeNode(
|
||||
ctx,
|
||||
(nested_partitioned_urls,),
|
||||
@@ -106,18 +94,14 @@ class TestPartition(TestFabric, unittest.TestCase):
|
||||
memory_limit=1 * GB,
|
||||
)
|
||||
plan = LogicalPlan(ctx, parsed_urls)
|
||||
final_output = self.execute_plan(
|
||||
plan, remove_empty_parquet=True, skip_task_with_empty_input=True
|
||||
).final_output
|
||||
final_output = self.execute_plan(plan, remove_empty_parquet=True, skip_task_with_empty_input=True).final_output
|
||||
self.assertTrue(isinstance(final_output, ParquetDataSet))
|
||||
self.assertEqual(dataset.num_rows, final_output.num_rows)
|
||||
|
||||
def test_hash_partition(self):
|
||||
for engine_type in ("duckdb", "arrow"):
|
||||
for partition_by_rows in (False, True):
|
||||
for hive_partitioning in (
|
||||
(False, True) if engine_type == "duckdb" else (False,)
|
||||
):
|
||||
for hive_partitioning in (False, True) if engine_type == "duckdb" else (False,):
|
||||
with self.subTest(
|
||||
engine_type=engine_type,
|
||||
partition_by_rows=partition_by_rows,
|
||||
@@ -155,26 +139,16 @@ class TestPartition(TestFabric, unittest.TestCase):
|
||||
exec_plan = self.execute_plan(plan)
|
||||
self.assertEqual(
|
||||
dataset.num_rows,
|
||||
pc.sum(
|
||||
exec_plan.final_output.to_arrow_table().column(
|
||||
"row_count"
|
||||
)
|
||||
).as_py(),
|
||||
pc.sum(exec_plan.final_output.to_arrow_table().column("row_count")).as_py(),
|
||||
)
|
||||
self.assertEqual(
|
||||
npartitions,
|
||||
len(exec_plan.final_output.load_partitioned_datasets(npartitions, DATA_PARTITION_COLUMN_NAME)),
|
||||
)
|
||||
self.assertEqual(
|
||||
npartitions,
|
||||
len(
|
||||
exec_plan.final_output.load_partitioned_datasets(
|
||||
npartitions, DATA_PARTITION_COLUMN_NAME
|
||||
)
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
npartitions,
|
||||
len(
|
||||
exec_plan.get_output(
|
||||
"hash_partitions"
|
||||
).load_partitioned_datasets(
|
||||
exec_plan.get_output("hash_partitions").load_partitioned_datasets(
|
||||
npartitions,
|
||||
DATA_PARTITION_COLUMN_NAME,
|
||||
hive_partitioning,
|
||||
@@ -185,9 +159,7 @@ class TestPartition(TestFabric, unittest.TestCase):
|
||||
def test_empty_hash_partition(self):
|
||||
for engine_type in ("duckdb", "arrow"):
|
||||
for partition_by_rows in (False, True):
|
||||
for hive_partitioning in (
|
||||
(False, True) if engine_type == "duckdb" else (False,)
|
||||
):
|
||||
for hive_partitioning in (False, True) if engine_type == "duckdb" else (False,):
|
||||
with self.subTest(
|
||||
engine_type=engine_type,
|
||||
partition_by_rows=partition_by_rows,
|
||||
@@ -199,9 +171,7 @@ class TestPartition(TestFabric, unittest.TestCase):
|
||||
npartitions = 3
|
||||
npartitions_nested = 4
|
||||
num_rows = 1
|
||||
head_rows = SqlEngineNode(
|
||||
ctx, (data_files,), f"select * from {{0}} limit {num_rows}"
|
||||
)
|
||||
head_rows = SqlEngineNode(ctx, (data_files,), f"select * from {{0}} limit {num_rows}")
|
||||
data_partitions = DataSetPartitionNode(
|
||||
ctx,
|
||||
(head_rows,),
|
||||
@@ -241,53 +211,31 @@ class TestPartition(TestFabric, unittest.TestCase):
|
||||
memory_limit=1 * GB,
|
||||
)
|
||||
plan = LogicalPlan(ctx, select_every_row)
|
||||
exec_plan = self.execute_plan(
|
||||
plan, skip_task_with_empty_input=True
|
||||
)
|
||||
exec_plan = self.execute_plan(plan, skip_task_with_empty_input=True)
|
||||
self.assertEqual(num_rows, exec_plan.final_output.num_rows)
|
||||
self.assertEqual(
|
||||
npartitions,
|
||||
len(
|
||||
exec_plan.final_output.load_partitioned_datasets(
|
||||
npartitions, "hash_partitions"
|
||||
)
|
||||
),
|
||||
len(exec_plan.final_output.load_partitioned_datasets(npartitions, "hash_partitions")),
|
||||
)
|
||||
self.assertEqual(
|
||||
npartitions_nested,
|
||||
len(
|
||||
exec_plan.final_output.load_partitioned_datasets(
|
||||
npartitions_nested, "nested_hash_partitions"
|
||||
)
|
||||
),
|
||||
len(exec_plan.final_output.load_partitioned_datasets(npartitions_nested, "nested_hash_partitions")),
|
||||
)
|
||||
self.assertEqual(
|
||||
npartitions,
|
||||
len(
|
||||
exec_plan.get_output(
|
||||
"hash_partitions"
|
||||
).load_partitioned_datasets(
|
||||
npartitions, "hash_partitions"
|
||||
)
|
||||
),
|
||||
len(exec_plan.get_output("hash_partitions").load_partitioned_datasets(npartitions, "hash_partitions")),
|
||||
)
|
||||
self.assertEqual(
|
||||
npartitions_nested,
|
||||
len(
|
||||
exec_plan.get_output(
|
||||
"nested_hash_partitions"
|
||||
).load_partitioned_datasets(
|
||||
npartitions_nested, "nested_hash_partitions"
|
||||
)
|
||||
exec_plan.get_output("nested_hash_partitions").load_partitioned_datasets(npartitions_nested, "nested_hash_partitions")
|
||||
),
|
||||
)
|
||||
if hive_partitioning:
|
||||
self.assertEqual(
|
||||
npartitions,
|
||||
len(
|
||||
exec_plan.get_output(
|
||||
"hash_partitions"
|
||||
).load_partitioned_datasets(
|
||||
exec_plan.get_output("hash_partitions").load_partitioned_datasets(
|
||||
npartitions,
|
||||
"hash_partitions",
|
||||
hive_partitioning=True,
|
||||
@@ -297,9 +245,7 @@ class TestPartition(TestFabric, unittest.TestCase):
|
||||
self.assertEqual(
|
||||
npartitions_nested,
|
||||
len(
|
||||
exec_plan.get_output(
|
||||
"nested_hash_partitions"
|
||||
).load_partitioned_datasets(
|
||||
exec_plan.get_output("nested_hash_partitions").load_partitioned_datasets(
|
||||
npartitions_nested,
|
||||
"nested_hash_partitions",
|
||||
hive_partitioning=True,
|
||||
@@ -341,19 +287,11 @@ class TestPartition(TestFabric, unittest.TestCase):
|
||||
exec_plan = self.execute_plan(plan)
|
||||
self.assertEqual(
|
||||
npartitions,
|
||||
len(
|
||||
exec_plan.final_output.load_partitioned_datasets(
|
||||
npartitions, data_partition_column
|
||||
)
|
||||
),
|
||||
len(exec_plan.final_output.load_partitioned_datasets(npartitions, data_partition_column)),
|
||||
)
|
||||
self.assertEqual(
|
||||
npartitions,
|
||||
len(
|
||||
exec_plan.get_output("input_partitions").load_partitioned_datasets(
|
||||
npartitions, data_partition_column, hive_partitioning
|
||||
)
|
||||
),
|
||||
len(exec_plan.get_output("input_partitions").load_partitioned_datasets(npartitions, data_partition_column, hive_partitioning)),
|
||||
)
|
||||
return exec_plan
|
||||
|
||||
@@ -376,12 +314,8 @@ class TestPartition(TestFabric, unittest.TestCase):
|
||||
)
|
||||
|
||||
ctx = Context()
|
||||
output1 = DataSourceNode(
|
||||
ctx, dataset=exec_plan1.get_output("input_partitions")
|
||||
)
|
||||
output2 = DataSourceNode(
|
||||
ctx, dataset=exec_plan2.get_output("input_partitions")
|
||||
)
|
||||
output1 = DataSourceNode(ctx, dataset=exec_plan1.get_output("input_partitions"))
|
||||
output2 = DataSourceNode(ctx, dataset=exec_plan2.get_output("input_partitions"))
|
||||
split_urls1 = LoadPartitionedDataSetNode(
|
||||
ctx,
|
||||
(output1,),
|
||||
@@ -411,16 +345,8 @@ class TestPartition(TestFabric, unittest.TestCase):
|
||||
plan = LogicalPlan(ctx, split_urls3)
|
||||
exec_plan3 = self.execute_plan(plan)
|
||||
# load each partition as arrow table and compare
|
||||
final_output_partitions1 = (
|
||||
exec_plan1.final_output.load_partitioned_datasets(
|
||||
npartitions, data_partition_column
|
||||
)
|
||||
)
|
||||
final_output_partitions3 = (
|
||||
exec_plan3.final_output.load_partitioned_datasets(
|
||||
npartitions, data_partition_column
|
||||
)
|
||||
)
|
||||
final_output_partitions1 = exec_plan1.final_output.load_partitioned_datasets(npartitions, data_partition_column)
|
||||
final_output_partitions3 = exec_plan3.final_output.load_partitioned_datasets(npartitions, data_partition_column)
|
||||
self.assertEqual(npartitions, len(final_output_partitions3))
|
||||
for x, y in zip(final_output_partitions1, final_output_partitions3):
|
||||
self._compare_arrow_tables(x.to_arrow_table(), y.to_arrow_table())
|
||||
@@ -433,9 +359,7 @@ class TestPartition(TestFabric, unittest.TestCase):
|
||||
SqlEngineNode.default_cpu_limit = 1
|
||||
SqlEngineNode.default_memory_limit = 1 * GB
|
||||
initial_reduce = r"select host, count(*) as cnt from {0} group by host"
|
||||
combine_reduce_results = (
|
||||
r"select host, cast(sum(cnt) as bigint) as cnt from {0} group by host"
|
||||
)
|
||||
combine_reduce_results = r"select host, cast(sum(cnt) as bigint) as cnt from {0} group by host"
|
||||
join_query = r"select host, cnt from {0} where (exists (select * from {1} where {1}.host = {0}.host)) and (exists (select * from {2} where {2}.host = {0}.host))"
|
||||
|
||||
partition_by_hosts = HashPartitionNode(
|
||||
@@ -496,11 +420,7 @@ class TestPartition(TestFabric, unittest.TestCase):
|
||||
url_count_by_3dims = SqlEngineNode(ctx, (partitioned_3dims,), initial_reduce)
|
||||
url_count_by_hosts_x_urls2 = SqlEngineNode(
|
||||
ctx,
|
||||
(
|
||||
ConsolidateNode(
|
||||
ctx, url_count_by_3dims, ["host_partition", "url_partition"]
|
||||
),
|
||||
),
|
||||
(ConsolidateNode(ctx, url_count_by_3dims, ["host_partition", "url_partition"]),),
|
||||
combine_reduce_results,
|
||||
output_name="url_count_by_hosts_x_urls2",
|
||||
)
|
||||
@@ -524,9 +444,7 @@ class TestPartition(TestFabric, unittest.TestCase):
|
||||
output_name="join_count_by_hosts_x_urls2",
|
||||
)
|
||||
|
||||
union_url_count_by_hosts = UnionNode(
|
||||
ctx, (url_count_by_hosts1, url_count_by_hosts2)
|
||||
)
|
||||
union_url_count_by_hosts = UnionNode(ctx, (url_count_by_hosts1, url_count_by_hosts2))
|
||||
union_url_count_by_hosts_x_urls = UnionNode(
|
||||
ctx,
|
||||
(
|
||||
@@ -576,18 +494,14 @@ class TestPartition(TestFabric, unittest.TestCase):
|
||||
ctx = Context()
|
||||
parquet_files = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
|
||||
data_source = DataSourceNode(ctx, parquet_files)
|
||||
file_partitions1 = CalculatePartitionFromFilename(
|
||||
ctx, (data_source,), npartitions=3, dimension="by_filename_hash1"
|
||||
)
|
||||
file_partitions1 = CalculatePartitionFromFilename(ctx, (data_source,), npartitions=3, dimension="by_filename_hash1")
|
||||
url_count1 = SqlEngineNode(
|
||||
ctx,
|
||||
(file_partitions1,),
|
||||
r"select host, count(*) as cnt from {0} group by host",
|
||||
output_name="url_count1",
|
||||
)
|
||||
file_partitions2 = CalculatePartitionFromFilename(
|
||||
ctx, (url_count1,), npartitions=3, dimension="by_filename_hash2"
|
||||
)
|
||||
file_partitions2 = CalculatePartitionFromFilename(ctx, (url_count1,), npartitions=3, dimension="by_filename_hash2")
|
||||
url_count2 = SqlEngineNode(
|
||||
ctx,
|
||||
(file_partitions2,),
|
||||
@@ -606,9 +520,7 @@ class TestPartition(TestFabric, unittest.TestCase):
|
||||
ctx = Context()
|
||||
parquet_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
|
||||
data_source = DataSourceNode(ctx, parquet_dataset)
|
||||
evenly_dist_data_source = EvenlyDistributedPartitionNode(
|
||||
ctx, (data_source,), npartitions=parquet_dataset.num_files
|
||||
)
|
||||
evenly_dist_data_source = EvenlyDistributedPartitionNode(ctx, (data_source,), npartitions=parquet_dataset.num_files)
|
||||
|
||||
parquet_datasets = [ParquetDataSet([p]) for p in parquet_dataset.resolved_paths]
|
||||
partitioned_data_source = UserPartitionedDataSourceNode(ctx, parquet_datasets)
|
||||
@@ -631,9 +543,7 @@ class TestPartition(TestFabric, unittest.TestCase):
|
||||
memory_limit=1 * GB,
|
||||
)
|
||||
|
||||
plan = LogicalPlan(
|
||||
ctx, UnionNode(ctx, [url_count_by_host1, url_count_by_host2])
|
||||
)
|
||||
plan = LogicalPlan(ctx, UnionNode(ctx, [url_count_by_host1, url_count_by_host2]))
|
||||
exec_plan = self.execute_plan(plan, enable_diagnostic_metrics=True)
|
||||
self._compare_arrow_tables(
|
||||
exec_plan.get_output("url_count_by_host1").to_arrow_table(),
|
||||
@@ -647,9 +557,7 @@ class TestPartition(TestFabric, unittest.TestCase):
|
||||
ctx = Context()
|
||||
parquet_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
|
||||
data_source = DataSourceNode(ctx, parquet_dataset)
|
||||
evenly_dist_data_source = EvenlyDistributedPartitionNode(
|
||||
ctx, (data_source,), npartitions=parquet_dataset.num_files
|
||||
)
|
||||
evenly_dist_data_source = EvenlyDistributedPartitionNode(ctx, (data_source,), npartitions=parquet_dataset.num_files)
|
||||
sql_query = SqlEngineNode(
|
||||
ctx,
|
||||
(evenly_dist_data_source,),
|
||||
|
||||
@@ -65,6 +65,4 @@ def test_fstest(sp: Session):
|
||||
|
||||
|
||||
def test_sort_mock_urls_v2(sp: Session):
|
||||
sort_mock_urls_v2(
|
||||
sp, ["tests/data/mock_urls/*.tsv"], sp._runtime_ctx.output_root, npartitions=3
|
||||
)
|
||||
sort_mock_urls_v2(sp, ["tests/data/mock_urls/*.tsv"], sp._runtime_ctx.output_root, npartitions=3)
|
||||
|
||||
@@ -21,9 +21,7 @@ from tests.test_fabric import TestFabric
|
||||
|
||||
|
||||
class RandomSleepTask(PythonScriptTask):
|
||||
def __init__(
|
||||
self, *args, sleep_secs: float, fail_first_try: bool, **kwargs
|
||||
) -> None:
|
||||
def __init__(self, *args, sleep_secs: float, fail_first_try: bool, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.sleep_secs = sleep_secs
|
||||
self.fail_first_try = fail_first_try
|
||||
@@ -58,24 +56,16 @@ class RandomSleepNode(PythonScriptNode):
|
||||
self.fail_first_try = fail_first_try
|
||||
|
||||
def spawn(self, *args, **kwargs) -> RandomSleepTask:
|
||||
sleep_secs = (
|
||||
random.random() if len(self.generated_tasks) % 20 else self.max_sleep_secs
|
||||
)
|
||||
return RandomSleepTask(
|
||||
*args, **kwargs, sleep_secs=sleep_secs, fail_first_try=self.fail_first_try
|
||||
)
|
||||
sleep_secs = random.random() if len(self.generated_tasks) % 20 else self.max_sleep_secs
|
||||
return RandomSleepTask(*args, **kwargs, sleep_secs=sleep_secs, fail_first_try=self.fail_first_try)
|
||||
|
||||
|
||||
class TestScheduler(TestFabric, unittest.TestCase):
|
||||
def create_random_sleep_plan(
|
||||
self, npartitions, max_sleep_secs, fail_first_try=False
|
||||
):
|
||||
def create_random_sleep_plan(self, npartitions, max_sleep_secs, fail_first_try=False):
|
||||
ctx = Context()
|
||||
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
|
||||
data_files = DataSourceNode(ctx, dataset)
|
||||
data_partitions = DataSetPartitionNode(
|
||||
ctx, (data_files,), npartitions=npartitions, partition_by_rows=True
|
||||
)
|
||||
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=npartitions, partition_by_rows=True)
|
||||
random_sleep = RandomSleepNode(
|
||||
ctx,
|
||||
(data_partitions,),
|
||||
@@ -87,13 +77,8 @@ class TestScheduler(TestFabric, unittest.TestCase):
|
||||
def check_executor_state(self, target_state: ExecutorState, nloops=200):
|
||||
for _ in range(nloops):
|
||||
latest_sched_state = self.get_latest_sched_state()
|
||||
if any(
|
||||
executor.state == target_state
|
||||
for executor in latest_sched_state.remote_executors
|
||||
):
|
||||
logger.info(
|
||||
f"found {target_state} executor in: {latest_sched_state.remote_executors}"
|
||||
)
|
||||
if any(executor.state == target_state for executor in latest_sched_state.remote_executors):
|
||||
logger.info(f"found {target_state} executor in: {latest_sched_state.remote_executors}")
|
||||
break
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
@@ -121,9 +106,7 @@ class TestScheduler(TestFabric, unittest.TestCase):
|
||||
latest_sched_state = self.get_latest_sched_state()
|
||||
self.check_executor_state(ExecutorState.GOOD)
|
||||
|
||||
for i, (executor, process) in enumerate(
|
||||
random.sample(list(zip(executors, processes[1:])), k=num_fail)
|
||||
):
|
||||
for i, (executor, process) in enumerate(random.sample(list(zip(executors, processes[1:])), k=num_fail)):
|
||||
if i % 2 == 0:
|
||||
logger.warning(f"kill executor: {executor}")
|
||||
process.kill()
|
||||
@@ -165,9 +148,7 @@ class TestScheduler(TestFabric, unittest.TestCase):
|
||||
self.assertGreater(len(latest_sched_state.abandoned_tasks), 0)
|
||||
|
||||
def test_stop_executor_on_failure(self):
|
||||
plan = self.create_random_sleep_plan(
|
||||
npartitions=3, max_sleep_secs=5, fail_first_try=True
|
||||
)
|
||||
plan = self.create_random_sleep_plan(npartitions=3, max_sleep_secs=5, fail_first_try=True)
|
||||
exec_plan = self.execute_plan(
|
||||
plan,
|
||||
num_executors=5,
|
||||
|
||||
@@ -5,9 +5,7 @@ from smallpond.dataframe import Session
|
||||
|
||||
def test_shutdown_cleanup(sp: Session):
|
||||
assert os.path.exists(sp._runtime_ctx.queue_root), "queue directory should exist"
|
||||
assert os.path.exists(
|
||||
sp._runtime_ctx.staging_root
|
||||
), "staging directory should exist"
|
||||
assert os.path.exists(sp._runtime_ctx.staging_root), "staging directory should exist"
|
||||
assert os.path.exists(sp._runtime_ctx.temp_root), "temp directory should exist"
|
||||
|
||||
# create some tasks and complete them
|
||||
@@ -16,15 +14,9 @@ def test_shutdown_cleanup(sp: Session):
|
||||
sp.shutdown()
|
||||
|
||||
# shutdown should clean up directories
|
||||
assert not os.path.exists(
|
||||
sp._runtime_ctx.queue_root
|
||||
), "queue directory should be cleared"
|
||||
assert not os.path.exists(
|
||||
sp._runtime_ctx.staging_root
|
||||
), "staging directory should be cleared"
|
||||
assert not os.path.exists(
|
||||
sp._runtime_ctx.temp_root
|
||||
), "temp directory should be cleared"
|
||||
assert not os.path.exists(sp._runtime_ctx.queue_root), "queue directory should be cleared"
|
||||
assert not os.path.exists(sp._runtime_ctx.staging_root), "staging directory should be cleared"
|
||||
assert not os.path.exists(sp._runtime_ctx.temp_root), "temp directory should be cleared"
|
||||
with open(sp._runtime_ctx.job_status_path) as fin:
|
||||
assert "success" in fin.read(), "job status should be success"
|
||||
|
||||
@@ -41,14 +33,8 @@ def test_shutdown_no_cleanup_on_failure(sp: Session):
|
||||
sp.shutdown()
|
||||
|
||||
# shutdown should not clean up directories
|
||||
assert os.path.exists(
|
||||
sp._runtime_ctx.queue_root
|
||||
), "queue directory should not be cleared"
|
||||
assert os.path.exists(
|
||||
sp._runtime_ctx.staging_root
|
||||
), "staging directory should not be cleared"
|
||||
assert os.path.exists(
|
||||
sp._runtime_ctx.temp_root
|
||||
), "temp directory should not be cleared"
|
||||
assert os.path.exists(sp._runtime_ctx.queue_root), "queue directory should not be cleared"
|
||||
assert os.path.exists(sp._runtime_ctx.staging_root), "staging directory should not be cleared"
|
||||
assert os.path.exists(sp._runtime_ctx.temp_root), "temp directory should not be cleared"
|
||||
with open(sp._runtime_ctx.job_status_path) as fin:
|
||||
assert "failure" in fin.read(), "job status should be failure"
|
||||
|
||||
@@ -85,9 +85,7 @@ class WorkQueueTestBase(object):
|
||||
def test_multi_consumers(self):
|
||||
numConsumers = 10
|
||||
numItems = 200
|
||||
result = self.pool.starmap_async(
|
||||
consumer, [(self.wq, id) for id in range(numConsumers)]
|
||||
)
|
||||
result = self.pool.starmap_async(consumer, [(self.wq, id) for id in range(numConsumers)])
|
||||
producer(self.wq, 0, numItems, numConsumers)
|
||||
|
||||
logger.info("waiting for result")
|
||||
|
||||
Reference in New Issue
Block a user