[fixed] the merged output is incorrect when parallel_num=1

This commit is contained in:
Dylancer1998 2024-04-11 09:30:32 +00:00
parent 7c34ad4fa4
commit ddf18bb444

View File

@ -90,29 +90,35 @@ def do_parallel_sampling(args, task, answer_extraction_fn, eval_fn, input_dir, o
local_pids = [global_pid for (global_pid, _, _) in procs] local_pids = [global_pid for (global_pid, _, _) in procs]
agg_preds = [] if global_n_procs == 1:
for fname in glob(os.path.join(output_dir, "predictions.*.json")): agg_preds = read_data(os.path.join(output_dir, "predictions.json"))
if any(str(pid) in fname for pid in local_pids): else:
agg_preds.extend(read_data(fname)) agg_preds = []
for fname in glob(os.path.join(output_dir, "predictions.*.json")):
if any(str(pid) in fname for pid in local_pids):
agg_preds.extend(read_data(fname))
if global_n_procs == 1:
metrics = read_data(os.path.join(output_dir, "metrics.json"))
result_msg = f"n samples = {metrics['n_samples']}"
else:
metrics = {}
n_samples = 0
for fname in glob(os.path.join(output_dir, "metrics.*.json")):
if not any(str(pid) in fname for pid in local_pids):
continue
_metrics = read_data(fname)
n_samples += _metrics['n_samples']
for key, val in _metrics.items():
if key != 'n_samples':
metrics[key] = metrics.get(key, 0) + val * _metrics['n_samples']
for key, val in metrics.items():
metrics[key] = val / max(n_samples, 1)
metrics = {} result_msg = f"n samples = {n_samples}"
n_samples = 0 for key, val in metrics.items():
for fname in glob(os.path.join(output_dir, "metrics.*.json")): result_msg += f"\n{key} = {val * 100}"
if not any(str(pid) in fname for pid in local_pids):
continue
_metrics = read_data(fname)
n_samples += _metrics['n_samples']
for key, val in _metrics.items():
if key != 'n_samples':
metrics[key] = metrics.get(key, 0) + val * _metrics['n_samples']
for key, val in metrics.items():
metrics[key] = val / max(n_samples, 1)
result_msg = f"n samples = {n_samples}" metrics['n_samples'] = n_samples
for key, val in metrics.items():
result_msg += f"\n{key} = {val * 100}"
metrics['n_samples'] = n_samples
return metrics, agg_preds, result_msg return metrics, agg_preds, result_msg
@ -196,4 +202,4 @@ def main():
print(f"src = {src} | task = {task} >>>\n{result_msg}\n\n", flush=True) print(f"src = {src} | task = {task} >>>\n{result_msg}\n\n", flush=True)
if __name__ == '__main__': if __name__ == '__main__':
main() main()