mirror of
https://github.com/deepseek-ai/DeepSeek-Math
synced 2024-11-22 03:27:40 +00:00
[fixed] the merged output was incorrect when parallel_num=1
This commit is contained in:
parent
7c34ad4fa4
commit
ddf18bb444
@ -90,29 +90,35 @@ def do_parallel_sampling(args, task, answer_extraction_fn, eval_fn, input_dir, o
|
|||||||
|
|
||||||
local_pids = [global_pid for (global_pid, _, _) in procs]
|
local_pids = [global_pid for (global_pid, _, _) in procs]
|
||||||
|
|
||||||
agg_preds = []
|
if global_n_procs == 1:
|
||||||
for fname in glob(os.path.join(output_dir, "predictions.*.json")):
|
agg_preds = read_data(os.path.join(output_dir, "predictions.json"))
|
||||||
if any(str(pid) in fname for pid in local_pids):
|
else:
|
||||||
agg_preds.extend(read_data(fname))
|
agg_preds = []
|
||||||
|
for fname in glob(os.path.join(output_dir, "predictions.*.json")):
|
||||||
|
if any(str(pid) in fname for pid in local_pids):
|
||||||
|
agg_preds.extend(read_data(fname))
|
||||||
|
if global_n_procs == 1:
|
||||||
|
metrics = read_data(os.path.join(output_dir, "metrics.json"))
|
||||||
|
result_msg = f"n samples = {metrics['n_samples']}"
|
||||||
|
else:
|
||||||
|
metrics = {}
|
||||||
|
n_samples = 0
|
||||||
|
for fname in glob(os.path.join(output_dir, "metrics.*.json")):
|
||||||
|
if not any(str(pid) in fname for pid in local_pids):
|
||||||
|
continue
|
||||||
|
_metrics = read_data(fname)
|
||||||
|
n_samples += _metrics['n_samples']
|
||||||
|
for key, val in _metrics.items():
|
||||||
|
if key != 'n_samples':
|
||||||
|
metrics[key] = metrics.get(key, 0) + val * _metrics['n_samples']
|
||||||
|
for key, val in metrics.items():
|
||||||
|
metrics[key] = val / max(n_samples, 1)
|
||||||
|
|
||||||
metrics = {}
|
result_msg = f"n samples = {n_samples}"
|
||||||
n_samples = 0
|
for key, val in metrics.items():
|
||||||
for fname in glob(os.path.join(output_dir, "metrics.*.json")):
|
result_msg += f"\n{key} = {val * 100}"
|
||||||
if not any(str(pid) in fname for pid in local_pids):
|
|
||||||
continue
|
|
||||||
_metrics = read_data(fname)
|
|
||||||
n_samples += _metrics['n_samples']
|
|
||||||
for key, val in _metrics.items():
|
|
||||||
if key != 'n_samples':
|
|
||||||
metrics[key] = metrics.get(key, 0) + val * _metrics['n_samples']
|
|
||||||
for key, val in metrics.items():
|
|
||||||
metrics[key] = val / max(n_samples, 1)
|
|
||||||
|
|
||||||
result_msg = f"n samples = {n_samples}"
|
metrics['n_samples'] = n_samples
|
||||||
for key, val in metrics.items():
|
|
||||||
result_msg += f"\n{key} = {val * 100}"
|
|
||||||
|
|
||||||
metrics['n_samples'] = n_samples
|
|
||||||
|
|
||||||
return metrics, agg_preds, result_msg
|
return metrics, agg_preds, result_msg
|
||||||
|
|
||||||
@ -196,4 +202,4 @@ def main():
|
|||||||
print(f"src = {src} | task = {task} >>>\n{result_msg}\n\n", flush=True)
|
print(f"src = {src} | task = {task} >>>\n{result_msg}\n\n", flush=True)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
Loading…
Reference in New Issue
Block a user