smallpond/tests/test_scheduler.py
2025-03-05 22:46:23 +08:00

162 lines
6.1 KiB
Python

import os.path
import random
import time
import unittest
from typing import List, Tuple
from loguru import logger
from smallpond.execution.scheduler import ExecutorState
from smallpond.execution.task import PythonScriptTask, RuntimeContext
from smallpond.logical.dataset import DataSet, ParquetDataSet
from smallpond.logical.node import (
Context,
DataSetPartitionNode,
DataSourceNode,
LogicalPlan,
Node,
PythonScriptNode,
)
from tests.test_fabric import TestFabric
class RandomSleepTask(PythonScriptTask):
    """Script task that sleeps for a fixed duration and then writes its own repr to its output file.

    With ``fail_first_try`` enabled, the task reports failure while its
    ``retry_count`` is still 0; the output file is written before that check.
    """

    def __init__(self, *args, sleep_secs: float, fail_first_try: bool, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.sleep_secs = sleep_secs
        self.fail_first_try = fail_first_try

    def process(
        self,
        runtime_ctx: RuntimeContext,
        input_datasets: List[DataSet],
        output_path: str,
    ) -> bool:
        """Sleep for ``sleep_secs``, write ``repr(self)`` under ``output_path`` and report the outcome.

        Returns:
            ``False`` when ``fail_first_try`` is set and ``retry_count`` is 0,
            ``True`` in every other case.
        """
        logger.info(f"sleeping {self.sleep_secs} secs")
        time.sleep(self.sleep_secs)

        output_file = os.path.join(output_path, self.output_filename)
        with open(output_file, "w") as fout:
            fout.write(repr(self))

        # The output is written before the simulated failure is reported.
        return not (self.fail_first_try and self.retry_count == 0)
class RandomSleepNode(PythonScriptNode):
    """Plan node whose spawned tasks are ``RandomSleepTask`` instances with varying sleep times."""

    def __init__(
        self,
        ctx: Context,
        input_deps: Tuple[Node, ...],
        *,
        max_sleep_secs=5,
        fail_first_try=False,
        **kwargs,
    ):
        super().__init__(ctx, input_deps, **kwargs)
        self.max_sleep_secs = max_sleep_secs
        self.fail_first_try = fail_first_try

    def spawn(self, *args, **kwargs) -> RandomSleepTask:
        """Create the next ``RandomSleepTask``, forwarding all positional and keyword arguments."""
        # When the count of already generated tasks is a multiple of 20, the new
        # task sleeps for the full `max_sleep_secs`; all others sleep for a random
        # duration in [0, 1) seconds.
        if len(self.generated_tasks) % 20 == 0:
            sleep_secs = self.max_sleep_secs
        else:
            sleep_secs = random.random()
        return RandomSleepTask(
            *args,
            **kwargs,
            sleep_secs=sleep_secs,
            fail_first_try=self.fail_first_try,
        )
class TestScheduler(TestFabric, unittest.TestCase):
    """Scheduler tests driven by logical plans made of randomly sleeping tasks."""

    def create_random_sleep_plan(self, npartitions, max_sleep_secs, fail_first_try=False):
        """Build a plan: parquet data source -> row-based partitioning -> ``RandomSleepNode``.

        Args:
            npartitions: number of partitions passed to ``DataSetPartitionNode``.
            max_sleep_secs: forwarded to ``RandomSleepNode``.
            fail_first_try: forwarded to ``RandomSleepNode``.

        Returns:
            A ``LogicalPlan`` rooted at the ``RandomSleepNode``.
        """
        ctx = Context()
        dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
        data_files = DataSourceNode(ctx, dataset)
        data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=npartitions, partition_by_rows=True)
        random_sleep = RandomSleepNode(
            ctx,
            (data_partitions,),
            max_sleep_secs=max_sleep_secs,
            fail_first_try=fail_first_try,
        )
        return LogicalPlan(ctx, random_sleep)

    def check_executor_state(self, target_state: ExecutorState, nloops=200):
        """Poll the latest scheduler state until some remote executor is in ``target_state``.

        Polls up to ``nloops`` times with a 0.1 second pause between attempts and
        fails the test if no executor ever reaches the target state.
        """
        # Pre-bind so the failure message is well-defined even if the loop body never runs.
        latest_sched_state = None
        for _ in range(nloops):
            latest_sched_state = self.get_latest_sched_state()
            if any(executor.state == target_state for executor in latest_sched_state.remote_executors):
                logger.info(f"found {target_state} executor in: {latest_sched_state.remote_executors}")
                return
            time.sleep(0.1)
        remote_executors = getattr(latest_sched_state, "remote_executors", None)
        self.fail(f"cannot find any executor in state {target_state}: {remote_executors}")

    def test_standalone_mode(self):
        """Run a small plan with ``num_executors=0``."""
        plan = self.create_random_sleep_plan(npartitions=10, max_sleep_secs=1)
        self.execute_plan(plan, num_executors=0)

    def test_failed_executors(self):
        """Kill or stall a subset of executors mid-run and expect the plan to still succeed."""
        num_exec = 6
        num_fail = 4
        plan = self.create_random_sleep_plan(npartitions=300, max_sleep_secs=10)
        _, executors, processes = self.start_execution(
            plan,
            num_executors=num_exec,
            secs_wq_poll_interval=0.1,
            secs_executor_probe_interval=0.5,
            console_log_level="WARNING",
        )
        latest_sched_state = self.get_latest_sched_state()
        # Wait until at least one executor is up before breaking any of them.
        self.check_executor_state(ExecutorState.GOOD)

        # NOTE(review): processes[1:] presumably skips a non-executor (scheduler) process
        # at index 0 -- confirm against TestFabric.start_execution.
        victims = random.sample(list(zip(executors, processes[1:])), k=num_fail)
        for i, (executor, process) in enumerate(victims):
            # Alternate between the two failure modes: hard kill and missed probes.
            if i % 2 == 0:
                logger.warning(f"kill executor: {executor}")
                process.kill()
            else:
                logger.warning(f"skip probes: {executor}")
                executor.skip_probes(latest_sched_state.ctx.max_num_missed_probes * 2)

        self.join_running_procs()
        latest_sched_state = self.get_latest_sched_state()
        self.assertTrue(latest_sched_state.success)
        self.assertGreater(len(latest_sched_state.abandoned_tasks), 0)
        self.assertGreaterEqual(
            len(latest_sched_state.stopped_executors),
            1,
            f"remote_executors: {latest_sched_state.remote_executors}",
        )
        self.assertGreaterEqual(
            len(latest_sched_state.failed_executors),
            num_fail / 2,
            f"remote_executors: {latest_sched_state.remote_executors}",
        )

    def test_speculative_scheduling(self):
        """Check that tasks are abandoned only when speculative execution is not disabled."""
        for speculative_exec in ("disable", "enable", "aggressive"):
            with self.subTest(speculative_exec=speculative_exec):
                plan = self.create_random_sleep_plan(npartitions=100, max_sleep_secs=10)
                self.execute_plan(
                    plan,
                    num_executors=3,
                    secs_wq_poll_interval=0.1,
                    secs_executor_probe_interval=0.5,
                    prioritize_retry=(speculative_exec == "aggressive"),
                    speculative_exec=speculative_exec,
                )
                latest_sched_state = self.get_latest_sched_state()
                if speculative_exec == "disable":
                    self.assertEqual(len(latest_sched_state.abandoned_tasks), 0)
                else:
                    self.assertGreater(len(latest_sched_state.abandoned_tasks), 0)

    def test_stop_executor_on_failure(self):
        """Run tasks that fail on their first try with ``stop_executor_on_failure`` and expect abandoned tasks."""
        plan = self.create_random_sleep_plan(npartitions=3, max_sleep_secs=5, fail_first_try=True)
        # The returned plan is not inspected; only the final scheduler state matters here.
        self.execute_plan(
            plan,
            num_executors=5,
            secs_wq_poll_interval=0.1,
            secs_executor_probe_interval=0.5,
            check_result=False,
            stop_executor_on_failure=True,
        )
        latest_sched_state = self.get_latest_sched_state()
        self.assertGreater(len(latest_sched_state.abandoned_tasks), 0)