ESFT/scripts/generate_expert_config.py

51 lines
2.0 KiB
Python
Raw Normal View History

import argparse
import json
import os
# Generate per-dataset expert-training configs from expert-score summaries.
#
# For each dataset named in --eval_datasets, this reads
# <expert_scores_dir>/<dataset>/summary.json, takes its
# "<score_function>_scores" section (a mapping layer -> {expert_id: score}),
# keeps per layer the highest-scoring experts until their cumulative score
# reaches --top_p, and writes the result to <output_dir>/<dataset>.json.

# Score sections that may be selected with --score_function.
SCORE_FUNCTIONS = ("gate", "token")


def _parse_args(argv=None):
    """Parse the command-line options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--eval_datasets", type=str, required=True)
    parser.add_argument("--expert_scores_dir", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    # Validated by argparse up front (usage message, exit code 2) rather than by
    # an ``assert`` inside the dataset loop, which only ran after the first
    # summary file had been opened and is removed entirely under ``python -O``.
    parser.add_argument(
        "--score_function", type=str, required=True, choices=SCORE_FUNCTIONS
    )
    parser.add_argument("--top_p", type=float, required=True)
    parser.add_argument("--train_shared_experts", action="store_true")
    parser.add_argument("--train_non_expert_modules", action="store_true")
    return parser.parse_args(argv)


def _parse_dataset_names(raw):
    """Split a comma-separated dataset list, trimming whitespace and dropping blanks."""
    # Without this, "a, b" looked for a directory named " b" and a trailing comma
    # produced an empty name (reading "<dir>//summary.json", writing ".json").
    return [name.strip() for name in raw.split(",") if name.strip()]


def _select_top_p_experts(layer_scores, top_p):
    """Pick one layer's highest-scoring experts until their summed score reaches ``top_p``.

    Args:
        layer_scores: Mapping of expert id (a numeric string as loaded from
            JSON, e.g. ``"3"``) to that expert's score.
        top_p: Cumulative-score threshold.

    Returns:
        Integer expert ids in descending score order (equal scores keep their
        input order). The expert whose score makes the running total reach
        ``top_p`` is included. If the scores never reach ``top_p``, every expert
        is returned; a ``top_p <= 0`` returns an empty list.
    """
    ranked = sorted(
        ((int(expert), score) for expert, score in layer_scores.items()),
        key=lambda pair: pair[1],
        reverse=True,
    )
    selected = []
    cumulative = 0
    for expert, score in ranked:
        # Stop once the experts chosen so far already cover top_p.
        if cumulative >= top_p:
            break
        selected.append(expert)
        cumulative += score
    return selected


def _build_expert_config(
    summary_file, score_function, top_p, train_shared_experts, train_non_expert_modules
):
    """Build the expert-training config for one dataset.

    Args:
        summary_file: Path to the dataset's ``summary.json``.
        score_function: One of ``SCORE_FUNCTIONS``; selects the
            ``"<score_function>_scores"`` section of the summary.
        top_p: Per-layer cumulative-score threshold.
        train_shared_experts: Value stored under ``"shared_experts"``.
        train_non_expert_modules: Value stored under ``"non_expert_modules"``.

    Returns:
        ``{"experts": {layer: [expert ids]}, "shared_experts": ...,
        "non_expert_modules": ...}``, with layer keys exactly as they appear in
        the summary.

    Raises:
        OSError: If ``summary_file`` cannot be opened.
        KeyError: If the summary has no ``"<score_function>_scores"`` section.
    """
    with open(summary_file, encoding="utf-8") as f:
        data = json.load(f)
    scores = data[f"{score_function}_scores"]
    return {
        "experts": {
            layer: _select_top_p_experts(layer_scores, top_p)
            for layer, layer_scores in scores.items()
        },
        "shared_experts": train_shared_experts,
        "non_expert_modules": train_non_expert_modules,
    }


def main(argv=None):
    """Write ``<output_dir>/<dataset>.json`` for every dataset in ``--eval_datasets``.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = _parse_args(argv)
    # Loop-invariant: create the output directory once.
    os.makedirs(args.output_dir, exist_ok=True)
    for dataset_name in _parse_dataset_names(args.eval_datasets):
        summary_file = os.path.join(
            args.expert_scores_dir, dataset_name, "summary.json"
        )
        expert_cfg = _build_expert_config(
            summary_file,
            args.score_function,
            args.top_p,
            args.train_shared_experts,
            args.train_non_expert_modules,
        )
        output_file = os.path.join(args.output_dir, f"{dataset_name}.json")
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(expert_cfg, f)


if __name__ == "__main__":
    main()