DeepSeek-Math/evaluation/run_subset_parallel.py
import os
import argparse
from tqdm import tqdm
from glob import glob
import time
import json
import subprocess
from utils import read_data
from data_processing.process_utils import *
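
# Worker count and rank are read from the environment set by the distributed
# launcher (e.g. torchrun sets WORLD_SIZE / RANK); default to a single local worker.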
_worker_num = int(os.environ.get('WORLD_SIZE', 1))
_worker_id = int(os.environ.get('RANK', 0))
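
# Append the task-specific instruction (chain-of-thought or tool-integrated)
# to every user turn of the conversation, in the question's language.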
def markup_question(args, item, language, src, task):
for i in range(len(item['messages']) - 2, -1, -2):
if language == 'zh':
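            # Chinese equivalents of the English instructions below:
            # "Please reason step by step" (cot) / "integrate natural language and Python" (tool),
            # "and put your final answer within \boxed{}."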
if task == 'cot':
item['messages'][i]['content'] = f"{item['messages'][i]['content']}\n请通过逐步推理来解答问题,并把最终答案放置于" + "\\boxed{}中。"
elif task == 'tool':
item['messages'][i]['content'] = f"{item['messages'][i]['content']}\n请结合自然语言和Python程序语言来解答问题并把最终答案放置于" + "\\boxed{}中。"
else:
pass
elif language == 'en':
if task == 'cot':
item['messages'][i]['content'] = f"{item['messages'][i]['content']}\nPlease reason step by step, and put your final answer within " + "\\boxed{}."
elif task == 'tool':
item['messages'][i]['content'] = f"{item['messages'][i]['content']}\nPlease integrate natural language reasoning with programs to solve the problem above, and put your final answer within " + "\\boxed{}."
else:
pass
return item
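
# Shard the evaluation across GPU groups, launch one inference subprocess per
# shard, then aggregate this worker's per-shard metrics and predictions.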
def do_parallel_sampling(args, task, answer_extraction_fn, eval_fn, input_dir, output_dir, log_dir):
if task == 'pal':
code_fname = "run_pal_eval"
elif task == 'cot':
code_fname = "run_cot_eval"
elif task == 'tool':
code_fname = "run_tool_integrated_eval"
else:
raise NotImplementedError()
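    # Partition the visible GPUs into groups of `ngpus_per_model`; each group
    # serves one model replica and is handled by one subprocess.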
n_procs = args.ngpus // args.ngpus_per_model
gpus = [str(i) for i in range(args.ngpus)]
gpu_groups = []
for i in range(n_procs):
gpu_groups.append(gpus[i * args.ngpus_per_model: (i + 1) * args.ngpus_per_model])
global_n_procs = n_procs * _worker_num
procs = []
for pid, gpus in enumerate(gpu_groups):
global_pid = n_procs * (args.rank or _worker_id) + pid
        # Skip shards whose metrics already exist (unless --overwrite is set),
        # without touching the existing log file.
        local_metric_path = os.path.join(output_dir, f"metrics.{global_pid}.json")
        if not args.overwrite and os.path.exists(local_metric_path) and read_data(local_metric_path)['n_samples'] > 0:
            continue
        logpath = os.path.join(log_dir, f"{global_pid}.log")
        f = open(logpath, "w")
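        # Build the per-shard inference command; each subprocess evaluates one
        # --subset_id of the data on its own GPU group.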
cmd = f"python infer/{code_fname}.py " \
f"--data_dir {input_dir} " \
f"--max_num_examples 100000000000000 " \
f"--save_dir {output_dir} " \
f"--model {args.model_path} " \
f"--tokenizer {args.tokenizer_path or args.model_path} " \
f"--eval_batch_size 1 " \
f"--temperature {args.temperature} " \
f"--repeat_id_start 0 " \
f"--n_repeat_sampling {args.n_repeats} " \
f"--n_subsets {global_n_procs} " \
f"--prompt_format {args.prompt_format} " \
f"--few_shot_prompt {args.few_shot_prompt} " \
f"--answer_extraction_fn {answer_extraction_fn} " \
f"--eval_fn {eval_fn} " \
f"--subset_id {global_pid} " \
f"--gpus {','.join(gpus)} "
if args.use_vllm:
cmd += " --use_vllm "
if args.load_in_half:
cmd += " --load_in_half "
procs.append((global_pid, subprocess.Popen(cmd.split(), stdout=f, stderr=f), f))
for (global_pid, proc, f) in procs:
print(f"Waiting for the {global_pid}th process to finish ...", flush=True)
proc.wait()
for (global_pid, proc, f) in procs:
print(f"Closing the {global_pid}th process ...", flush=True)
f.close()
time.sleep(1)
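    # Aggregate this worker's shards: concatenate predictions and compute a
    # sample-weighted average of every metric.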
local_pids = [global_pid for (global_pid, _, _) in procs]
agg_preds = []
for fname in glob(os.path.join(output_dir, "predictions.*.json")):
        # Match the shard id exactly, so e.g. pid 1 does not match predictions.10.json.
        if any(fname.endswith(f".{pid}.json") for pid in local_pids):
agg_preds.extend(read_data(fname))
metrics = {}
n_samples = 0
for fname in glob(os.path.join(output_dir, "metrics.*.json")):
        if not any(fname.endswith(f".{pid}.json") for pid in local_pids):
continue
_metrics = read_data(fname)
n_samples += _metrics['n_samples']
for key, val in _metrics.items():
if key != 'n_samples':
metrics[key] = metrics.get(key, 0) + val * _metrics['n_samples']
for key, val in metrics.items():
metrics[key] = val / max(n_samples, 1)
result_msg = f"n samples = {n_samples}"
for key, val in metrics.items():
result_msg += f"\n{key} = {val * 100}"
metrics['n_samples'] = n_samples
return metrics, agg_preds, result_msg
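
# Read the test config, materialize the per-task test files, launch parallel
# sampling, and write aggregated metrics/predictions for each (source, task) pair.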
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--output-dir", type=str, required=True, help="default to `model_path`_predictions")
parser.add_argument("--model-path", type=str, required=True)
parser.add_argument("--tokenizer-path", type=str, default=None)
parser.add_argument("--model-size", type=str, choices=['1b', '7b', '13b', '33b', '34b', '70b'], default="7b")
parser.add_argument("--test-conf", type=str, default="configs/zero_shot_test_configs.json", help="path to testing data config file that maps from a source to its info")
parser.add_argument("--ngpus", type=int, default=8)
parser.add_argument("--overwrite", action='store_true')
parser.add_argument("--temperature", type=float, default=0)
parser.add_argument("--n-repeats", type=int, default=1)
parser.add_argument("--use-vllm", action='store_true')
parser.add_argument("--load_in_half", action='store_true')
parser.add_argument("--prompt_format", type=str, default="sft")
parser.add_argument("--few_shot_prompt", type=str, default=None)
parser.add_argument("--no-markup-question", action='store_true')
parser.add_argument("--rank", type=int, default=None)
parser.add_argument("--seed", type=int, default=42)
args, _ = parser.parse_known_args()
print(f"Evaluating {args.model_path}", flush=True)
if args.output_dir is None:
args.output_dir = f"{args.model_path.rstrip('/')}_predictions"
args.ngpus_per_model = 4 if args.model_size in ['70b', '33b', '34b'] else 1
assert args.ngpus % args.ngpus_per_model == 0
default_few_shot_prompt = args.few_shot_prompt
test_conf = read_data(args.test_conf)
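    # Each config entry maps a data source to its test_path, process_fn, language,
    # tasks, answer_extraction_fn, eval_fn, and an optional few_shot_prompt.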
for src, info in test_conf.items():
if args.n_repeats > 1:
_src = f"{src}/sample_logs"
else:
_src = f"{src}/infer_logs"
if _worker_num > 1:
_src = f"{_src}/{args.rank or _worker_id}"
if args.prompt_format == 'few_shot':
args.few_shot_prompt = info.get('few_shot_prompt', None) or default_few_shot_prompt
for task in info['tasks']:
fname = os.path.join(args.output_dir, _src, task, "test_data", "test.jsonl")
input_dir = os.path.dirname(fname)
os.makedirs(input_dir, exist_ok=True)
metric_path = os.path.join(args.output_dir, _src, task, "samples", "metrics.json")
if not args.overwrite and os.path.exists(metric_path) and read_data(metric_path)['n_samples'] > 0:
continue
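            # Rebuild the test file: apply the source-specific process_fn to every
            # sample and append the task instruction to each question.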
with open(fname, "w") as file:
data = read_data(info['test_path'])
for i, sample in enumerate(tqdm(data, desc=f'processing {src}')):
fn = eval(info['process_fn'])
sample['id'] = sample.get('id', f"{src}-{i}")
for j, item in enumerate(fn(sample)):
item['dataset'] = src
item['id'] = f"{src}-test-{i}-{j}"
assert 'answer' in item
if not args.no_markup_question:
item = markup_question(args, item, info['language'], src, task)
print(json.dumps(item), file=file, flush=True)
output_dir = os.path.join(args.output_dir, _src, task, "samples")
log_dir = os.path.join(args.output_dir, _src, task, "logs")
os.makedirs(output_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)
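            # Run the sharded inference for this (src, task) pair, then persist the
            # aggregated metrics and predictions.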
metrics, agg_preds, result_msg = do_parallel_sampling(args, task, info['answer_extraction_fn'], info['eval_fn'], input_dir, output_dir, log_dir)
os.makedirs(os.path.dirname(metric_path), exist_ok=True)
            with open(metric_path, "w") as metric_file:
                json.dump(metrics, metric_file, indent=4)
data_path = os.path.join(args.output_dir, _src, task, "samples", "predictions.json")
os.makedirs(os.path.dirname(data_path), exist_ok=True)
with open(data_path, "w") as file:
json.dump(agg_preds, file, ensure_ascii=False)
print(f"src = {src} | task = {task} >>>\n{result_msg}\n\n", flush=True)
if __name__ == '__main__':
main()