Signed-off-by: Runji Wang <runji@deepseek.com>
Runji Wang 2025-03-02 22:28:52 +08:00
parent 5a56a052bf
commit 9492a2872a
7 changed files with 293 additions and 631 deletions


@@ -13,22 +13,26 @@ import ray
import ray.exceptions
from loguru import logger
from smallpond.execution.manager import JobManager
from smallpond.execution.task import Task
from smallpond.io.filesystem import remove_path
from smallpond.logical.dataset import *
from smallpond.logical.node import *
from smallpond.logical.optimizer import Optimizer
from smallpond.logical.planner import Planner
from smallpond.session import SessionBase
class Session(SessionBase):
class Session:
# Extended session class with additional methods to create DataFrames.
def __init__(self, **kwargs):
super().__init__(**kwargs)
"""
Create a smallpond environment.
"""
self._job_manager = JobManager(**kwargs)
self._ctx = Context()
self._runtime_ctx = self._job_manager.runtime_ctx
self._nodes: List[Node] = []
self._node_to_tasks: Dict[Node, List[Task]] = {}
"""
When a DataFrame is evaluated, the tasks of the logical plan are stored here.
@@ -158,52 +162,6 @@ class Session(SessionBase):
return
self._shutdown_called = True
# log status
finished = self._all_tasks_finished()
with open(self._runtime_ctx.job_status_path, "a") as fout:
status = "success" if finished else "failure"
fout.write(f"{status}@{datetime.now():%Y-%m-%d-%H-%M-%S}\n")
# clean up runtime directories if success
if finished:
logger.info("all tasks are finished, cleaning up")
self._runtime_ctx.cleanup(remove_output_root=self.config.remove_output_root)
else:
logger.warning("tasks are not finished!")
super().shutdown()
def _summarize_task(self) -> Tuple[int, int]:
"""
Return the total number of tasks and the number of tasks that are finished.
"""
dataset_refs = [
task._dataset_ref
for tasks in self._node_to_tasks.values()
for task in tasks
if task._dataset_ref is not None
]
ready_tasks, _ = ray.wait(
dataset_refs, num_returns=len(dataset_refs), timeout=0, fetch_local=False
)
return len(dataset_refs), len(ready_tasks)
def _all_tasks_finished(self) -> bool:
"""
Check if all tasks are finished.
"""
dataset_refs = [
task._dataset_ref
for tasks in self._node_to_tasks.values()
for task in tasks
]
try:
ray.get(dataset_refs, timeout=0)
except Exception:
# GetTimeoutError is raised if any task is not finished
# RuntimeError is raised if any task failed
return False
return True
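For reference, a minimal standalone sketch of the non-blocking Ray idioms used by the two helpers above (illustrative only; `slow_add` is a made-up remote function, not part of smallpond):
import ray

@ray.remote
def slow_add(a: int, b: int) -> int:
    return a + b

refs = [slow_add.remote(i, i) for i in range(4)]
# ray.wait with timeout=0 returns immediately: `ready` holds only the finished refs,
# and fetch_local=False avoids pulling their values to this node.
ready, _pending = ray.wait(refs, num_returns=len(refs), timeout=0, fetch_local=False)
print(f"{len(ready)}/{len(refs)} tasks finished")
# ray.get with timeout=0 raises GetTimeoutError if any ref is unfinished,
# or re-raises the task's own exception if one failed.
try:
    ray.get(refs, timeout=0)
    all_finished = True
except Exception:
    all_finished = False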
class DataFrame:
@@ -216,7 +174,7 @@ class DataFrame:
def __init__(self, session: Session, plan: Node, recompute: bool = False):
self.session = session
self.plan = plan
self.optimized_plan: Optional[Node] = None
# self.optimized_plan: Optional[Node] = None
self.need_recompute = recompute
"""Whether to recompute the data regardless of whether it's already computed."""
@@ -229,17 +187,17 @@
"""
Get or create tasks to compute the data.
"""
# optimize the plan
if self.optimized_plan is None:
logger.info(f"optimizing\n{LogicalPlan(self.session._ctx, self.plan)}")
self.optimized_plan = Optimizer(
exclude_nodes=set(self.session._node_to_tasks.keys())
).visit(self.plan)
logger.info(
f"optimized\n{LogicalPlan(self.session._ctx, self.optimized_plan)}"
)
# # optimize the plan
# if self.optimized_plan is None:
# logger.info(f"optimizing\n{LogicalPlan(self.session._ctx, self.plan)}")
# self.optimized_plan = Optimizer(
# exclude_nodes=set(self.session._node_to_tasks.keys())
# ).visit(self.plan)
# logger.info(
# f"optimized\n{LogicalPlan(self.session._ctx, self.optimized_plan)}"
# )
# return the tasks if already created
if tasks := self.session._node_to_tasks.get(self.optimized_plan):
if tasks := self.session._node_to_tasks.get(self.plan):
return tasks
# remove all completed task files if recompute is needed
@@ -247,16 +205,16 @@
remove_path(
os.path.join(
self.session._runtime_ctx.completed_task_dir,
str(self.optimized_plan.id),
str(self.plan.id),
)
)
logger.info(f"cleared all results of {self.optimized_plan!r}")
logger.info(f"cleared all results of {self.plan!r}")
# create tasks for the optimized plan
planner = Planner(self.session._runtime_ctx)
# let planner update self.session._node_to_tasks
planner.node_to_tasks = self.session._node_to_tasks
return planner.visit(self.optimized_plan)
return planner.visit(self.plan)
def is_computed(self) -> bool:
"""

smallpond/execution/_ray.py (new file, 103 lines added)

@@ -0,0 +1,103 @@
import copy
import os
import ray
from loguru import logger
from smallpond.common import DEFAULT_MAX_RETRY_COUNT
from smallpond.execution.task import Task
from smallpond.execution.workqueue import WorkStatus
from smallpond.io.filesystem import dump, load
from smallpond.logical.dataset import DataSet
def run_on_ray(task: Task) -> ray.ObjectRef:
"""
Run the task on Ray.
Return an `ObjectRef`, which can be used with `ray.get` to wait for the output dataset.
A `_dataset_ref` attribute is added to the task to store the reference.
"""
if task._dataset_ref is not None:
# already started
return task._dataset_ref
# read the output dataset if the task has already finished
if os.path.exists(task.ray_dataset_path):
logger.info(f"task {task.key} already finished, skipping")
output = load(task.ray_dataset_path)
task._dataset_ref = ray.put(output)
return task._dataset_ref
task = copy.copy(task)
task.input_deps = {dep_key: None for dep_key in task.input_deps}
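# Keep only the dependency keys (values become None) so that pickling this copy does not
# drag the entire upstream task graph into the Ray task payload; the actual input DataSets
# are passed separately below as ObjectRefs, which Ray resolves before calling exec_task.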
@ray.remote
def exec_task(task: Task, *inputs: DataSet) -> DataSet:
import multiprocessing as mp
import os
from pathlib import Path
from loguru import logger
# ray uses a process pool to execute tasks
# we set the current process name to the task name
# so that the task name shows up in the logs
mp.current_process().name = task.key
# probe the retry count
task.retry_count = 0
while os.path.exists(task.ray_marker_path):
task.retry_count += 1
if task.retry_count > DEFAULT_MAX_RETRY_COUNT:
raise RuntimeError(
f"task {task.key} failed after {task.retry_count} retries"
)
if task.retry_count > 0:
logger.warning(
f"task {task.key} is being retried for the {task.retry_count}th time"
)
# create the marker file
Path(task.ray_marker_path).touch()
# put the inputs into the task
assert len(inputs) == len(task.input_deps)
task.input_datasets = list(inputs)
# execute the task
status = task.exec()
if status != WorkStatus.SUCCEED:
raise task.exception or RuntimeError(
f"task {task.key} failed with status {status}"
)
# dump the output dataset atomically
os.makedirs(os.path.dirname(task.ray_dataset_path), exist_ok=True)
dump(task.output, task.ray_dataset_path, atomic_write=True)
return task.output
# this shows as {"name": ...} in timeline
exec_task._function_name = repr(task)
remote_function = exec_task.options(
# ray task name
# do not include task id so that they can be grouped by node in ray dashboard
name=f"{task.node_id}.{task.__class__.__name__}",
num_cpus=task.cpu_limit,
num_gpus=task.gpu_limit,
memory=int(task.memory_limit),
# note: `exec_on_scheduler` is ignored here,
# because the dataset is distributed on ray
)
try:
task._dataset_ref = remote_function.remote(
task, *[run_on_ray(dep) for dep in task.input_deps.values()]
)
except RuntimeError as e:
if (
"SimpleQueue objects should only be shared between processes through inheritance"
in str(e)
):
raise RuntimeError(
f"Can't pickle task '{task.key}'. Please check if your function has captured unpicklable objects. {task.location}\n"
f"HINT: DO NOT use externally imported loguru logger in your task. Please import it within the task."
) from e
raise e
return task._dataset_ref
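A minimal usage sketch for the new helper (illustrative only; assumes `root_task` is a fully planned `Task` whose `input_deps` were built by the Planner):
import ray
from smallpond.execution._ray import run_on_ray

ref = run_on_ray(root_task)   # recursively submits root_task and all of its input_deps
output = ray.get(ref)         # blocks until the whole dependency subtree has finished
print(output)                 # the root task's output DataSet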


@@ -212,14 +212,10 @@ class JobManager(object):
sched_state_observers.insert(0, sched_state_exporter)
if os.path.exists(self.runtime_ctx.sched_state_path):
logger.warning(
f"loading scheduler state from: {self.runtime_ctx.sched_state_path}"
)
self.scheduler: Scheduler = load(self.runtime_ctx.sched_state_path)
self.scheduler.sched_epoch += 1
self.scheduler.sched_state_observers = sched_state_observers
self.scheduler = Scheduler.recover_from_file(self.runtime_ctx.sched_state_path, sched_state_observers)
else:
self.scheduler = Scheduler(
ctx=self.runtime_ctx,
max_retry_count=max_retry_count,
max_fail_count=max_fail_count,
prioritize_retry=prioritize_retry,


@@ -1,5 +1,6 @@
import copy
import cProfile
from datetime import datetime
import itertools
import multiprocessing as mp
import os
@@ -12,7 +13,18 @@ from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from functools import cached_property
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Set, Tuple, Union
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
Set,
Tuple,
Union,
)
import numpy as np
from loguru import logger
@@ -40,7 +52,7 @@ from smallpond.execution.workqueue import (
WorkQueueInMemory,
WorkQueueOnFilesystem,
)
from smallpond.io.filesystem import dump, remove_path
from smallpond.io.filesystem import dump, load, remove_path
from smallpond.logical.node import LogicalPlan, Node
from smallpond.utility import cprofile_to_string
@@ -359,10 +371,9 @@ class Scheduler(object):
"""
large_num_nontrivial_tasks = 200 if pytest_running() else 20000
StateCallback = Callable[["Scheduler"], Any]
class StateObserver(object):
def __init__(self, callback: "Scheduler.StateCallback" = None) -> None:
def __init__(self, callback: Callable[["Scheduler"], Any] = None) -> None:
assert callback is None or isinstance(callback, Callable)
self.enabled = True
self.callback = callback
@@ -380,6 +391,7 @@
def __init__(
self,
ctx: RuntimeContext,
*,
max_retry_count: int = DEFAULT_MAX_RETRY_COUNT,
max_fail_count: int = DEFAULT_MAX_FAIL_COUNT,
@@ -390,18 +402,46 @@
remove_output_root: bool = False,
sched_state_observers: Optional[List[StateObserver]] = None,
) -> None:
"""
Initialize the scheduler.
Parameters
----------
ctx: RuntimeContext
The runtime context.
max_retry_count: int, optional
The maximum retry count. Defaults to 5.
max_fail_count: int, optional
The maximum fail count. Defaults to 3.
prioritize_retry: bool, optional
Whether to prioritize retried tasks. Defaults to False.
speculative_exec: Literal["disable", "enable", "aggressive"], optional
If "enable", long-running tasks are rescheduled on other executors.
If "aggressive", long-running tasks are rescheduled even more eagerly.
If "disable", no speculative execution is performed.
Defaults to "enable".
stop_executor_on_failure: bool, optional
Whether to stop the executor when a task fails. Defaults to False.
nonzero_exitcode_as_oom: bool, optional
Whether to treat a non-zero exit code as out-of-memory. Defaults to False.
remove_output_root: bool, optional
Whether to remove the output root on exit. Defaults to False.
sched_state_observers: Optional[List[StateObserver]], optional
The observers notified of scheduler state changes.
"""
# configs
self.ctx = ctx
self.max_retry_count = max_retry_count
self.max_fail_count = max_fail_count
self.standalone_mode = self.ctx.num_executors == 0
self.prioritize_retry = prioritize_retry
self.disable_speculative_exec = speculative_exec == "disable"
self.aggressive_speculative_exec = speculative_exec == "aggressive"
self.speculative_exec: Literal["disable", "enable", "aggressive"] = (
speculative_exec
)
self.stop_executor_on_failure = stop_executor_on_failure
self.nonzero_exitcode_as_oom = nonzero_exitcode_as_oom
self.remove_output_root = remove_output_root
self.sched_state_observers: List[Scheduler.StateObserver] = (
sched_state_observers or []
)
self.sched_state_observers = sched_state_observers or []
self.secs_state_notify_interval = self.ctx.secs_executor_probe_interval * 2
# task states
self.local_queue: List[Task] = []
@@ -416,30 +456,47 @@
self.local_executor = LocalExecutor.create(self.ctx, "localhost")
self.available_executors = {self.local_executor.id: self.local_executor}
# other runtime states
self.sched_running = False
self.sched_start_time = 0
self.failure = False
self.sched_start_time = time.time()
self.last_executor_probe_time = 0
self.last_state_notify_time = 0
self.probe_epoch = 0
self.sched_epoch = 0
self._post_init()
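With the bare `*` added above, everything after `ctx` is now keyword-only; construction has to spell the options out, as in this sketch (illustrative, mirroring the test fixture at the end of this diff):
scheduler = Scheduler(
    ctx=runtime_ctx,          # only `ctx` may still be passed positionally
    max_retry_count=5,
    max_fail_count=3,
    speculative_exec="enable",
)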
@staticmethod
def recover_from_file(
sched_state_path: str, sched_state_observers: List[StateObserver]
) -> "Scheduler":
"""
Recover the scheduler from the previous run.
"""
logger.warning(f"loading scheduler state from: {sched_state_path}")
self: Scheduler = load(sched_state_path)
self.sched_epoch += 1
self.prioritize_retry = True
# observers are not pickled, so we need to re-add them
self.sched_state_observers = sched_state_observers
self._post_init()
return self
def __getstate__(self):
state = self.__dict__.copy()
del state["sched_state_observers"]
del state["profiler"]
return state
def __setstate__(self, state):
self.__dict__.update(state)
self.sched_state_observers = []
self.profiler = None
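A sketch of the checkpoint/recovery round-trip these hooks enable (assumed flow; the actual dump is performed by a state observer such as `sched_state_exporter`, which is not shown in this hunk):
dump(scheduler, ctx.sched_state_path)   # __getstate__ drops the observers and the profiler
restored = Scheduler.recover_from_file(ctx.sched_state_path, sched_state_observers)
# recover_from_file bumps sched_epoch, forces prioritize_retry, re-attaches the observers,
# and re-runs _post_init() so profiling, timers, and the requeueing of pending tasks happen again.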
@property
def elapsed_time(self):
return time.time() - self.sched_start_time
@property
def success(self) -> bool:
return self.exec_plan.root_task.key in self.succeeded_tasks
@property
def progress(self) -> Tuple[int, int, float]:
num_succeeded = len(self.succeeded_nontrivial_tasks)
@@ -582,7 +639,7 @@ class Scheduler(object):
for executor in self.working_executors:
for idx, item in enumerate(executor.running_works.values()):
aggressive_retry = (
self.aggressive_speculative_exec
self.speculative_exec == "aggressive"
and len(self.good_executors) >= self.ctx.num_executors
)
short_sched_queue = len(self.sched_queue) < len(self.good_executors)
@@ -644,7 +701,7 @@
for executor in self.working_executors:
executor.probe(self.probe_epoch)
# start speculative execution of tasks
if not self.disable_speculative_exec:
if self.speculative_exec != "disable":
self.start_speculative_execution()
def update_executor_states(self):
@@ -777,7 +834,7 @@
return task
@logger.catch(reraise=pytest_running(), message="failed to clean temp files")
def clean_temp_files(self, pool: ThreadPoolExecutor):
def clean_temp_files(self):
remove_path(self.ctx.queue_root)
remove_path(self.ctx.temp_root)
remove_path(self.ctx.staging_root)
@@ -786,7 +843,10 @@
logger.info(
f"removing outputs of {len(abandoned_tasks)} abandoned tasks: {abandoned_tasks[:3]} ..."
)
assert list(pool.map(lambda t: t.clean_output(force=True), abandoned_tasks))
with ThreadPoolExecutor(32) as pool:
assert list(
pool.map(lambda t: t.clean_output(force=True), abandoned_tasks)
)
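# Note: wrapping pool.map(...) in list() forces every future to resolve right here, so any
# exception raised by clean_output propagates out of clean_temp_files (where @logger.catch
# above logs it) instead of being silently dropped by the lazy map iterator.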
@logger.catch(reraise=pytest_running(), message="failed to export task metrics")
def export_task_metrics(self):
@@ -963,55 +1023,45 @@
def log_current_status(self):
with open(self.ctx.job_status_path, "w") as fout:
if self.sched_running:
status = "running"
elif self.success:
status = "success"
else:
if self.failure:
status = "failure"
fout.write(f"{status}@{int(time.time())}")
else:
status = "running"
fout.write(f"{status}@{datetime.now().isoformat()}")
def run(self, exec_plan: ExecutionPlan) -> bool:
def _post_init(self):
"""
Run the execution plan.
Common initialization after startup or recovery.
"""
mp.current_process().name = f"SchedulerMainProcess#{self.sched_epoch}"
logger.info(
f"start to run scheduler #{self.sched_epoch} on {socket.gethostname()}"
)
perf_profile = None
if self.ctx.enable_profiling:
perf_profile = cProfile.Profile()
perf_profile.enable()
self.profiler = cProfile.Profile()
self.profiler.enable()
self.sched_running = True
self.sched_start_time = time.time()
self.last_executor_probe_time = 0
self.last_state_notify_time = 0
self.prioritize_retry |= self.sched_epoch > 0
if self.local_queue or self.sched_queue:
pending_tasks = [
item
for item in self.local_queue + self.sched_queue
if isinstance(item, Task)
]
self.local_queue.clear()
self.sched_queue.clear()
logger.info(
f"requeue {len(pending_tasks)} pending tasks with latest epoch #{self.sched_epoch}: {pending_tasks[:3]} ..."
)
self.try_enqueue(pending_tasks)
with ThreadPoolExecutor(32) as pool:
self.sched_running = True
self.sched_start_time = time.time()
self.last_executor_probe_time = 0
self.last_state_notify_time = 0
self.prioritize_retry |= self.sched_epoch > 0
if self.local_queue or self.sched_queue:
pending_tasks = [
item
for item in self.local_queue + self.sched_queue
if isinstance(item, Task)
]
self.local_queue.clear()
self.sched_queue.clear()
logger.info(
f"requeue {len(pending_tasks)} pending tasks with latest epoch #{self.sched_epoch}: {pending_tasks[:3]} ..."
)
self.try_enqueue(pending_tasks)
if self.sched_epoch == 0:
leaf_tasks = self.exec_plan.leaves
logger.info(
f"enqueue {len(leaf_tasks)} leaf tasks: {leaf_tasks[:3]} ..."
)
self.try_enqueue(leaf_tasks)
self.log_overall_progress()
while (num_finished_tasks := self.process_finished_tasks(pool)) > 0:
logger.info(
@@ -1019,54 +1069,64 @@
)
self.log_overall_progress()
earlier_running_tasks = [
item for item in self.running_works if isinstance(item, Task)
]
if earlier_running_tasks:
logger.info(
f"enqueue {len(earlier_running_tasks)} earlier running tasks: {earlier_running_tasks[:3]} ..."
)
self.try_enqueue(earlier_running_tasks)
earlier_running_tasks = [
item for item in self.running_works if isinstance(item, Task)
]
if earlier_running_tasks:
logger.info(
f"enqueue {len(earlier_running_tasks)} earlier running tasks: {earlier_running_tasks[:3]} ..."
)
self.try_enqueue(earlier_running_tasks)
self.suspend_good_executors()
self.add_state_observer(
Scheduler.StateObserver(Scheduler.log_current_status)
)
self.add_state_observer(
Scheduler.StateObserver(Scheduler.export_timeline_figs)
)
self.suspend_good_executors()
self.add_state_observer(Scheduler.StateObserver(Scheduler.log_current_status))
self.add_state_observer(Scheduler.StateObserver(Scheduler.export_timeline_figs))
self.notify_state_observers(force_notify=True)
def run(self, exec_plan: ExecutionPlan) -> bool:
"""
Run the execution plan.
"""
leaf_tasks = exec_plan.leaves
logger.info(f"enqueue {len(leaf_tasks)} leaf tasks: {leaf_tasks[:3]} ...")
self.try_enqueue(leaf_tasks)
try:
with ThreadPoolExecutor(32) as pool:
self.local_executor.start(pool)
self.sched_loop(pool, exec_plan.root_task)
finally:
logger.info(f"schedule loop stopped")
self.notify_state_observers(force_notify=True)
try:
self.local_executor.start(pool)
self.sched_loop(pool)
finally:
logger.info(f"schedule loop stopped")
self.sched_running = False
self.notify_state_observers(force_notify=True)
self.export_task_metrics()
self.stop_executors()
logger.success(f"final output path: {exec_plan.final_output_path}")
logger.info(
f"analyzed plan:{os.linesep}{exec_plan.analyzed_logical_plan.explain_str()}"
)
# if --output_path is specified, remove the output root as well
if self.remove_output_root or self.ctx.final_output_path:
remove_path(self.ctx.staging_root)
remove_path(self.ctx.output_root)
return True
if self.success:
self.clean_temp_files(pool)
logger.success(f"final output path: {self.exec_plan.final_output_path}")
logger.info(
f"analyzed plan:{os.linesep}{self.exec_plan.analyzed_logical_plan.explain_str()}"
)
def cleanup(self):
self.export_task_metrics()
self.stop_executors()
if perf_profile is not None:
# if --output_path is specified, remove the output root as well
if self.remove_output_root:
remove_path(self.ctx.staging_root)
remove_path(self.ctx.output_root)
if not self.failure:
self.clean_temp_files()
if self.ctx.final_output_path:
remove_path(self.ctx.final_output_path)
if self.profiler is not None:
logger.debug(
f"scheduler perf profile:{os.linesep}{cprofile_to_string(perf_profile)}"
f"scheduler perf profile:{os.linesep}{cprofile_to_string(self.profiler)}"
)
logger.info(f"scheduler of job {self.ctx.job_id} exits")
logger.complete()
return self.success
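Sketch of the scheduler lifecycle implied by this refactor (assumed call order; the real call sites live in JobManager and are only partially shown in this diff):
scheduler = Scheduler(ctx=runtime_ctx)   # or Scheduler.recover_from_file(...) after a restart
succeeded = scheduler.run(exec_plan)     # drives sched_loop until the root task finishes or a failure occurs
scheduler.cleanup()                      # export metrics, stop executors, remove temp files on success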
def try_enqueue(self, tasks: Union[Iterable[Task], Task]):
tasks = tasks if isinstance(tasks, Iterable) else [tasks]
@@ -1092,15 +1152,15 @@
else:
self.sched_queue.append(task)
def sched_loop(self, pool: ThreadPoolExecutor) -> bool:
def sched_loop(self, pool: ThreadPoolExecutor, task: Task):
"""
Run the scheduler loop until the task is finished or failed.
"""
has_progress = True
do_notify = False
if self.success:
logger.success(f"job already succeeded, stopping scheduler ...")
return True
while self.sched_running:
while not self.failure and self.tasks[task.key].status == WorkStatus.INCOMPLETE:
self.probe_executors()
self.update_executor_states()
@@ -1153,9 +1213,6 @@
has_progress |= self.process_finished_tasks(pool) > 0
# out of loop
return self.success
def dispatch_tasks(self, pool: ThreadPoolExecutor):
# sort pending tasks
item_sort_key = (
@@ -1240,6 +1297,10 @@
return num_dispatched_tasks
def process_finished_tasks(self, pool: ThreadPoolExecutor) -> int:
"""
Process finished tasks from all executors.
Return the number of finished tasks.
"""
pop_results = pool.map(RemoteExecutor.pop, self.available_executors.values())
num_finished_tasks = 0
@@ -1309,7 +1370,7 @@
f"task failed too many times: {finished_task}, stopping ..."
)
self.stop_executors()
self.sched_running = False
self.failure = True
if not executor.local and finished_task.oom(
self.nonzero_exitcode_as_oom


@@ -42,7 +42,6 @@ import pandas as pd
import psutil
import pyarrow as arrow
import pyarrow.parquet as parquet
import ray
from loguru import logger
from smallpond.common import (
@@ -642,10 +641,6 @@ class Task(WorkItem):
# implementor can use this variable as a checkpoint and restore from it after interrupted
self.runtime_state = None
# if the task is executed by ray, this is the reference to the output dataset
# do not use this variable directly, use `self.run_on_ray()` instead
self._dataset_ref: Optional[ray.ObjectRef] = None
def __repr__(self) -> str:
return f"{self.key}.{self.sched_epoch}.{self.retry_count},{self.node_id}"
@@ -1104,97 +1099,6 @@
for name, value in metrics.items():
self.perf_metrics[name] += value
def run_on_ray(self) -> ray.ObjectRef:
"""
Run the task on Ray.
Return an `ObjectRef`, which can be used with `ray.get` to wait for the output dataset.
"""
if self._dataset_ref is not None:
# already started
return self._dataset_ref
# read the output dataset if the task has already finished
if os.path.exists(self.ray_dataset_path):
logger.info(f"task {self.key} already finished, skipping")
output = load(self.ray_dataset_path)
self._dataset_ref = ray.put(output)
return self._dataset_ref
task = copy.copy(self)
task.input_deps = {dep_key: None for dep_key in task.input_deps}
@ray.remote
def exec_task(task: Task, *inputs: DataSet) -> DataSet:
import multiprocessing as mp
import os
from pathlib import Path
from loguru import logger
# ray uses a process pool to execute tasks
# we set the current process name to the task name
# so that the task name shows up in the logs
mp.current_process().name = task.key
# probe the retry count
task.retry_count = 0
while os.path.exists(task.ray_marker_path):
task.retry_count += 1
if task.retry_count > DEFAULT_MAX_RETRY_COUNT:
raise RuntimeError(
f"task {task.key} failed after {task.retry_count} retries"
)
if task.retry_count > 0:
logger.warning(
f"task {task.key} is being retried for the {task.retry_count}th time"
)
# create the marker file
Path(task.ray_marker_path).touch()
# put the inputs into the task
assert len(inputs) == len(task.input_deps)
task.input_datasets = list(inputs)
# execute the task
status = task.exec()
if status != WorkStatus.SUCCEED:
raise task.exception or RuntimeError(
f"task {task.key} failed with status {status}"
)
# dump the output dataset atomically
os.makedirs(os.path.dirname(task.ray_dataset_path), exist_ok=True)
dump(task.output, task.ray_dataset_path, atomic_write=True)
return task.output
# this shows as {"name": ...} in timeline
exec_task._function_name = repr(task)
remote_function = exec_task.options(
# ray task name
# do not include task id so that they can be grouped by node in ray dashboard
name=f"{task.node_id}.{self.__class__.__name__}",
num_cpus=self.cpu_limit,
num_gpus=self.gpu_limit,
memory=int(self.memory_limit),
# note: `exec_on_scheduler` is ignored here,
# because the dataset is distributed on ray
)
try:
self._dataset_ref = remote_function.remote(
task, *[dep.run_on_ray() for dep in self.input_deps.values()]
)
except RuntimeError as e:
if (
"SimpleQueue objects should only be shared between processes through inheritance"
in str(e)
):
raise RuntimeError(
f"Can't pickle task '{task.key}'. Please check if your function has captured unpicklable objects. {task.location}\n"
f"HINT: DO NOT use externally imported loguru logger in your task. Please import it within the task."
) from e
raise e
return self._dataset_ref
class ExecSqlQueryMixin(Task):


@@ -21,368 +21,7 @@ import graphviz.backend.execute
from loguru import logger
import smallpond
from smallpond.execution.manager import JobManager
from smallpond.execution.task import JobId, RuntimeContext
from smallpond.logical.node import Context
from smallpond.platform import Platform, get_platform
class SessionBase:
def __init__(self, **kwargs):
"""
Create a smallpond environment.
"""
super().__init__()
self._ctx = Context()
self.config, self._platform = Config.from_args_and_env(**kwargs)
# construct runtime context for Tasks
runtime_ctx = RuntimeContext(
job_id=JobId(hex=self.config.job_id),
job_time=self.config.job_time,
data_root=self.config.data_root,
num_executors=self.config.num_executors,
bind_numa_node=self.config.bind_numa_node,
shared_log_root=self._platform.shared_log_root(),
)
self._runtime_ctx = runtime_ctx
# if `spawn` is specified, spawn a job and exit
if os.environ.get("SP_SPAWN") == "1":
self._spawn_self()
exit(0)
self._runtime_ctx.initialize(exec_id=socket.gethostname())
logger.info(f"using platform: {self._platform}")
logger.info(f"command-line arguments: {' '.join(sys.argv)}")
logger.info(f"session config: {self.config}")
def setup_worker():
runtime_ctx._init_logs(
exec_id=socket.gethostname(), capture_stdout_stderr=True
)
if self.config.ray_address is None:
# find the memory allocator
if self.config.memory_allocator == "system":
malloc_path = ""
elif self.config.memory_allocator == "jemalloc":
malloc_path = shutil.which("libjemalloc.so.2")
assert malloc_path is not None, "jemalloc is not installed"
elif self.config.memory_allocator == "mimalloc":
malloc_path = shutil.which("libmimalloc.so.2.1")
assert malloc_path is not None, "mimalloc is not installed"
else:
raise ValueError(
f"unsupported memory allocator: {self.config.memory_allocator}"
)
memory_purge_delay = 10000
# start ray head node
# for ray head node to access grafana
os.environ["RAY_GRAFANA_HOST"] = "http://localhost:8122"
self._ray_address = ray.init(
# start a new local cluster
address="local",
# disable local CPU resource if not running on localhost
num_cpus=(
0
if self.config.num_executors > 0
else self._runtime_ctx.usable_cpu_count
),
# set the memory limit to the available memory size
_memory=self._runtime_ctx.usable_memory_size,
# setup logging for workers
log_to_driver=False,
runtime_env={
"worker_process_setup_hook": setup_worker,
"env_vars": {
"LD_PRELOAD": malloc_path,
"MALLOC_CONF": f"percpu_arena:percpu,background_thread:true,metadata_thp:auto,dirty_decay_ms:{memory_purge_delay},muzzy_decay_ms:{memory_purge_delay},oversize_threshold:0,lg_tcache_max:16",
"MIMALLOC_PURGE_DELAY": f"{memory_purge_delay}",
"ARROW_DEFAULT_MEMORY_POOL": self.config.memory_allocator,
"ARROW_IO_THREADS": "2",
"OMP_NUM_THREADS": "2",
"POLARS_MAX_THREADS": "2",
"NUMEXPR_MAX_THREADS": "2",
"RAY_PROFILING": "1",
},
},
dashboard_host="0.0.0.0",
dashboard_port=8008,
# for prometheus to scrape metrics
_metrics_export_port=8080,
).address_info["gcs_address"]
logger.info(f"started ray cluster at {self._ray_address}")
self._prometheus_process = self._start_prometheus()
self._grafana_process = self._start_grafana()
else:
self._ray_address = self.config.ray_address
self._prometheus_process = None
self._grafana_process = None
logger.info(f"connected to ray cluster at {self._ray_address}")
# start workers
if self.config.num_executors > 0:
# override configs
kwargs["job_id"] = self.config.job_id
self._job_names = self._platform.start_job(
self.config.num_executors,
entrypoint=os.path.join(os.path.dirname(__file__), "worker.py"),
args=[
f"--ray_address={self._ray_address}",
f"--log_dir={self._runtime_ctx.log_root}",
*(["--bind_numa_node"] if self.config.bind_numa_node else []),
],
extra_opts=kwargs,
)
else:
self._job_names = []
# spawn a thread to periodically dump metrics
self._stop_event = threading.Event()
self._dump_thread = threading.Thread(
name="dump_thread", target=self._dump_periodically, daemon=True
)
self._dump_thread.start()
def shutdown(self):
"""
Shutdown the session.
"""
logger.info("shutting down session")
self._stop_event.set()
# stop all jobs
for job_name in self._job_names:
self._platform.stop_job(job_name)
self._job_names = []
self._dump_thread.join()
if self.config.ray_address is None:
ray.shutdown()
if self._prometheus_process is not None:
self._prometheus_process.terminate()
self._prometheus_process.wait()
self._prometheus_process = None
logger.info("stopped prometheus")
if self._grafana_process is not None:
self._grafana_process.terminate()
self._grafana_process.wait()
self._grafana_process = None
logger.info("stopped grafana")
def _spawn_self(self):
"""
Spawn a new job to run the current script.
"""
self._platform.start_job(
num_nodes=1,
entrypoint=sys.argv[0],
args=sys.argv[1:],
extra_opts=dict(
tags=["smallpond", "scheduler", smallpond.__version__],
),
envs={
k: v
for k, v in os.environ.items()
if k.startswith("SP_") and k != "SP_SPAWN"
},
)
def _start_prometheus(self) -> Optional[subprocess.Popen]:
"""
Start prometheus server if it exists.
"""
prometheus_path = shutil.which("prometheus")
if prometheus_path is None:
logger.warning("prometheus is not found")
return None
os.makedirs(f"{self._runtime_ctx.log_root}/prometheus", exist_ok=True)
proc = subprocess.Popen(
[
prometheus_path,
"--config.file=/tmp/ray/session_latest/metrics/prometheus/prometheus.yml",
f"--storage.tsdb.path={self._runtime_ctx.log_root}/prometheus/data",
],
stderr=open(f"{self._runtime_ctx.log_root}/prometheus/prometheus.log", "w"),
)
logger.info("started prometheus")
return proc
def _start_grafana(self) -> Optional[subprocess.Popen]:
"""
Start grafana server if it exists.
"""
homepath = self._platform.grafana_homepath()
if homepath is None:
logger.warning("grafana is not found")
return None
os.makedirs(f"{self._runtime_ctx.log_root}/grafana", exist_ok=True)
proc = subprocess.Popen(
[
shutil.which("grafana"),
"server",
"--config",
"/tmp/ray/session_latest/metrics/grafana/grafana.ini",
"-homepath",
homepath,
"web",
],
stdout=open(f"{self._runtime_ctx.log_root}/grafana/grafana.log", "w"),
env={
"GF_SERVER_HTTP_PORT": "8122", # redirect to an available port
"GF_SERVER_ROOT_URL": os.environ.get("RAY_GRAFANA_IFRAME_HOST")
or "http://localhost:8122",
"GF_PATHS_DATA": f"{self._runtime_ctx.log_root}/grafana/data",
},
)
logger.info(f"started grafana at http://localhost:8122")
return proc
@property
def runtime_ctx(self) -> RuntimeContext:
return self._runtime_ctx
def graph(self) -> Digraph:
"""
Get the logical plan graph.
"""
# implemented in Session class
raise NotImplementedError("graph")
def dump_graph(self, path: Optional[str] = None):
"""
Dump the logical plan graph to a file.
"""
path = path or os.path.join(self.runtime_ctx.log_root, "graph")
try:
self.graph().render(path, format="png")
logger.debug(f"dumped graph to {path}")
except graphviz.backend.execute.ExecutableNotFound as e:
logger.warning(f"graphviz is not installed, skipping graph dump")
def dump_timeline(self, path: Optional[str] = None):
"""
Dump the task timeline to a file.
"""
path = path or os.path.join(self.runtime_ctx.log_root, "timeline")
# the default timeline is grouped by worker
exec_path = f"{path}_exec"
ray.timeline(exec_path)
logger.debug(f"dumped timeline to {exec_path}")
# generate another timeline grouped by node
with open(exec_path) as f:
records = json.load(f)
new_records = []
for record in records:
# swap record name and pid-tid
name = record["name"]
try:
node_id = name.split(",")[-1]
task_id = name.split("-")[1].split(".")[0]
task_name = name.split("-")[0]
record["pid"] = f"{node_id}-{task_name}"
record["tid"] = f"task {task_id}"
new_records.append(record)
except Exception:
# filter out other records
pass
node_path = f"{path}_plan"
with open(node_path, "w") as f:
json.dump(new_records, f)
logger.debug(f"dumped timeline to {node_path}")
def _summarize_task(self) -> Tuple[int, int]:
# implemented in Session class
raise NotImplementedError("summarize_task")
def _dump_periodically(self):
"""
Dump the graph and timeline every minute.
Set `self._stop_event` to trigger a final dump and stop this thread.
"""
while not self._stop_event.is_set():
self._stop_event.wait(60)
self.dump_graph()
self.dump_timeline()
num_total_tasks, num_finished_tasks = self._summarize_task()
percent = (
num_finished_tasks / num_total_tasks * 100 if num_total_tasks > 0 else 0
)
logger.info(
f"progress: {num_finished_tasks}/{num_total_tasks} tasks ({percent:.1f}%)"
)
@dataclass
class Config:
"""
Configuration for a session.
"""
job_id: str # JOBID
job_time: datetime # JOB_TIME
data_root: str # DATA_ROOT
num_executors: int # NUM_NODES_TOTAL
ray_address: Optional[str] # RAY_ADDRESS
bind_numa_node: bool # BIND_NUMA_NODE
memory_allocator: str # MEMORY_ALLOCATOR
remove_output_root: bool
@staticmethod
def from_args_and_env(
platform: Optional[str] = None,
job_id: Optional[str] = None,
job_time: Optional[datetime] = None,
data_root: Optional[str] = None,
num_executors: Optional[int] = None,
ray_address: Optional[str] = None,
bind_numa_node: Optional[bool] = None,
memory_allocator: Optional[str] = None,
_remove_output_root: bool = True,
**kwargs,
) -> Config:
"""
Load config from arguments and environment variables.
If not specified, use the default value.
"""
def get_env(key: str, type: type = str):
"""
Get an environment variable and convert it to the given type.
If the variable is not set, return None.
"""
value = os.environ.get(f"SP_{key}")
return type(value) if value is not None else None
platform = get_platform(get_env("PLATFORM") or platform)
job_id = get_env("JOBID") or job_id or platform.default_job_id()
job_time = (
get_env("JOB_TIME", datetime.fromisoformat)
or job_time
or platform.default_job_time()
)
data_root = get_env("DATA_ROOT") or data_root or platform.default_data_root()
num_executors = get_env("NUM_EXECUTORS", int) or num_executors or 0
ray_address = get_env("RAY_ADDRESS") or ray_address
bind_numa_node = get_env("BIND_NUMA_NODE") == "1" or bind_numa_node
memory_allocator = (
get_env("MEMORY_ALLOCATOR")
or memory_allocator
or platform.default_memory_allocator()
)
config = Config(
job_id=job_id,
job_time=job_time,
data_root=data_root,
num_executors=num_executors,
ray_address=ray_address,
bind_numa_node=bind_numa_node,
memory_allocator=memory_allocator,
remove_output_root=_remove_output_root,
)
return config, platform


@@ -210,6 +210,7 @@ class TestFabric(unittest.TestCase):
self.sched_states = self.queue_manager.Queue()
scheduler = Scheduler(
ctx=runtime_ctx,
max_retry_count=max_retry_count,
max_fail_count=max_fail_count,
prioritize_retry=prioritize_retry,