JobManager: move all arguments to __init__

Signed-off-by: Runji Wang <runji@deepseek.com>
This commit is contained in:
Runji Wang 2025-02-28 20:03:55 +08:00
parent ed112db42a
commit 5a56a052bf
4 changed files with 79 additions and 82 deletions

View File

@ -372,10 +372,10 @@ class Driver(object):
elif args.mode == "scheduler":
assert plan is not None
jobmgr = JobManager(
args.data_root, args.python_venv, args.task_image, args.platform
)
exec_plan = jobmgr.run(
plan,
data_root=args.data_root,
python_venv=args.python_venv,
task_image=args.task_image,
platform=args.platform,
job_id=args.job_id,
job_time=args.job_time,
job_name=args.job_name,
@ -412,6 +412,7 @@ class Driver(object):
disable_log_rotation=args.disable_log_rotation,
output_path=args.output_path,
)
exec_plan = jobmgr.run(plan)
retval = exec_plan if exec_plan.successful else None
if stop_process_on_done:

View File

@ -52,20 +52,11 @@ class JobManager(object):
def __init__(
self,
*,
data_root: Optional[str] = None,
python_venv: Optional[str] = None,
task_image: str = "default",
platform: Optional[str] = None,
) -> None:
self.platform = get_platform(platform)
self.data_root = os.path.abspath(data_root or self.platform.default_data_root())
self.python_venv = python_venv
self.task_image = task_image
@logger.catch(reraise=True, message="job manager terminated unexpectedly")
def run(
self,
plan: LogicalPlan,
job_id: Optional[str] = None,
job_time: Optional[float] = None,
job_name: str = "smallpond",
@ -107,9 +98,15 @@ class JobManager(object):
sched_state_observers: Optional[List[Scheduler.StateObserver]] = None,
output_path: Optional[str] = None,
**kwargs,
) -> ExecutionPlan:
) -> None:
self.platform = get_platform(platform)
logger.info(f"using platform: {self.platform}")
self.data_root = os.path.abspath(data_root or self.platform.default_data_root())
self.python_venv = python_venv
self.task_image = task_image
self.manifest_only_final_results = manifest_only_final_results
job_id = JobId(hex=job_id or self.platform.default_job_id())
job_time = (
datetime.fromtimestamp(job_time)
@ -145,7 +142,7 @@ class JobManager(object):
self.platform.shared_log_root() if share_log_analytics else None
)
runtime_ctx = RuntimeContext(
self.runtime_ctx = RuntimeContext(
job_id,
job_time,
self.data_root,
@ -166,27 +163,13 @@ class JobManager(object):
output_path=output_path,
**kwargs,
)
runtime_ctx.initialize(socket.gethostname(), root_exist_ok=True)
self.runtime_ctx.initialize(socket.gethostname(), root_exist_ok=True)
logger.info(
f"command-line arguments: {' '.join([os.path.basename(sys.argv[0]), *sys.argv[1:]])}"
)
dump(runtime_ctx, runtime_ctx.runtime_ctx_path, atomic_write=True)
logger.info(f"saved runtime context at {runtime_ctx.runtime_ctx_path}")
dump(plan, runtime_ctx.logcial_plan_path, atomic_write=True)
logger.info(f"saved logcial plan at {runtime_ctx.logcial_plan_path}")
plan.graph().render(runtime_ctx.logcial_plan_graph_path, format="png")
logger.info(
f"saved logcial plan graph at {runtime_ctx.logcial_plan_graph_path}.png"
)
exec_plan = Planner(runtime_ctx).create_exec_plan(
plan, manifest_only_final_results
)
dump(exec_plan, runtime_ctx.exec_plan_path, atomic_write=True)
logger.info(f"saved execution plan at {runtime_ctx.exec_plan_path}")
dump(self.runtime_ctx, self.runtime_ctx.runtime_ctx_path, atomic_write=True)
logger.info(f"saved runtime context at {self.runtime_ctx.runtime_ctx_path}")
sidecar_list = sidecars or []
@ -225,27 +208,26 @@ class JobManager(object):
sched_state_observers = sched_state_observers or []
if enable_sched_state_dump:
sched_state_exporter = SchedStateExporter(runtime_ctx.sched_state_path)
sched_state_exporter = SchedStateExporter(self.runtime_ctx.sched_state_path)
sched_state_observers.insert(0, sched_state_exporter)
if os.path.exists(runtime_ctx.sched_state_path):
if os.path.exists(self.runtime_ctx.sched_state_path):
logger.warning(
f"loading scheduler state from: {runtime_ctx.sched_state_path}"
f"loading scheduler state from: {self.runtime_ctx.sched_state_path}"
)
scheduler: Scheduler = load(runtime_ctx.sched_state_path)
scheduler.sched_epoch += 1
scheduler.sched_state_observers = sched_state_observers
self.scheduler: Scheduler = load(self.runtime_ctx.sched_state_path)
self.scheduler.sched_epoch += 1
self.scheduler.sched_state_observers = sched_state_observers
else:
scheduler = Scheduler(
exec_plan,
max_retry_count,
max_fail_count,
prioritize_retry,
speculative_exec,
stop_executor_on_failure,
nonzero_exitcode_as_oom,
remove_output_root,
sched_state_observers,
self.scheduler = Scheduler(
max_retry_count=max_retry_count,
max_fail_count=max_fail_count,
prioritize_retry=prioritize_retry,
speculative_exec=speculative_exec,
stop_executor_on_failure=stop_executor_on_failure,
nonzero_exitcode_as_oom=nonzero_exitcode_as_oom,
remove_output_root=remove_output_root,
sched_state_observers=sched_state_observers,
)
# start executors
self.platform.start_job(
@ -257,7 +239,7 @@ class JobManager(object):
"--data_root",
self.data_root,
"--runtime_ctx_path",
runtime_ctx.runtime_ctx_path,
self.runtime_ctx.runtime_ctx_path,
"executor",
],
envs={
@ -278,6 +260,25 @@ class JobManager(object):
),
)
@logger.catch(reraise=True, message="job manager terminated unexpectedly")
def run(
self,
plan: LogicalPlan,
) -> ExecutionPlan:
dump(plan, self.runtime_ctx.logcial_plan_path, atomic_write=True)
logger.info(f"saved logcial plan at {self.runtime_ctx.logcial_plan_path}")
plan.graph().render(self.runtime_ctx.logcial_plan_graph_path, format="png")
logger.info(
f"saved logcial plan graph at {self.runtime_ctx.logcial_plan_graph_path}.png"
)
exec_plan = Planner(self.runtime_ctx).create_exec_plan(
plan, self.manifest_only_final_results
)
dump(exec_plan, self.runtime_ctx.exec_plan_path, atomic_write=True)
logger.info(f"saved execution plan at {self.runtime_ctx.exec_plan_path}")
# run scheduler
scheduler.run()
return scheduler.exec_plan
self.scheduler.run(exec_plan)
return exec_plan

View File

@ -12,7 +12,7 @@ from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from functools import cached_property
from typing import Any, Callable, Dict, Iterable, List, Literal, Set, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Set, Tuple, Union
import numpy as np
from loguru import logger
@ -380,20 +380,16 @@ class Scheduler(object):
def __init__(
self,
exec_plan: ExecutionPlan,
*,
max_retry_count: int = DEFAULT_MAX_RETRY_COUNT,
max_fail_count: int = DEFAULT_MAX_FAIL_COUNT,
prioritize_retry=False,
prioritize_retry: bool = False,
speculative_exec: Literal["disable", "enable", "aggressive"] = "enable",
stop_executor_on_failure=False,
nonzero_exitcode_as_oom=False,
remove_output_root=False,
sched_state_observers=None,
stop_executor_on_failure: bool = False,
nonzero_exitcode_as_oom: bool = False,
remove_output_root: bool = False,
sched_state_observers: Optional[List[StateObserver]] = None,
) -> None:
self.ctx = exec_plan.ctx
self.exec_plan = exec_plan
self.logical_plan: LogicalPlan = self.exec_plan.logical_plan
self.logical_nodes = self.logical_plan.nodes
self.max_retry_count = max_retry_count
self.max_fail_count = max_fail_count
self.standalone_mode = self.ctx.num_executors == 0
@ -410,16 +406,12 @@ class Scheduler(object):
# task states
self.local_queue: List[Task] = []
self.sched_queue: List[Task] = []
self.tasks: Dict[str, Task] = self.exec_plan.tasks
self.scheduled_tasks: Dict[TaskRuntimeId, Task] = OrderedDict()
self.finished_tasks: Dict[TaskRuntimeId, Task] = OrderedDict()
self.succeeded_tasks: Dict[str, Task] = OrderedDict()
self.nontrivial_tasks = dict(
(key, task)
for (key, task) in self.tasks.items()
if not task.exec_on_scheduler
)
self.succeeded_nontrivial_tasks: Dict[str, Task] = OrderedDict()
self.tasks: Dict[str, Task] = {}
self.scheduled_tasks: Dict[TaskRuntimeId, Task] = {}
self.finished_tasks: Dict[TaskRuntimeId, Task] = {}
self.succeeded_tasks: Dict[str, Task] = {}
self.nontrivial_tasks: Dict[str, Task] = {}
self.succeeded_nontrivial_tasks: Dict[str, Task] = {}
# executor pool
self.local_executor = LocalExecutor.create(self.ctx, "localhost")
self.available_executors = {self.local_executor.id: self.local_executor}
@ -431,9 +423,6 @@ class Scheduler(object):
self.probe_epoch = 0
self.sched_epoch = 0
def __call__(self, *args: Any, **kwds: Any) -> Any:
return self.run()
def __getstate__(self):
state = self.__dict__.copy()
del state["sched_state_observers"]
@ -982,7 +971,10 @@ class Scheduler(object):
status = "failure"
fout.write(f"{status}@{int(time.time())}")
def run(self) -> bool:
def run(self, exec_plan: ExecutionPlan) -> bool:
"""
Run the execution plan.
"""
mp.current_process().name = f"SchedulerMainProcess#{self.sched_epoch}"
logger.info(
f"start to run scheduler #{self.sched_epoch} on {socket.gethostname()}"

View File

@ -28,11 +28,14 @@ generate_data()
def run_scheduler(
runtime_ctx: RuntimeContext, scheduler: Scheduler, queue: queue.Queue
runtime_ctx: RuntimeContext,
scheduler: Scheduler,
queue: queue.Queue,
exec_plan: ExecutionPlan,
):
runtime_ctx.initialize("scheduler")
scheduler.add_state_observer(Scheduler.StateObserver(SaveSchedState(queue)))
retval = scheduler.run()
retval = scheduler.run(exec_plan)
print(f"scheduler exited with value {retval}", file=sys.stderr)
@ -206,9 +209,7 @@ class TestFabric(unittest.TestCase):
self.queue_manager = Manager()
self.sched_states = self.queue_manager.Queue()
exec_plan = Planner(runtime_ctx).create_exec_plan(plan)
scheduler = Scheduler(
exec_plan,
max_retry_count=max_retry_count,
max_fail_count=max_fail_count,
prioritize_retry=prioritize_retry,
@ -216,6 +217,8 @@ class TestFabric(unittest.TestCase):
stop_executor_on_failure=stop_executor_on_failure,
nonzero_exitcode_as_oom=nonzero_exitcode_as_oom,
)
exec_plan = Planner(runtime_ctx).create_exec_plan(plan)
self.latest_state = scheduler
self.executors = [
Executor.create(runtime_ctx, f"executor-{i}") for i in range(num_executors)
@ -225,7 +228,7 @@ class TestFabric(unittest.TestCase):
target=run_scheduler,
# XXX: on macOS, the scheduler's state observers are cleared when the scheduler crosses a process boundary,
# so we pass the queue and add the observer in the new process
args=(runtime_ctx, scheduler, self.sched_states),
args=(runtime_ctx, scheduler, self.sched_states, exec_plan),
name="scheduler",
)
]