From ddf18bb44455462f21f53cfa4dd6df1d17675afe Mon Sep 17 00:00:00 2001 From: Dylancer1998 Date: Thu, 11 Apr 2024 09:30:32 +0000 Subject: [PATCH] [fixed] the merging output is incorrect, when parallel_num=1 --- evaluation/run_subset_parallel.py | 50 +++++++++++++++++-------------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/evaluation/run_subset_parallel.py b/evaluation/run_subset_parallel.py index da65019..a68c4eb 100644 --- a/evaluation/run_subset_parallel.py +++ b/evaluation/run_subset_parallel.py @@ -90,29 +90,35 @@ def do_parallel_sampling(args, task, answer_extraction_fn, eval_fn, input_dir, o local_pids = [global_pid for (global_pid, _, _) in procs] - agg_preds = [] - for fname in glob(os.path.join(output_dir, "predictions.*.json")): - if any(str(pid) in fname for pid in local_pids): - agg_preds.extend(read_data(fname)) + if global_n_procs == 1: + agg_preds = read_data(os.path.join(output_dir, "predictions.json")) + else: + agg_preds = [] + for fname in glob(os.path.join(output_dir, "predictions.*.json")): + if any(str(pid) in fname for pid in local_pids): + agg_preds.extend(read_data(fname)) + if global_n_procs == 1: + metrics = read_data(os.path.join(output_dir, "metrics.json")) + result_msg = f"n samples = {metrics['n_samples']}" + else: + metrics = {} + n_samples = 0 + for fname in glob(os.path.join(output_dir, "metrics.*.json")): + if not any(str(pid) in fname for pid in local_pids): + continue + _metrics = read_data(fname) + n_samples += _metrics['n_samples'] + for key, val in _metrics.items(): + if key != 'n_samples': + metrics[key] = metrics.get(key, 0) + val * _metrics['n_samples'] + for key, val in metrics.items(): + metrics[key] = val / max(n_samples, 1) - metrics = {} - n_samples = 0 - for fname in glob(os.path.join(output_dir, "metrics.*.json")): - if not any(str(pid) in fname for pid in local_pids): - continue - _metrics = read_data(fname) - n_samples += _metrics['n_samples'] - for key, val in _metrics.items(): - if key != 'n_samples': - metrics[key] = metrics.get(key, 0) + val * _metrics['n_samples'] - for key, val in metrics.items(): - metrics[key] = val / max(n_samples, 1) + result_msg = f"n samples = {n_samples}" + for key, val in metrics.items(): + result_msg += f"\n{key} = {val * 100}" - result_msg = f"n samples = {n_samples}" - for key, val in metrics.items(): - result_msg += f"\n{key} = {val * 100}" - - metrics['n_samples'] = n_samples + metrics['n_samples'] = n_samples return metrics, agg_preds, result_msg @@ -196,4 +202,4 @@ def main(): print(f"src = {src} | task = {task} >>>\n{result_msg}\n\n", flush=True) if __name__ == '__main__': - main() + main() \ No newline at end of file