mirror of
https://github.com/deepseek-ai/smallpond
synced 2025-06-26 18:27:45 +00:00
JobManager: move all arguments to __init__
Signed-off-by: Runji Wang <runji@deepseek.com>
This commit is contained in:
parent
ed112db42a
commit
5a56a052bf
@ -372,10 +372,10 @@ class Driver(object):
|
||||
elif args.mode == "scheduler":
|
||||
assert plan is not None
|
||||
jobmgr = JobManager(
|
||||
args.data_root, args.python_venv, args.task_image, args.platform
|
||||
)
|
||||
exec_plan = jobmgr.run(
|
||||
plan,
|
||||
data_root=args.data_root,
|
||||
python_venv=args.python_venv,
|
||||
task_image=args.task_image,
|
||||
platform=args.platform,
|
||||
job_id=args.job_id,
|
||||
job_time=args.job_time,
|
||||
job_name=args.job_name,
|
||||
@ -412,6 +412,7 @@ class Driver(object):
|
||||
disable_log_rotation=args.disable_log_rotation,
|
||||
output_path=args.output_path,
|
||||
)
|
||||
exec_plan = jobmgr.run(plan)
|
||||
retval = exec_plan if exec_plan.successful else None
|
||||
|
||||
if stop_process_on_done:
|
||||
|
@ -52,20 +52,11 @@ class JobManager(object):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
data_root: Optional[str] = None,
|
||||
python_venv: Optional[str] = None,
|
||||
task_image: str = "default",
|
||||
platform: Optional[str] = None,
|
||||
) -> None:
|
||||
self.platform = get_platform(platform)
|
||||
self.data_root = os.path.abspath(data_root or self.platform.default_data_root())
|
||||
self.python_venv = python_venv
|
||||
self.task_image = task_image
|
||||
|
||||
@logger.catch(reraise=True, message="job manager terminated unexpectedly")
|
||||
def run(
|
||||
self,
|
||||
plan: LogicalPlan,
|
||||
job_id: Optional[str] = None,
|
||||
job_time: Optional[float] = None,
|
||||
job_name: str = "smallpond",
|
||||
@ -107,9 +98,15 @@ class JobManager(object):
|
||||
sched_state_observers: Optional[List[Scheduler.StateObserver]] = None,
|
||||
output_path: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> ExecutionPlan:
|
||||
) -> None:
|
||||
self.platform = get_platform(platform)
|
||||
logger.info(f"using platform: {self.platform}")
|
||||
|
||||
self.data_root = os.path.abspath(data_root or self.platform.default_data_root())
|
||||
self.python_venv = python_venv
|
||||
self.task_image = task_image
|
||||
self.manifest_only_final_results = manifest_only_final_results
|
||||
|
||||
job_id = JobId(hex=job_id or self.platform.default_job_id())
|
||||
job_time = (
|
||||
datetime.fromtimestamp(job_time)
|
||||
@ -145,7 +142,7 @@ class JobManager(object):
|
||||
self.platform.shared_log_root() if share_log_analytics else None
|
||||
)
|
||||
|
||||
runtime_ctx = RuntimeContext(
|
||||
self.runtime_ctx = RuntimeContext(
|
||||
job_id,
|
||||
job_time,
|
||||
self.data_root,
|
||||
@ -166,27 +163,13 @@ class JobManager(object):
|
||||
output_path=output_path,
|
||||
**kwargs,
|
||||
)
|
||||
runtime_ctx.initialize(socket.gethostname(), root_exist_ok=True)
|
||||
self.runtime_ctx.initialize(socket.gethostname(), root_exist_ok=True)
|
||||
logger.info(
|
||||
f"command-line arguments: {' '.join([os.path.basename(sys.argv[0]), *sys.argv[1:]])}"
|
||||
)
|
||||
|
||||
dump(runtime_ctx, runtime_ctx.runtime_ctx_path, atomic_write=True)
|
||||
logger.info(f"saved runtime context at {runtime_ctx.runtime_ctx_path}")
|
||||
|
||||
dump(plan, runtime_ctx.logcial_plan_path, atomic_write=True)
|
||||
logger.info(f"saved logcial plan at {runtime_ctx.logcial_plan_path}")
|
||||
|
||||
plan.graph().render(runtime_ctx.logcial_plan_graph_path, format="png")
|
||||
logger.info(
|
||||
f"saved logcial plan graph at {runtime_ctx.logcial_plan_graph_path}.png"
|
||||
)
|
||||
|
||||
exec_plan = Planner(runtime_ctx).create_exec_plan(
|
||||
plan, manifest_only_final_results
|
||||
)
|
||||
dump(exec_plan, runtime_ctx.exec_plan_path, atomic_write=True)
|
||||
logger.info(f"saved execution plan at {runtime_ctx.exec_plan_path}")
|
||||
dump(self.runtime_ctx, self.runtime_ctx.runtime_ctx_path, atomic_write=True)
|
||||
logger.info(f"saved runtime context at {self.runtime_ctx.runtime_ctx_path}")
|
||||
|
||||
sidecar_list = sidecars or []
|
||||
|
||||
@ -225,27 +208,26 @@ class JobManager(object):
|
||||
sched_state_observers = sched_state_observers or []
|
||||
|
||||
if enable_sched_state_dump:
|
||||
sched_state_exporter = SchedStateExporter(runtime_ctx.sched_state_path)
|
||||
sched_state_exporter = SchedStateExporter(self.runtime_ctx.sched_state_path)
|
||||
sched_state_observers.insert(0, sched_state_exporter)
|
||||
|
||||
if os.path.exists(runtime_ctx.sched_state_path):
|
||||
if os.path.exists(self.runtime_ctx.sched_state_path):
|
||||
logger.warning(
|
||||
f"loading scheduler state from: {runtime_ctx.sched_state_path}"
|
||||
f"loading scheduler state from: {self.runtime_ctx.sched_state_path}"
|
||||
)
|
||||
scheduler: Scheduler = load(runtime_ctx.sched_state_path)
|
||||
scheduler.sched_epoch += 1
|
||||
scheduler.sched_state_observers = sched_state_observers
|
||||
self.scheduler: Scheduler = load(self.runtime_ctx.sched_state_path)
|
||||
self.scheduler.sched_epoch += 1
|
||||
self.scheduler.sched_state_observers = sched_state_observers
|
||||
else:
|
||||
scheduler = Scheduler(
|
||||
exec_plan,
|
||||
max_retry_count,
|
||||
max_fail_count,
|
||||
prioritize_retry,
|
||||
speculative_exec,
|
||||
stop_executor_on_failure,
|
||||
nonzero_exitcode_as_oom,
|
||||
remove_output_root,
|
||||
sched_state_observers,
|
||||
self.scheduler = Scheduler(
|
||||
max_retry_count=max_retry_count,
|
||||
max_fail_count=max_fail_count,
|
||||
prioritize_retry=prioritize_retry,
|
||||
speculative_exec=speculative_exec,
|
||||
stop_executor_on_failure=stop_executor_on_failure,
|
||||
nonzero_exitcode_as_oom=nonzero_exitcode_as_oom,
|
||||
remove_output_root=remove_output_root,
|
||||
sched_state_observers=sched_state_observers,
|
||||
)
|
||||
# start executors
|
||||
self.platform.start_job(
|
||||
@ -257,7 +239,7 @@ class JobManager(object):
|
||||
"--data_root",
|
||||
self.data_root,
|
||||
"--runtime_ctx_path",
|
||||
runtime_ctx.runtime_ctx_path,
|
||||
self.runtime_ctx.runtime_ctx_path,
|
||||
"executor",
|
||||
],
|
||||
envs={
|
||||
@ -278,6 +260,25 @@ class JobManager(object):
|
||||
),
|
||||
)
|
||||
|
||||
@logger.catch(reraise=True, message="job manager terminated unexpectedly")
|
||||
def run(
|
||||
self,
|
||||
plan: LogicalPlan,
|
||||
) -> ExecutionPlan:
|
||||
dump(plan, self.runtime_ctx.logcial_plan_path, atomic_write=True)
|
||||
logger.info(f"saved logcial plan at {self.runtime_ctx.logcial_plan_path}")
|
||||
|
||||
plan.graph().render(self.runtime_ctx.logcial_plan_graph_path, format="png")
|
||||
logger.info(
|
||||
f"saved logcial plan graph at {self.runtime_ctx.logcial_plan_graph_path}.png"
|
||||
)
|
||||
|
||||
exec_plan = Planner(self.runtime_ctx).create_exec_plan(
|
||||
plan, self.manifest_only_final_results
|
||||
)
|
||||
dump(exec_plan, self.runtime_ctx.exec_plan_path, atomic_write=True)
|
||||
logger.info(f"saved execution plan at {self.runtime_ctx.exec_plan_path}")
|
||||
|
||||
# run scheduler
|
||||
scheduler.run()
|
||||
return scheduler.exec_plan
|
||||
self.scheduler.run(exec_plan)
|
||||
return exec_plan
|
||||
|
@ -12,7 +12,7 @@ from collections import OrderedDict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from enum import Enum
|
||||
from functools import cached_property
|
||||
from typing import Any, Callable, Dict, Iterable, List, Literal, Set, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Set, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
@ -380,20 +380,16 @@ class Scheduler(object):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
exec_plan: ExecutionPlan,
|
||||
*,
|
||||
max_retry_count: int = DEFAULT_MAX_RETRY_COUNT,
|
||||
max_fail_count: int = DEFAULT_MAX_FAIL_COUNT,
|
||||
prioritize_retry=False,
|
||||
prioritize_retry: bool = False,
|
||||
speculative_exec: Literal["disable", "enable", "aggressive"] = "enable",
|
||||
stop_executor_on_failure=False,
|
||||
nonzero_exitcode_as_oom=False,
|
||||
remove_output_root=False,
|
||||
sched_state_observers=None,
|
||||
stop_executor_on_failure: bool = False,
|
||||
nonzero_exitcode_as_oom: bool = False,
|
||||
remove_output_root: bool = False,
|
||||
sched_state_observers: Optional[List[StateObserver]] = None,
|
||||
) -> None:
|
||||
self.ctx = exec_plan.ctx
|
||||
self.exec_plan = exec_plan
|
||||
self.logical_plan: LogicalPlan = self.exec_plan.logical_plan
|
||||
self.logical_nodes = self.logical_plan.nodes
|
||||
self.max_retry_count = max_retry_count
|
||||
self.max_fail_count = max_fail_count
|
||||
self.standalone_mode = self.ctx.num_executors == 0
|
||||
@ -410,16 +406,12 @@ class Scheduler(object):
|
||||
# task states
|
||||
self.local_queue: List[Task] = []
|
||||
self.sched_queue: List[Task] = []
|
||||
self.tasks: Dict[str, Task] = self.exec_plan.tasks
|
||||
self.scheduled_tasks: Dict[TaskRuntimeId, Task] = OrderedDict()
|
||||
self.finished_tasks: Dict[TaskRuntimeId, Task] = OrderedDict()
|
||||
self.succeeded_tasks: Dict[str, Task] = OrderedDict()
|
||||
self.nontrivial_tasks = dict(
|
||||
(key, task)
|
||||
for (key, task) in self.tasks.items()
|
||||
if not task.exec_on_scheduler
|
||||
)
|
||||
self.succeeded_nontrivial_tasks: Dict[str, Task] = OrderedDict()
|
||||
self.tasks: Dict[str, Task] = {}
|
||||
self.scheduled_tasks: Dict[TaskRuntimeId, Task] = {}
|
||||
self.finished_tasks: Dict[TaskRuntimeId, Task] = {}
|
||||
self.succeeded_tasks: Dict[str, Task] = {}
|
||||
self.nontrivial_tasks: Dict[str, Task] = {}
|
||||
self.succeeded_nontrivial_tasks: Dict[str, Task] = {}
|
||||
# executor pool
|
||||
self.local_executor = LocalExecutor.create(self.ctx, "localhost")
|
||||
self.available_executors = {self.local_executor.id: self.local_executor}
|
||||
@ -431,9 +423,6 @@ class Scheduler(object):
|
||||
self.probe_epoch = 0
|
||||
self.sched_epoch = 0
|
||||
|
||||
def __call__(self, *args: Any, **kwds: Any) -> Any:
|
||||
return self.run()
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
del state["sched_state_observers"]
|
||||
@ -982,7 +971,10 @@ class Scheduler(object):
|
||||
status = "failure"
|
||||
fout.write(f"{status}@{int(time.time())}")
|
||||
|
||||
def run(self) -> bool:
|
||||
def run(self, exec_plan: ExecutionPlan) -> bool:
|
||||
"""
|
||||
Run the execution plan.
|
||||
"""
|
||||
mp.current_process().name = f"SchedulerMainProcess#{self.sched_epoch}"
|
||||
logger.info(
|
||||
f"start to run scheduler #{self.sched_epoch} on {socket.gethostname()}"
|
||||
|
@ -28,11 +28,14 @@ generate_data()
|
||||
|
||||
|
||||
def run_scheduler(
|
||||
runtime_ctx: RuntimeContext, scheduler: Scheduler, queue: queue.Queue
|
||||
runtime_ctx: RuntimeContext,
|
||||
scheduler: Scheduler,
|
||||
queue: queue.Queue,
|
||||
exec_plan: ExecutionPlan,
|
||||
):
|
||||
runtime_ctx.initialize("scheduler")
|
||||
scheduler.add_state_observer(Scheduler.StateObserver(SaveSchedState(queue)))
|
||||
retval = scheduler.run()
|
||||
retval = scheduler.run(exec_plan)
|
||||
print(f"scheduler exited with value {retval}", file=sys.stderr)
|
||||
|
||||
|
||||
@ -206,9 +209,7 @@ class TestFabric(unittest.TestCase):
|
||||
self.queue_manager = Manager()
|
||||
self.sched_states = self.queue_manager.Queue()
|
||||
|
||||
exec_plan = Planner(runtime_ctx).create_exec_plan(plan)
|
||||
scheduler = Scheduler(
|
||||
exec_plan,
|
||||
max_retry_count=max_retry_count,
|
||||
max_fail_count=max_fail_count,
|
||||
prioritize_retry=prioritize_retry,
|
||||
@ -216,6 +217,8 @@ class TestFabric(unittest.TestCase):
|
||||
stop_executor_on_failure=stop_executor_on_failure,
|
||||
nonzero_exitcode_as_oom=nonzero_exitcode_as_oom,
|
||||
)
|
||||
exec_plan = Planner(runtime_ctx).create_exec_plan(plan)
|
||||
|
||||
self.latest_state = scheduler
|
||||
self.executors = [
|
||||
Executor.create(runtime_ctx, f"executor-{i}") for i in range(num_executors)
|
||||
@ -225,7 +228,7 @@ class TestFabric(unittest.TestCase):
|
||||
target=run_scheduler,
|
||||
# XXX: on macOS, scheduler state observer will be cleared when cross-process
|
||||
# so we pass the queue and add the observer in the new process
|
||||
args=(runtime_ctx, scheduler, self.sched_states),
|
||||
args=(runtime_ctx, scheduler, self.sched_states, exec_plan),
|
||||
name="scheduler",
|
||||
)
|
||||
]
|
||||
|
Loading…
Reference in New Issue
Block a user