mirror of
https://github.com/deepseek-ai/smallpond
synced 2025-06-26 18:27:45 +00:00
162 lines
6.1 KiB
Python
162 lines
6.1 KiB
Python
import os.path
|
|
import random
|
|
import time
|
|
import unittest
|
|
from typing import List, Tuple
|
|
|
|
from loguru import logger
|
|
|
|
from smallpond.execution.scheduler import ExecutorState
|
|
from smallpond.execution.task import PythonScriptTask, RuntimeContext
|
|
from smallpond.logical.dataset import DataSet, ParquetDataSet
|
|
from smallpond.logical.node import (
|
|
Context,
|
|
DataSetPartitionNode,
|
|
DataSourceNode,
|
|
LogicalPlan,
|
|
Node,
|
|
PythonScriptNode,
|
|
)
|
|
from tests.test_fabric import TestFabric
|
|
|
|
|
|
class RandomSleepTask(PythonScriptTask):
    """Script task that sleeps for a fixed time and may report failure on its first attempt.

    Used by the scheduler tests to simulate slow and flaky tasks.
    """

    def __init__(self, *args, sleep_secs: float, fail_first_try: bool, **kwargs) -> None:
        """Forward positional/keyword arguments to the base task and store the test knobs.

        Args:
            sleep_secs: how long `process` sleeps before writing its output.
            fail_first_try: when true, `process` returns False while `retry_count` is 0.
        """
        super().__init__(*args, **kwargs)
        self.sleep_secs = sleep_secs
        self.fail_first_try = fail_first_try

    def process(
        self,
        runtime_ctx: RuntimeContext,
        input_datasets: List[DataSet],
        output_path: str,
    ) -> bool:
        """Sleep, write `repr(self)` into the output file, and report success or failure.

        Returns:
            False when `fail_first_try` is set and `retry_count` is 0; True otherwise.
        """
        logger.info(f"sleeping {self.sleep_secs} secs")
        time.sleep(self.sleep_secs)

        # `output_filename` and `retry_count` are not defined in this class;
        # presumably they are provided by PythonScriptTask -- confirm there.
        target_file = os.path.join(output_path, self.output_filename)
        with open(target_file, "w") as sink:
            sink.write(repr(self))

        # The file is written even when the attempt is reported as failed.
        return not (self.fail_first_try and self.retry_count == 0)
|
|
|
|
|
|
class RandomSleepNode(PythonScriptNode):
    """Logical node whose spawned tasks are `RandomSleepTask` instances."""

    def __init__(
        self,
        ctx: Context,
        input_deps: Tuple[Node, ...],
        *,
        max_sleep_secs=5,
        fail_first_try=False,
        **kwargs,
    ):
        """Create the node and remember the sleep/failure settings for spawned tasks.

        Args:
            ctx: context passed through to the base node.
            input_deps: upstream nodes passed through to the base node.
            max_sleep_secs: sleep duration given to the tasks picked as slow ones in `spawn`.
            fail_first_try: forwarded to every spawned task.
            **kwargs: remaining options passed through to the base node.
        """
        super().__init__(ctx, input_deps, **kwargs)
        self.max_sleep_secs = max_sleep_secs
        self.fail_first_try = fail_first_try

    def spawn(self, *args, **kwargs) -> RandomSleepTask:
        """Build one task; its sleep time depends on how many tasks exist so far.

        When `len(self.generated_tasks)` is a multiple of 20 (including 0) the
        task sleeps `max_sleep_secs`; otherwise it sleeps `random.random()`
        seconds, i.e. a value in [0.0, 1.0).
        """
        # `generated_tasks` is not defined in this class; presumably the base
        # node maintains it -- confirm against PythonScriptNode.
        if len(self.generated_tasks) % 20:
            nap_secs = random.random()
        else:
            nap_secs = self.max_sleep_secs
        return RandomSleepTask(*args, **kwargs, sleep_secs=nap_secs, fail_first_try=self.fail_first_try)
|
|
|
|
|
|
class TestScheduler(TestFabric, unittest.TestCase):
    """Scheduler tests driven by logical plans made of `RandomSleepNode` tasks."""

    def create_random_sleep_plan(self, npartitions, max_sleep_secs, fail_first_try=False):
        """Build a plan: parquet source -> row-based partitioning -> RandomSleepNode.

        Args:
            npartitions: number of partitions requested from DataSetPartitionNode.
            max_sleep_secs: forwarded to RandomSleepNode.
            fail_first_try: forwarded to RandomSleepNode.

        Returns:
            A LogicalPlan whose root is the RandomSleepNode.
        """
        ctx = Context()
        dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
        data_files = DataSourceNode(ctx, dataset)
        data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=npartitions, partition_by_rows=True)
        random_sleep = RandomSleepNode(
            ctx,
            (data_partitions,),
            max_sleep_secs=max_sleep_secs,
            fail_first_try=fail_first_try,
        )
        return LogicalPlan(ctx, random_sleep)

    def check_executor_state(self, target_state: ExecutorState, nloops=200):
        """Poll the latest scheduler state until some remote executor is in `target_state`.

        Polls up to `nloops` times with a 0.1 second pause between polls and
        fails the test if no executor reaches the state.
        """
        # Bound up front so the failure message below works even if nloops <= 0.
        latest_sched_state = None
        for _ in range(nloops):
            latest_sched_state = self.get_latest_sched_state()
            if any(executor.state == target_state for executor in latest_sched_state.remote_executors):
                logger.info(f"found {target_state} executor in: {latest_sched_state.remote_executors}")
                return
            time.sleep(0.1)
        remote_executors = getattr(latest_sched_state, "remote_executors", None)
        self.fail(f"cannot find any executor in state {target_state}: {remote_executors}")

    def test_standalone_mode(self):
        """Run a small plan with `num_executors=0`."""
        plan = self.create_random_sleep_plan(npartitions=10, max_sleep_secs=1)
        self.execute_plan(plan, num_executors=0)

    def test_failed_executors(self):
        """Disrupt some executors mid-run and check that the plan still succeeds."""
        num_exec = 6
        num_fail = 4
        plan = self.create_random_sleep_plan(npartitions=300, max_sleep_secs=10)

        _, executors, processes = self.start_execution(
            plan,
            num_executors=num_exec,
            secs_wq_poll_interval=0.1,
            secs_executor_probe_interval=0.5,
            console_log_level="WARNING",
        )
        latest_sched_state = self.get_latest_sched_state()
        # Wait until at least one executor is reported GOOD before disrupting any.
        self.check_executor_state(ExecutorState.GOOD)

        # NOTE(review): executors are paired with processes[1:]; presumably
        # processes[0] is not an executor process -- confirm against
        # TestFabric.start_execution.
        victims = random.sample(list(zip(executors, processes[1:])), k=num_fail)
        for i, (executor, process) in enumerate(victims):
            # Alternate between killing the process and making the executor skip probes.
            if i % 2 == 0:
                logger.warning(f"kill executor: {executor}")
                process.kill()
            else:
                logger.warning(f"skip probes: {executor}")
                executor.skip_probes(latest_sched_state.ctx.max_num_missed_probes * 2)

        self.join_running_procs()
        latest_sched_state = self.get_latest_sched_state()
        self.assertTrue(latest_sched_state.success)
        self.assertGreater(len(latest_sched_state.abandoned_tasks), 0)
        self.assertLessEqual(
            1,
            len(latest_sched_state.stopped_executors),
            f"remote_executors: {latest_sched_state.remote_executors}",
        )
        self.assertLessEqual(
            num_fail / 2,
            len(latest_sched_state.failed_executors),
            f"remote_executors: {latest_sched_state.remote_executors}",
        )

    def test_speculative_scheduling(self):
        """Run the plan under each `speculative_exec` mode and check abandoned tasks.

        With "disable" no task may be abandoned; with "enable" or "aggressive"
        at least one must be.
        """
        for speculative_exec in ("disable", "enable", "aggressive"):
            with self.subTest(speculative_exec=speculative_exec):
                plan = self.create_random_sleep_plan(npartitions=100, max_sleep_secs=10)
                self.execute_plan(
                    plan,
                    num_executors=3,
                    secs_wq_poll_interval=0.1,
                    secs_executor_probe_interval=0.5,
                    prioritize_retry=(speculative_exec == "aggressive"),
                    speculative_exec=speculative_exec,
                )
                latest_sched_state = self.get_latest_sched_state()
                if speculative_exec == "disable":
                    self.assertEqual(len(latest_sched_state.abandoned_tasks), 0)
                else:
                    self.assertGreater(len(latest_sched_state.abandoned_tasks), 0)

    def test_stop_executor_on_failure(self):
        """Run tasks that fail their first try with `stop_executor_on_failure=True`.

        Expects at least one abandoned task in the final scheduler state.
        """
        plan = self.create_random_sleep_plan(npartitions=3, max_sleep_secs=5, fail_first_try=True)
        # The return value of execute_plan is not needed; only the scheduler state is checked.
        self.execute_plan(
            plan,
            num_executors=5,
            secs_wq_poll_interval=0.1,
            secs_executor_probe_interval=0.5,
            check_result=False,
            stop_executor_on_failure=True,
        )
        latest_sched_state = self.get_latest_sched_state()
        self.assertGreater(len(latest_sched_state.abandoned_tasks), 0)
|