diff --git a/smallpond/dataframe.py b/smallpond/dataframe.py index 4974910..9968cf1 100644 --- a/smallpond/dataframe.py +++ b/smallpond/dataframe.py @@ -13,22 +13,26 @@ import ray import ray.exceptions from loguru import logger +from smallpond.execution.manager import JobManager from smallpond.execution.task import Task from smallpond.io.filesystem import remove_path from smallpond.logical.dataset import * from smallpond.logical.node import * from smallpond.logical.optimizer import Optimizer from smallpond.logical.planner import Planner -from smallpond.session import SessionBase -class Session(SessionBase): +class Session: # Extended session class with additional methods to create DataFrames. def __init__(self, **kwargs): - super().__init__(**kwargs) + """ + Create a smallpond environment. + """ + self._job_manager = JobManager(**kwargs) + self._ctx = Context() + self._runtime_ctx = self._job_manager.runtime_ctx self._nodes: List[Node] = [] - self._node_to_tasks: Dict[Node, List[Task]] = {} """ When a DataFrame is evaluated, the tasks of the logical plan are stored here. @@ -158,52 +162,6 @@ class Session(SessionBase): return self._shutdown_called = True - # log status - finished = self._all_tasks_finished() - with open(self._runtime_ctx.job_status_path, "a") as fout: - status = "success" if finished else "failure" - fout.write(f"{status}@{datetime.now():%Y-%m-%d-%H-%M-%S}\n") - - # clean up runtime directories if success - if finished: - logger.info("all tasks are finished, cleaning up") - self._runtime_ctx.cleanup(remove_output_root=self.config.remove_output_root) - else: - logger.warning("tasks are not finished!") - - super().shutdown() - - def _summarize_task(self) -> Tuple[int, int]: - """ - Return the total number of tasks and the number of tasks that are finished. - """ - dataset_refs = [ - task._dataset_ref - for tasks in self._node_to_tasks.values() - for task in tasks - if task._dataset_ref is not None - ] - ready_tasks, _ = ray.wait( - dataset_refs, num_returns=len(dataset_refs), timeout=0, fetch_local=False - ) - return len(dataset_refs), len(ready_tasks) - - def _all_tasks_finished(self) -> bool: - """ - Check if all tasks are finished. - """ - dataset_refs = [ - task._dataset_ref - for tasks in self._node_to_tasks.values() - for task in tasks - ] - try: - ray.get(dataset_refs, timeout=0) - except Exception: - # GetTimeoutError is raised if any task is not finished - # RuntimeError is raised if any task failed - return False - return True class DataFrame: @@ -216,7 +174,7 @@ class DataFrame: def __init__(self, session: Session, plan: Node, recompute: bool = False): self.session = session self.plan = plan - self.optimized_plan: Optional[Node] = None + # self.optimized_plan: Optional[Node] = None self.need_recompute = recompute """Whether to recompute the data regardless of whether it's already computed.""" @@ -229,17 +187,17 @@ class DataFrame: """ Get or create tasks to compute the data. 
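+        Tasks are cached in `self.session._node_to_tasks`, so repeated calls for the same plan return the existing tasks.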
""" - # optimize the plan - if self.optimized_plan is None: - logger.info(f"optimizing\n{LogicalPlan(self.session._ctx, self.plan)}") - self.optimized_plan = Optimizer( - exclude_nodes=set(self.session._node_to_tasks.keys()) - ).visit(self.plan) - logger.info( - f"optimized\n{LogicalPlan(self.session._ctx, self.optimized_plan)}" - ) + # # optimize the plan + # if self.optimized_plan is None: + # logger.info(f"optimizing\n{LogicalPlan(self.session._ctx, self.plan)}") + # self.optimized_plan = Optimizer( + # exclude_nodes=set(self.session._node_to_tasks.keys()) + # ).visit(self.plan) + # logger.info( + # f"optimized\n{LogicalPlan(self.session._ctx, self.optimized_plan)}" + # ) # return the tasks if already created - if tasks := self.session._node_to_tasks.get(self.optimized_plan): + if tasks := self.session._node_to_tasks.get(self.plan): return tasks # remove all completed task files if recompute is needed @@ -247,16 +205,16 @@ class DataFrame: remove_path( os.path.join( self.session._runtime_ctx.completed_task_dir, - str(self.optimized_plan.id), + str(self.plan.id), ) ) - logger.info(f"cleared all results of {self.optimized_plan!r}") + logger.info(f"cleared all results of {self.plan!r}") # create tasks for the optimized plan planner = Planner(self.session._runtime_ctx) # let planner update self.session._node_to_tasks planner.node_to_tasks = self.session._node_to_tasks - return planner.visit(self.optimized_plan) + return planner.visit(self.plan) def is_computed(self) -> bool: """ diff --git a/smallpond/execution/_ray.py b/smallpond/execution/_ray.py new file mode 100644 index 0000000..7a51276 --- /dev/null +++ b/smallpond/execution/_ray.py @@ -0,0 +1,103 @@ +import copy +import os +import ray +from loguru import logger + +from smallpond.common import DEFAULT_MAX_RETRY_COUNT +from smallpond.execution.task import Task +from smallpond.execution.workqueue import WorkStatus +from smallpond.io.filesystem import dump, load +from smallpond.logical.dataset import DataSet + + +def run_on_ray(task: Task) -> ray.ObjectRef: + """ + Run the task on Ray. + Return an `ObjectRef`, which can be used with `ray.get` to wait for the output dataset. + A `_dataset_ref` attribute is added to the task to store the reference. 
+ """ + if task._dataset_ref is not None: + # already started + return task._dataset_ref + + # read the output dataset if the task has already finished + if os.path.exists(task.ray_dataset_path): + logger.info(f"task {task.key} already finished, skipping") + output = load(task.ray_dataset_path) + task._dataset_ref = ray.put(output) + return task._dataset_ref + + task = copy.copy(task) + task.input_deps = {dep_key: None for dep_key in task.input_deps} + + @ray.remote + def exec_task(task: Task, *inputs: DataSet) -> DataSet: + import multiprocessing as mp + import os + from pathlib import Path + + from loguru import logger + + # ray use a process pool to execute tasks + # we set the current process name to the task name + # so that we can see task name in the logs + mp.current_process().name = task.key + + # probe the retry count + task.retry_count = 0 + while os.path.exists(task.ray_marker_path): + task.retry_count += 1 + if task.retry_count > DEFAULT_MAX_RETRY_COUNT: + raise RuntimeError( + f"task {task.key} failed after {task.retry_count} retries" + ) + if task.retry_count > 0: + logger.warning( + f"task {task.key} is being retried for the {task.retry_count}th time" + ) + # create the marker file + Path(task.ray_marker_path).touch() + + # put the inputs into the task + assert len(inputs) == len(task.input_deps) + task.input_datasets = list(inputs) + # execute the task + status = task.exec() + if status != WorkStatus.SUCCEED: + raise task.exception or RuntimeError( + f"task {task.key} failed with status {status}" + ) + + # dump the output dataset atomically + os.makedirs(os.path.dirname(task.ray_dataset_path), exist_ok=True) + dump(task.output, task.ray_dataset_path, atomic_write=True) + return task.output + + # this shows as {"name": ...} in timeline + exec_task._function_name = repr(task) + + remote_function = exec_task.options( + # ray task name + # do not include task id so that they can be grouped by node in ray dashboard + name=f"{task.node_id}.{task.__class__.__name__}", + num_cpus=task.cpu_limit, + num_gpus=task.gpu_limit, + memory=int(task.memory_limit), + # note: `exec_on_scheduler` is ignored here, + # because dataset is distributed on ray + ) + try: + task._dataset_ref = remote_function.remote( + task, *[run_on_ray(dep) for dep in task.input_deps.values()] + ) + except RuntimeError as e: + if ( + "SimpleQueue objects should only be shared between processes through inheritance" + in str(e) + ): + raise RuntimeError( + f"Can't pickle task '{task.key}'. Please check if your function has captured unpicklable objects. {task.location}\n" + f"HINT: DO NOT use externally imported loguru logger in your task. Please import it within the task." 
+ ) from e + raise e + return task._dataset_ref diff --git a/smallpond/execution/manager.py b/smallpond/execution/manager.py index c89b845..61a47c4 100644 --- a/smallpond/execution/manager.py +++ b/smallpond/execution/manager.py @@ -212,14 +212,10 @@ class JobManager(object): sched_state_observers.insert(0, sched_state_exporter) if os.path.exists(self.runtime_ctx.sched_state_path): - logger.warning( - f"loading scheduler state from: {self.runtime_ctx.sched_state_path}" - ) - self.scheduler: Scheduler = load(self.runtime_ctx.sched_state_path) - self.scheduler.sched_epoch += 1 - self.scheduler.sched_state_observers = sched_state_observers + self.scheduler = Scheduler.recover_from_file(self.runtime_ctx.sched_state_path, sched_state_observers) else: self.scheduler = Scheduler( + ctx=self.runtime_ctx, max_retry_count=max_retry_count, max_fail_count=max_fail_count, prioritize_retry=prioritize_retry, diff --git a/smallpond/execution/scheduler.py b/smallpond/execution/scheduler.py index e2e1057..8d39ac3 100644 --- a/smallpond/execution/scheduler.py +++ b/smallpond/execution/scheduler.py @@ -1,5 +1,6 @@ import copy import cProfile +from datetime import datetime import itertools import multiprocessing as mp import os @@ -12,7 +13,18 @@ from collections import OrderedDict from concurrent.futures import ThreadPoolExecutor from enum import Enum from functools import cached_property -from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Set, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Literal, + Optional, + Set, + Tuple, + Union, +) import numpy as np from loguru import logger @@ -40,7 +52,7 @@ from smallpond.execution.workqueue import ( WorkQueueInMemory, WorkQueueOnFilesystem, ) -from smallpond.io.filesystem import dump, remove_path +from smallpond.io.filesystem import dump, load, remove_path from smallpond.logical.node import LogicalPlan, Node from smallpond.utility import cprofile_to_string @@ -359,10 +371,9 @@ class Scheduler(object): """ large_num_nontrivial_tasks = 200 if pytest_running() else 20000 - StateCallback = Callable[["Scheduler"], Any] class StateObserver(object): - def __init__(self, callback: "Scheduler.StateCallback" = None) -> None: + def __init__(self, callback: Callable[["Scheduler"], Any] = None) -> None: assert callback is None or isinstance(callback, Callable) self.enabled = True self.callback = callback @@ -380,6 +391,7 @@ class Scheduler(object): def __init__( self, + ctx: RuntimeContext, *, max_retry_count: int = DEFAULT_MAX_RETRY_COUNT, max_fail_count: int = DEFAULT_MAX_FAIL_COUNT, @@ -390,18 +402,46 @@ class Scheduler(object): remove_output_root: bool = False, sched_state_observers: Optional[List[StateObserver]] = None, ) -> None: + """ + Initialize the scheduler. + + Parameters + ---------- + ctx: RuntimeContext + The runtime context. + max_retry_count: int, optional + The maximum retry count. Default to 5. + max_fail_count: int, optional + The maximum fail count. Default to 3. + prioritize_retry: bool, optional + Whether to prioritize retry. Default to False. + speculative_exec: Literal["disable", "enable", "aggressive"], optional + If "enable", long-running tasks will be rescheduled on other executors. + If "aggressive", it will be more aggressive to reschedule long-running tasks. + If "disable", no speculative execution will be performed. + Default to "enable". + stop_executor_on_failure: bool, optional + Whether to stop the executor on failure. Default to False. 
+ nonzero_exitcode_as_oom: bool, optional + Whether to treat non-zero exit code as out-of-memory. Default to False. + remove_output_root: bool, optional + Whether to remove the output root on exit. Default to False. + sched_state_observers: Optional[List[StateObserver]], optional + The state observers. + """ + # configs + self.ctx = ctx self.max_retry_count = max_retry_count self.max_fail_count = max_fail_count self.standalone_mode = self.ctx.num_executors == 0 self.prioritize_retry = prioritize_retry - self.disable_speculative_exec = speculative_exec == "disable" - self.aggressive_speculative_exec = speculative_exec == "aggressive" + self.speculative_exec: Literal["disable", "enable", "aggressive"] = ( + speculative_exec + ) self.stop_executor_on_failure = stop_executor_on_failure self.nonzero_exitcode_as_oom = nonzero_exitcode_as_oom self.remove_output_root = remove_output_root - self.sched_state_observers: List[Scheduler.StateObserver] = ( - sched_state_observers or [] - ) + self.sched_state_observers = sched_state_observers or [] self.secs_state_notify_interval = self.ctx.secs_executor_probe_interval * 2 # task states self.local_queue: List[Task] = [] @@ -416,30 +456,47 @@ class Scheduler(object): self.local_executor = LocalExecutor.create(self.ctx, "localhost") self.available_executors = {self.local_executor.id: self.local_executor} # other runtime states - self.sched_running = False - self.sched_start_time = 0 + self.failure = False + self.sched_start_time = time.time() self.last_executor_probe_time = 0 self.last_state_notify_time = 0 self.probe_epoch = 0 self.sched_epoch = 0 + self._post_init() + + @staticmethod + def recover_from_file( + sched_state_path: str, sched_state_observers: List[StateObserver] + ) -> "Scheduler": + """ + Recover the scheduler from the previous run. 
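+        Observers are not pickled with the scheduler state, so the given `sched_state_observers` are re-attached to the restored instance.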
+ """ + logger.warning(f"loading scheduler state from: {sched_state_path}") + self: Scheduler = load(sched_state_path) + + self.sched_epoch += 1 + self.prioritize_retry = True + # observers are not pickled, so we need to re-add them + self.sched_state_observers = sched_state_observers + + self._post_init() + def __getstate__(self): state = self.__dict__.copy() del state["sched_state_observers"] + del state["profiler"] return state def __setstate__(self, state): self.__dict__.update(state) self.sched_state_observers = [] + self.profiler = None @property def elapsed_time(self): return time.time() - self.sched_start_time - @property - def success(self) -> bool: - return self.exec_plan.root_task.key in self.succeeded_tasks - @property def progress(self) -> Tuple[int, int, float]: num_succeeded = len(self.succeeded_nontrivial_tasks) @@ -582,7 +639,7 @@ class Scheduler(object): for executor in self.working_executors: for idx, item in enumerate(executor.running_works.values()): aggressive_retry = ( - self.aggressive_speculative_exec + self.speculative_exec == "aggressive" and len(self.good_executors) >= self.ctx.num_executors ) short_sched_queue = len(self.sched_queue) < len(self.good_executors) @@ -644,7 +701,7 @@ class Scheduler(object): for executor in self.working_executors: executor.probe(self.probe_epoch) # start speculative execution of tasks - if not self.disable_speculative_exec: + if self.speculative_exec != "disable": self.start_speculative_execution() def update_executor_states(self): @@ -777,7 +834,7 @@ class Scheduler(object): return task @logger.catch(reraise=pytest_running(), message="failed to clean temp files") - def clean_temp_files(self, pool: ThreadPoolExecutor): + def clean_temp_files(self): remove_path(self.ctx.queue_root) remove_path(self.ctx.temp_root) remove_path(self.ctx.staging_root) @@ -786,7 +843,10 @@ class Scheduler(object): logger.info( f"removing outputs of {len(abandoned_tasks)} abandoned tasks: {abandoned_tasks[:3]} ..." ) - assert list(pool.map(lambda t: t.clean_output(force=True), abandoned_tasks)) + with ThreadPoolExecutor(32) as pool: + assert list( + pool.map(lambda t: t.clean_output(force=True), abandoned_tasks) + ) @logger.catch(reraise=pytest_running(), message="failed to export task metrics") def export_task_metrics(self): @@ -963,55 +1023,45 @@ class Scheduler(object): def log_current_status(self): with open(self.ctx.job_status_path, "w") as fout: - if self.sched_running: - status = "running" - elif self.success: - status = "success" - else: + if self.failure: status = "failure" - fout.write(f"{status}@{int(time.time())}") + else: + status = "running" + fout.write(f"{status}@{datetime.now().isoformat()}") - def run(self, exec_plan: ExecutionPlan) -> bool: + def _post_init(self): """ - Run the execution plan. + Common initialization after startup or recovery. 
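+        Resets the probe and notify timers and re-queues any tasks left pending in the local and scheduler queues under the current epoch.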
""" mp.current_process().name = f"SchedulerMainProcess#{self.sched_epoch}" logger.info( f"start to run scheduler #{self.sched_epoch} on {socket.gethostname()}" ) - perf_profile = None if self.ctx.enable_profiling: - perf_profile = cProfile.Profile() - perf_profile.enable() + self.profiler = cProfile.Profile() + self.profiler.enable() + + self.sched_running = True + self.sched_start_time = time.time() + self.last_executor_probe_time = 0 + self.last_state_notify_time = 0 + self.prioritize_retry |= self.sched_epoch > 0 + + if self.local_queue or self.sched_queue: + pending_tasks = [ + item + for item in self.local_queue + self.sched_queue + if isinstance(item, Task) + ] + self.local_queue.clear() + self.sched_queue.clear() + logger.info( + f"requeue {len(pending_tasks)} pending tasks with latest epoch #{self.sched_epoch}: {pending_tasks[:3]} ..." + ) + self.try_enqueue(pending_tasks) with ThreadPoolExecutor(32) as pool: - self.sched_running = True - self.sched_start_time = time.time() - self.last_executor_probe_time = 0 - self.last_state_notify_time = 0 - self.prioritize_retry |= self.sched_epoch > 0 - - if self.local_queue or self.sched_queue: - pending_tasks = [ - item - for item in self.local_queue + self.sched_queue - if isinstance(item, Task) - ] - self.local_queue.clear() - self.sched_queue.clear() - logger.info( - f"requeue {len(pending_tasks)} pending tasks with latest epoch #{self.sched_epoch}: {pending_tasks[:3]} ..." - ) - self.try_enqueue(pending_tasks) - - if self.sched_epoch == 0: - leaf_tasks = self.exec_plan.leaves - logger.info( - f"enqueue {len(leaf_tasks)} leaf tasks: {leaf_tasks[:3]} ..." - ) - self.try_enqueue(leaf_tasks) - self.log_overall_progress() while (num_finished_tasks := self.process_finished_tasks(pool)) > 0: logger.info( @@ -1019,54 +1069,64 @@ class Scheduler(object): ) self.log_overall_progress() - earlier_running_tasks = [ - item for item in self.running_works if isinstance(item, Task) - ] - if earlier_running_tasks: - logger.info( - f"enqueue {len(earlier_running_tasks)} earlier running tasks: {earlier_running_tasks[:3]} ..." - ) - self.try_enqueue(earlier_running_tasks) + earlier_running_tasks = [ + item for item in self.running_works if isinstance(item, Task) + ] + if earlier_running_tasks: + logger.info( + f"enqueue {len(earlier_running_tasks)} earlier running tasks: {earlier_running_tasks[:3]} ..." + ) + self.try_enqueue(earlier_running_tasks) - self.suspend_good_executors() - self.add_state_observer( - Scheduler.StateObserver(Scheduler.log_current_status) - ) - self.add_state_observer( - Scheduler.StateObserver(Scheduler.export_timeline_figs) - ) + self.suspend_good_executors() + self.add_state_observer(Scheduler.StateObserver(Scheduler.log_current_status)) + self.add_state_observer(Scheduler.StateObserver(Scheduler.export_timeline_figs)) + self.notify_state_observers(force_notify=True) + + def run(self, exec_plan: ExecutionPlan) -> bool: + """ + Run the execution plan. 
+ """ + leaf_tasks = exec_plan.leaves + logger.info(f"enqueue {len(leaf_tasks)} leaf tasks: {leaf_tasks[:3]} ...") + self.try_enqueue(leaf_tasks) + + try: + with ThreadPoolExecutor(32) as pool: + self.local_executor.start(pool) + self.sched_loop(pool, exec_plan.root_task) + finally: + logger.info(f"schedule loop stopped") self.notify_state_observers(force_notify=True) - try: - self.local_executor.start(pool) - self.sched_loop(pool) - finally: - logger.info(f"schedule loop stopped") - self.sched_running = False - self.notify_state_observers(force_notify=True) - self.export_task_metrics() - self.stop_executors() + logger.success(f"final output path: {exec_plan.final_output_path}") + logger.info( + f"analyzed plan:{os.linesep}{exec_plan.analyzed_logical_plan.explain_str()}" + ) - # if --output_path is specified, remove the output root as well - if self.remove_output_root or self.ctx.final_output_path: - remove_path(self.ctx.staging_root) - remove_path(self.ctx.output_root) + return True - if self.success: - self.clean_temp_files(pool) - logger.success(f"final output path: {self.exec_plan.final_output_path}") - logger.info( - f"analyzed plan:{os.linesep}{self.exec_plan.analyzed_logical_plan.explain_str()}" - ) + def cleanup(self): + self.export_task_metrics() + self.stop_executors() - if perf_profile is not None: + # if --output_path is specified, remove the output root as well + if self.remove_output_root: + remove_path(self.ctx.staging_root) + remove_path(self.ctx.output_root) + + if not self.failure: + self.clean_temp_files() + if self.ctx.final_output_path: + remove_path(self.ctx.final_output_path) + + if self.profiler is not None: logger.debug( - f"scheduler perf profile:{os.linesep}{cprofile_to_string(perf_profile)}" + f"scheduler perf profile:{os.linesep}{cprofile_to_string(self.profiler)}" ) logger.info(f"scheduler of job {self.ctx.job_id} exits") logger.complete() - return self.success def try_enqueue(self, tasks: Union[Iterable[Task], Task]): tasks = tasks if isinstance(tasks, Iterable) else [tasks] @@ -1092,15 +1152,15 @@ class Scheduler(object): else: self.sched_queue.append(task) - def sched_loop(self, pool: ThreadPoolExecutor) -> bool: + def sched_loop(self, pool: ThreadPoolExecutor, task: Task): + """ + Run the scheduler loop until the task is finished or failed. + """ + has_progress = True do_notify = False - if self.success: - logger.success(f"job already succeeded, stopping scheduler ...") - return True - - while self.sched_running: + while not self.failure and self.tasks[task.key].status == WorkStatus.INCOMPLETE: self.probe_executors() self.update_executor_states() @@ -1153,9 +1213,6 @@ class Scheduler(object): has_progress |= self.process_finished_tasks(pool) > 0 - # out of loop - return self.success - def dispatch_tasks(self, pool: ThreadPoolExecutor): # sort pending tasks item_sort_key = ( @@ -1240,6 +1297,10 @@ class Scheduler(object): return num_dispatched_tasks def process_finished_tasks(self, pool: ThreadPoolExecutor) -> int: + """ + Process finished tasks from all executors. + Return the number of finished tasks. + """ pop_results = pool.map(RemoteExecutor.pop, self.available_executors.values()) num_finished_tasks = 0 @@ -1309,7 +1370,7 @@ class Scheduler(object): f"task failed too many times: {finished_task}, stopping ..." 
) self.stop_executors() - self.sched_running = False + self.failure = True if not executor.local and finished_task.oom( self.nonzero_exitcode_as_oom diff --git a/smallpond/execution/task.py b/smallpond/execution/task.py index 43d8d22..dbfd41d 100644 --- a/smallpond/execution/task.py +++ b/smallpond/execution/task.py @@ -42,7 +42,6 @@ import pandas as pd import psutil import pyarrow as arrow import pyarrow.parquet as parquet -import ray from loguru import logger from smallpond.common import ( @@ -642,10 +641,6 @@ class Task(WorkItem): # implementor can use this variable as a checkpoint and restore from it after interrupted self.runtime_state = None - # if the task is executed by ray, this is the reference to the output dataset - # do not use this variable directly, use `self.run_on_ray()` instead - self._dataset_ref: Optional[ray.ObjectRef] = None - def __repr__(self) -> str: return f"{self.key}.{self.sched_epoch}.{self.retry_count},{self.node_id}" @@ -1104,97 +1099,6 @@ class Task(WorkItem): for name, value in metrics.items(): self.perf_metrics[name] += value - def run_on_ray(self) -> ray.ObjectRef: - """ - Run the task on Ray. - Return an `ObjectRef`, which can be used with `ray.get` to wait for the output dataset. - """ - if self._dataset_ref is not None: - # already started - return self._dataset_ref - - # read the output dataset if the task has already finished - if os.path.exists(self.ray_dataset_path): - logger.info(f"task {self.key} already finished, skipping") - output = load(self.ray_dataset_path) - self._dataset_ref = ray.put(output) - return self._dataset_ref - - task = copy.copy(self) - task.input_deps = {dep_key: None for dep_key in task.input_deps} - - @ray.remote - def exec_task(task: Task, *inputs: DataSet) -> DataSet: - import multiprocessing as mp - import os - from pathlib import Path - - from loguru import logger - - # ray use a process pool to execute tasks - # we set the current process name to the task name - # so that we can see task name in the logs - mp.current_process().name = task.key - - # probe the retry count - task.retry_count = 0 - while os.path.exists(task.ray_marker_path): - task.retry_count += 1 - if task.retry_count > DEFAULT_MAX_RETRY_COUNT: - raise RuntimeError( - f"task {task.key} failed after {task.retry_count} retries" - ) - if task.retry_count > 0: - logger.warning( - f"task {task.key} is being retried for the {task.retry_count}th time" - ) - # create the marker file - Path(task.ray_marker_path).touch() - - # put the inputs into the task - assert len(inputs) == len(task.input_deps) - task.input_datasets = list(inputs) - # execute the task - status = task.exec() - if status != WorkStatus.SUCCEED: - raise task.exception or RuntimeError( - f"task {task.key} failed with status {status}" - ) - - # dump the output dataset atomically - os.makedirs(os.path.dirname(task.ray_dataset_path), exist_ok=True) - dump(task.output, task.ray_dataset_path, atomic_write=True) - return task.output - - # this shows as {"name": ...} in timeline - exec_task._function_name = repr(task) - - remote_function = exec_task.options( - # ray task name - # do not include task id so that they can be grouped by node in ray dashboard - name=f"{task.node_id}.{self.__class__.__name__}", - num_cpus=self.cpu_limit, - num_gpus=self.gpu_limit, - memory=int(self.memory_limit), - # note: `exec_on_scheduler` is ignored here, - # because dataset is distributed on ray - ) - try: - self._dataset_ref = remote_function.remote( - task, *[dep.run_on_ray() for dep in self.input_deps.values()] - ) - 
except RuntimeError as e: - if ( - "SimpleQueue objects should only be shared between processes through inheritance" - in str(e) - ): - raise RuntimeError( - f"Can't pickle task '{task.key}'. Please check if your function has captured unpicklable objects. {task.location}\n" - f"HINT: DO NOT use externally imported loguru logger in your task. Please import it within the task." - ) from e - raise e - return self._dataset_ref - class ExecSqlQueryMixin(Task): diff --git a/smallpond/session.py b/smallpond/session.py index 171be42..a07b26e 100644 --- a/smallpond/session.py +++ b/smallpond/session.py @@ -21,368 +21,7 @@ import graphviz.backend.execute from loguru import logger import smallpond +from smallpond.execution.manager import JobManager from smallpond.execution.task import JobId, RuntimeContext from smallpond.logical.node import Context from smallpond.platform import Platform, get_platform - - -class SessionBase: - def __init__(self, **kwargs): - """ - Create a smallpond environment. - """ - super().__init__() - self._ctx = Context() - self.config, self._platform = Config.from_args_and_env(**kwargs) - - # construct runtime context for Tasks - runtime_ctx = RuntimeContext( - job_id=JobId(hex=self.config.job_id), - job_time=self.config.job_time, - data_root=self.config.data_root, - num_executors=self.config.num_executors, - bind_numa_node=self.config.bind_numa_node, - shared_log_root=self._platform.shared_log_root(), - ) - self._runtime_ctx = runtime_ctx - - # if `spawn` is specified, spawn a job and exit - if os.environ.get("SP_SPAWN") == "1": - self._spawn_self() - exit(0) - - self._runtime_ctx.initialize(exec_id=socket.gethostname()) - logger.info(f"using platform: {self._platform}") - logger.info(f"command-line arguments: {' '.join(sys.argv)}") - logger.info(f"session config: {self.config}") - - def setup_worker(): - runtime_ctx._init_logs( - exec_id=socket.gethostname(), capture_stdout_stderr=True - ) - - if self.config.ray_address is None: - # find the memory allocator - if self.config.memory_allocator == "system": - malloc_path = "" - elif self.config.memory_allocator == "jemalloc": - malloc_path = shutil.which("libjemalloc.so.2") - assert malloc_path is not None, "jemalloc is not installed" - elif self.config.memory_allocator == "mimalloc": - malloc_path = shutil.which("libmimalloc.so.2.1") - assert malloc_path is not None, "mimalloc is not installed" - else: - raise ValueError( - f"unsupported memory allocator: {self.config.memory_allocator}" - ) - memory_purge_delay = 10000 - - # start ray head node - # for ray head node to access grafana - os.environ["RAY_GRAFANA_HOST"] = "http://localhost:8122" - self._ray_address = ray.init( - # start a new local cluster - address="local", - # disable local CPU resource if not running on localhost - num_cpus=( - 0 - if self.config.num_executors > 0 - else self._runtime_ctx.usable_cpu_count - ), - # set the memory limit to the available memory size - _memory=self._runtime_ctx.usable_memory_size, - # setup logging for workers - log_to_driver=False, - runtime_env={ - "worker_process_setup_hook": setup_worker, - "env_vars": { - "LD_PRELOAD": malloc_path, - "MALLOC_CONF": f"percpu_arena:percpu,background_thread:true,metadata_thp:auto,dirty_decay_ms:{memory_purge_delay},muzzy_decay_ms:{memory_purge_delay},oversize_threshold:0,lg_tcache_max:16", - "MIMALLOC_PURGE_DELAY": f"{memory_purge_delay}", - "ARROW_DEFAULT_MEMORY_POOL": self.config.memory_allocator, - "ARROW_IO_THREADS": "2", - "OMP_NUM_THREADS": "2", - "POLARS_MAX_THREADS": "2", - 
"NUMEXPR_MAX_THREADS": "2", - "RAY_PROFILING": "1", - }, - }, - dashboard_host="0.0.0.0", - dashboard_port=8008, - # for prometheus to scrape metrics - _metrics_export_port=8080, - ).address_info["gcs_address"] - logger.info(f"started ray cluster at {self._ray_address}") - - self._prometheus_process = self._start_prometheus() - self._grafana_process = self._start_grafana() - else: - self._ray_address = self.config.ray_address - self._prometheus_process = None - self._grafana_process = None - logger.info(f"connected to ray cluster at {self._ray_address}") - - # start workers - if self.config.num_executors > 0: - # override configs - kwargs["job_id"] = self.config.job_id - - self._job_names = self._platform.start_job( - self.config.num_executors, - entrypoint=os.path.join(os.path.dirname(__file__), "worker.py"), - args=[ - f"--ray_address={self._ray_address}", - f"--log_dir={self._runtime_ctx.log_root}", - *(["--bind_numa_node"] if self.config.bind_numa_node else []), - ], - extra_opts=kwargs, - ) - else: - self._job_names = [] - - # spawn a thread to periodically dump metrics - self._stop_event = threading.Event() - self._dump_thread = threading.Thread( - name="dump_thread", target=self._dump_periodically, daemon=True - ) - self._dump_thread.start() - - def shutdown(self): - """ - Shutdown the session. - """ - logger.info("shutting down session") - self._stop_event.set() - - # stop all jobs - for job_name in self._job_names: - self._platform.stop_job(job_name) - self._job_names = [] - - self._dump_thread.join() - if self.config.ray_address is None: - ray.shutdown() - if self._prometheus_process is not None: - self._prometheus_process.terminate() - self._prometheus_process.wait() - self._prometheus_process = None - logger.info("stopped prometheus") - if self._grafana_process is not None: - self._grafana_process.terminate() - self._grafana_process.wait() - self._grafana_process = None - logger.info("stopped grafana") - - def _spawn_self(self): - """ - Spawn a new job to run the current script. - """ - self._platform.start_job( - num_nodes=1, - entrypoint=sys.argv[0], - args=sys.argv[1:], - extra_opts=dict( - tags=["smallpond", "scheduler", smallpond.__version__], - ), - envs={ - k: v - for k, v in os.environ.items() - if k.startswith("SP_") and k != "SP_SPAWN" - }, - ) - - def _start_prometheus(self) -> Optional[subprocess.Popen]: - """ - Start prometheus server if it exists. - """ - prometheus_path = shutil.which("prometheus") - if prometheus_path is None: - logger.warning("prometheus is not found") - return None - os.makedirs(f"{self._runtime_ctx.log_root}/prometheus", exist_ok=True) - proc = subprocess.Popen( - [ - prometheus_path, - "--config.file=/tmp/ray/session_latest/metrics/prometheus/prometheus.yml", - f"--storage.tsdb.path={self._runtime_ctx.log_root}/prometheus/data", - ], - stderr=open(f"{self._runtime_ctx.log_root}/prometheus/prometheus.log", "w"), - ) - logger.info("started prometheus") - return proc - - def _start_grafana(self) -> Optional[subprocess.Popen]: - """ - Start grafana server if it exists. 
- """ - homepath = self._platform.grafana_homepath() - if homepath is None: - logger.warning("grafana is not found") - return None - os.makedirs(f"{self._runtime_ctx.log_root}/grafana", exist_ok=True) - proc = subprocess.Popen( - [ - shutil.which("grafana"), - "server", - "--config", - "/tmp/ray/session_latest/metrics/grafana/grafana.ini", - "-homepath", - homepath, - "web", - ], - stdout=open(f"{self._runtime_ctx.log_root}/grafana/grafana.log", "w"), - env={ - "GF_SERVER_HTTP_PORT": "8122", # redirect to an available port - "GF_SERVER_ROOT_URL": os.environ.get("RAY_GRAFANA_IFRAME_HOST") - or "http://localhost:8122", - "GF_PATHS_DATA": f"{self._runtime_ctx.log_root}/grafana/data", - }, - ) - logger.info(f"started grafana at http://localhost:8122") - return proc - - @property - def runtime_ctx(self) -> RuntimeContext: - return self._runtime_ctx - - def graph(self) -> Digraph: - """ - Get the logical plan graph. - """ - # implemented in Session class - raise NotImplementedError("graph") - - def dump_graph(self, path: Optional[str] = None): - """ - Dump the logical plan graph to a file. - """ - path = path or os.path.join(self.runtime_ctx.log_root, "graph") - try: - self.graph().render(path, format="png") - logger.debug(f"dumped graph to {path}") - except graphviz.backend.execute.ExecutableNotFound as e: - logger.warning(f"graphviz is not installed, skipping graph dump") - - def dump_timeline(self, path: Optional[str] = None): - """ - Dump the task timeline to a file. - """ - path = path or os.path.join(self.runtime_ctx.log_root, "timeline") - # the default timeline is grouped by worker - exec_path = f"{path}_exec" - ray.timeline(exec_path) - logger.debug(f"dumped timeline to {exec_path}") - - # generate another timeline grouped by node - with open(exec_path) as f: - records = json.load(f) - new_records = [] - for record in records: - # swap record name and pid-tid - name = record["name"] - try: - node_id = name.split(",")[-1] - task_id = name.split("-")[1].split(".")[0] - task_name = name.split("-")[0] - record["pid"] = f"{node_id}-{task_name}" - record["tid"] = f"task {task_id}" - new_records.append(record) - except Exception: - # filter out other records - pass - node_path = f"{path}_plan" - with open(node_path, "w") as f: - json.dump(new_records, f) - logger.debug(f"dumped timeline to {node_path}") - - def _summarize_task(self) -> Tuple[int, int]: - # implemented in Session class - raise NotImplementedError("summarize_task") - - def _dump_periodically(self): - """ - Dump the graph and timeline every minute. - Set `self._stop_event` to have a final dump and stop this thread. - """ - while not self._stop_event.is_set(): - self._stop_event.wait(60) - self.dump_graph() - self.dump_timeline() - num_total_tasks, num_finished_tasks = self._summarize_task() - percent = ( - num_finished_tasks / num_total_tasks * 100 if num_total_tasks > 0 else 0 - ) - logger.info( - f"progress: {num_finished_tasks}/{num_total_tasks} tasks ({percent:.1f}%)" - ) - - -@dataclass -class Config: - """ - Configuration for a session. 
- """ - - job_id: str # JOBID - job_time: datetime # JOB_TIME - data_root: str # DATA_ROOT - num_executors: int # NUM_NODES_TOTAL - ray_address: Optional[str] # RAY_ADDRESS - bind_numa_node: bool # BIND_NUMA_NODE - memory_allocator: str # MEMORY_ALLOCATOR - remove_output_root: bool - - @staticmethod - def from_args_and_env( - platform: Optional[str] = None, - job_id: Optional[str] = None, - job_time: Optional[datetime] = None, - data_root: Optional[str] = None, - num_executors: Optional[int] = None, - ray_address: Optional[str] = None, - bind_numa_node: Optional[bool] = None, - memory_allocator: Optional[str] = None, - _remove_output_root: bool = True, - **kwargs, - ) -> Config: - """ - Load config from arguments and environment variables. - If not specified, use the default value. - """ - - def get_env(key: str, type: type = str): - """ - Get an environment variable and convert it to the given type. - If the variable is not set, return None. - """ - value = os.environ.get(f"SP_{key}") - return type(value) if value is not None else None - - platform = get_platform(get_env("PLATFORM") or platform) - job_id = get_env("JOBID") or job_id or platform.default_job_id() - job_time = ( - get_env("JOB_TIME", datetime.fromisoformat) - or job_time - or platform.default_job_time() - ) - data_root = get_env("DATA_ROOT") or data_root or platform.default_data_root() - num_executors = get_env("NUM_EXECUTORS", int) or num_executors or 0 - ray_address = get_env("RAY_ADDRESS") or ray_address - bind_numa_node = get_env("BIND_NUMA_NODE") == "1" or bind_numa_node - memory_allocator = ( - get_env("MEMORY_ALLOCATOR") - or memory_allocator - or platform.default_memory_allocator() - ) - - config = Config( - job_id=job_id, - job_time=job_time, - data_root=data_root, - num_executors=num_executors, - ray_address=ray_address, - bind_numa_node=bind_numa_node, - memory_allocator=memory_allocator, - remove_output_root=_remove_output_root, - ) - return config, platform diff --git a/tests/test_fabric.py b/tests/test_fabric.py index 01c8b16..32bf0f7 100644 --- a/tests/test_fabric.py +++ b/tests/test_fabric.py @@ -210,6 +210,7 @@ class TestFabric(unittest.TestCase): self.sched_states = self.queue_manager.Queue() scheduler = Scheduler( + ctx=runtime_ctx, max_retry_count=max_retry_count, max_fail_count=max_fail_count, prioritize_retry=prioritize_retry,
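
The new `smallpond/execution/_ray.py` module above turns Ray submission into a standalone function: `run_on_ray(task)` caches an `ObjectRef` on each task and recursively submits its `input_deps`, so a driver only needs to launch the root task and wait on its reference. A minimal driver-side sketch under that assumption (the helper names `launch_plan`, `summarize_progress`, and `wait_for_output` are illustrative, not part of smallpond, and `root_task` is assumed to be a planned `Task` whose `_dataset_ref` starts out as `None`):

from typing import List, Tuple

import ray

from smallpond.execution._ray import run_on_ray
from smallpond.execution.task import Task
from smallpond.logical.dataset import DataSet


def launch_plan(root_task: Task) -> ray.ObjectRef:
    # run_on_ray() recursively submits all input_deps and caches the
    # resulting ObjectRef on each task, so repeated calls are idempotent.
    return run_on_ray(root_task)


def summarize_progress(tasks: List[Task]) -> Tuple[int, int]:
    # Non-blocking progress check over already-launched tasks: count how
    # many output datasets are ready without fetching them to the driver.
    refs = [t._dataset_ref for t in tasks if t._dataset_ref is not None]
    ready, _ = ray.wait(refs, num_returns=len(refs), timeout=0, fetch_local=False)
    return len(refs), len(ready)


def wait_for_output(root_task: Task) -> DataSet:
    # Block until the root task's output dataset has been materialized.
    return ray.get(launch_plan(root_task))

The `ray.wait(..., timeout=0, fetch_local=False)` call mirrors the progress check that this diff removes from `Session._summarize_task`, so the same non-blocking summary remains possible on top of the refactored API.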