"""Launch the proof-search pipeline on a single node.

The script wires together a data loader, a Lean 4 verifier scheduler, one
generator process per visible GPU and a pool of search processes, then waits
for the search processes to finish before shutting everything down.
"""
import os
import copy
import time
import warnings
import argparse

import torch

from prover.workers import (
    DataLoader,
    Scheduler,
    ProcessScheduler,
    GeneratorProcess,
    SearchProcess,
)
from prover.lean.verifier import Lean4ServerScheduler
from prover.utils import get_datetime, load_config, AttrDict


def parse_args() -> argparse.Namespace:
    """Parse the launcher's command-line options.

    Returns:
        The parsed namespace with ``config``, ``log_dir``, ``node_rank`` and
        ``world_size`` attributes.
    """
    parser = argparse.ArgumentParser()
    # The config path is passed unconditionally to ``load_config``; make it
    # mandatory so a missing flag yields a clear usage error instead of
    # ``load_config(None)``.
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--log_dir", type=str, default=f'logs/{get_datetime()}')
    parser.add_argument("--node_rank", type=int, default=0)
    parser.add_argument("--world_size", type=int, default=1)
    return parser.parse_args()


def main() -> None:
    """Build all workers, run the search to completion and shut down.

    Raises:
        RuntimeError: If no CUDA device is visible to torch.
    """
    args = parse_args()
    cfg = load_config(args.config)
    os.makedirs(args.log_dir, exist_ok=True)

    ngpus = torch.cuda.device_count()
    # Explicit raise rather than ``assert`` so the check survives ``python -O``.
    if ngpus < 1:
        raise RuntimeError('At least one CUDA device is required to launch the generator processes')

    # create data loader
    data_loader = DataLoader(
        data_path=cfg.data_path,
        data_split=cfg.get('data_split', None),
        data_repeat=cfg.get('data_repeat', 1),
        node_rank=args.node_rank,
        world_size=args.world_size,
        log_dir=args.log_dir,
    )

    # build Lean verifier
    verifier_scheduler = Lean4ServerScheduler(
        max_concurrent_requests=cfg.lean_max_concurrent_requests,
        memory_limit=cfg.lean_memory_limit,
        timeout=cfg.lean_timeout,
        name='verifier',
    )

    # load LLM models on gpus: one generator process per visible device, all
    # sharing the generator scheduler's queue, status table and lock
    generator_scheduler = ProcessScheduler(batch_size=cfg.batch_size, name='generator')
    llm_processes = [
        GeneratorProcess(
            local_rank=local_rank,
            node_rank=args.node_rank,
            model_path=cfg.model_path,
            task_queue=generator_scheduler.task_queue,
            request_statuses=generator_scheduler.request_statuses,
            lock=generator_scheduler.lock,
            args=cfg.model_args,
        )
        for local_rank in range(ngpus)
    ]

    # create a unified scheduler interface
    scheduler = Scheduler(dict(
        verifier=verifier_scheduler,
        generator=generator_scheduler,
    ))

    # launch search processes; never start more of them than the data loader
    # reports via ``size()``, and offset ``idx`` by the node rank so indices
    # do not collide across nodes
    n_search = min(cfg.n_search_procs, data_loader.size())
    search_processes = [
        SearchProcess(
            idx=i + args.node_rank * cfg.n_search_procs,
            log_dir=args.log_dir,
            tokenizer_path=cfg.model_path,
            scheduler=scheduler,
            data_loader=data_loader,
            cfg=cfg,
        )
        for i in range(n_search)
    ]

    for p in search_processes:
        p.start()
    print(f'Complete launching {len(search_processes)} SearchProcesses')

    for p in llm_processes:
        p.start()
    print(f'Complete launching {len(llm_processes)} LLMProcesses')

    try:
        for p in search_processes:
            p.join()
        print(f'All {len(search_processes)} SearchProcesses stopped')
    finally:
        # Close the scheduler even when the wait above is interrupted
        # (e.g. Ctrl-C); presumably this is what lets the generator and
        # verifier workers exit -- confirm against Scheduler.close().
        scheduler.close()

    for p in llm_processes:
        p.join()
    print(f'All {len(llm_processes)} LLMProcesses stopped')


if __name__ == "__main__":
    main()