# DeepSeek-Prover-V1.5/prover/launch.py
# (web-viewer listing: 96 lines, 2.8 KiB, Python; dated 2024-08-16 03:33:21 +00:00)

import os
import copy
import time
import warnings
import argparse
import torch
from prover.workers import DataLoader, Scheduler, ProcessScheduler, GeneratorProcess, SearchProcess
from prover.lean.verifier import Lean4ServerScheduler
from prover.utils import get_datetime, load_config, AttrDict
if __name__ == "__main__":
    # Entry point: parse CLI options, build the data loader and the
    # verifier/generator schedulers, then run the search workers to
    # completion and shut the generator workers down afterwards.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str)
    # The default log directory is derived from get_datetime() at parse time.
    parser.add_argument("--log_dir", type=str, default=f'logs/{get_datetime()}')
    parser.add_argument("--node_rank", type=int, default=0)
    parser.add_argument("--world_size", type=int, default=1)
    args = parser.parse_args()

    cfg = load_config(args.config)
    os.makedirs(args.log_dir, exist_ok=True)

    # One GeneratorProcess is created per visible CUDA device below, so at
    # least one device is required. Raise explicitly instead of using
    # `assert`, which is stripped under `python -O` and would let the script
    # continue with no generator workers at all.
    ngpus = torch.cuda.device_count()
    if ngpus < 1:
        raise RuntimeError(
            f'At least one CUDA device is required to launch generator processes, found {ngpus}'
        )

    # create data loader
    data_loader = DataLoader(
        data_path=cfg.data_path,
        data_split=cfg.get('data_split', None),
        data_repeat=cfg.get('data_repeat', 1),
        node_rank=args.node_rank,
        world_size=args.world_size,
        log_dir=args.log_dir,
    )

    # build Lean verifier
    verifier_scheduler = Lean4ServerScheduler(
        max_concurrent_requests=cfg.lean_max_concurrent_requests,
        memory_limit=cfg.lean_memory_limit,
        timeout=cfg.lean_timeout,
        name='verifier',
    )

    # load LLM models on gpus
    # Every GeneratorProcess shares the generator scheduler's task queue,
    # request-status table and lock; one process is created per local GPU.
    generator_scheduler = ProcessScheduler(batch_size=cfg.batch_size, name='generator')
    llm_processes = [
        GeneratorProcess(
            local_rank=local_rank,
            node_rank=args.node_rank,
            model_path=cfg.model_path,
            task_queue=generator_scheduler.task_queue,
            request_statuses=generator_scheduler.request_statuses,
            lock=generator_scheduler.lock,
            args=cfg.model_args,
        )
        for local_rank in range(ngpus)
    ]

    # create a unified scheduler interface
    scheduler = Scheduler(dict(
        verifier=verifier_scheduler,
        generator=generator_scheduler,
    ))

    # launch search processes
    # The worker count is capped at data_loader.size(). Worker indices are
    # offset by node_rank * n_search_procs -- presumably to keep them unique
    # across nodes when every node uses the same n_search_procs; confirm
    # against SearchProcess.
    search_processes = [
        SearchProcess(
            idx=i+args.node_rank*cfg.n_search_procs,
            log_dir=args.log_dir,
            tokenizer_path=cfg.model_path,
            scheduler=scheduler,
            data_loader=data_loader,
            cfg=cfg,
        )
        for i in range(min(cfg.n_search_procs, data_loader.size()))
    ]
    for p in search_processes:
        p.start()
    print(f'Complete launching {len(search_processes)} SearchProcesses')

    for p in llm_processes:
        p.start()
    print(f'Complete launching {len(llm_processes)} LLMProcesses')

    # Block until every search worker has finished.
    for p in search_processes:
        p.join()
    print(f'All {len(search_processes)} SearchProcesses stopped')

    # NOTE(review): close() is called before joining the generator
    # processes, which suggests it is what lets them exit -- confirm against
    # Scheduler.close before reordering these steps.
    scheduler.close()

    for p in llm_processes:
        p.join()
    print(f'All {len(llm_processes)} LLMProcesses stopped')