Mirror of https://github.com/deepseek-ai/smallpond, synced 2025-06-26 18:27:45 +00:00

stash

Signed-off-by: Runji Wang <runji@deepseek.com>

This commit is contained in:
parent 5a56a052bf
commit 9492a2872a
@@ -13,22 +13,26 @@ import ray
import ray.exceptions
from loguru import logger

from smallpond.execution.manager import JobManager
from smallpond.execution.task import Task
from smallpond.io.filesystem import remove_path
from smallpond.logical.dataset import *
from smallpond.logical.node import *
from smallpond.logical.optimizer import Optimizer
from smallpond.logical.planner import Planner
from smallpond.session import SessionBase


class Session(SessionBase):
class Session:
    # Extended session class with additional methods to create DataFrames.

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        """
        Create a smallpond environment.
        """
        self._job_manager = JobManager(**kwargs)
        self._ctx = Context()
        self._runtime_ctx = self._job_manager.runtime_ctx
        self._nodes: List[Node] = []

        self._node_to_tasks: Dict[Node, List[Task]] = {}
        """
        When a DataFrame is evaluated, the tasks of the logical plan are stored here.

@@ -158,52 +162,6 @@ class Session(SessionBase):
            return
        self._shutdown_called = True

        # log status
        finished = self._all_tasks_finished()
        with open(self._runtime_ctx.job_status_path, "a") as fout:
            status = "success" if finished else "failure"
            fout.write(f"{status}@{datetime.now():%Y-%m-%d-%H-%M-%S}\n")

        # clean up runtime directories if success
        if finished:
            logger.info("all tasks are finished, cleaning up")
            self._runtime_ctx.cleanup(remove_output_root=self.config.remove_output_root)
        else:
            logger.warning("tasks are not finished!")

        super().shutdown()

    def _summarize_task(self) -> Tuple[int, int]:
        """
        Return the total number of tasks and the number of tasks that are finished.
        """
        dataset_refs = [
            task._dataset_ref
            for tasks in self._node_to_tasks.values()
            for task in tasks
            if task._dataset_ref is not None
        ]
        ready_tasks, _ = ray.wait(
            dataset_refs, num_returns=len(dataset_refs), timeout=0, fetch_local=False
        )
        return len(dataset_refs), len(ready_tasks)

    def _all_tasks_finished(self) -> bool:
        """
        Check if all tasks are finished.
        """
        dataset_refs = [
            task._dataset_ref
            for tasks in self._node_to_tasks.values()
            for task in tasks
        ]
        try:
            ray.get(dataset_refs, timeout=0)
        except Exception:
            # GetTimeoutError is raised if any task is not finished
            # RuntimeError is raised if any task failed
            return False
        return True


class DataFrame:
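For context on the two checks in `_summarize_task` and `_all_tasks_finished` above: the first takes a non-blocking snapshot with `ray.wait(..., timeout=0)`, the second leans on `ray.get(..., timeout=0)` raising when anything is unfinished or failed. A minimal standalone sketch of the same pattern, using plain Ray tasks as stand-ins for the `_dataset_ref` handles (illustrative only, not part of the commit):

import ray

ray.init(ignore_reinit_error=True)

@ray.remote
def work(i: int) -> int:
    return i * i

refs = [work.remote(i) for i in range(8)]  # stand-ins for task._dataset_ref

# non-blocking progress summary, as in _summarize_task
ready, _ = ray.wait(refs, num_returns=len(refs), timeout=0, fetch_local=False)
print(f"{len(ready)}/{len(refs)} tasks finished")

# all-finished check, as in _all_tasks_finished:
# ray.get with timeout=0 raises GetTimeoutError if any ref is not ready
try:
    ray.get(refs, timeout=0)
    all_done = True
except Exception:
    all_done = False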
@@ -216,7 +174,7 @@ class DataFrame:
    def __init__(self, session: Session, plan: Node, recompute: bool = False):
        self.session = session
        self.plan = plan
        self.optimized_plan: Optional[Node] = None
        # self.optimized_plan: Optional[Node] = None
        self.need_recompute = recompute
        """Whether to recompute the data regardless of whether it's already computed."""

@@ -229,17 +187,17 @@ class DataFrame:
        """
        Get or create tasks to compute the data.
        """
        # optimize the plan
        if self.optimized_plan is None:
            logger.info(f"optimizing\n{LogicalPlan(self.session._ctx, self.plan)}")
            self.optimized_plan = Optimizer(
                exclude_nodes=set(self.session._node_to_tasks.keys())
            ).visit(self.plan)
            logger.info(
                f"optimized\n{LogicalPlan(self.session._ctx, self.optimized_plan)}"
            )
        # # optimize the plan
        # if self.optimized_plan is None:
        #     logger.info(f"optimizing\n{LogicalPlan(self.session._ctx, self.plan)}")
        #     self.optimized_plan = Optimizer(
        #         exclude_nodes=set(self.session._node_to_tasks.keys())
        #     ).visit(self.plan)
        #     logger.info(
        #         f"optimized\n{LogicalPlan(self.session._ctx, self.optimized_plan)}"
        #     )
        # return the tasks if already created
        if tasks := self.session._node_to_tasks.get(self.optimized_plan):
        if tasks := self.session._node_to_tasks.get(self.plan):
            return tasks

        # remove all completed task files if recompute is needed
@@ -247,16 +205,16 @@
            remove_path(
                os.path.join(
                    self.session._runtime_ctx.completed_task_dir,
                    str(self.optimized_plan.id),
                    str(self.plan.id),
                )
            )
            logger.info(f"cleared all results of {self.optimized_plan!r}")
            logger.info(f"cleared all results of {self.plan!r}")

        # create tasks for the optimized plan
        planner = Planner(self.session._runtime_ctx)
        # let planner update self.session._node_to_tasks
        planner.node_to_tasks = self.session._node_to_tasks
        return planner.visit(self.optimized_plan)
        return planner.visit(self.plan)

    def is_computed(self) -> bool:
        """

smallpond/execution/_ray.py (new file, 103 lines)

@@ -0,0 +1,103 @@
import copy
import os

import ray
from loguru import logger

from smallpond.common import DEFAULT_MAX_RETRY_COUNT
from smallpond.execution.task import Task
from smallpond.execution.workqueue import WorkStatus
from smallpond.io.filesystem import dump, load
from smallpond.logical.dataset import DataSet


def run_on_ray(task: Task) -> ray.ObjectRef:
    """
    Run the task on Ray.
    Return an `ObjectRef`, which can be used with `ray.get` to wait for the output dataset.
    A `_dataset_ref` attribute is added to the task to store the reference.
    """
    if task._dataset_ref is not None:
        # already started
        return task._dataset_ref

    # read the output dataset if the task has already finished
    if os.path.exists(task.ray_dataset_path):
        logger.info(f"task {task.key} already finished, skipping")
        output = load(task.ray_dataset_path)
        task._dataset_ref = ray.put(output)
        return task._dataset_ref

    task = copy.copy(task)
    task.input_deps = {dep_key: None for dep_key in task.input_deps}

    @ray.remote
    def exec_task(task: Task, *inputs: DataSet) -> DataSet:
        import multiprocessing as mp
        import os
        from pathlib import Path

        from loguru import logger

        # ray use a process pool to execute tasks
        # we set the current process name to the task name
        # so that we can see task name in the logs
        mp.current_process().name = task.key

        # probe the retry count
        task.retry_count = 0
        while os.path.exists(task.ray_marker_path):
            task.retry_count += 1
            if task.retry_count > DEFAULT_MAX_RETRY_COUNT:
                raise RuntimeError(
                    f"task {task.key} failed after {task.retry_count} retries"
                )
        if task.retry_count > 0:
            logger.warning(
                f"task {task.key} is being retried for the {task.retry_count}th time"
            )
        # create the marker file
        Path(task.ray_marker_path).touch()

        # put the inputs into the task
        assert len(inputs) == len(task.input_deps)
        task.input_datasets = list(inputs)
        # execute the task
        status = task.exec()
        if status != WorkStatus.SUCCEED:
            raise task.exception or RuntimeError(
                f"task {task.key} failed with status {status}"
            )

        # dump the output dataset atomically
        os.makedirs(os.path.dirname(task.ray_dataset_path), exist_ok=True)
        dump(task.output, task.ray_dataset_path, atomic_write=True)
        return task.output

    # this shows as {"name": ...} in timeline
    exec_task._function_name = repr(task)

    remote_function = exec_task.options(
        # ray task name
        # do not include task id so that they can be grouped by node in ray dashboard
        name=f"{task.node_id}.{task.__class__.__name__}",
        num_cpus=task.cpu_limit,
        num_gpus=task.gpu_limit,
        memory=int(task.memory_limit),
        # note: `exec_on_scheduler` is ignored here,
        # because dataset is distributed on ray
    )
    try:
        task._dataset_ref = remote_function.remote(
            task, *[run_on_ray(dep) for dep in task.input_deps.values()]
        )
    except RuntimeError as e:
        if (
            "SimpleQueue objects should only be shared between processes through inheritance"
            in str(e)
        ):
            raise RuntimeError(
                f"Can't pickle task '{task.key}'. Please check if your function has captured unpicklable objects. {task.location}\n"
                f"HINT: DO NOT use externally imported loguru logger in your task. Please import it within the task."
            ) from e
        raise e
    return task._dataset_ref
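A rough usage sketch for the `run_on_ray` function above. It assumes `ray.init()` has already been called and that `root_task` (a hypothetical name) is a fully-built smallpond `Task` whose `input_deps` are themselves `Task` objects; because the function recurses into `input_deps`, submitting the root task submits the whole dependency graph:

import ray

from smallpond.execution._ray import run_on_ray

ref = run_on_ray(root_task)      # submits root_task and, recursively, its input_deps
output_dataset = ray.get(ref)    # blocks until the task's output DataSet is available

# calling it again is a no-op: the cached task._dataset_ref is returned
assert run_on_ray(root_task) is ref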
@@ -212,14 +212,10 @@ class JobManager(object):
            sched_state_observers.insert(0, sched_state_exporter)

        if os.path.exists(self.runtime_ctx.sched_state_path):
            logger.warning(
                f"loading scheduler state from: {self.runtime_ctx.sched_state_path}"
            )
            self.scheduler: Scheduler = load(self.runtime_ctx.sched_state_path)
            self.scheduler.sched_epoch += 1
            self.scheduler.sched_state_observers = sched_state_observers
            self.scheduler = Scheduler.recover_from_file(self.runtime_ctx.sched_state_path, sched_state_observers)
        else:
            self.scheduler = Scheduler(
                ctx=self.runtime_ctx,
                max_retry_count=max_retry_count,
                max_fail_count=max_fail_count,
                prioritize_retry=prioritize_retry,

@@ -1,5 +1,6 @@
import copy
import cProfile
from datetime import datetime
import itertools
import multiprocessing as mp
import os

@@ -12,7 +13,18 @@ from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from functools import cached_property
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Set, Tuple, Union
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Literal,
    Optional,
    Set,
    Tuple,
    Union,
)

import numpy as np
from loguru import logger

@@ -40,7 +52,7 @@ from smallpond.execution.workqueue import (
    WorkQueueInMemory,
    WorkQueueOnFilesystem,
)
from smallpond.io.filesystem import dump, remove_path
from smallpond.io.filesystem import dump, load, remove_path
from smallpond.logical.node import LogicalPlan, Node
from smallpond.utility import cprofile_to_string
@@ -359,10 +371,9 @@ class Scheduler(object):
    """

    large_num_nontrivial_tasks = 200 if pytest_running() else 20000
    StateCallback = Callable[["Scheduler"], Any]

    class StateObserver(object):
        def __init__(self, callback: "Scheduler.StateCallback" = None) -> None:
        def __init__(self, callback: Callable[["Scheduler"], Any] = None) -> None:
            assert callback is None or isinstance(callback, Callable)
            self.enabled = True
            self.callback = callback

@@ -380,6 +391,7 @@ class Scheduler(object):

    def __init__(
        self,
        ctx: RuntimeContext,
        *,
        max_retry_count: int = DEFAULT_MAX_RETRY_COUNT,
        max_fail_count: int = DEFAULT_MAX_FAIL_COUNT,

@@ -390,18 +402,46 @@ class Scheduler(object):
        remove_output_root: bool = False,
        sched_state_observers: Optional[List[StateObserver]] = None,
    ) -> None:
        """
        Initialize the scheduler.

        Parameters
        ----------
        ctx: RuntimeContext
            The runtime context.
        max_retry_count: int, optional
            The maximum retry count. Default to 5.
        max_fail_count: int, optional
            The maximum fail count. Default to 3.
        prioritize_retry: bool, optional
            Whether to prioritize retry. Default to False.
        speculative_exec: Literal["disable", "enable", "aggressive"], optional
            If "enable", long-running tasks will be rescheduled on other executors.
            If "aggressive", it will be more aggressive to reschedule long-running tasks.
            If "disable", no speculative execution will be performed.
            Default to "enable".
        stop_executor_on_failure: bool, optional
            Whether to stop the executor on failure. Default to False.
        nonzero_exitcode_as_oom: bool, optional
            Whether to treat non-zero exit code as out-of-memory. Default to False.
        remove_output_root: bool, optional
            Whether to remove the output root on exit. Default to False.
        sched_state_observers: Optional[List[StateObserver]], optional
            The state observers.
        """
        # configs
        self.ctx = ctx
        self.max_retry_count = max_retry_count
        self.max_fail_count = max_fail_count
        self.standalone_mode = self.ctx.num_executors == 0
        self.prioritize_retry = prioritize_retry
        self.disable_speculative_exec = speculative_exec == "disable"
        self.aggressive_speculative_exec = speculative_exec == "aggressive"
        self.speculative_exec: Literal["disable", "enable", "aggressive"] = (
            speculative_exec
        )
        self.stop_executor_on_failure = stop_executor_on_failure
        self.nonzero_exitcode_as_oom = nonzero_exitcode_as_oom
        self.remove_output_root = remove_output_root
        self.sched_state_observers: List[Scheduler.StateObserver] = (
            sched_state_observers or []
        )
        self.sched_state_observers = sched_state_observers or []
        self.secs_state_notify_interval = self.ctx.secs_executor_probe_interval * 2
        # task states
        self.local_queue: List[Task] = []

@@ -416,30 +456,47 @@ class Scheduler(object):
        self.local_executor = LocalExecutor.create(self.ctx, "localhost")
        self.available_executors = {self.local_executor.id: self.local_executor}
        # other runtime states
        self.sched_running = False
        self.sched_start_time = 0
        self.failure = False
        self.sched_start_time = time.time()
        self.last_executor_probe_time = 0
        self.last_state_notify_time = 0
        self.probe_epoch = 0
        self.sched_epoch = 0

        self._post_init()

    @staticmethod
    def recover_from_file(
        sched_state_path: str, sched_state_observers: List[StateObserver]
    ) -> "Scheduler":
        """
        Recover the scheduler from the previous run.
        """
        logger.warning(f"loading scheduler state from: {sched_state_path}")
        self: Scheduler = load(sched_state_path)

        self.sched_epoch += 1
        self.prioritize_retry = True
        # observers are not pickled, so we need to re-add them
        self.sched_state_observers = sched_state_observers

        self._post_init()

    def __getstate__(self):
        state = self.__dict__.copy()
        del state["sched_state_observers"]
        del state["profiler"]
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.sched_state_observers = []
        self.profiler = None

    @property
    def elapsed_time(self):
        return time.time() - self.sched_start_time

    @property
    def success(self) -> bool:
        return self.exec_plan.root_task.key in self.succeeded_tasks

    @property
    def progress(self) -> Tuple[int, int, float]:
        num_succeeded = len(self.succeeded_nontrivial_tasks)
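The recovery path above works by pickling the whole scheduler while excluding fields that cannot or should not be pickled (`sched_state_observers`, `profiler`) and re-attaching them after load; `recover_from_file` also bumps `sched_epoch` and forces `prioritize_retry`. A self-contained sketch of that pattern (illustrative only, not smallpond code):

import pickle


class Checkpointable:
    """Transient fields are dropped in __getstate__ and re-attached after recovery."""

    def __init__(self):
        self.epoch = 0
        self.observers = []   # not pickled
        self.profiler = None  # not pickled

    def __getstate__(self):
        state = self.__dict__.copy()
        del state["observers"]
        del state["profiler"]
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.observers = []
        self.profiler = None

    @staticmethod
    def recover(path: str, observers: list) -> "Checkpointable":
        with open(path, "rb") as f:
            obj: "Checkpointable" = pickle.load(f)
        obj.epoch += 1            # analogous to sched_epoch += 1
        obj.observers = observers
        return obj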
@@ -582,7 +639,7 @@ class Scheduler(object):
        for executor in self.working_executors:
            for idx, item in enumerate(executor.running_works.values()):
                aggressive_retry = (
                    self.aggressive_speculative_exec
                    self.speculative_exec == "aggressive"
                    and len(self.good_executors) >= self.ctx.num_executors
                )
                short_sched_queue = len(self.sched_queue) < len(self.good_executors)

@@ -644,7 +701,7 @@ class Scheduler(object):
        for executor in self.working_executors:
            executor.probe(self.probe_epoch)
        # start speculative execution of tasks
        if not self.disable_speculative_exec:
        if self.speculative_exec != "disable":
            self.start_speculative_execution()

    def update_executor_states(self):

@@ -777,7 +834,7 @@ class Scheduler(object):
        return task

    @logger.catch(reraise=pytest_running(), message="failed to clean temp files")
    def clean_temp_files(self, pool: ThreadPoolExecutor):
    def clean_temp_files(self):
        remove_path(self.ctx.queue_root)
        remove_path(self.ctx.temp_root)
        remove_path(self.ctx.staging_root)

@@ -786,7 +843,10 @@ class Scheduler(object):
        logger.info(
            f"removing outputs of {len(abandoned_tasks)} abandoned tasks: {abandoned_tasks[:3]} ..."
        )
        assert list(pool.map(lambda t: t.clean_output(force=True), abandoned_tasks))
        with ThreadPoolExecutor(32) as pool:
            assert list(
                pool.map(lambda t: t.clean_output(force=True), abandoned_tasks)
            )

    @logger.catch(reraise=pytest_running(), message="failed to export task metrics")
    def export_task_metrics(self):

@@ -963,29 +1023,25 @@ class Scheduler(object):

    def log_current_status(self):
        with open(self.ctx.job_status_path, "w") as fout:
            if self.sched_running:
                status = "running"
            elif self.success:
                status = "success"
            else:
            if self.failure:
                status = "failure"
            fout.write(f"{status}@{int(time.time())}")
            else:
                status = "running"
            fout.write(f"{status}@{datetime.now().isoformat()}")

    def run(self, exec_plan: ExecutionPlan) -> bool:
    def _post_init(self):
        """
        Run the execution plan.
        Common initialization after startup or recovery.
        """
        mp.current_process().name = f"SchedulerMainProcess#{self.sched_epoch}"
        logger.info(
            f"start to run scheduler #{self.sched_epoch} on {socket.gethostname()}"
        )

        perf_profile = None
        if self.ctx.enable_profiling:
            perf_profile = cProfile.Profile()
            perf_profile.enable()
            self.profiler = cProfile.Profile()
            self.profiler.enable()

        with ThreadPoolExecutor(32) as pool:
            self.sched_running = True
            self.sched_start_time = time.time()
            self.last_executor_probe_time = 0
@@ -1005,13 +1061,7 @@
            )
            self.try_enqueue(pending_tasks)

            if self.sched_epoch == 0:
                leaf_tasks = self.exec_plan.leaves
                logger.info(
                    f"enqueue {len(leaf_tasks)} leaf tasks: {leaf_tasks[:3]} ..."
                )
                self.try_enqueue(leaf_tasks)

        with ThreadPoolExecutor(32) as pool:
            self.log_overall_progress()
            while (num_finished_tasks := self.process_finished_tasks(pool)) > 0:
                logger.info(

@@ -1029,44 +1079,54 @@
            self.try_enqueue(earlier_running_tasks)

        self.suspend_good_executors()
        self.add_state_observer(
            Scheduler.StateObserver(Scheduler.log_current_status)
        )
        self.add_state_observer(
            Scheduler.StateObserver(Scheduler.export_timeline_figs)
        )
        self.add_state_observer(Scheduler.StateObserver(Scheduler.log_current_status))
        self.add_state_observer(Scheduler.StateObserver(Scheduler.export_timeline_figs))
        self.notify_state_observers(force_notify=True)

    def run(self, exec_plan: ExecutionPlan) -> bool:
        """
        Run the execution plan.
        """
        leaf_tasks = exec_plan.leaves
        logger.info(f"enqueue {len(leaf_tasks)} leaf tasks: {leaf_tasks[:3]} ...")
        self.try_enqueue(leaf_tasks)

        try:
            with ThreadPoolExecutor(32) as pool:
                self.local_executor.start(pool)
                self.sched_loop(pool)
                self.sched_loop(pool, exec_plan.root_task)
        finally:
            logger.info(f"schedule loop stopped")
            self.sched_running = False
            self.notify_state_observers(force_notify=True)

        logger.success(f"final output path: {exec_plan.final_output_path}")
        logger.info(
            f"analyzed plan:{os.linesep}{exec_plan.analyzed_logical_plan.explain_str()}"
        )

        return True

    def cleanup(self):
        self.export_task_metrics()
        self.stop_executors()

        # if --output_path is specified, remove the output root as well
        if self.remove_output_root or self.ctx.final_output_path:
        if self.remove_output_root:
            remove_path(self.ctx.staging_root)
            remove_path(self.ctx.output_root)

        if self.success:
            self.clean_temp_files(pool)
            logger.success(f"final output path: {self.exec_plan.final_output_path}")
            logger.info(
                f"analyzed plan:{os.linesep}{self.exec_plan.analyzed_logical_plan.explain_str()}"
            )
        if not self.failure:
            self.clean_temp_files()
        if self.ctx.final_output_path:
            remove_path(self.ctx.final_output_path)

        if perf_profile is not None:
        if self.profiler is not None:
            logger.debug(
                f"scheduler perf profile:{os.linesep}{cprofile_to_string(perf_profile)}"
                f"scheduler perf profile:{os.linesep}{cprofile_to_string(self.profiler)}"
            )

        logger.info(f"scheduler of job {self.ctx.job_id} exits")
        logger.complete()
        return self.success

    def try_enqueue(self, tasks: Union[Iterable[Task], Task]):
        tasks = tasks if isinstance(tasks, Iterable) else [tasks]

@@ -1092,15 +1152,15 @@
        else:
            self.sched_queue.append(task)

    def sched_loop(self, pool: ThreadPoolExecutor) -> bool:
    def sched_loop(self, pool: ThreadPoolExecutor, task: Task):
        """
        Run the scheduler loop until the task is finished or failed.
        """

        has_progress = True
        do_notify = False

        if self.success:
            logger.success(f"job already succeeded, stopping scheduler ...")
            return True

        while self.sched_running:
        while not self.failure and self.tasks[task.key].status == WorkStatus.INCOMPLETE:
            self.probe_executors()
            self.update_executor_states()

@@ -1153,9 +1213,6 @@

            has_progress |= self.process_finished_tasks(pool) > 0

        # out of loop
        return self.success

    def dispatch_tasks(self, pool: ThreadPoolExecutor):
        # sort pending tasks
        item_sort_key = (

@@ -1240,6 +1297,10 @@
        return num_dispatched_tasks

    def process_finished_tasks(self, pool: ThreadPoolExecutor) -> int:
        """
        Process finished tasks from all executors.
        Return the number of finished tasks.
        """
        pop_results = pool.map(RemoteExecutor.pop, self.available_executors.values())
        num_finished_tasks = 0

@@ -1309,7 +1370,7 @@
                    f"task failed too many times: {finished_task}, stopping ..."
                )
                self.stop_executors()
                self.sched_running = False
                self.failure = True

            if not executor.local and finished_task.oom(
                self.nonzero_exitcode_as_oom
@@ -42,7 +42,6 @@ import pandas as pd
import psutil
import pyarrow as arrow
import pyarrow.parquet as parquet
import ray
from loguru import logger

from smallpond.common import (

@@ -642,10 +641,6 @@ class Task(WorkItem):
        # implementor can use this variable as a checkpoint and restore from it after interrupted
        self.runtime_state = None

        # if the task is executed by ray, this is the reference to the output dataset
        # do not use this variable directly, use `self.run_on_ray()` instead
        self._dataset_ref: Optional[ray.ObjectRef] = None

    def __repr__(self) -> str:
        return f"{self.key}.{self.sched_epoch}.{self.retry_count},{self.node_id}"

@@ -1104,97 +1099,6 @@ class Task(WorkItem):
        for name, value in metrics.items():
            self.perf_metrics[name] += value

    def run_on_ray(self) -> ray.ObjectRef:
        """
        Run the task on Ray.
        Return an `ObjectRef`, which can be used with `ray.get` to wait for the output dataset.
        """
        if self._dataset_ref is not None:
            # already started
            return self._dataset_ref

        # read the output dataset if the task has already finished
        if os.path.exists(self.ray_dataset_path):
            logger.info(f"task {self.key} already finished, skipping")
            output = load(self.ray_dataset_path)
            self._dataset_ref = ray.put(output)
            return self._dataset_ref

        task = copy.copy(self)
        task.input_deps = {dep_key: None for dep_key in task.input_deps}

        @ray.remote
        def exec_task(task: Task, *inputs: DataSet) -> DataSet:
            import multiprocessing as mp
            import os
            from pathlib import Path

            from loguru import logger

            # ray use a process pool to execute tasks
            # we set the current process name to the task name
            # so that we can see task name in the logs
            mp.current_process().name = task.key

            # probe the retry count
            task.retry_count = 0
            while os.path.exists(task.ray_marker_path):
                task.retry_count += 1
                if task.retry_count > DEFAULT_MAX_RETRY_COUNT:
                    raise RuntimeError(
                        f"task {task.key} failed after {task.retry_count} retries"
                    )
            if task.retry_count > 0:
                logger.warning(
                    f"task {task.key} is being retried for the {task.retry_count}th time"
                )
            # create the marker file
            Path(task.ray_marker_path).touch()

            # put the inputs into the task
            assert len(inputs) == len(task.input_deps)
            task.input_datasets = list(inputs)
            # execute the task
            status = task.exec()
            if status != WorkStatus.SUCCEED:
                raise task.exception or RuntimeError(
                    f"task {task.key} failed with status {status}"
                )

            # dump the output dataset atomically
            os.makedirs(os.path.dirname(task.ray_dataset_path), exist_ok=True)
            dump(task.output, task.ray_dataset_path, atomic_write=True)
            return task.output

        # this shows as {"name": ...} in timeline
        exec_task._function_name = repr(task)

        remote_function = exec_task.options(
            # ray task name
            # do not include task id so that they can be grouped by node in ray dashboard
            name=f"{task.node_id}.{self.__class__.__name__}",
            num_cpus=self.cpu_limit,
            num_gpus=self.gpu_limit,
            memory=int(self.memory_limit),
            # note: `exec_on_scheduler` is ignored here,
            # because dataset is distributed on ray
        )
        try:
            self._dataset_ref = remote_function.remote(
                task, *[dep.run_on_ray() for dep in self.input_deps.values()]
            )
        except RuntimeError as e:
            if (
                "SimpleQueue objects should only be shared between processes through inheritance"
                in str(e)
            ):
                raise RuntimeError(
                    f"Can't pickle task '{task.key}'. Please check if your function has captured unpicklable objects. {task.location}\n"
                    f"HINT: DO NOT use externally imported loguru logger in your task. Please import it within the task."
                ) from e
            raise e
        return self._dataset_ref


class ExecSqlQueryMixin(Task):
@@ -21,368 +21,7 @@ import graphviz.backend.execute
from loguru import logger

import smallpond
from smallpond.execution.manager import JobManager
from smallpond.execution.task import JobId, RuntimeContext
from smallpond.logical.node import Context
from smallpond.platform import Platform, get_platform


class SessionBase:
    def __init__(self, **kwargs):
        """
        Create a smallpond environment.
        """
        super().__init__()
        self._ctx = Context()
        self.config, self._platform = Config.from_args_and_env(**kwargs)

        # construct runtime context for Tasks
        runtime_ctx = RuntimeContext(
            job_id=JobId(hex=self.config.job_id),
            job_time=self.config.job_time,
            data_root=self.config.data_root,
            num_executors=self.config.num_executors,
            bind_numa_node=self.config.bind_numa_node,
            shared_log_root=self._platform.shared_log_root(),
        )
        self._runtime_ctx = runtime_ctx

        # if `spawn` is specified, spawn a job and exit
        if os.environ.get("SP_SPAWN") == "1":
            self._spawn_self()
            exit(0)

        self._runtime_ctx.initialize(exec_id=socket.gethostname())
        logger.info(f"using platform: {self._platform}")
        logger.info(f"command-line arguments: {' '.join(sys.argv)}")
        logger.info(f"session config: {self.config}")

        def setup_worker():
            runtime_ctx._init_logs(
                exec_id=socket.gethostname(), capture_stdout_stderr=True
            )

        if self.config.ray_address is None:
            # find the memory allocator
            if self.config.memory_allocator == "system":
                malloc_path = ""
            elif self.config.memory_allocator == "jemalloc":
                malloc_path = shutil.which("libjemalloc.so.2")
                assert malloc_path is not None, "jemalloc is not installed"
            elif self.config.memory_allocator == "mimalloc":
                malloc_path = shutil.which("libmimalloc.so.2.1")
                assert malloc_path is not None, "mimalloc is not installed"
            else:
                raise ValueError(
                    f"unsupported memory allocator: {self.config.memory_allocator}"
                )
            memory_purge_delay = 10000

            # start ray head node
            # for ray head node to access grafana
            os.environ["RAY_GRAFANA_HOST"] = "http://localhost:8122"
            self._ray_address = ray.init(
                # start a new local cluster
                address="local",
                # disable local CPU resource if not running on localhost
                num_cpus=(
                    0
                    if self.config.num_executors > 0
                    else self._runtime_ctx.usable_cpu_count
                ),
                # set the memory limit to the available memory size
                _memory=self._runtime_ctx.usable_memory_size,
                # setup logging for workers
                log_to_driver=False,
                runtime_env={
                    "worker_process_setup_hook": setup_worker,
                    "env_vars": {
                        "LD_PRELOAD": malloc_path,
                        "MALLOC_CONF": f"percpu_arena:percpu,background_thread:true,metadata_thp:auto,dirty_decay_ms:{memory_purge_delay},muzzy_decay_ms:{memory_purge_delay},oversize_threshold:0,lg_tcache_max:16",
                        "MIMALLOC_PURGE_DELAY": f"{memory_purge_delay}",
                        "ARROW_DEFAULT_MEMORY_POOL": self.config.memory_allocator,
                        "ARROW_IO_THREADS": "2",
                        "OMP_NUM_THREADS": "2",
                        "POLARS_MAX_THREADS": "2",
                        "NUMEXPR_MAX_THREADS": "2",
                        "RAY_PROFILING": "1",
                    },
                },
                dashboard_host="0.0.0.0",
                dashboard_port=8008,
                # for prometheus to scrape metrics
                _metrics_export_port=8080,
            ).address_info["gcs_address"]
            logger.info(f"started ray cluster at {self._ray_address}")

            self._prometheus_process = self._start_prometheus()
            self._grafana_process = self._start_grafana()
        else:
            self._ray_address = self.config.ray_address
            self._prometheus_process = None
            self._grafana_process = None
            logger.info(f"connected to ray cluster at {self._ray_address}")

        # start workers
        if self.config.num_executors > 0:
            # override configs
            kwargs["job_id"] = self.config.job_id

            self._job_names = self._platform.start_job(
                self.config.num_executors,
                entrypoint=os.path.join(os.path.dirname(__file__), "worker.py"),
                args=[
                    f"--ray_address={self._ray_address}",
                    f"--log_dir={self._runtime_ctx.log_root}",
                    *(["--bind_numa_node"] if self.config.bind_numa_node else []),
                ],
                extra_opts=kwargs,
            )
        else:
            self._job_names = []

        # spawn a thread to periodically dump metrics
        self._stop_event = threading.Event()
        self._dump_thread = threading.Thread(
            name="dump_thread", target=self._dump_periodically, daemon=True
        )
        self._dump_thread.start()

    def shutdown(self):
        """
        Shutdown the session.
        """
        logger.info("shutting down session")
        self._stop_event.set()

        # stop all jobs
        for job_name in self._job_names:
            self._platform.stop_job(job_name)
        self._job_names = []

        self._dump_thread.join()
        if self.config.ray_address is None:
            ray.shutdown()
            if self._prometheus_process is not None:
                self._prometheus_process.terminate()
                self._prometheus_process.wait()
                self._prometheus_process = None
                logger.info("stopped prometheus")
            if self._grafana_process is not None:
                self._grafana_process.terminate()
                self._grafana_process.wait()
                self._grafana_process = None
                logger.info("stopped grafana")

    def _spawn_self(self):
        """
        Spawn a new job to run the current script.
        """
        self._platform.start_job(
            num_nodes=1,
            entrypoint=sys.argv[0],
            args=sys.argv[1:],
            extra_opts=dict(
                tags=["smallpond", "scheduler", smallpond.__version__],
            ),
            envs={
                k: v
                for k, v in os.environ.items()
                if k.startswith("SP_") and k != "SP_SPAWN"
            },
        )
    def _start_prometheus(self) -> Optional[subprocess.Popen]:
        """
        Start prometheus server if it exists.
        """
        prometheus_path = shutil.which("prometheus")
        if prometheus_path is None:
            logger.warning("prometheus is not found")
            return None
        os.makedirs(f"{self._runtime_ctx.log_root}/prometheus", exist_ok=True)
        proc = subprocess.Popen(
            [
                prometheus_path,
                "--config.file=/tmp/ray/session_latest/metrics/prometheus/prometheus.yml",
                f"--storage.tsdb.path={self._runtime_ctx.log_root}/prometheus/data",
            ],
            stderr=open(f"{self._runtime_ctx.log_root}/prometheus/prometheus.log", "w"),
        )
        logger.info("started prometheus")
        return proc

    def _start_grafana(self) -> Optional[subprocess.Popen]:
        """
        Start grafana server if it exists.
        """
        homepath = self._platform.grafana_homepath()
        if homepath is None:
            logger.warning("grafana is not found")
            return None
        os.makedirs(f"{self._runtime_ctx.log_root}/grafana", exist_ok=True)
        proc = subprocess.Popen(
            [
                shutil.which("grafana"),
                "server",
                "--config",
                "/tmp/ray/session_latest/metrics/grafana/grafana.ini",
                "-homepath",
                homepath,
                "web",
            ],
            stdout=open(f"{self._runtime_ctx.log_root}/grafana/grafana.log", "w"),
            env={
                "GF_SERVER_HTTP_PORT": "8122",  # redirect to an available port
                "GF_SERVER_ROOT_URL": os.environ.get("RAY_GRAFANA_IFRAME_HOST")
                or "http://localhost:8122",
                "GF_PATHS_DATA": f"{self._runtime_ctx.log_root}/grafana/data",
            },
        )
        logger.info(f"started grafana at http://localhost:8122")
        return proc

    @property
    def runtime_ctx(self) -> RuntimeContext:
        return self._runtime_ctx

    def graph(self) -> Digraph:
        """
        Get the logical plan graph.
        """
        # implemented in Session class
        raise NotImplementedError("graph")

    def dump_graph(self, path: Optional[str] = None):
        """
        Dump the logical plan graph to a file.
        """
        path = path or os.path.join(self.runtime_ctx.log_root, "graph")
        try:
            self.graph().render(path, format="png")
            logger.debug(f"dumped graph to {path}")
        except graphviz.backend.execute.ExecutableNotFound as e:
            logger.warning(f"graphviz is not installed, skipping graph dump")

    def dump_timeline(self, path: Optional[str] = None):
        """
        Dump the task timeline to a file.
        """
        path = path or os.path.join(self.runtime_ctx.log_root, "timeline")
        # the default timeline is grouped by worker
        exec_path = f"{path}_exec"
        ray.timeline(exec_path)
        logger.debug(f"dumped timeline to {exec_path}")

        # generate another timeline grouped by node
        with open(exec_path) as f:
            records = json.load(f)
        new_records = []
        for record in records:
            # swap record name and pid-tid
            name = record["name"]
            try:
                node_id = name.split(",")[-1]
                task_id = name.split("-")[1].split(".")[0]
                task_name = name.split("-")[0]
                record["pid"] = f"{node_id}-{task_name}"
                record["tid"] = f"task {task_id}"
                new_records.append(record)
            except Exception:
                # filter out other records
                pass
        node_path = f"{path}_plan"
        with open(node_path, "w") as f:
            json.dump(new_records, f)
        logger.debug(f"dumped timeline to {node_path}")
    def _summarize_task(self) -> Tuple[int, int]:
        # implemented in Session class
        raise NotImplementedError("summarize_task")

    def _dump_periodically(self):
        """
        Dump the graph and timeline every minute.
        Set `self._stop_event` to have a final dump and stop this thread.
        """
        while not self._stop_event.is_set():
            self._stop_event.wait(60)
            self.dump_graph()
            self.dump_timeline()
            num_total_tasks, num_finished_tasks = self._summarize_task()
            percent = (
                num_finished_tasks / num_total_tasks * 100 if num_total_tasks > 0 else 0
            )
            logger.info(
                f"progress: {num_finished_tasks}/{num_total_tasks} tasks ({percent:.1f}%)"
            )


@dataclass
class Config:
    """
    Configuration for a session.
    """

    job_id: str  # JOBID
    job_time: datetime  # JOB_TIME
    data_root: str  # DATA_ROOT
    num_executors: int  # NUM_NODES_TOTAL
    ray_address: Optional[str]  # RAY_ADDRESS
    bind_numa_node: bool  # BIND_NUMA_NODE
    memory_allocator: str  # MEMORY_ALLOCATOR
    remove_output_root: bool

    @staticmethod
    def from_args_and_env(
        platform: Optional[str] = None,
        job_id: Optional[str] = None,
        job_time: Optional[datetime] = None,
        data_root: Optional[str] = None,
        num_executors: Optional[int] = None,
        ray_address: Optional[str] = None,
        bind_numa_node: Optional[bool] = None,
        memory_allocator: Optional[str] = None,
        _remove_output_root: bool = True,
        **kwargs,
    ) -> Config:
        """
        Load config from arguments and environment variables.
        If not specified, use the default value.
        """

        def get_env(key: str, type: type = str):
            """
            Get an environment variable and convert it to the given type.
            If the variable is not set, return None.
            """
            value = os.environ.get(f"SP_{key}")
            return type(value) if value is not None else None

        platform = get_platform(get_env("PLATFORM") or platform)
        job_id = get_env("JOBID") or job_id or platform.default_job_id()
        job_time = (
            get_env("JOB_TIME", datetime.fromisoformat)
            or job_time
            or platform.default_job_time()
        )
        data_root = get_env("DATA_ROOT") or data_root or platform.default_data_root()
        num_executors = get_env("NUM_EXECUTORS", int) or num_executors or 0
        ray_address = get_env("RAY_ADDRESS") or ray_address
        bind_numa_node = get_env("BIND_NUMA_NODE") == "1" or bind_numa_node
        memory_allocator = (
            get_env("MEMORY_ALLOCATOR")
            or memory_allocator
            or platform.default_memory_allocator()
        )

        config = Config(
            job_id=job_id,
            job_time=job_time,
            data_root=data_root,
            num_executors=num_executors,
            ray_address=ray_address,
            bind_numa_node=bind_numa_node,
            memory_allocator=memory_allocator,
            remove_output_root=_remove_output_root,
        )
        return config, platform
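In `Config.from_args_and_env` above, every setting is read from an `SP_`-prefixed environment variable first and only then from the keyword argument or the platform default. A hedged sketch of how that plays out (assuming `Config` is importable from `smallpond.session`; values are placeholders):

import os

from smallpond.session import Config

# environment variables win over keyword arguments
os.environ["SP_NUM_EXECUTORS"] = "4"
os.environ["SP_MEMORY_ALLOCATOR"] = "jemalloc"

config, platform = Config.from_args_and_env(num_executors=2)
assert config.num_executors == 4              # SP_NUM_EXECUTORS overrides the argument
assert config.memory_allocator == "jemalloc"  # overrides the platform default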
@@ -210,6 +210,7 @@ class TestFabric(unittest.TestCase):
        self.sched_states = self.queue_manager.Queue()

        scheduler = Scheduler(
            ctx=runtime_ctx,
            max_retry_count=max_retry_count,
            max_fail_count=max_fail_count,
            prioritize_retry=prioritize_retry,