reformat code with --line-length=150 (#18)

This commit is contained in:
Runji Wang
2025-03-05 22:46:23 +08:00
committed by GitHub
parent ed112db42a
commit 52ecc5e455
48 changed files with 794 additions and 2604 deletions

View File

@@ -14,19 +14,13 @@ from filelock import FileLock
def generate_url_and_domain() -> Tuple[str, str]:
domain_part = "".join(
random.choices(string.ascii_lowercase, k=random.randint(5, 15))
)
domain_part = "".join(random.choices(string.ascii_lowercase, k=random.randint(5, 15)))
tld = random.choice(["com", "net", "org", "cn", "edu", "gov", "co", "io"])
domain = f"www.{domain_part}.{tld}"
path_segments = []
for _ in range(random.randint(1, 3)):
segment = "".join(
random.choices(
string.ascii_lowercase + string.digits, k=random.randint(3, 10)
)
)
segment = "".join(random.choices(string.ascii_lowercase + string.digits, k=random.randint(3, 10)))
path_segments.append(segment)
path = "/" + "/".join(path_segments)
@@ -42,26 +36,18 @@ def generate_random_date() -> str:
start = datetime(2023, 1, 1, tzinfo=timezone.utc)
end = datetime(2023, 12, 31, tzinfo=timezone.utc)
delta = end - start
random_date = start + timedelta(
seconds=random.randint(0, int(delta.total_seconds()))
)
random_date = start + timedelta(seconds=random.randint(0, int(delta.total_seconds())))
return random_date.strftime("%Y-%m-%dT%H:%M:%SZ")
def generate_content() -> bytes:
target_length = (
random.randint(1000, 100000)
if random.random() < 0.8
else random.randint(100000, 1000000)
)
target_length = random.randint(1000, 100000) if random.random() < 0.8 else random.randint(100000, 1000000)
before = b"<!DOCTYPE html><html><head><title>Random Page</title></head><body>"
after = b"</body></html>"
total_before_after = len(before) + len(after)
fill_length = max(target_length - total_before_after, 0)
filler = "".join(random.choices(string.printable, k=fill_length)).encode("ascii")[
:fill_length
]
filler = "".join(random.choices(string.printable, k=fill_length)).encode("ascii")[:fill_length]
return before + filler + after
@@ -103,9 +89,7 @@ def generate_random_string(length: int) -> str:
def generate_random_url() -> str:
"""Generate a random URL"""
path = generate_random_string(random.randint(10, 20))
return (
f"com.{random.randint(10000, 999999)}.{random.randint(100, 9999)}/{path}.html"
)
return f"com.{random.randint(10000, 999999)}.{random.randint(100, 9999)}/{path}.html"
def generate_random_data() -> str:
@@ -138,9 +122,7 @@ def generate_url_parquet_files(output_dir: str, num_files: int = 10):
)
def generate_url_tsv_files(
output_dir: str, num_files: int = 10, lines_per_file: int = 100
):
def generate_url_tsv_files(output_dir: str, num_files: int = 10, lines_per_file: int = 100):
"""Generate multiple files, each containing a specified number of random data lines"""
os.makedirs(output_dir, exist_ok=True)
for i in range(num_files):
@@ -166,16 +148,12 @@ def generate_data(path: str = "tests/data"):
with FileLock(path + "/data.lock"):
print("Generating data...")
if not os.path.exists(path + "/mock_urls"):
generate_url_tsv_files(
output_dir=path + "/mock_urls", num_files=10, lines_per_file=100
)
generate_url_tsv_files(output_dir=path + "/mock_urls", num_files=10, lines_per_file=100)
generate_url_parquet_files(output_dir=path + "/mock_urls", num_files=10)
if not os.path.exists(path + "/arrow"):
generate_arrow_files(output_dir=path + "/arrow", num_files=10)
if not os.path.exists(path + "/large_array"):
concat_arrow_files(
input_dir=path + "/arrow", output_dir=path + "/large_array"
)
concat_arrow_files(input_dir=path + "/arrow", output_dir=path + "/large_array")
if not os.path.exists(path + "/long_path_list.txt"):
generate_long_path_list(path=path + "/long_path_list.txt")
except Exception as e:

View File

@@ -37,10 +37,7 @@ class TestArrow(TestFabric, unittest.TestCase):
with self.subTest(dataset_path=dataset_path):
metadata = parquet.read_metadata(dataset_path)
file_num_rows = metadata.num_rows
data_size = sum(
metadata.row_group(i).total_byte_size
for i in range(metadata.num_row_groups)
)
data_size = sum(metadata.row_group(i).total_byte_size for i in range(metadata.num_row_groups))
row_range = RowRange(
path=dataset_path,
begin=100,
@@ -48,9 +45,7 @@ class TestArrow(TestFabric, unittest.TestCase):
data_size=data_size,
file_num_rows=file_num_rows,
)
expected = self._load_parquet_files([dataset_path]).slice(
offset=100, length=100
)
expected = self._load_parquet_files([dataset_path]).slice(offset=100, length=100)
actual = load_from_parquet_files([row_range])
self._compare_arrow_tables(expected, actual)
@@ -62,21 +57,15 @@ class TestArrow(TestFabric, unittest.TestCase):
with self.subTest(dataset_path=dataset_path):
parquet_files = glob.glob(dataset_path)
expected = self._load_parquet_files(parquet_files)
with tempfile.TemporaryDirectory(
dir=self.output_root_abspath
) as output_dir:
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
ok = dump_to_parquet_files(expected, output_dir)
self.assertTrue(ok)
actual = self._load_parquet_files(
glob.glob(f"{output_dir}/*.parquet")
)
actual = self._load_parquet_files(glob.glob(f"{output_dir}/*.parquet"))
self._compare_arrow_tables(expected, actual)
def test_dump_load_empty_table(self):
# create empty table
empty_table = self._load_parquet_files(
["tests/data/arrow/data0.parquet"]
).slice(length=0)
empty_table = self._load_parquet_files(["tests/data/arrow/data0.parquet"]).slice(length=0)
self.assertEqual(empty_table.num_rows, 0)
# dump empty table
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
@@ -94,9 +83,7 @@ class TestArrow(TestFabric, unittest.TestCase):
):
with self.subTest(dataset_path=dataset_path):
parquet_files = glob.glob(dataset_path)
expected_num_rows = sum(
parquet.read_metadata(file).num_rows for file in parquet_files
)
expected_num_rows = sum(parquet.read_metadata(file).num_rows for file in parquet_files)
with build_batch_reader_from_files(
parquet_files,
batch_size=expected_num_rows,
@@ -104,9 +91,7 @@ class TestArrow(TestFabric, unittest.TestCase):
) as batch_reader, ConcurrentIter(batch_reader) as concurrent_iter:
total_num_rows = 0
for batch in concurrent_iter:
print(
f"batch.num_rows {batch.num_rows}, max_batch_row_size {expected_num_rows}"
)
print(f"batch.num_rows {batch.num_rows}, max_batch_row_size {expected_num_rows}")
self.assertLessEqual(batch.num_rows, expected_num_rows)
total_num_rows += batch.num_rows
self.assertEqual(total_num_rows, expected_num_rows)
@@ -121,9 +106,7 @@ class TestArrow(TestFabric, unittest.TestCase):
table = self._load_parquet_files(parquet_files)
total_num_rows = 0
for batch in table.to_batches(max_chunksize=table.num_rows):
print(
f"batch.num_rows {batch.num_rows}, max_batch_row_size {table.num_rows}"
)
print(f"batch.num_rows {batch.num_rows}, max_batch_row_size {table.num_rows}")
self.assertLessEqual(batch.num_rows, table.num_rows)
total_num_rows += batch.num_rows
self.assertEqual(total_num_rows, table.num_rows)
@@ -135,26 +118,14 @@ class TestArrow(TestFabric, unittest.TestCase):
print(f"table_with_meta.schema.metadata {table_with_meta.schema.metadata}")
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
self.assertTrue(
dump_to_parquet_files(
table_with_meta, output_dir, "arrow_schema_metadata", max_workers=2
)
)
parquet_files = glob.glob(
os.path.join(output_dir, "arrow_schema_metadata*.parquet")
)
loaded_table = load_from_parquet_files(
parquet_files, table.column_names[:1]
)
self.assertTrue(dump_to_parquet_files(table_with_meta, output_dir, "arrow_schema_metadata", max_workers=2))
parquet_files = glob.glob(os.path.join(output_dir, "arrow_schema_metadata*.parquet"))
loaded_table = load_from_parquet_files(parquet_files, table.column_names[:1])
print(f"loaded_table.schema.metadata {loaded_table.schema.metadata}")
self.assertEqual(
table_with_meta.schema.metadata, loaded_table.schema.metadata
)
self.assertEqual(table_with_meta.schema.metadata, loaded_table.schema.metadata)
with parquet.ParquetFile(parquet_files[0]) as file:
print(f"file.schema_arrow.metadata {file.schema_arrow.metadata}")
self.assertEqual(
table_with_meta.schema.metadata, file.schema_arrow.metadata
)
self.assertEqual(table_with_meta.schema.metadata, file.schema_arrow.metadata)
def test_load_mixed_string_types(self):
parquet_paths = glob.glob("tests/data/arrow/*.parquet")
@@ -166,9 +137,7 @@ class TestArrow(TestFabric, unittest.TestCase):
loaded_table = load_from_parquet_files(parquet_paths)
self.assertEqual(table.num_rows * 2, loaded_table.num_rows)
batch_reader = build_batch_reader_from_files(parquet_paths)
self.assertEqual(
table.num_rows * 2, sum(batch.num_rows for batch in batch_reader)
)
self.assertEqual(table.num_rows * 2, sum(batch.num_rows for batch in batch_reader))
@logger.catch(reraise=True, message="failed to load parquet files")
def _load_from_parquet_files_with_log(self, paths, columns):

View File

@@ -45,9 +45,7 @@ class TestBench(TestFabric, unittest.TestCase):
num_sort_partitions = 1 << 3
for shuffle_engine in ("duckdb", "arrow"):
for sort_engine in ("duckdb", "arrow", "polars"):
with self.subTest(
shuffle_engine=shuffle_engine, sort_engine=sort_engine
):
with self.subTest(shuffle_engine=shuffle_engine, sort_engine=sort_engine):
ctx = Context()
random_records = generate_random_records(
ctx,

View File

@@ -34,9 +34,7 @@ class TestCommon(TestFabric, unittest.TestCase):
npartitions = data.draw(st.integers(1, 2 * nelements))
items = list(range(nelements))
computed = split_into_rows(items, npartitions)
expected = [
get_nth_partition(items, n, npartitions) for n in range(npartitions)
]
expected = [get_nth_partition(items, n, npartitions) for n in range(npartitions)]
self.assertEqual(expected, computed)
@given(st.data())

View File

@@ -70,9 +70,7 @@ def test_flat_map(sp: Session):
# user need to specify the schema if can not be inferred from the mapping values
df3 = df.flat_map(lambda r: [{"c": None}], schema=pa.schema([("c", pa.int64())]))
assert df3.to_arrow() == pa.table(
{"c": pa.array([None, None, None], type=pa.int64())}
)
assert df3.to_arrow() == pa.table({"c": pa.array([None, None, None], type=pa.int64())})
def test_map_batches(sp: Session):
@@ -99,10 +97,7 @@ def test_random_shuffle(sp: Session):
assert sorted(shuffled) == list(range(1000))
def count_inversions(arr: List[int]) -> int:
return sum(
sum(1 for j in range(i + 1, len(arr)) if arr[i] > arr[j])
for i in range(len(arr))
)
return sum(sum(1 for j in range(i + 1, len(arr)) if arr[i] > arr[j]) for i in range(len(arr)))
# check the shuffle is random enough
# the expected number of inversions is n*(n-1)/4 = 249750
@@ -158,9 +153,7 @@ def test_partial_sql(sp: Session):
# join
df1 = sp.from_arrow(pa.table({"id1": [1, 2, 3], "val1": ["a", "b", "c"]}))
df2 = sp.from_arrow(pa.table({"id2": [1, 2, 3], "val2": ["d", "e", "f"]}))
joined = sp.partial_sql(
"select id1, val1, val2 from {0} join {1} on id1 = id2", df1, df2
)
joined = sp.partial_sql("select id1, val1, val2 from {0} join {1} on id1 = id2", df1, df2)
assert joined.to_arrow() == pa.table(
{"id1": [1, 2, 3], "val1": ["a", "b", "c"], "val2": ["d", "e", "f"]},
schema=pa.schema(
@@ -193,10 +186,7 @@ def test_unpicklable_task_exception(sp: Session):
df.map(lambda x: logger.info("use outside logger")).to_arrow()
except Exception as ex:
assert "Can't pickle task" in str(ex)
assert (
"HINT: DO NOT use externally imported loguru logger in your task. Please import it within the task."
in str(ex)
)
assert "HINT: DO NOT use externally imported loguru logger in your task. Please import it within the task." in str(ex)
else:
assert False, "expected exception"

View File

@@ -30,9 +30,7 @@ class TestDataSet(TestFabric, unittest.TestCase):
dataset = ParquetDataSet([os.path.join(self.output_root_abspath, "*.parquet")])
self.assertEqual(num_urls, dataset.num_rows)
def _generate_parquet_dataset(
self, output_path, npartitions, num_rows, row_group_size
):
def _generate_parquet_dataset(self, output_path, npartitions, num_rows, row_group_size):
duckdb.sql(
f"""copy (
select range as i, range % {npartitions} as partition from range(0, {num_rows}) )
@@ -41,9 +39,7 @@ class TestDataSet(TestFabric, unittest.TestCase):
)
return ParquetDataSet([f"{output_path}/**/*.parquet"])
def _check_partition_datasets(
self, orig_dataset: ParquetDataSet, partition_func, npartition
):
def _check_partition_datasets(self, orig_dataset: ParquetDataSet, partition_func, npartition):
# build partitioned datasets
partitioned_datasets = partition_func(npartition)
self.assertEqual(npartition, len(partitioned_datasets))
@@ -52,9 +48,7 @@ class TestDataSet(TestFabric, unittest.TestCase):
sum(dataset.num_rows for dataset in partitioned_datasets),
)
# load as arrow table
loaded_table = arrow.concat_tables(
[dataset.to_arrow_table(max_workers=1) for dataset in partitioned_datasets]
)
loaded_table = arrow.concat_tables([dataset.to_arrow_table(max_workers=1) for dataset in partitioned_datasets])
self.assertEqual(orig_dataset.num_rows, loaded_table.num_rows)
# compare arrow tables
orig_table = orig_dataset.to_arrow_table(max_workers=1)
@@ -74,9 +68,7 @@ class TestDataSet(TestFabric, unittest.TestCase):
def test_partition_by_files(self):
output_path = os.path.join(self.output_root_abspath, "test_partition_by_files")
orig_dataset = self._generate_parquet_dataset(
output_path, npartitions=11, num_rows=170 * 1000, row_group_size=10 * 1000
)
orig_dataset = self._generate_parquet_dataset(output_path, npartitions=11, num_rows=170 * 1000, row_group_size=10 * 1000)
num_files = len(orig_dataset.resolved_paths)
for npartition in range(1, num_files + 1):
for random_shuffle in (False, True):
@@ -84,17 +76,13 @@ class TestDataSet(TestFabric, unittest.TestCase):
orig_dataset.reset(orig_dataset.paths, orig_dataset.root_dir)
self._check_partition_datasets(
orig_dataset,
lambda n: orig_dataset.partition_by_files(
n, random_shuffle=random_shuffle
),
lambda n: orig_dataset.partition_by_files(n, random_shuffle=random_shuffle),
npartition,
)
def test_partition_by_rows(self):
output_path = os.path.join(self.output_root_abspath, "test_partition_by_rows")
orig_dataset = self._generate_parquet_dataset(
output_path, npartitions=11, num_rows=170 * 1000, row_group_size=10 * 1000
)
orig_dataset = self._generate_parquet_dataset(output_path, npartitions=11, num_rows=170 * 1000, row_group_size=10 * 1000)
num_files = len(orig_dataset.resolved_paths)
for npartition in range(1, 2 * num_files + 1):
for random_shuffle in (False, True):
@@ -102,9 +90,7 @@ class TestDataSet(TestFabric, unittest.TestCase):
orig_dataset.reset(orig_dataset.paths, orig_dataset.root_dir)
self._check_partition_datasets(
orig_dataset,
lambda n: orig_dataset.partition_by_rows(
n, random_shuffle=random_shuffle
),
lambda n: orig_dataset.partition_by_rows(n, random_shuffle=random_shuffle),
npartition,
)
@@ -116,9 +102,7 @@ class TestDataSet(TestFabric, unittest.TestCase):
self.assertEqual(len(dataset.resolved_paths), len(filenames))
def test_paths_with_char_ranges(self):
dataset_with_char_ranges = ParquetDataSet(
["tests/data/arrow/data[0-9].parquet"]
)
dataset_with_char_ranges = ParquetDataSet(["tests/data/arrow/data[0-9].parquet"])
dataset_with_wildcards = ParquetDataSet(["tests/data/arrow/*.parquet"])
self.assertEqual(
len(dataset_with_char_ranges.resolved_paths),
@@ -126,9 +110,7 @@ class TestDataSet(TestFabric, unittest.TestCase):
)
def test_to_arrow_table_batch_reader(self):
memdb = duckdb.connect(
database=":memory:", config={"arrow_large_buffer_size": "true"}
)
memdb = duckdb.connect(database=":memory:", config={"arrow_large_buffer_size": "true"})
for dataset_path in (
"tests/data/arrow/*.parquet",
"tests/data/large_array/*.parquet",
@@ -137,24 +119,14 @@ class TestDataSet(TestFabric, unittest.TestCase):
print(f"dataset_path: {dataset_path}, conn: {conn}")
with self.subTest(dataset_path=dataset_path, conn=conn):
dataset = ParquetDataSet([dataset_path])
to_batches = dataset.to_arrow_table(
max_workers=1, conn=conn
).to_batches(max_chunksize=DEFAULT_ROW_GROUP_SIZE * 2)
batch_reader = dataset.to_batch_reader(
batch_size=DEFAULT_ROW_GROUP_SIZE * 2, conn=conn
)
with ConcurrentIter(
batch_reader, max_buffer_size=2
) as batch_reader:
to_batches = dataset.to_arrow_table(max_workers=1, conn=conn).to_batches(max_chunksize=DEFAULT_ROW_GROUP_SIZE * 2)
batch_reader = dataset.to_batch_reader(batch_size=DEFAULT_ROW_GROUP_SIZE * 2, conn=conn)
with ConcurrentIter(batch_reader, max_buffer_size=2) as batch_reader:
for batch_iter in (to_batches, batch_reader):
total_num_rows = 0
for batch in batch_iter:
print(
f"batch.num_rows {batch.num_rows}, max_batch_row_size {DEFAULT_ROW_GROUP_SIZE*2}"
)
self.assertLessEqual(
batch.num_rows, DEFAULT_ROW_GROUP_SIZE * 2
)
print(f"batch.num_rows {batch.num_rows}, max_batch_row_size {DEFAULT_ROW_GROUP_SIZE*2}")
self.assertLessEqual(batch.num_rows, DEFAULT_ROW_GROUP_SIZE * 2)
total_num_rows += batch.num_rows
print(f"{dataset_path}: total_num_rows {total_num_rows}")
self.assertEqual(total_num_rows, dataset.num_rows)
@@ -167,8 +139,6 @@ def test_arrow_reader(benchmark, reader: str, dataset_path: str):
dataset = ParquetDataSet([dataset_path])
conn = None
if reader == "duckdb":
conn = duckdb.connect(
database=":memory:", config={"arrow_large_buffer_size": "true"}
)
conn = duckdb.connect(database=":memory:", config={"arrow_large_buffer_size": "true"})
benchmark(dataset.to_arrow_table, conn=conn)
# result: arrow reader is 4x faster than duckdb reader in small dataset, 1.4x faster in large dataset

View File

@@ -7,9 +7,7 @@ from smallpond.io.arrow import cast_columns_to_large_string
from tests.test_fabric import TestFabric
@unittest.skipUnless(
importlib.util.find_spec("deltalake") is not None, "cannot find deltalake"
)
@unittest.skipUnless(importlib.util.find_spec("deltalake") is not None, "cannot find deltalake")
class TestDeltaLake(TestFabric, unittest.TestCase):
def test_read_write_deltalake(self):
from deltalake import DeltaTable, write_deltalake
@@ -20,9 +18,7 @@ class TestDeltaLake(TestFabric, unittest.TestCase):
):
parquet_files = glob.glob(dataset_path)
expected = self._load_parquet_files(parquet_files)
with self.subTest(dataset_path=dataset_path), tempfile.TemporaryDirectory(
dir=self.output_root_abspath
) as output_dir:
with self.subTest(dataset_path=dataset_path), tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
write_deltalake(output_dir, expected, large_dtypes=True)
dt = DeltaTable(output_dir)
self._compare_arrow_tables(expected, dt.to_pyarrow_table())
@@ -35,12 +31,8 @@ class TestDeltaLake(TestFabric, unittest.TestCase):
"tests/data/large_array/*.parquet",
):
parquet_files = glob.glob(dataset_path)
with self.subTest(dataset_path=dataset_path), tempfile.TemporaryDirectory(
dir=self.output_root_abspath
) as output_dir:
table = cast_columns_to_large_string(
self._load_parquet_files(parquet_files)
)
with self.subTest(dataset_path=dataset_path), tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
table = cast_columns_to_large_string(self._load_parquet_files(parquet_files))
write_deltalake(output_dir, table, large_dtypes=True, mode="overwrite")
write_deltalake(output_dir, table, large_dtypes=False, mode="append")
loaded_table = DeltaTable(output_dir).to_pyarrow_table()

View File

@@ -69,9 +69,7 @@ class OutputMsgPythonTask(PythonScriptTask):
input_datasets: List[DataSet],
output_path: str,
) -> bool:
logger.info(
f"msg: {self.msg}, num files: {input_datasets[0].num_files}, local gpu ranks: {self.local_gpu_ranks}"
)
logger.info(f"msg: {self.msg}, num files: {input_datasets[0].num_files}, local gpu ranks: {self.local_gpu_ranks}")
self.inject_fault()
return True
@@ -105,9 +103,7 @@ class CopyInputArrowNode(ArrowComputeNode):
super().__init__(ctx, input_deps)
self.msg = msg
def process(
self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]
) -> arrow.Table:
def process(self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table:
return copy_input_arrow(runtime_ctx, input_tables, self.msg)
@@ -117,24 +113,18 @@ class CopyInputStreamNode(ArrowStreamNode):
super().__init__(ctx, input_deps)
self.msg = msg
def process(
self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]
) -> Iterable[arrow.Table]:
def process(self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]) -> Iterable[arrow.Table]:
return copy_input_stream(runtime_ctx, input_readers, self.msg)
def copy_input_arrow(
runtime_ctx: RuntimeContext, input_tables: List[arrow.Table], msg: str
) -> arrow.Table:
def copy_input_arrow(runtime_ctx: RuntimeContext, input_tables: List[arrow.Table], msg: str) -> arrow.Table:
logger.info(f"msg: {msg}, num rows: {input_tables[0].num_rows}")
time.sleep(runtime_ctx.secs_executor_probe_interval)
runtime_ctx.task.inject_fault()
return input_tables[0]
def copy_input_stream(
runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader], msg: str
) -> Iterable[arrow.Table]:
def copy_input_stream(runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader], msg: str) -> Iterable[arrow.Table]:
for index, batch in enumerate(input_readers[0]):
logger.info(f"msg: {msg}, batch index: {index}, num rows: {batch.num_rows}")
time.sleep(runtime_ctx.secs_executor_probe_interval)
@@ -146,62 +136,44 @@ def copy_input_stream(
runtime_ctx.task.inject_fault()
def copy_input_batch(
runtime_ctx: RuntimeContext, input_batches: List[arrow.Table], msg: str
) -> arrow.Table:
def copy_input_batch(runtime_ctx: RuntimeContext, input_batches: List[arrow.Table], msg: str) -> arrow.Table:
logger.info(f"msg: {msg}, num rows: {input_batches[0].num_rows}")
time.sleep(runtime_ctx.secs_executor_probe_interval)
runtime_ctx.task.inject_fault()
return input_batches[0]
def copy_input_data_frame(
runtime_ctx: RuntimeContext, input_dfs: List[DataFrame]
) -> DataFrame:
def copy_input_data_frame(runtime_ctx: RuntimeContext, input_dfs: List[DataFrame]) -> DataFrame:
runtime_ctx.task.inject_fault()
return input_dfs[0]
def copy_input_data_frame_batch(
runtime_ctx: RuntimeContext, input_dfs: List[DataFrame]
) -> DataFrame:
def copy_input_data_frame_batch(runtime_ctx: RuntimeContext, input_dfs: List[DataFrame]) -> DataFrame:
runtime_ctx.task.inject_fault()
return input_dfs[0]
def merge_input_tables(
runtime_ctx: RuntimeContext, input_batches: List[arrow.Table]
) -> arrow.Table:
def merge_input_tables(runtime_ctx: RuntimeContext, input_batches: List[arrow.Table]) -> arrow.Table:
runtime_ctx.task.inject_fault()
output = arrow.concat_tables(input_batches)
logger.info(
f"input rows: {[len(batch) for batch in input_batches]}, output rows: {len(output)}"
)
logger.info(f"input rows: {[len(batch) for batch in input_batches]}, output rows: {len(output)}")
return output
def merge_input_data_frames(
runtime_ctx: RuntimeContext, input_dfs: List[DataFrame]
) -> DataFrame:
def merge_input_data_frames(runtime_ctx: RuntimeContext, input_dfs: List[DataFrame]) -> DataFrame:
runtime_ctx.task.inject_fault()
output = pandas.concat(input_dfs)
logger.info(
f"input rows: {[len(df) for df in input_dfs]}, output rows: {len(output)}"
)
logger.info(f"input rows: {[len(df) for df in input_dfs]}, output rows: {len(output)}")
return output
def parse_url(
runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]
) -> arrow.Table:
def parse_url(runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table:
urls = input_tables[0].columns[0]
hosts = [url.as_py().split("/", maxsplit=2)[0] for url in urls]
return input_tables[0].append_column("host", arrow.array(hosts))
def nonzero_exit_code(
runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str
) -> bool:
def nonzero_exit_code(runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str) -> bool:
import sys
if runtime_ctx.task._memory_boost == 1:
@@ -210,9 +182,7 @@ def nonzero_exit_code(
# create an empty file with a fixed name
def empty_file(
runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str
) -> bool:
def empty_file(runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str) -> bool:
import os
with open(os.path.join(output_path, "file"), "w") as fout:
@@ -231,9 +201,7 @@ def split_url(urls: arrow.array) -> arrow.array:
return arrow.array(url_parts, type=arrow.list_(arrow.string()))
def choose_random_urls(
runtime_ctx: RuntimeContext, input_tables: List[arrow.Table], k: int = 5
) -> arrow.Table:
def choose_random_urls(runtime_ctx: RuntimeContext, input_tables: List[arrow.Table], k: int = 5) -> arrow.Table:
# get the current running task
runtime_task = runtime_ctx.task
# access task-specific attributes
@@ -255,16 +223,12 @@ class TestExecution(TestFabric, unittest.TestCase):
def test_arrow_task(self):
for use_duckdb_reader in (False, True):
with self.subTest(use_duckdb_reader=use_duckdb_reader):
with tempfile.TemporaryDirectory(
dir=self.output_root_abspath
) as output_dir:
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_table = dataset.to_arrow_table()
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(
ctx, (data_files,), npartitions=7
)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=7)
if use_duckdb_reader:
data_partitions = ProjectionNode(
ctx,
@@ -274,9 +238,7 @@ class TestExecution(TestFabric, unittest.TestCase):
arrow_compute = ArrowComputeNode(
ctx,
(data_partitions,),
process_func=functools.partial(
copy_input_arrow, msg="arrow compute"
),
process_func=functools.partial(copy_input_arrow, msg="arrow compute"),
use_duckdb_reader=use_duckdb_reader,
output_name="arrow_compute",
output_path=output_dir,
@@ -285,9 +247,7 @@ class TestExecution(TestFabric, unittest.TestCase):
arrow_stream = ArrowStreamNode(
ctx,
(data_partitions,),
process_func=functools.partial(
copy_input_stream, msg="arrow stream"
),
process_func=functools.partial(copy_input_stream, msg="arrow stream"),
streaming_batch_size=10,
secs_checkpoint_interval=0.5,
use_duckdb_reader=use_duckdb_reader,
@@ -298,9 +258,7 @@ class TestExecution(TestFabric, unittest.TestCase):
arrow_batch = ArrowBatchNode(
ctx,
(data_partitions,),
process_func=functools.partial(
copy_input_batch, msg="arrow batch"
),
process_func=functools.partial(copy_input_batch, msg="arrow batch"),
streaming_batch_size=10,
secs_checkpoint_interval=0.5,
use_duckdb_reader=use_duckdb_reader,
@@ -314,12 +272,8 @@ class TestExecution(TestFabric, unittest.TestCase):
output_path=output_dir,
)
plan = LogicalPlan(ctx, data_sink)
exec_plan = self.execute_plan(
plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5
)
self.assertTrue(
all(map(os.path.exists, exec_plan.final_output.resolved_paths))
)
exec_plan = self.execute_plan(plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5)
self.assertTrue(all(map(os.path.exists, exec_plan.final_output.resolved_paths)))
arrow_compute_output = ParquetDataSet(
[os.path.join(output_dir, "arrow_compute", "**/*.parquet")],
recursive=True,
@@ -334,21 +288,15 @@ class TestExecution(TestFabric, unittest.TestCase):
)
self._compare_arrow_tables(
data_table,
arrow_compute_output.to_arrow_table().select(
data_table.column_names
),
arrow_compute_output.to_arrow_table().select(data_table.column_names),
)
self._compare_arrow_tables(
data_table,
arrow_stream_output.to_arrow_table().select(
data_table.column_names
),
arrow_stream_output.to_arrow_table().select(data_table.column_names),
)
self._compare_arrow_tables(
data_table,
arrow_batch_output.to_arrow_table().select(
data_table.column_names
),
arrow_batch_output.to_arrow_table().select(data_table.column_names),
)
def test_pandas_task(self):
@@ -376,16 +324,10 @@ class TestExecution(TestFabric, unittest.TestCase):
output_path=output_dir,
cpu_limit=2,
)
data_sink = DataSinkNode(
ctx, (pandas_compute, pandas_batch), output_path=output_dir
)
data_sink = DataSinkNode(ctx, (pandas_compute, pandas_batch), output_path=output_dir)
plan = LogicalPlan(ctx, data_sink)
exec_plan = self.execute_plan(
plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5
)
self.assertTrue(
all(map(os.path.exists, exec_plan.final_output.resolved_paths))
)
exec_plan = self.execute_plan(plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5)
self.assertTrue(all(map(os.path.exists, exec_plan.final_output.resolved_paths)))
pandas_compute_output = ParquetDataSet(
[os.path.join(output_dir, "pandas_compute", "**/*.parquet")],
recursive=True,
@@ -394,21 +336,15 @@ class TestExecution(TestFabric, unittest.TestCase):
[os.path.join(output_dir, "pandas_batch", "**/*.parquet")],
recursive=True,
)
self._compare_arrow_tables(
data_table, pandas_compute_output.to_arrow_table()
)
self._compare_arrow_tables(data_table, pandas_compute_output.to_arrow_table())
self._compare_arrow_tables(data_table, pandas_batch_output.to_arrow_table())
def test_variable_length_input_datasets(self):
ctx = Context()
small_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
large_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"] * 10)
small_partitions = DataSetPartitionNode(
ctx, (DataSourceNode(ctx, small_dataset),), npartitions=7
)
large_partitions = DataSetPartitionNode(
ctx, (DataSourceNode(ctx, large_dataset),), npartitions=7
)
small_partitions = DataSetPartitionNode(ctx, (DataSourceNode(ctx, small_dataset),), npartitions=7)
large_partitions = DataSetPartitionNode(ctx, (DataSourceNode(ctx, large_dataset),), npartitions=7)
arrow_batch = ArrowBatchNode(
ctx,
(small_partitions, large_partitions),
@@ -428,9 +364,7 @@ class TestExecution(TestFabric, unittest.TestCase):
cpu_limit=2,
)
plan = LogicalPlan(ctx, RootNode(ctx, (arrow_batch, pandas_batch)))
exec_plan = self.execute_plan(
plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5
)
exec_plan = self.execute_plan(plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5)
self.assertTrue(all(map(os.path.exists, exec_plan.final_output.resolved_paths)))
arrow_batch_output = ParquetDataSet(
[os.path.join(exec_plan.ctx.output_root, "arrow_batch", "**/*.parquet")],
@@ -440,9 +374,7 @@ class TestExecution(TestFabric, unittest.TestCase):
[os.path.join(exec_plan.ctx.output_root, "pandas_batch", "**/*.parquet")],
recursive=True,
)
self.assertEqual(
small_dataset.num_rows + large_dataset.num_rows, arrow_batch_output.num_rows
)
self.assertEqual(small_dataset.num_rows + large_dataset.num_rows, arrow_batch_output.num_rows)
self.assertEqual(
small_dataset.num_rows + large_dataset.num_rows,
pandas_batch_output.num_rows,
@@ -453,9 +385,7 @@ class TestExecution(TestFabric, unittest.TestCase):
# select columns when defining dataset
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"], columns=["url"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(
ctx, (data_files,), npartitions=3, partition_by_rows=True
)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=3, partition_by_rows=True)
# projection as input of arrow node
generated_columns = ["filename", "file_row_number"]
urls_with_host = ArrowComputeNode(
@@ -480,9 +410,7 @@ class TestExecution(TestFabric, unittest.TestCase):
# unify different schemas
merged_diff_schemas = ProjectionNode(
ctx,
DataSetPartitionNode(
ctx, (distinct_urls_with_host, urls_with_host), npartitions=1
),
DataSetPartitionNode(ctx, (distinct_urls_with_host, urls_with_host), npartitions=1),
union_by_name=True,
)
host_partitions = HashPartitionNode(
@@ -513,9 +441,7 @@ class TestExecution(TestFabric, unittest.TestCase):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(
ctx, (data_files,), npartitions=dataset.num_files
)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=dataset.num_files)
ctx.create_function(
"split_url",
split_url,
@@ -537,9 +463,7 @@ class TestExecution(TestFabric, unittest.TestCase):
npartitions = 1000
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"] * npartitions)
data_files = DataSourceNode(ctx, dataset)
data_partitions = EvenlyDistributedPartitionNode(
ctx, (data_files,), npartitions=npartitions
)
data_partitions = EvenlyDistributedPartitionNode(ctx, (data_files,), npartitions=npartitions)
output_msg = OutputMsgPythonNode(ctx, (data_partitions,))
plan = LogicalPlan(ctx, output_msg)
self.execute_plan(
@@ -552,13 +476,9 @@ class TestExecution(TestFabric, unittest.TestCase):
def test_many_producers_and_partitions(self):
ctx = Context()
npartitions = 10000
dataset = ParquetDataSet(
["tests/data/mock_urls/*.parquet"] * (npartitions * 10)
)
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"] * (npartitions * 10))
data_files = DataSourceNode(ctx, dataset)
data_partitions = EvenlyDistributedPartitionNode(
ctx, (data_files,), npartitions=npartitions, cpu_limit=1
)
data_partitions = EvenlyDistributedPartitionNode(ctx, (data_files,), npartitions=npartitions, cpu_limit=1)
data_partitions.max_num_producer_tasks = 20
output_msg = OutputMsgPythonNode(ctx, (data_partitions,))
plan = LogicalPlan(ctx, output_msg)
@@ -573,12 +493,8 @@ class TestExecution(TestFabric, unittest.TestCase):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(
ctx, (data_files,), npartitions=dataset.num_files
)
output_msg = OutputMsgPythonNode(
ctx, (data_partitions,), cpu_limit=1, gpu_limit=0.5
)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=dataset.num_files)
output_msg = OutputMsgPythonNode(ctx, (data_partitions,), cpu_limit=1, gpu_limit=0.5)
plan = LogicalPlan(ctx, output_msg)
runtime_ctx = RuntimeContext(
JobId.new(),
@@ -596,9 +512,7 @@ class TestExecution(TestFabric, unittest.TestCase):
data_files = DataSourceNode(ctx, dataset)
copy_input_arrow_node = CopyInputArrowNode(ctx, (data_files,), "hello")
copy_input_stream_node = CopyInputStreamNode(ctx, (data_files,), "hello")
output_msg = OutputMsgPythonNode2(
ctx, (copy_input_arrow_node, copy_input_stream_node), "hello"
)
output_msg = OutputMsgPythonNode2(ctx, (copy_input_arrow_node, copy_input_stream_node), "hello")
plan = LogicalPlan(ctx, output_msg)
self.execute_plan(plan)
@@ -606,9 +520,7 @@ class TestExecution(TestFabric, unittest.TestCase):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
uniq_urls = SqlEngineNode(
ctx, (data_files,), r"select distinct * from {0}", memory_limit=2 * MB
)
uniq_urls = SqlEngineNode(ctx, (data_files,), r"select distinct * from {0}", memory_limit=2 * MB)
uniq_url_partitions = DataSetPartitionNode(ctx, (uniq_urls,), 2)
uniq_url_count = SqlEngineNode(
ctx,
@@ -637,9 +549,7 @@ class TestExecution(TestFabric, unittest.TestCase):
memory_limit=1 * GB,
)
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
data_sink = DataSinkNode(
ctx, (arrow_compute, arrow_stream), output_path=output_dir
)
data_sink = DataSinkNode(ctx, (arrow_compute, arrow_stream), output_path=output_dir)
plan = LogicalPlan(ctx, data_sink)
self.execute_plan(
plan,
@@ -652,17 +562,11 @@ class TestExecution(TestFabric, unittest.TestCase):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
nonzero_exitcode = PythonScriptNode(
ctx, (data_files,), process_func=nonzero_exit_code
)
nonzero_exitcode = PythonScriptNode(ctx, (data_files,), process_func=nonzero_exit_code)
plan = LogicalPlan(ctx, nonzero_exitcode)
exec_plan = self.execute_plan(
plan, num_executors=1, check_result=False, nonzero_exitcode_as_oom=False
)
exec_plan = self.execute_plan(plan, num_executors=1, check_result=False, nonzero_exitcode_as_oom=False)
self.assertFalse(exec_plan.successful)
exec_plan = self.execute_plan(
plan, num_executors=1, check_result=False, nonzero_exitcode_as_oom=True
)
exec_plan = self.execute_plan(plan, num_executors=1, check_result=False, nonzero_exitcode_as_oom=True)
self.assertTrue(exec_plan.successful)
def test_manifest_only_data_sink(self):
@@ -675,9 +579,7 @@ class TestExecution(TestFabric, unittest.TestCase):
dataset = ParquetDataSet(filenames)
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=512)
data_sink = DataSinkNode(
ctx, (data_partitions,), output_path=output_dir, manifest_only=True
)
data_sink = DataSinkNode(ctx, (data_partitions,), output_path=output_dir, manifest_only=True)
plan = LogicalPlan(ctx, data_sink)
self.execute_plan(plan)
@@ -734,12 +636,8 @@ class TestExecution(TestFabric, unittest.TestCase):
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=10)
url_counts = SqlEngineNode(
ctx, (data_partitions,), r"select count(url) as cnt from {0}"
)
distinct_url_counts = SqlEngineNode(
ctx, (data_partitions,), r"select count(distinct url) as cnt from {0}"
)
url_counts = SqlEngineNode(ctx, (data_partitions,), r"select count(url) as cnt from {0}")
distinct_url_counts = SqlEngineNode(ctx, (data_partitions,), r"select count(distinct url) as cnt from {0}")
merged_counts = DataSetPartitionNode(
ctx,
(
@@ -764,9 +662,7 @@ class TestExecution(TestFabric, unittest.TestCase):
r"select count(url) as cnt from {0}",
output_name="url_counts",
)
distinct_url_counts = SqlEngineNode(
ctx, (data_partitions,), r"select count(distinct url) as cnt from {0}"
)
distinct_url_counts = SqlEngineNode(ctx, (data_partitions,), r"select count(distinct url) as cnt from {0}")
merged_counts = DataSetPartitionNode(
ctx,
(
@@ -787,24 +683,14 @@ class TestExecution(TestFabric, unittest.TestCase):
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=10)
empty_files1 = PythonScriptNode(
ctx, (data_partitions,), process_func=empty_file
)
empty_files2 = PythonScriptNode(
ctx, (data_partitions,), process_func=empty_file
)
empty_files1 = PythonScriptNode(ctx, (data_partitions,), process_func=empty_file)
empty_files2 = PythonScriptNode(ctx, (data_partitions,), process_func=empty_file)
link_path = os.path.join(self.runtime_ctx.output_root, "link")
copy_path = os.path.join(self.runtime_ctx.output_root, "copy")
copy_input_path = os.path.join(self.runtime_ctx.output_root, "copy_input")
data_link = DataSinkNode(
ctx, (empty_files1, empty_files2), type="link", output_path=link_path
)
data_copy = DataSinkNode(
ctx, (empty_files1, empty_files2), type="copy", output_path=copy_path
)
data_copy_input = DataSinkNode(
ctx, (data_partitions,), type="copy", output_path=copy_input_path
)
data_link = DataSinkNode(ctx, (empty_files1, empty_files2), type="link", output_path=link_path)
data_copy = DataSinkNode(ctx, (empty_files1, empty_files2), type="copy", output_path=copy_path)
data_copy_input = DataSinkNode(ctx, (data_partitions,), type="copy", output_path=copy_input_path)
plan = LogicalPlan(ctx, RootNode(ctx, (data_link, data_copy, data_copy_input)))
self.execute_plan(plan)
@@ -813,33 +699,19 @@ class TestExecution(TestFabric, unittest.TestCase):
self.assertEqual(21, len(os.listdir(copy_path)))
# file name should not be modified if no conflict
self.assertEqual(
set(
filename
for filename in os.listdir("tests/data/mock_urls")
if filename.endswith(".parquet")
),
set(
filename
for filename in os.listdir(copy_input_path)
if filename.endswith(".parquet")
),
set(filename for filename in os.listdir("tests/data/mock_urls") if filename.endswith(".parquet")),
set(filename for filename in os.listdir(copy_input_path) if filename.endswith(".parquet")),
)
def test_literal_datasets_as_data_sources(self):
ctx = Context()
num_rows = 10
query_dataset = SqlQueryDataSet(f"select i from range({num_rows}) as x(i)")
table_dataset = ArrowTableDataSet(
arrow.Table.from_arrays([list(range(num_rows))], names=["i"])
)
table_dataset = ArrowTableDataSet(arrow.Table.from_arrays([list(range(num_rows))], names=["i"]))
query_source = DataSourceNode(ctx, query_dataset)
table_source = DataSourceNode(ctx, table_dataset)
query_partitions = DataSetPartitionNode(
ctx, (query_source,), npartitions=num_rows, partition_by_rows=True
)
table_partitions = DataSetPartitionNode(
ctx, (table_source,), npartitions=num_rows, partition_by_rows=True
)
query_partitions = DataSetPartitionNode(ctx, (query_source,), npartitions=num_rows, partition_by_rows=True)
table_partitions = DataSetPartitionNode(ctx, (table_source,), npartitions=num_rows, partition_by_rows=True)
joined_rows = SqlEngineNode(
ctx,
(query_partitions, table_partitions),

View File

@@ -27,9 +27,7 @@ from tests.datagen import generate_data
generate_data()
def run_scheduler(
runtime_ctx: RuntimeContext, scheduler: Scheduler, queue: queue.Queue
):
def run_scheduler(runtime_ctx: RuntimeContext, scheduler: Scheduler, queue: queue.Queue):
runtime_ctx.initialize("scheduler")
scheduler.add_state_observer(Scheduler.StateObserver(SaveSchedState(queue)))
retval = scheduler.run()
@@ -130,9 +128,7 @@ class TestFabric(unittest.TestCase):
process.kill()
process.join()
logger.info(
f"#{i} process {process.name} exited with code {process.exitcode}"
)
logger.info(f"#{i} process {process.name} exited with code {process.exitcode}")
def start_execution(
self,
@@ -189,11 +185,7 @@ class TestFabric(unittest.TestCase):
secs_wq_poll_interval=secs_wq_poll_interval,
secs_executor_probe_interval=secs_executor_probe_interval,
max_num_missed_probes=max_num_missed_probes,
fault_inject_prob=(
fault_inject_prob
if fault_inject_prob is not None
else self.fault_inject_prob
),
fault_inject_prob=(fault_inject_prob if fault_inject_prob is not None else self.fault_inject_prob),
enable_profiling=enable_profiling,
enable_diagnostic_metrics=enable_diagnostic_metrics,
remove_empty_parquet=remove_empty_parquet,
@@ -217,9 +209,7 @@ class TestFabric(unittest.TestCase):
nonzero_exitcode_as_oom=nonzero_exitcode_as_oom,
)
self.latest_state = scheduler
self.executors = [
Executor.create(runtime_ctx, f"executor-{i}") for i in range(num_executors)
]
self.executors = [Executor.create(runtime_ctx, f"executor-{i}") for i in range(num_executors)]
self.processes = [
Process(
target=run_scheduler,
@@ -229,10 +219,7 @@ class TestFabric(unittest.TestCase):
name="scheduler",
)
]
self.processes += [
Process(target=run_executor, args=(runtime_ctx, executor), name=executor.id)
for executor in self.executors
]
self.processes += [Process(target=run_executor, args=(runtime_ctx, executor), name=executor.id) for executor in self.executors]
for process in reversed(self.processes):
process.start()
@@ -264,15 +251,9 @@ class TestFabric(unittest.TestCase):
self.assertTrue(latest_state.success)
return latest_state.exec_plan
def _load_parquet_files(
self, paths, filesystem: fsspec.AbstractFileSystem = None
) -> arrow.Table:
def _load_parquet_files(self, paths, filesystem: fsspec.AbstractFileSystem = None) -> arrow.Table:
def read_parquet_file(path):
return arrow.Table.from_batches(
parquet.ParquetFile(
path, buffer_size=16 * MB, filesystem=filesystem
).iter_batches()
)
return arrow.Table.from_batches(parquet.ParquetFile(path, buffer_size=16 * MB, filesystem=filesystem).iter_batches())
with ThreadPoolExecutor(16) as pool:
return arrow.concat_tables(pool.map(read_parquet_file, paths))

View File

@@ -17,9 +17,7 @@ class TestFilesystem(TestFabric, unittest.TestCase):
def test_pickle_trace(self):
with self.assertRaises(TypeError) as context:
with tempfile.TemporaryDirectory(
dir=self.output_root_abspath
) as output_dir:
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
thread = threading.Thread()
pickle_path = os.path.join(output_dir, "thread.pickle")
dump(thread, pickle_path)

View File

@@ -21,19 +21,11 @@ class TestLogicalPlan(TestFabric, unittest.TestCase):
def test_join_chunkmeta_inodes(self):
ctx = Context()
chunkmeta_dump = DataSourceNode(
ctx, dataset=ParquetDataSet(["tests/data/chunkmeta*.parquet"])
)
chunkmeta_partitions = HashPartitionNode(
ctx, (chunkmeta_dump,), npartitions=2, hash_columns=["inodeId"]
)
chunkmeta_dump = DataSourceNode(ctx, dataset=ParquetDataSet(["tests/data/chunkmeta*.parquet"]))
chunkmeta_partitions = HashPartitionNode(ctx, (chunkmeta_dump,), npartitions=2, hash_columns=["inodeId"])
inodes_dump = DataSourceNode(
ctx, dataset=ParquetDataSet(["tests/data/inodes*.parquet"])
)
inodes_partitions = HashPartitionNode(
ctx, (inodes_dump,), npartitions=2, hash_columns=["inode_id"]
)
inodes_dump = DataSourceNode(ctx, dataset=ParquetDataSet(["tests/data/inodes*.parquet"]))
inodes_partitions = HashPartitionNode(ctx, (inodes_dump,), npartitions=2, hash_columns=["inode_id"])
num_gc_chunks = SqlEngineNode(
ctx,
@@ -53,12 +45,8 @@ class TestLogicalPlan(TestFabric, unittest.TestCase):
ctx = Context()
parquet_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_source = DataSourceNode(ctx, parquet_dataset)
partition_dim_a = EvenlyDistributedPartitionNode(
ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="A"
)
partition_dim_b = EvenlyDistributedPartitionNode(
ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="B"
)
partition_dim_a = EvenlyDistributedPartitionNode(ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="A")
partition_dim_b = EvenlyDistributedPartitionNode(ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="B")
join_two_inputs = SqlEngineNode(
ctx,
(partition_dim_a, partition_dim_b),
@@ -73,9 +61,7 @@ class TestLogicalPlan(TestFabric, unittest.TestCase):
ctx = Context()
parquet_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_source = DataSourceNode(ctx, parquet_dataset)
partition_dim_a = EvenlyDistributedPartitionNode(
ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="A"
)
partition_dim_a = EvenlyDistributedPartitionNode(ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="A")
partition_dim_a2 = EvenlyDistributedPartitionNode(
ctx,
(data_source,),
@@ -94,9 +80,7 @@ class TestLogicalPlan(TestFabric, unittest.TestCase):
)
plan = LogicalPlan(
ctx,
DataSetPartitionNode(
ctx, (join_two_inputs1, join_two_inputs2), npartitions=1
),
DataSetPartitionNode(ctx, (join_two_inputs1, join_two_inputs2), npartitions=1),
)
logger.info(str(plan))
with self.assertRaises(AssertionError) as context:

View File

@@ -31,9 +31,7 @@ from tests.test_fabric import TestFabric
class CalculatePartitionFromFilename(UserDefinedPartitionNode):
def partition(self, runtime_ctx: RuntimeContext, dataset: DataSet) -> List[DataSet]:
partitioned_datasets: List[ParquetDataSet] = [
ParquetDataSet([]) for _ in range(self.npartitions)
]
partitioned_datasets: List[ParquetDataSet] = [ParquetDataSet([]) for _ in range(self.npartitions)]
for path in dataset.resolved_paths:
partition_idx = hash(path) % self.npartitions
partitioned_datasets[partition_idx].paths.append(path)
@@ -45,9 +43,7 @@ class TestPartition(TestFabric, unittest.TestCase):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"] * 10)
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(
ctx, (data_files,), npartitions=dataset.num_files
)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=dataset.num_files)
count_rows = SqlEngineNode(
ctx,
(data_partitions,),
@@ -62,9 +58,7 @@ class TestPartition(TestFabric, unittest.TestCase):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(
ctx, (data_files,), npartitions=dataset.num_rows, partition_by_rows=True
)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=dataset.num_rows, partition_by_rows=True)
count_rows = SqlEngineNode(
ctx,
(data_partitions,),
@@ -74,18 +68,14 @@ class TestPartition(TestFabric, unittest.TestCase):
)
plan = LogicalPlan(ctx, count_rows)
exec_plan = self.execute_plan(plan, num_executors=5)
self.assertEqual(
exec_plan.final_output.to_arrow_table().num_rows, dataset.num_rows
)
self.assertEqual(exec_plan.final_output.to_arrow_table().num_rows, dataset.num_rows)
def test_empty_dataset_partition(self):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
# create more partitions than files
data_partitions = EvenlyDistributedPartitionNode(
ctx, (data_files,), npartitions=dataset.num_files * 2
)
data_partitions = EvenlyDistributedPartitionNode(ctx, (data_files,), npartitions=dataset.num_files * 2)
data_partitions.max_num_producer_tasks = 3
unique_urls = SqlEngineNode(
ctx,
@@ -95,9 +85,7 @@ class TestPartition(TestFabric, unittest.TestCase):
memory_limit=1 * GB,
)
# nested partition
nested_partitioned_urls = EvenlyDistributedPartitionNode(
ctx, (unique_urls,), npartitions=3, dimension="nested", nested=True
)
nested_partitioned_urls = EvenlyDistributedPartitionNode(ctx, (unique_urls,), npartitions=3, dimension="nested", nested=True)
parsed_urls = ArrowComputeNode(
ctx,
(nested_partitioned_urls,),
@@ -106,18 +94,14 @@ class TestPartition(TestFabric, unittest.TestCase):
memory_limit=1 * GB,
)
plan = LogicalPlan(ctx, parsed_urls)
final_output = self.execute_plan(
plan, remove_empty_parquet=True, skip_task_with_empty_input=True
).final_output
final_output = self.execute_plan(plan, remove_empty_parquet=True, skip_task_with_empty_input=True).final_output
self.assertTrue(isinstance(final_output, ParquetDataSet))
self.assertEqual(dataset.num_rows, final_output.num_rows)
def test_hash_partition(self):
for engine_type in ("duckdb", "arrow"):
for partition_by_rows in (False, True):
for hive_partitioning in (
(False, True) if engine_type == "duckdb" else (False,)
):
for hive_partitioning in (False, True) if engine_type == "duckdb" else (False,):
with self.subTest(
engine_type=engine_type,
partition_by_rows=partition_by_rows,
@@ -155,26 +139,16 @@ class TestPartition(TestFabric, unittest.TestCase):
exec_plan = self.execute_plan(plan)
self.assertEqual(
dataset.num_rows,
pc.sum(
exec_plan.final_output.to_arrow_table().column(
"row_count"
)
).as_py(),
pc.sum(exec_plan.final_output.to_arrow_table().column("row_count")).as_py(),
)
self.assertEqual(
npartitions,
len(exec_plan.final_output.load_partitioned_datasets(npartitions, DATA_PARTITION_COLUMN_NAME)),
)
self.assertEqual(
npartitions,
len(
exec_plan.final_output.load_partitioned_datasets(
npartitions, DATA_PARTITION_COLUMN_NAME
)
),
)
self.assertEqual(
npartitions,
len(
exec_plan.get_output(
"hash_partitions"
).load_partitioned_datasets(
exec_plan.get_output("hash_partitions").load_partitioned_datasets(
npartitions,
DATA_PARTITION_COLUMN_NAME,
hive_partitioning,
@@ -185,9 +159,7 @@ class TestPartition(TestFabric, unittest.TestCase):
def test_empty_hash_partition(self):
for engine_type in ("duckdb", "arrow"):
for partition_by_rows in (False, True):
for hive_partitioning in (
(False, True) if engine_type == "duckdb" else (False,)
):
for hive_partitioning in (False, True) if engine_type == "duckdb" else (False,):
with self.subTest(
engine_type=engine_type,
partition_by_rows=partition_by_rows,
@@ -199,9 +171,7 @@ class TestPartition(TestFabric, unittest.TestCase):
npartitions = 3
npartitions_nested = 4
num_rows = 1
head_rows = SqlEngineNode(
ctx, (data_files,), f"select * from {{0}} limit {num_rows}"
)
head_rows = SqlEngineNode(ctx, (data_files,), f"select * from {{0}} limit {num_rows}")
data_partitions = DataSetPartitionNode(
ctx,
(head_rows,),
@@ -241,53 +211,31 @@ class TestPartition(TestFabric, unittest.TestCase):
memory_limit=1 * GB,
)
plan = LogicalPlan(ctx, select_every_row)
exec_plan = self.execute_plan(
plan, skip_task_with_empty_input=True
)
exec_plan = self.execute_plan(plan, skip_task_with_empty_input=True)
self.assertEqual(num_rows, exec_plan.final_output.num_rows)
self.assertEqual(
npartitions,
len(
exec_plan.final_output.load_partitioned_datasets(
npartitions, "hash_partitions"
)
),
len(exec_plan.final_output.load_partitioned_datasets(npartitions, "hash_partitions")),
)
self.assertEqual(
npartitions_nested,
len(
exec_plan.final_output.load_partitioned_datasets(
npartitions_nested, "nested_hash_partitions"
)
),
len(exec_plan.final_output.load_partitioned_datasets(npartitions_nested, "nested_hash_partitions")),
)
self.assertEqual(
npartitions,
len(
exec_plan.get_output(
"hash_partitions"
).load_partitioned_datasets(
npartitions, "hash_partitions"
)
),
len(exec_plan.get_output("hash_partitions").load_partitioned_datasets(npartitions, "hash_partitions")),
)
self.assertEqual(
npartitions_nested,
len(
exec_plan.get_output(
"nested_hash_partitions"
).load_partitioned_datasets(
npartitions_nested, "nested_hash_partitions"
)
exec_plan.get_output("nested_hash_partitions").load_partitioned_datasets(npartitions_nested, "nested_hash_partitions")
),
)
if hive_partitioning:
self.assertEqual(
npartitions,
len(
exec_plan.get_output(
"hash_partitions"
).load_partitioned_datasets(
exec_plan.get_output("hash_partitions").load_partitioned_datasets(
npartitions,
"hash_partitions",
hive_partitioning=True,
@@ -297,9 +245,7 @@ class TestPartition(TestFabric, unittest.TestCase):
self.assertEqual(
npartitions_nested,
len(
exec_plan.get_output(
"nested_hash_partitions"
).load_partitioned_datasets(
exec_plan.get_output("nested_hash_partitions").load_partitioned_datasets(
npartitions_nested,
"nested_hash_partitions",
hive_partitioning=True,
@@ -341,19 +287,11 @@ class TestPartition(TestFabric, unittest.TestCase):
exec_plan = self.execute_plan(plan)
self.assertEqual(
npartitions,
len(
exec_plan.final_output.load_partitioned_datasets(
npartitions, data_partition_column
)
),
len(exec_plan.final_output.load_partitioned_datasets(npartitions, data_partition_column)),
)
self.assertEqual(
npartitions,
len(
exec_plan.get_output("input_partitions").load_partitioned_datasets(
npartitions, data_partition_column, hive_partitioning
)
),
len(exec_plan.get_output("input_partitions").load_partitioned_datasets(npartitions, data_partition_column, hive_partitioning)),
)
return exec_plan
@@ -376,12 +314,8 @@ class TestPartition(TestFabric, unittest.TestCase):
)
ctx = Context()
output1 = DataSourceNode(
ctx, dataset=exec_plan1.get_output("input_partitions")
)
output2 = DataSourceNode(
ctx, dataset=exec_plan2.get_output("input_partitions")
)
output1 = DataSourceNode(ctx, dataset=exec_plan1.get_output("input_partitions"))
output2 = DataSourceNode(ctx, dataset=exec_plan2.get_output("input_partitions"))
split_urls1 = LoadPartitionedDataSetNode(
ctx,
(output1,),
@@ -411,16 +345,8 @@ class TestPartition(TestFabric, unittest.TestCase):
plan = LogicalPlan(ctx, split_urls3)
exec_plan3 = self.execute_plan(plan)
# load each partition as arrow table and compare
final_output_partitions1 = (
exec_plan1.final_output.load_partitioned_datasets(
npartitions, data_partition_column
)
)
final_output_partitions3 = (
exec_plan3.final_output.load_partitioned_datasets(
npartitions, data_partition_column
)
)
final_output_partitions1 = exec_plan1.final_output.load_partitioned_datasets(npartitions, data_partition_column)
final_output_partitions3 = exec_plan3.final_output.load_partitioned_datasets(npartitions, data_partition_column)
self.assertEqual(npartitions, len(final_output_partitions3))
for x, y in zip(final_output_partitions1, final_output_partitions3):
self._compare_arrow_tables(x.to_arrow_table(), y.to_arrow_table())
@@ -433,9 +359,7 @@ class TestPartition(TestFabric, unittest.TestCase):
SqlEngineNode.default_cpu_limit = 1
SqlEngineNode.default_memory_limit = 1 * GB
initial_reduce = r"select host, count(*) as cnt from {0} group by host"
combine_reduce_results = (
r"select host, cast(sum(cnt) as bigint) as cnt from {0} group by host"
)
combine_reduce_results = r"select host, cast(sum(cnt) as bigint) as cnt from {0} group by host"
join_query = r"select host, cnt from {0} where (exists (select * from {1} where {1}.host = {0}.host)) and (exists (select * from {2} where {2}.host = {0}.host))"
partition_by_hosts = HashPartitionNode(
@@ -496,11 +420,7 @@ class TestPartition(TestFabric, unittest.TestCase):
url_count_by_3dims = SqlEngineNode(ctx, (partitioned_3dims,), initial_reduce)
url_count_by_hosts_x_urls2 = SqlEngineNode(
ctx,
(
ConsolidateNode(
ctx, url_count_by_3dims, ["host_partition", "url_partition"]
),
),
(ConsolidateNode(ctx, url_count_by_3dims, ["host_partition", "url_partition"]),),
combine_reduce_results,
output_name="url_count_by_hosts_x_urls2",
)
@@ -524,9 +444,7 @@ class TestPartition(TestFabric, unittest.TestCase):
output_name="join_count_by_hosts_x_urls2",
)
union_url_count_by_hosts = UnionNode(
ctx, (url_count_by_hosts1, url_count_by_hosts2)
)
union_url_count_by_hosts = UnionNode(ctx, (url_count_by_hosts1, url_count_by_hosts2))
union_url_count_by_hosts_x_urls = UnionNode(
ctx,
(
@@ -576,18 +494,14 @@ class TestPartition(TestFabric, unittest.TestCase):
ctx = Context()
parquet_files = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_source = DataSourceNode(ctx, parquet_files)
file_partitions1 = CalculatePartitionFromFilename(
ctx, (data_source,), npartitions=3, dimension="by_filename_hash1"
)
file_partitions1 = CalculatePartitionFromFilename(ctx, (data_source,), npartitions=3, dimension="by_filename_hash1")
url_count1 = SqlEngineNode(
ctx,
(file_partitions1,),
r"select host, count(*) as cnt from {0} group by host",
output_name="url_count1",
)
file_partitions2 = CalculatePartitionFromFilename(
ctx, (url_count1,), npartitions=3, dimension="by_filename_hash2"
)
file_partitions2 = CalculatePartitionFromFilename(ctx, (url_count1,), npartitions=3, dimension="by_filename_hash2")
url_count2 = SqlEngineNode(
ctx,
(file_partitions2,),
@@ -606,9 +520,7 @@ class TestPartition(TestFabric, unittest.TestCase):
ctx = Context()
parquet_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_source = DataSourceNode(ctx, parquet_dataset)
evenly_dist_data_source = EvenlyDistributedPartitionNode(
ctx, (data_source,), npartitions=parquet_dataset.num_files
)
evenly_dist_data_source = EvenlyDistributedPartitionNode(ctx, (data_source,), npartitions=parquet_dataset.num_files)
parquet_datasets = [ParquetDataSet([p]) for p in parquet_dataset.resolved_paths]
partitioned_data_source = UserPartitionedDataSourceNode(ctx, parquet_datasets)
@@ -631,9 +543,7 @@ class TestPartition(TestFabric, unittest.TestCase):
memory_limit=1 * GB,
)
plan = LogicalPlan(
ctx, UnionNode(ctx, [url_count_by_host1, url_count_by_host2])
)
plan = LogicalPlan(ctx, UnionNode(ctx, [url_count_by_host1, url_count_by_host2]))
exec_plan = self.execute_plan(plan, enable_diagnostic_metrics=True)
self._compare_arrow_tables(
exec_plan.get_output("url_count_by_host1").to_arrow_table(),
@@ -647,9 +557,7 @@ class TestPartition(TestFabric, unittest.TestCase):
ctx = Context()
parquet_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_source = DataSourceNode(ctx, parquet_dataset)
evenly_dist_data_source = EvenlyDistributedPartitionNode(
ctx, (data_source,), npartitions=parquet_dataset.num_files
)
evenly_dist_data_source = EvenlyDistributedPartitionNode(ctx, (data_source,), npartitions=parquet_dataset.num_files)
sql_query = SqlEngineNode(
ctx,
(evenly_dist_data_source,),

View File

@@ -65,6 +65,4 @@ def test_fstest(sp: Session):
def test_sort_mock_urls_v2(sp: Session):
sort_mock_urls_v2(
sp, ["tests/data/mock_urls/*.tsv"], sp._runtime_ctx.output_root, npartitions=3
)
sort_mock_urls_v2(sp, ["tests/data/mock_urls/*.tsv"], sp._runtime_ctx.output_root, npartitions=3)

View File

@@ -21,9 +21,7 @@ from tests.test_fabric import TestFabric
class RandomSleepTask(PythonScriptTask):
def __init__(
self, *args, sleep_secs: float, fail_first_try: bool, **kwargs
) -> None:
def __init__(self, *args, sleep_secs: float, fail_first_try: bool, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.sleep_secs = sleep_secs
self.fail_first_try = fail_first_try
@@ -58,24 +56,16 @@ class RandomSleepNode(PythonScriptNode):
self.fail_first_try = fail_first_try
def spawn(self, *args, **kwargs) -> RandomSleepTask:
sleep_secs = (
random.random() if len(self.generated_tasks) % 20 else self.max_sleep_secs
)
return RandomSleepTask(
*args, **kwargs, sleep_secs=sleep_secs, fail_first_try=self.fail_first_try
)
sleep_secs = random.random() if len(self.generated_tasks) % 20 else self.max_sleep_secs
return RandomSleepTask(*args, **kwargs, sleep_secs=sleep_secs, fail_first_try=self.fail_first_try)
class TestScheduler(TestFabric, unittest.TestCase):
def create_random_sleep_plan(
self, npartitions, max_sleep_secs, fail_first_try=False
):
def create_random_sleep_plan(self, npartitions, max_sleep_secs, fail_first_try=False):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(
ctx, (data_files,), npartitions=npartitions, partition_by_rows=True
)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=npartitions, partition_by_rows=True)
random_sleep = RandomSleepNode(
ctx,
(data_partitions,),
@@ -87,13 +77,8 @@ class TestScheduler(TestFabric, unittest.TestCase):
def check_executor_state(self, target_state: ExecutorState, nloops=200):
for _ in range(nloops):
latest_sched_state = self.get_latest_sched_state()
if any(
executor.state == target_state
for executor in latest_sched_state.remote_executors
):
logger.info(
f"found {target_state} executor in: {latest_sched_state.remote_executors}"
)
if any(executor.state == target_state for executor in latest_sched_state.remote_executors):
logger.info(f"found {target_state} executor in: {latest_sched_state.remote_executors}")
break
time.sleep(0.1)
else:
@@ -121,9 +106,7 @@ class TestScheduler(TestFabric, unittest.TestCase):
latest_sched_state = self.get_latest_sched_state()
self.check_executor_state(ExecutorState.GOOD)
for i, (executor, process) in enumerate(
random.sample(list(zip(executors, processes[1:])), k=num_fail)
):
for i, (executor, process) in enumerate(random.sample(list(zip(executors, processes[1:])), k=num_fail)):
if i % 2 == 0:
logger.warning(f"kill executor: {executor}")
process.kill()
@@ -165,9 +148,7 @@ class TestScheduler(TestFabric, unittest.TestCase):
self.assertGreater(len(latest_sched_state.abandoned_tasks), 0)
def test_stop_executor_on_failure(self):
plan = self.create_random_sleep_plan(
npartitions=3, max_sleep_secs=5, fail_first_try=True
)
plan = self.create_random_sleep_plan(npartitions=3, max_sleep_secs=5, fail_first_try=True)
exec_plan = self.execute_plan(
plan,
num_executors=5,

View File

@@ -5,9 +5,7 @@ from smallpond.dataframe import Session
def test_shutdown_cleanup(sp: Session):
assert os.path.exists(sp._runtime_ctx.queue_root), "queue directory should exist"
assert os.path.exists(
sp._runtime_ctx.staging_root
), "staging directory should exist"
assert os.path.exists(sp._runtime_ctx.staging_root), "staging directory should exist"
assert os.path.exists(sp._runtime_ctx.temp_root), "temp directory should exist"
# create some tasks and complete them
@@ -16,15 +14,9 @@ def test_shutdown_cleanup(sp: Session):
sp.shutdown()
# shutdown should clean up directories
assert not os.path.exists(
sp._runtime_ctx.queue_root
), "queue directory should be cleared"
assert not os.path.exists(
sp._runtime_ctx.staging_root
), "staging directory should be cleared"
assert not os.path.exists(
sp._runtime_ctx.temp_root
), "temp directory should be cleared"
assert not os.path.exists(sp._runtime_ctx.queue_root), "queue directory should be cleared"
assert not os.path.exists(sp._runtime_ctx.staging_root), "staging directory should be cleared"
assert not os.path.exists(sp._runtime_ctx.temp_root), "temp directory should be cleared"
with open(sp._runtime_ctx.job_status_path) as fin:
assert "success" in fin.read(), "job status should be success"
@@ -41,14 +33,8 @@ def test_shutdown_no_cleanup_on_failure(sp: Session):
sp.shutdown()
# shutdown should not clean up directories
assert os.path.exists(
sp._runtime_ctx.queue_root
), "queue directory should not be cleared"
assert os.path.exists(
sp._runtime_ctx.staging_root
), "staging directory should not be cleared"
assert os.path.exists(
sp._runtime_ctx.temp_root
), "temp directory should not be cleared"
assert os.path.exists(sp._runtime_ctx.queue_root), "queue directory should not be cleared"
assert os.path.exists(sp._runtime_ctx.staging_root), "staging directory should not be cleared"
assert os.path.exists(sp._runtime_ctx.temp_root), "temp directory should not be cleared"
with open(sp._runtime_ctx.job_status_path) as fin:
assert "failure" in fin.read(), "job status should be failure"

View File

@@ -85,9 +85,7 @@ class WorkQueueTestBase(object):
def test_multi_consumers(self):
numConsumers = 10
numItems = 200
result = self.pool.starmap_async(
consumer, [(self.wq, id) for id in range(numConsumers)]
)
result = self.pool.starmap_async(consumer, [(self.wq, id) for id in range(numConsumers)])
producer(self.wq, 0, numItems, numConsumers)
logger.info("waiting for result")