Signed-off-by: Runji Wang <runji@deepseek.com>
Runji Wang 2025-03-02 22:28:52 +08:00
parent 5a56a052bf
commit 9492a2872a
7 changed files with 293 additions and 631 deletions

View File

@ -13,22 +13,26 @@ import ray
import ray.exceptions
from loguru import logger
from smallpond.execution.manager import JobManager
from smallpond.execution.task import Task
from smallpond.io.filesystem import remove_path
from smallpond.logical.dataset import *
from smallpond.logical.node import *
from smallpond.logical.optimizer import Optimizer
from smallpond.logical.planner import Planner
from smallpond.session import SessionBase
class Session(SessionBase):
class Session:
# Extended session class with additional methods to create DataFrames.
def __init__(self, **kwargs):
super().__init__(**kwargs)
"""
Create a smallpond environment.
"""
self._job_manager = JobManager(**kwargs)
self._ctx = Context()
self._runtime_ctx = self._job_manager.runtime_ctx
self._nodes: List[Node] = []
self._node_to_tasks: Dict[Node, List[Task]] = {}
"""
When a DataFrame is evaluated, the tasks of the logical plan are stored here. When a DataFrame is evaluated, the tasks of the logical plan are stored here.
@ -158,52 +162,6 @@ class Session(SessionBase):
return
self._shutdown_called = True
# log status
finished = self._all_tasks_finished()
with open(self._runtime_ctx.job_status_path, "a") as fout:
status = "success" if finished else "failure"
fout.write(f"{status}@{datetime.now():%Y-%m-%d-%H-%M-%S}\n")
# clean up runtime directories if success
if finished:
logger.info("all tasks are finished, cleaning up")
self._runtime_ctx.cleanup(remove_output_root=self.config.remove_output_root)
else:
logger.warning("tasks are not finished!")
super().shutdown()
def _summarize_task(self) -> Tuple[int, int]:
"""
Return the total number of tasks and the number of tasks that are finished.
"""
dataset_refs = [
task._dataset_ref
for tasks in self._node_to_tasks.values()
for task in tasks
if task._dataset_ref is not None
]
ready_tasks, _ = ray.wait(
dataset_refs, num_returns=len(dataset_refs), timeout=0, fetch_local=False
)
return len(dataset_refs), len(ready_tasks)
def _all_tasks_finished(self) -> bool:
"""
Check if all tasks are finished.
"""
dataset_refs = [
task._dataset_ref
for tasks in self._node_to_tasks.values()
for task in tasks
]
try:
ray.get(dataset_refs, timeout=0)
except Exception:
# GetTimeoutError is raised if any task is not finished
# RuntimeError is raised if any task failed
return False
return True
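Both helpers above lean on Ray's non-blocking readiness checks. A minimal standalone sketch of that pattern, assuming only a list of object references (the name refs below stands in for the collected _dataset_ref values and is not part of this diff):

import ray

def count_ready(refs: list) -> tuple:
    # ray.wait with timeout=0 returns immediately; fetch_local=False tests
    # readiness without pulling the objects to this node.
    ready, _ = ray.wait(refs, num_returns=len(refs), timeout=0, fetch_local=False)
    return len(refs), len(ready)

def all_ready(refs: list) -> bool:
    try:
        # raises GetTimeoutError while anything is still running,
        # or re-raises a task's own exception if it failed
        ray.get(refs, timeout=0)
    except Exception:
        return False
    return True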
class DataFrame:
@ -216,7 +174,7 @@ class DataFrame:
def __init__(self, session: Session, plan: Node, recompute: bool = False):
self.session = session
self.plan = plan
self.optimized_plan: Optional[Node] = None
# self.optimized_plan: Optional[Node] = None
self.need_recompute = recompute
"""Whether to recompute the data regardless of whether it's already computed."""
@ -229,17 +187,17 @@ class DataFrame:
"""
Get or create tasks to compute the data.
"""
# optimize the plan
if self.optimized_plan is None:
logger.info(f"optimizing\n{LogicalPlan(self.session._ctx, self.plan)}")
self.optimized_plan = Optimizer(
exclude_nodes=set(self.session._node_to_tasks.keys())
).visit(self.plan)
logger.info(
f"optimized\n{LogicalPlan(self.session._ctx, self.optimized_plan)}"
)
# # optimize the plan
# if self.optimized_plan is None:
# logger.info(f"optimizing\n{LogicalPlan(self.session._ctx, self.plan)}")
# self.optimized_plan = Optimizer(
# exclude_nodes=set(self.session._node_to_tasks.keys())
# ).visit(self.plan)
# logger.info(
# f"optimized\n{LogicalPlan(self.session._ctx, self.optimized_plan)}"
# )
# return the tasks if already created
if tasks := self.session._node_to_tasks.get(self.optimized_plan):
if tasks := self.session._node_to_tasks.get(self.plan):
return tasks
# remove all completed task files if recompute is needed
@ -247,16 +205,16 @@
remove_path(
os.path.join(
self.session._runtime_ctx.completed_task_dir,
str(self.optimized_plan.id),
str(self.plan.id),
)
)
logger.info(f"cleared all results of {self.optimized_plan!r}")
logger.info(f"cleared all results of {self.plan!r}")
# create tasks for the optimized plan
planner = Planner(self.session._runtime_ctx)
# let planner update self.session._node_to_tasks
planner.node_to_tasks = self.session._node_to_tasks
return planner.visit(self.optimized_plan)
return planner.visit(self.plan)
def is_computed(self) -> bool:
"""

smallpond/execution/_ray.py (new file, 103 lines)
View File

@ -0,0 +1,103 @@
import copy
import os
import ray
from loguru import logger
from smallpond.common import DEFAULT_MAX_RETRY_COUNT
from smallpond.execution.task import Task
from smallpond.execution.workqueue import WorkStatus
from smallpond.io.filesystem import dump, load
from smallpond.logical.dataset import DataSet
def run_on_ray(task: Task) -> ray.ObjectRef:
"""
Run the task on Ray.
Return an `ObjectRef`, which can be used with `ray.get` to wait for the output dataset.
A `_dataset_ref` attribute is added to the task to store the reference.
"""
if task._dataset_ref is not None:
# already started
return task._dataset_ref
# read the output dataset if the task has already finished
if os.path.exists(task.ray_dataset_path):
logger.info(f"task {task.key} already finished, skipping")
output = load(task.ray_dataset_path)
task._dataset_ref = ray.put(output)
return task._dataset_ref
task = copy.copy(task)
task.input_deps = {dep_key: None for dep_key in task.input_deps}
@ray.remote
def exec_task(task: Task, *inputs: DataSet) -> DataSet:
import multiprocessing as mp
import os
from pathlib import Path
from loguru import logger
# ray use a process pool to execute tasks
# we set the current process name to the task name
# so that we can see task name in the logs
mp.current_process().name = task.key
# probe the retry count
task.retry_count = 0
while os.path.exists(task.ray_marker_path):
task.retry_count += 1
if task.retry_count > DEFAULT_MAX_RETRY_COUNT:
raise RuntimeError(
f"task {task.key} failed after {task.retry_count} retries"
)
if task.retry_count > 0:
logger.warning(
f"task {task.key} is being retried for the {task.retry_count}th time"
)
# create the marker file
Path(task.ray_marker_path).touch()
# put the inputs into the task
assert len(inputs) == len(task.input_deps)
task.input_datasets = list(inputs)
# execute the task
status = task.exec()
if status != WorkStatus.SUCCEED:
raise task.exception or RuntimeError(
f"task {task.key} failed with status {status}"
)
# dump the output dataset atomically
os.makedirs(os.path.dirname(task.ray_dataset_path), exist_ok=True)
dump(task.output, task.ray_dataset_path, atomic_write=True)
return task.output
# this shows as {"name": ...} in timeline
exec_task._function_name = repr(task)
remote_function = exec_task.options(
# ray task name
# do not include task id so that they can be grouped by node in ray dashboard
name=f"{task.node_id}.{task.__class__.__name__}",
num_cpus=task.cpu_limit,
num_gpus=task.gpu_limit,
memory=int(task.memory_limit),
# note: `exec_on_scheduler` is ignored here,
# because dataset is distributed on ray
)
try:
task._dataset_ref = remote_function.remote(
task, *[run_on_ray(dep) for dep in task.input_deps.values()]
)
except RuntimeError as e:
if (
"SimpleQueue objects should only be shared between processes through inheritance"
in str(e)
):
raise RuntimeError(
f"Can't pickle task '{task.key}'. Please check if your function has captured unpicklable objects. {task.location}\n"
f"HINT: DO NOT use externally imported loguru logger in your task. Please import it within the task."
) from e
raise e
return task._dataset_ref
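A hedged usage sketch of the helper above: submitting a task schedules its input_deps recursively, and the returned ObjectRef can be awaited with ray.get. The compute wrapper below is illustrative and not part of this commit:

import ray
from smallpond.execution._ray import run_on_ray
from smallpond.execution.task import Task
from smallpond.logical.dataset import DataSet

def compute(task: Task) -> DataSet:
    # schedules the task (and, transitively, its input_deps) on Ray,
    # then blocks until the output dataset is available
    return ray.get(run_on_ray(task))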

View File

@ -212,14 +212,10 @@ class JobManager(object):
sched_state_observers.insert(0, sched_state_exporter)
if os.path.exists(self.runtime_ctx.sched_state_path):
logger.warning(
f"loading scheduler state from: {self.runtime_ctx.sched_state_path}"
)
self.scheduler: Scheduler = load(self.runtime_ctx.sched_state_path)
self.scheduler.sched_epoch += 1
self.scheduler.sched_state_observers = sched_state_observers
self.scheduler = Scheduler.recover_from_file(self.runtime_ctx.sched_state_path, sched_state_observers)
else:
self.scheduler = Scheduler(
ctx=self.runtime_ctx,
max_retry_count=max_retry_count,
max_fail_count=max_fail_count,
prioritize_retry=prioritize_retry,

View File

@ -1,5 +1,6 @@
import copy
import cProfile
from datetime import datetime
import itertools
import multiprocessing as mp
import os
@ -12,7 +13,18 @@ from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from functools import cached_property
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Set, Tuple, Union
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
Set,
Tuple,
Union,
)
import numpy as np
from loguru import logger
@ -40,7 +52,7 @@ from smallpond.execution.workqueue import (
WorkQueueInMemory,
WorkQueueOnFilesystem,
)
from smallpond.io.filesystem import dump, remove_path
from smallpond.io.filesystem import dump, load, remove_path
from smallpond.logical.node import LogicalPlan, Node
from smallpond.utility import cprofile_to_string
@ -359,10 +371,9 @@ class Scheduler(object):
""" """
large_num_nontrivial_tasks = 200 if pytest_running() else 20000 large_num_nontrivial_tasks = 200 if pytest_running() else 20000
StateCallback = Callable[["Scheduler"], Any]
class StateObserver(object): class StateObserver(object):
def __init__(self, callback: "Scheduler.StateCallback" = None) -> None: def __init__(self, callback: Callable[["Scheduler"], Any] = None) -> None:
assert callback is None or isinstance(callback, Callable) assert callback is None or isinstance(callback, Callable)
self.enabled = True self.enabled = True
self.callback = callback self.callback = callback
@ -380,6 +391,7 @@ class Scheduler(object):
def __init__(
self,
ctx: RuntimeContext,
*,
max_retry_count: int = DEFAULT_MAX_RETRY_COUNT,
max_fail_count: int = DEFAULT_MAX_FAIL_COUNT,
@ -390,18 +402,46 @@
remove_output_root: bool = False,
sched_state_observers: Optional[List[StateObserver]] = None,
) -> None:
"""
Initialize the scheduler.
Parameters
----------
ctx: RuntimeContext
The runtime context.
max_retry_count: int, optional
The maximum retry count. Default to 5.
max_fail_count: int, optional
The maximum fail count. Default to 3.
prioritize_retry: bool, optional
Whether to prioritize retry. Default to False.
speculative_exec: Literal["disable", "enable", "aggressive"], optional
If "enable", long-running tasks will be rescheduled on other executors.
If "aggressive", it will be more aggressive to reschedule long-running tasks.
If "disable", no speculative execution will be performed.
Default to "enable".
stop_executor_on_failure: bool, optional
Whether to stop the executor on failure. Default to False.
nonzero_exitcode_as_oom: bool, optional
Whether to treat non-zero exit code as out-of-memory. Default to False.
remove_output_root: bool, optional
Whether to remove the output root on exit. Default to False.
sched_state_observers: Optional[List[StateObserver]], optional
The state observers.
"""
# configs
self.ctx = ctx
self.max_retry_count = max_retry_count
self.max_fail_count = max_fail_count
self.standalone_mode = self.ctx.num_executors == 0
self.prioritize_retry = prioritize_retry
self.disable_speculative_exec = speculative_exec == "disable"
self.aggressive_speculative_exec = speculative_exec == "aggressive"
self.speculative_exec: Literal["disable", "enable", "aggressive"] = (
speculative_exec
)
self.stop_executor_on_failure = stop_executor_on_failure
self.nonzero_exitcode_as_oom = nonzero_exitcode_as_oom
self.remove_output_root = remove_output_root
self.sched_state_observers: List[Scheduler.StateObserver] = (
sched_state_observers or []
)
self.sched_state_observers = sched_state_observers or []
self.secs_state_notify_interval = self.ctx.secs_executor_probe_interval * 2
# task states
self.local_queue: List[Task] = []
@ -416,30 +456,47 @@
self.local_executor = LocalExecutor.create(self.ctx, "localhost")
self.available_executors = {self.local_executor.id: self.local_executor}
# other runtime states
self.sched_running = False
self.failure = False
self.sched_start_time = 0
self.sched_start_time = time.time()
self.last_executor_probe_time = 0
self.last_state_notify_time = 0
self.probe_epoch = 0
self.sched_epoch = 0
self._post_init()
@staticmethod
def recover_from_file(
sched_state_path: str, sched_state_observers: List[StateObserver]
) -> "Scheduler":
"""
Recover the scheduler from the previous run.
"""
logger.warning(f"loading scheduler state from: {sched_state_path}")
self: Scheduler = load(sched_state_path)
self.sched_epoch += 1
self.prioritize_retry = True
# observers are not pickled, so we need to re-add them
self.sched_state_observers = sched_state_observers
self._post_init()
return self
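For context, a sketch of the recover-or-create wiring that the JobManager hunk earlier in this commit now uses; runtime_ctx and observers are placeholders for the caller's RuntimeContext and StateObserver list:

import os

if os.path.exists(runtime_ctx.sched_state_path):
    # resume a previous run: bumps sched_epoch and re-attaches the (unpickled) observers
    scheduler = Scheduler.recover_from_file(runtime_ctx.sched_state_path, observers)
else:
    scheduler = Scheduler(ctx=runtime_ctx, sched_state_observers=observers)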
def __getstate__(self):
state = self.__dict__.copy()
del state["sched_state_observers"]
del state["profiler"]
return state
def __setstate__(self, state):
self.__dict__.update(state)
self.sched_state_observers = []
self.profiler = None
@property
def elapsed_time(self):
return time.time() - self.sched_start_time
@property
def success(self) -> bool:
return self.exec_plan.root_task.key in self.succeeded_tasks
@property
def progress(self) -> Tuple[int, int, float]:
num_succeeded = len(self.succeeded_nontrivial_tasks)
@ -582,7 +639,7 @@ class Scheduler(object):
for executor in self.working_executors:
for idx, item in enumerate(executor.running_works.values()):
aggressive_retry = (
self.aggressive_speculative_exec
self.speculative_exec == "aggressive"
and len(self.good_executors) >= self.ctx.num_executors
)
short_sched_queue = len(self.sched_queue) < len(self.good_executors)
@ -644,7 +701,7 @@ class Scheduler(object):
for executor in self.working_executors:
executor.probe(self.probe_epoch)
# start speculative execution of tasks
if not self.disable_speculative_exec:
if self.speculative_exec != "disable":
self.start_speculative_execution()
def update_executor_states(self):
@ -777,7 +834,7 @@ class Scheduler(object):
return task
@logger.catch(reraise=pytest_running(), message="failed to clean temp files")
def clean_temp_files(self, pool: ThreadPoolExecutor):
def clean_temp_files(self):
remove_path(self.ctx.queue_root)
remove_path(self.ctx.temp_root)
remove_path(self.ctx.staging_root)
@ -786,7 +843,10 @@
logger.info(
f"removing outputs of {len(abandoned_tasks)} abandoned tasks: {abandoned_tasks[:3]} ..."
)
assert list(pool.map(lambda t: t.clean_output(force=True), abandoned_tasks))
with ThreadPoolExecutor(32) as pool:
assert list(
pool.map(lambda t: t.clean_output(force=True), abandoned_tasks)
)
@logger.catch(reraise=pytest_running(), message="failed to export task metrics")
def export_task_metrics(self):
@ -963,55 +1023,45 @@ class Scheduler(object):
def log_current_status(self):
with open(self.ctx.job_status_path, "w") as fout:
if self.sched_running:
status = "running"
elif self.success:
status = "success"
else:
status = "failure"
fout.write(f"{status}@{int(time.time())}")
if self.failure:
status = "failure"
else:
status = "running"
fout.write(f"{status}@{datetime.now().isoformat()}")
def run(self, exec_plan: ExecutionPlan) -> bool:
"""
Run the execution plan.
"""
def _post_init(self):
"""
Common initialization after startup or recovery.
"""
mp.current_process().name = f"SchedulerMainProcess#{self.sched_epoch}"
logger.info(
f"start to run scheduler #{self.sched_epoch} on {socket.gethostname()}"
)
perf_profile = None
if self.ctx.enable_profiling:
perf_profile = cProfile.Profile()
perf_profile.enable()
self.profiler = cProfile.Profile()
self.profiler.enable()
self.sched_running = True
self.sched_start_time = time.time()
self.last_executor_probe_time = 0
self.last_state_notify_time = 0
self.prioritize_retry |= self.sched_epoch > 0
if self.local_queue or self.sched_queue:
pending_tasks = [
item
for item in self.local_queue + self.sched_queue
if isinstance(item, Task)
]
self.local_queue.clear()
self.sched_queue.clear()
logger.info(
f"requeue {len(pending_tasks)} pending tasks with latest epoch #{self.sched_epoch}: {pending_tasks[:3]} ..."
)
self.try_enqueue(pending_tasks)
with ThreadPoolExecutor(32) as pool: with ThreadPoolExecutor(32) as pool:
self.sched_running = True
self.sched_start_time = time.time()
self.last_executor_probe_time = 0
self.last_state_notify_time = 0
self.prioritize_retry |= self.sched_epoch > 0
if self.local_queue or self.sched_queue:
pending_tasks = [
item
for item in self.local_queue + self.sched_queue
if isinstance(item, Task)
]
self.local_queue.clear()
self.sched_queue.clear()
logger.info(
f"requeue {len(pending_tasks)} pending tasks with latest epoch #{self.sched_epoch}: {pending_tasks[:3]} ..."
)
self.try_enqueue(pending_tasks)
if self.sched_epoch == 0:
leaf_tasks = self.exec_plan.leaves
logger.info(
f"enqueue {len(leaf_tasks)} leaf tasks: {leaf_tasks[:3]} ..."
)
self.try_enqueue(leaf_tasks)
self.log_overall_progress()
while (num_finished_tasks := self.process_finished_tasks(pool)) > 0:
logger.info(
@ -1019,54 +1069,64 @@ class Scheduler(object):
)
self.log_overall_progress()
earlier_running_tasks = [
item for item in self.running_works if isinstance(item, Task)
]
if earlier_running_tasks:
logger.info(
f"enqueue {len(earlier_running_tasks)} earlier running tasks: {earlier_running_tasks[:3]} ..."
)
self.try_enqueue(earlier_running_tasks)
self.suspend_good_executors()
self.add_state_observer(
Scheduler.StateObserver(Scheduler.log_current_status)
)
self.add_state_observer(
Scheduler.StateObserver(Scheduler.export_timeline_figs)
)
self.notify_state_observers(force_notify=True)
try:
self.local_executor.start(pool)
self.sched_loop(pool)
finally:
logger.info(f"schedule loop stopped")
self.sched_running = False
self.notify_state_observers(force_notify=True)
self.export_task_metrics()
self.stop_executors()
# if --output_path is specified, remove the output root as well
if self.remove_output_root or self.ctx.final_output_path:
remove_path(self.ctx.staging_root)
remove_path(self.ctx.output_root)
if self.success:
self.clean_temp_files(pool)
logger.success(f"final output path: {self.exec_plan.final_output_path}")
logger.info(
f"analyzed plan:{os.linesep}{self.exec_plan.analyzed_logical_plan.explain_str()}"
)
if perf_profile is not None:
logger.debug(
f"scheduler perf profile:{os.linesep}{cprofile_to_string(perf_profile)}"
)
logger.info(f"scheduler of job {self.ctx.job_id} exits")
logger.complete()
return self.success
self.add_state_observer(Scheduler.StateObserver(Scheduler.log_current_status))
self.add_state_observer(Scheduler.StateObserver(Scheduler.export_timeline_figs))
self.notify_state_observers(force_notify=True)
def run(self, exec_plan: ExecutionPlan) -> bool:
"""
Run the execution plan.
"""
leaf_tasks = exec_plan.leaves
logger.info(f"enqueue {len(leaf_tasks)} leaf tasks: {leaf_tasks[:3]} ...")
self.try_enqueue(leaf_tasks)
try:
with ThreadPoolExecutor(32) as pool:
self.local_executor.start(pool)
self.sched_loop(pool, exec_plan.root_task)
finally:
logger.info(f"schedule loop stopped")
self.notify_state_observers(force_notify=True)
logger.success(f"final output path: {exec_plan.final_output_path}")
logger.info(
f"analyzed plan:{os.linesep}{exec_plan.analyzed_logical_plan.explain_str()}"
)
return True
def cleanup(self):
self.export_task_metrics()
self.stop_executors()
# if --output_path is specified, remove the output root as well
if self.remove_output_root:
remove_path(self.ctx.staging_root)
remove_path(self.ctx.output_root)
if not self.failure:
self.clean_temp_files()
if self.ctx.final_output_path:
remove_path(self.ctx.final_output_path)
if self.profiler is not None:
logger.debug(
f"scheduler perf profile:{os.linesep}{cprofile_to_string(self.profiler)}"
)
logger.info(f"scheduler of job {self.ctx.job_id} exits")
logger.complete()
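Under the split above, run() only drives a single execution plan and teardown moves to cleanup(). A hedged sketch of the caller-side flow, presumably in the JobManager (names illustrative):

try:
    scheduler.run(exec_plan)   # enqueues the leaf tasks and loops until the root task finishes or a failure is recorded
finally:
    scheduler.cleanup()        # export metrics, stop executors, remove temp and output dirs as configured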
def try_enqueue(self, tasks: Union[Iterable[Task], Task]):
tasks = tasks if isinstance(tasks, Iterable) else [tasks]
@ -1092,15 +1152,15 @@
else:
self.sched_queue.append(task)
def sched_loop(self, pool: ThreadPoolExecutor) -> bool:
def sched_loop(self, pool: ThreadPoolExecutor, task: Task):
"""
Run the scheduler loop until the task is finished or failed.
"""
has_progress = True
do_notify = False
if self.success:
logger.success(f"job already succeeded, stopping scheduler ...")
return True
while self.sched_running:
while not self.failure and self.tasks[task.key].status == WorkStatus.INCOMPLETE:
self.probe_executors()
self.update_executor_states()
@ -1153,9 +1213,6 @@ class Scheduler(object):
has_progress |= self.process_finished_tasks(pool) > 0
# out of loop
return self.success
def dispatch_tasks(self, pool: ThreadPoolExecutor):
# sort pending tasks
item_sort_key = (
@ -1240,6 +1297,10 @@ class Scheduler(object):
return num_dispatched_tasks
def process_finished_tasks(self, pool: ThreadPoolExecutor) -> int:
"""
Process finished tasks from all executors.
Return the number of finished tasks.
"""
pop_results = pool.map(RemoteExecutor.pop, self.available_executors.values())
num_finished_tasks = 0
@ -1309,7 +1370,7 @@ class Scheduler(object):
f"task failed too many times: {finished_task}, stopping ..." f"task failed too many times: {finished_task}, stopping ..."
) )
self.stop_executors() self.stop_executors()
self.sched_running = False self.failure = True
if not executor.local and finished_task.oom( if not executor.local and finished_task.oom(
self.nonzero_exitcode_as_oom self.nonzero_exitcode_as_oom

View File

@ -42,7 +42,6 @@ import pandas as pd
import psutil
import pyarrow as arrow
import pyarrow.parquet as parquet
import ray
from loguru import logger
from smallpond.common import (
@ -642,10 +641,6 @@ class Task(WorkItem):
# implementor can use this variable as a checkpoint and restore from it after interrupted
self.runtime_state = None
# if the task is executed by ray, this is the reference to the output dataset
# do not use this variable directly, use `self.run_on_ray()` instead
self._dataset_ref: Optional[ray.ObjectRef] = None
def __repr__(self) -> str:
return f"{self.key}.{self.sched_epoch}.{self.retry_count},{self.node_id}"
@ -1104,97 +1099,6 @@ class Task(WorkItem):
for name, value in metrics.items():
self.perf_metrics[name] += value
def run_on_ray(self) -> ray.ObjectRef:
"""
Run the task on Ray.
Return an `ObjectRef`, which can be used with `ray.get` to wait for the output dataset.
"""
if self._dataset_ref is not None:
# already started
return self._dataset_ref
# read the output dataset if the task has already finished
if os.path.exists(self.ray_dataset_path):
logger.info(f"task {self.key} already finished, skipping")
output = load(self.ray_dataset_path)
self._dataset_ref = ray.put(output)
return self._dataset_ref
task = copy.copy(self)
task.input_deps = {dep_key: None for dep_key in task.input_deps}
@ray.remote
def exec_task(task: Task, *inputs: DataSet) -> DataSet:
import multiprocessing as mp
import os
from pathlib import Path
from loguru import logger
# ray use a process pool to execute tasks
# we set the current process name to the task name
# so that we can see task name in the logs
mp.current_process().name = task.key
# probe the retry count
task.retry_count = 0
while os.path.exists(task.ray_marker_path):
task.retry_count += 1
if task.retry_count > DEFAULT_MAX_RETRY_COUNT:
raise RuntimeError(
f"task {task.key} failed after {task.retry_count} retries"
)
if task.retry_count > 0:
logger.warning(
f"task {task.key} is being retried for the {task.retry_count}th time"
)
# create the marker file
Path(task.ray_marker_path).touch()
# put the inputs into the task
assert len(inputs) == len(task.input_deps)
task.input_datasets = list(inputs)
# execute the task
status = task.exec()
if status != WorkStatus.SUCCEED:
raise task.exception or RuntimeError(
f"task {task.key} failed with status {status}"
)
# dump the output dataset atomically
os.makedirs(os.path.dirname(task.ray_dataset_path), exist_ok=True)
dump(task.output, task.ray_dataset_path, atomic_write=True)
return task.output
# this shows as {"name": ...} in timeline
exec_task._function_name = repr(task)
remote_function = exec_task.options(
# ray task name
# do not include task id so that they can be grouped by node in ray dashboard
name=f"{task.node_id}.{self.__class__.__name__}",
num_cpus=self.cpu_limit,
num_gpus=self.gpu_limit,
memory=int(self.memory_limit),
# note: `exec_on_scheduler` is ignored here,
# because dataset is distributed on ray
)
try:
self._dataset_ref = remote_function.remote(
task, *[dep.run_on_ray() for dep in self.input_deps.values()]
)
except RuntimeError as e:
if (
"SimpleQueue objects should only be shared between processes through inheritance"
in str(e)
):
raise RuntimeError(
f"Can't pickle task '{task.key}'. Please check if your function has captured unpicklable objects. {task.location}\n"
f"HINT: DO NOT use externally imported loguru logger in your task. Please import it within the task."
) from e
raise e
return self._dataset_ref
class ExecSqlQueryMixin(Task): class ExecSqlQueryMixin(Task):

View File

@ -21,368 +21,7 @@ import graphviz.backend.execute
from loguru import logger
import smallpond
from smallpond.execution.manager import JobManager
from smallpond.execution.task import JobId, RuntimeContext
from smallpond.logical.node import Context
from smallpond.platform import Platform, get_platform
class SessionBase:
def __init__(self, **kwargs):
"""
Create a smallpond environment.
"""
super().__init__()
self._ctx = Context()
self.config, self._platform = Config.from_args_and_env(**kwargs)
# construct runtime context for Tasks
runtime_ctx = RuntimeContext(
job_id=JobId(hex=self.config.job_id),
job_time=self.config.job_time,
data_root=self.config.data_root,
num_executors=self.config.num_executors,
bind_numa_node=self.config.bind_numa_node,
shared_log_root=self._platform.shared_log_root(),
)
self._runtime_ctx = runtime_ctx
# if `spawn` is specified, spawn a job and exit
if os.environ.get("SP_SPAWN") == "1":
self._spawn_self()
exit(0)
self._runtime_ctx.initialize(exec_id=socket.gethostname())
logger.info(f"using platform: {self._platform}")
logger.info(f"command-line arguments: {' '.join(sys.argv)}")
logger.info(f"session config: {self.config}")
def setup_worker():
runtime_ctx._init_logs(
exec_id=socket.gethostname(), capture_stdout_stderr=True
)
if self.config.ray_address is None:
# find the memory allocator
if self.config.memory_allocator == "system":
malloc_path = ""
elif self.config.memory_allocator == "jemalloc":
malloc_path = shutil.which("libjemalloc.so.2")
assert malloc_path is not None, "jemalloc is not installed"
elif self.config.memory_allocator == "mimalloc":
malloc_path = shutil.which("libmimalloc.so.2.1")
assert malloc_path is not None, "mimalloc is not installed"
else:
raise ValueError(
f"unsupported memory allocator: {self.config.memory_allocator}"
)
memory_purge_delay = 10000
# start ray head node
# for ray head node to access grafana
os.environ["RAY_GRAFANA_HOST"] = "http://localhost:8122"
self._ray_address = ray.init(
# start a new local cluster
address="local",
# disable local CPU resource if not running on localhost
num_cpus=(
0
if self.config.num_executors > 0
else self._runtime_ctx.usable_cpu_count
),
# set the memory limit to the available memory size
_memory=self._runtime_ctx.usable_memory_size,
# setup logging for workers
log_to_driver=False,
runtime_env={
"worker_process_setup_hook": setup_worker,
"env_vars": {
"LD_PRELOAD": malloc_path,
"MALLOC_CONF": f"percpu_arena:percpu,background_thread:true,metadata_thp:auto,dirty_decay_ms:{memory_purge_delay},muzzy_decay_ms:{memory_purge_delay},oversize_threshold:0,lg_tcache_max:16",
"MIMALLOC_PURGE_DELAY": f"{memory_purge_delay}",
"ARROW_DEFAULT_MEMORY_POOL": self.config.memory_allocator,
"ARROW_IO_THREADS": "2",
"OMP_NUM_THREADS": "2",
"POLARS_MAX_THREADS": "2",
"NUMEXPR_MAX_THREADS": "2",
"RAY_PROFILING": "1",
},
},
dashboard_host="0.0.0.0",
dashboard_port=8008,
# for prometheus to scrape metrics
_metrics_export_port=8080,
).address_info["gcs_address"]
logger.info(f"started ray cluster at {self._ray_address}")
self._prometheus_process = self._start_prometheus()
self._grafana_process = self._start_grafana()
else:
self._ray_address = self.config.ray_address
self._prometheus_process = None
self._grafana_process = None
logger.info(f"connected to ray cluster at {self._ray_address}")
# start workers
if self.config.num_executors > 0:
# override configs
kwargs["job_id"] = self.config.job_id
self._job_names = self._platform.start_job(
self.config.num_executors,
entrypoint=os.path.join(os.path.dirname(__file__), "worker.py"),
args=[
f"--ray_address={self._ray_address}",
f"--log_dir={self._runtime_ctx.log_root}",
*(["--bind_numa_node"] if self.config.bind_numa_node else []),
],
extra_opts=kwargs,
)
else:
self._job_names = []
# spawn a thread to periodically dump metrics
self._stop_event = threading.Event()
self._dump_thread = threading.Thread(
name="dump_thread", target=self._dump_periodically, daemon=True
)
self._dump_thread.start()
def shutdown(self):
"""
Shutdown the session.
"""
logger.info("shutting down session")
self._stop_event.set()
# stop all jobs
for job_name in self._job_names:
self._platform.stop_job(job_name)
self._job_names = []
self._dump_thread.join()
if self.config.ray_address is None:
ray.shutdown()
if self._prometheus_process is not None:
self._prometheus_process.terminate()
self._prometheus_process.wait()
self._prometheus_process = None
logger.info("stopped prometheus")
if self._grafana_process is not None:
self._grafana_process.terminate()
self._grafana_process.wait()
self._grafana_process = None
logger.info("stopped grafana")
def _spawn_self(self):
"""
Spawn a new job to run the current script.
"""
self._platform.start_job(
num_nodes=1,
entrypoint=sys.argv[0],
args=sys.argv[1:],
extra_opts=dict(
tags=["smallpond", "scheduler", smallpond.__version__],
),
envs={
k: v
for k, v in os.environ.items()
if k.startswith("SP_") and k != "SP_SPAWN"
},
)
def _start_prometheus(self) -> Optional[subprocess.Popen]:
"""
Start prometheus server if it exists.
"""
prometheus_path = shutil.which("prometheus")
if prometheus_path is None:
logger.warning("prometheus is not found")
return None
os.makedirs(f"{self._runtime_ctx.log_root}/prometheus", exist_ok=True)
proc = subprocess.Popen(
[
prometheus_path,
"--config.file=/tmp/ray/session_latest/metrics/prometheus/prometheus.yml",
f"--storage.tsdb.path={self._runtime_ctx.log_root}/prometheus/data",
],
stderr=open(f"{self._runtime_ctx.log_root}/prometheus/prometheus.log", "w"),
)
logger.info("started prometheus")
return proc
def _start_grafana(self) -> Optional[subprocess.Popen]:
"""
Start grafana server if it exists.
"""
homepath = self._platform.grafana_homepath()
if homepath is None:
logger.warning("grafana is not found")
return None
os.makedirs(f"{self._runtime_ctx.log_root}/grafana", exist_ok=True)
proc = subprocess.Popen(
[
shutil.which("grafana"),
"server",
"--config",
"/tmp/ray/session_latest/metrics/grafana/grafana.ini",
"-homepath",
homepath,
"web",
],
stdout=open(f"{self._runtime_ctx.log_root}/grafana/grafana.log", "w"),
env={
"GF_SERVER_HTTP_PORT": "8122", # redirect to an available port
"GF_SERVER_ROOT_URL": os.environ.get("RAY_GRAFANA_IFRAME_HOST")
or "http://localhost:8122",
"GF_PATHS_DATA": f"{self._runtime_ctx.log_root}/grafana/data",
},
)
logger.info(f"started grafana at http://localhost:8122")
return proc
@property
def runtime_ctx(self) -> RuntimeContext:
return self._runtime_ctx
def graph(self) -> Digraph:
"""
Get the logical plan graph.
"""
# implemented in Session class
raise NotImplementedError("graph")
def dump_graph(self, path: Optional[str] = None):
"""
Dump the logical plan graph to a file.
"""
path = path or os.path.join(self.runtime_ctx.log_root, "graph")
try:
self.graph().render(path, format="png")
logger.debug(f"dumped graph to {path}")
except graphviz.backend.execute.ExecutableNotFound as e:
logger.warning(f"graphviz is not installed, skipping graph dump")
def dump_timeline(self, path: Optional[str] = None):
"""
Dump the task timeline to a file.
"""
path = path or os.path.join(self.runtime_ctx.log_root, "timeline")
# the default timeline is grouped by worker
exec_path = f"{path}_exec"
ray.timeline(exec_path)
logger.debug(f"dumped timeline to {exec_path}")
# generate another timeline grouped by node
with open(exec_path) as f:
records = json.load(f)
new_records = []
for record in records:
# swap record name and pid-tid
name = record["name"]
try:
node_id = name.split(",")[-1]
task_id = name.split("-")[1].split(".")[0]
task_name = name.split("-")[0]
record["pid"] = f"{node_id}-{task_name}"
record["tid"] = f"task {task_id}"
new_records.append(record)
except Exception:
# filter out other records
pass
node_path = f"{path}_plan"
with open(node_path, "w") as f:
json.dump(new_records, f)
logger.debug(f"dumped timeline to {node_path}")
def _summarize_task(self) -> Tuple[int, int]:
# implemented in Session class
raise NotImplementedError("summarize_task")
def _dump_periodically(self):
"""
Dump the graph and timeline every minute.
Set `self._stop_event` to have a final dump and stop this thread.
"""
while not self._stop_event.is_set():
self._stop_event.wait(60)
self.dump_graph()
self.dump_timeline()
num_total_tasks, num_finished_tasks = self._summarize_task()
percent = (
num_finished_tasks / num_total_tasks * 100 if num_total_tasks > 0 else 0
)
logger.info(
f"progress: {num_finished_tasks}/{num_total_tasks} tasks ({percent:.1f}%)"
)
@dataclass
class Config:
"""
Configuration for a session.
"""
job_id: str # JOBID
job_time: datetime # JOB_TIME
data_root: str # DATA_ROOT
num_executors: int # NUM_NODES_TOTAL
ray_address: Optional[str] # RAY_ADDRESS
bind_numa_node: bool # BIND_NUMA_NODE
memory_allocator: str # MEMORY_ALLOCATOR
remove_output_root: bool
@staticmethod
def from_args_and_env(
platform: Optional[str] = None,
job_id: Optional[str] = None,
job_time: Optional[datetime] = None,
data_root: Optional[str] = None,
num_executors: Optional[int] = None,
ray_address: Optional[str] = None,
bind_numa_node: Optional[bool] = None,
memory_allocator: Optional[str] = None,
_remove_output_root: bool = True,
**kwargs,
) -> Config:
"""
Load config from arguments and environment variables.
If not specified, use the default value.
"""
def get_env(key: str, type: type = str):
"""
Get an environment variable and convert it to the given type.
If the variable is not set, return None.
"""
value = os.environ.get(f"SP_{key}")
return type(value) if value is not None else None
platform = get_platform(get_env("PLATFORM") or platform)
job_id = get_env("JOBID") or job_id or platform.default_job_id()
job_time = (
get_env("JOB_TIME", datetime.fromisoformat)
or job_time
or platform.default_job_time()
)
data_root = get_env("DATA_ROOT") or data_root or platform.default_data_root()
num_executors = get_env("NUM_EXECUTORS", int) or num_executors or 0
ray_address = get_env("RAY_ADDRESS") or ray_address
bind_numa_node = get_env("BIND_NUMA_NODE") == "1" or bind_numa_node
memory_allocator = (
get_env("MEMORY_ALLOCATOR")
or memory_allocator
or platform.default_memory_allocator()
)
config = Config(
job_id=job_id,
job_time=job_time,
data_root=data_root,
num_executors=num_executors,
ray_address=ray_address,
bind_numa_node=bind_numa_node,
memory_allocator=memory_allocator,
remove_output_root=_remove_output_root,
)
return config, platform
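A hedged usage sketch of the precedence implemented above: SP_-prefixed environment variables win over keyword arguments, which win over platform defaults (values illustrative):

import os

os.environ["SP_NUM_EXECUTORS"] = "4"
os.environ["SP_MEMORY_ALLOCATOR"] = "jemalloc"

config, platform = Config.from_args_and_env(num_executors=2, data_root="/tmp/sp-data")
assert config.num_executors == 4              # the env var overrides the keyword argument
assert config.memory_allocator == "jemalloc"
assert config.data_root == "/tmp/sp-data"     # no SP_DATA_ROOT set, so the argument wins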

View File

@ -210,6 +210,7 @@ class TestFabric(unittest.TestCase):
self.sched_states = self.queue_manager.Queue()
scheduler = Scheduler(
ctx=runtime_ctx,
max_retry_count=max_retry_count,
max_fail_count=max_fail_count,
prioritize_retry=prioritize_retry,