# Mirror of https://github.com/deepseek-ai/DeepSeek-Prover-V1.5
# Synced 2024-11-24 13:08:07 +00:00 (49 lines, 1.8 KiB, Python)
import os
|
|
import copy
|
|
|
|
import torch
|
|
import torch.multiprocessing as mp
|
|
|
|
from prover.utils import load_jsonl_objects
|
|
|
|
|
|
class DataLoader(object):
    """Queue of the problem runs this node still has to execute.

    On construction the dataset at ``data_path`` is read, runs already
    marked as finished under ``log_dir`` are skipped, and the remaining
    runs are dealt out round-robin across ``world_size`` nodes. Only the
    runs assigned to ``node_rank`` are placed on this instance's queue.
    """

    def __init__(self, data_path, data_split, data_repeat, node_rank, world_size, log_dir):
        """Build the queue of pending ``(prob_idx, prob_runname, prob)`` items.

        Args:
            data_path: Passed to ``load_jsonl_objects``. Each loaded problem
                is indexed with ``'name'`` (and ``'split'`` when a split
                filter is given) -- presumably a list of dicts parsed from
                a JSONL file; confirm against ``prover.utils``.
            data_split: ``None`` to keep every problem, or a split name /
                collection of split names to keep.
            data_repeat: Number of runs (``run0``, ``run1``, ...) to
                schedule per problem.
            node_rank: Index of this node; it receives every pending run
                whose running count is congruent to it modulo
                ``world_size``.
            world_size: Total number of nodes sharing the workload.
            log_dir: Directory scanned for finished-run markers. A
                directory that does not exist is treated as "nothing
                finished yet".
        """
        self.manager = mp.Manager()
        self.queue = self.manager.Queue()
        self.lock = mp.Lock()
        self.finished_flag_filename = 'finished_running.txt'

        done_set = self._collect_finished_runs(log_dir)

        todo_count = 0
        if isinstance(data_split, str):
            # Accept a single split name as shorthand for a one-element list.
            data_split = [data_split]
        dataset = load_jsonl_objects(data_path)
        for _repeat in range(data_repeat):
            for prob_idx, prob in enumerate(dataset):
                prob_runname = os.path.join(prob['name'], f'run{_repeat}')
                # Finished runs are keyed as '<prob_idx>_<name>/run<k>',
                # i.e. the same relative path the log-directory scan yields.
                if f'{prob_idx}_{prob_runname}' in done_set:
                    continue
                if data_split is not None and prob['split'] not in data_split:
                    continue
                if todo_count % world_size == node_rank:
                    self.queue.put((prob_idx, prob_runname, copy.deepcopy(prob)))
                # Counted for every pending run, not only the ones queued
                # here, so that the modulo test above shards the work.
                todo_count += 1
        print('Number of TODO Problems: {}'.format(self.queue.qsize()))

    def _collect_finished_runs(self, log_dir):
        """Return the relative ``<dir>/<run*>`` paths that hold the finished flag.

        Args:
            log_dir: Root directory whose immediate sub-directories are
                searched for ``run*`` entries containing
                ``self.finished_flag_filename``.

        Returns:
            A set of paths relative to ``log_dir``; empty when ``log_dir``
            is not an existing directory.
        """
        done_set = set()
        if not os.path.isdir(log_dir):
            # First launch: no logs have been written yet, so nothing is done.
            return done_set
        for dirname in os.listdir(log_dir):
            run_dir = os.path.join(log_dir, dirname)
            if not os.path.isdir(run_dir):
                continue
            for subdirname in os.listdir(run_dir):
                flag_path = os.path.join(run_dir, subdirname, self.finished_flag_filename)
                if subdirname.startswith('run') and os.path.exists(flag_path):
                    done_set.add(os.path.join(dirname, subdirname))
        return done_set

    def size(self):
        """Return the number of items currently waiting in the queue."""
        return self.queue.qsize()

    def get(self):
        """Pop the next pending item, or ``(None, None, None)`` if none is left.

        Returns:
            The ``(prob_idx, prob_runname, prob)`` tuple that was queued, or
            a triple of ``None`` when the queue is empty.
        """
        # The size check and the pop happen under one lock so that callers
        # going through this method cannot both claim the last item.
        with self.lock:
            if self.queue.qsize() > 0:
                return self.queue.get()
        return None, None, None