update eval and readme

This commit is contained in:
ZihanWang314 2024-08-09 18:06:57 +08:00
parent 809d0e377e
commit 36c6194349
10 changed files with 333 additions and 152 deletions

View File

@ -10,6 +10,9 @@ Y. Wu.
**ESFT** aims to efficiently customize Large Language Models (LLMs) with Mixture-of-Experts (MoE) architecture by adjusting only task-relevant parts, improving efficiency and performance while using fewer resources and storage.
## 📰 News
📅 **2024.8.11:** We now release the **ESFT training code**! ✨ You can now try it with your own models and dataset!
## 🚀 Quick Start
@ -19,9 +22,9 @@ git clone https://github.com/deepseek-ai/ESFT.git
cd esft
```
### Install dependencies
### Install required dependencies
```bash
pip install transformers torch safetensors
pip install transformers torch safetensors accelerate
```
### Download necessary adapters
@ -32,35 +35,38 @@ bash scripts/download_adapters.sh
## 🔧Key Scripts
1. **eval.py**
This script evaluates the performance of the model on various datasets. **Usage:**
1. **eval_multigpu.py**
This script evaluates the performance of the model on various datasets. See **scripts/eval.sh** for detailed configs and explanations.
**Usage:**
```bash
python scripts/eval.py \
--eval_datasets=translation \
python eval_multigpu.py \
--eval_dataset=translation \
--base_model_path=deepseek-ai/ESFT-vanilla-lite \
--adapter_dir=all_models/adapters/token \
--output_dir=results/completions/token \
--max_new_tokens=512 \
--openai_api_key=REPLACE_WITH_YOUR_KEY \
--eval_batch_size=2
--adapter_dir=all_models/adapters/token/translation \
--output_path=results/completions/token/translation.jsonl \
--openai_api_key=YOUR_OPENAI_API_KEY
```
2. **get_expert_scores.py**
This script calculates the scores for each expert based on the evaluation datasets.
**Usage:**
```bash
python scripts/get_expert_scores.py \
--eval_datasets=intent,summary,law,translation \
python scripts/expert/get_expert_scores.py \
--eval_dataset=translation \
--base_model_path=deepseek-ai/ESFT-vanilla-lite \
--output_dir=results/expert_scores \
--n_sample_tokens=8192 # the sample size hyperparameter
--output_dir=results/expert_scores/translation \
--n_sample_tokens=131072 \
--world_size=4 \
--gpus_per_rank=2
```
3. **generate_expert_config.py**
This script generates the configuration to convert a MoE model with only task-relevant tasks trained based on evaluation scores.
**Usage:**
```bash
python scripts/generate_expert_config.py \
python scripts/expert/generate_expert_config.py \
--eval_datasets=intent,summary,law,translation \
--expert_scores_dir=results/expert_scores \
--output_dir=results/expert_configs \
@ -68,13 +74,32 @@ python scripts/generate_expert_config.py \
--top_p=0.2 # the scoring function and top_p are hyperparameters
```
4. **train.py** and **train_ep.py**
This script trains the model with the expert configuration generated by the previous script. The train_ep.py file uses expert parallel and has been optimized for multi-GPU training.
**Usage:**
```bash
python train.py \
--base_model_path=deepseek-ai/ESFT-vanilla-lite \
--expert_config=results/expert_configs/intent.json \
--train_dataset=intent \
--train_config=configs/base.yaml \
--output_dir=results/checkpoints/intent
torchrun --nproc-per-node=8 train_ep.py \
--base_model_path=deepseek-ai/ESFT-vanilla-lite \
--expert_config=results/expert_configs/translation.json \
--train_dataset=translation \
--train_config=configs/base.yaml \
--output_dir=results/checkpoints/translation
```
## Contact and Support
For bug reports, feature requests, and general inquiries, please open an issue on our GitHub Issues page. Make sure to include as much detail as possible to help us address your issue quickly.
## 🌟Todo list
- ☑️ 📝 Update models, evaluation scripts, and expert selection scripts
- 🔲 🔧 Update training scripts
- ☑️ 🔧 Update training scripts
- 🔲 🚀 More...

92
eval_multigpu.py Normal file
View File

@ -0,0 +1,92 @@
import json
import argparse
from torch import device
from benchmarks import *
import os
from esft import load_base_model, add_adapter
import torch.multiprocessing as mp
from itertools import accumulate
from accelerate import dispatch_model
from transformers import AutoModelForCausalLM, AutoTokenizer
def infer_auto_device_map(model, pp_splits, visible_devices):
assert len(pp_splits) == len(visible_devices)
device_map = {
"model.embed_tokens": 0,
"model.norm": len(pp_splits) - 1,
"lm_head": len(pp_splits) - 1
}
assert len(model.model.layers) == sum(pp_splits)
pp_splits = [0, *list(accumulate(pp_splits))]
for idx, (start, end) in enumerate(zip(pp_splits[:-1], pp_splits[1:])):
for i in range(start, end):
device_map.update({f"model.layers.{i}": idx})
for k, v in device_map.items():
device_map[k] = visible_devices[v]
return device_map
def eval_model(rank, args, model, dataset):
config = {
"max_new_tokens": args.max_new_tokens,
"eval_batch_size": args.eval_batch_size,
"openai_api_key": args.openai_api_key
}
evaluator_map = {
"intent": IntentEvaluator,
"summary": SummaryEvaluator,
"law": LawEvaluator,
"translation": TranslationEvaluator
}
try:
evaluator_cls = evaluator_map[args.eval_dataset]
print(f"Rank {rank} starting evaluation...", flush=True)
tokenizer = AutoTokenizer.from_pretrained(args.base_model_path)
visible_devices = list(range(rank * args.gpus_per_rank, (rank + 1) * args.gpus_per_rank))
device_map = infer_auto_device_map(model, [14, 13], visible_devices)
model = dispatch_model(model, device_map)
cur_dataset = dataset[rank::args.world_size]
evaluator = evaluator_cls(cur_dataset, config)
with torch.no_grad():
results, metrics = evaluator.evaluate(model, tokenizer)
os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
with open(args.output_path + f".rank_{rank}", "w") as f:
for res, m in zip(results, metrics):
obj = {
"example": res,
"score": m
}
f.write(json.dumps(obj, ensure_ascii=False) + "\n")
except Exception as e:
print(f"Error in process {rank}: {e}", flush=True)
raise
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Evaluate a model with adapters on a specified dataset.")
parser.add_argument("--eval_dataset", type=str, required=True, help="Name of the evaluation dataset")
parser.add_argument("--base_model_path", type=str, required=True, help="Path to the base model")
parser.add_argument("--adapter_dir", type=str, required=True, help="Directory containing the adapter")
parser.add_argument("--output_path", type=str, required=True, help="Path to save the evaluation results")
parser.add_argument("--max_new_tokens", type=int, default=128, help="Maximum number of new tokens")
parser.add_argument("--openai_api_key", type=str, required=True, help="API key for OpenAI")
parser.add_argument("--eval_batch_size", type=int, default=1, help="Batch size for evaluation")
parser.add_argument("--world_size", type=int, default=4, help="Number of processes to use for evaluation")
parser.add_argument("--gpus_per_rank", type=int, default=2, help="Number of GPUs per process")
args = parser.parse_args()
print("Loading base model...")
model = AutoModelForCausalLM.from_pretrained(args.base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16) # not using tokenizer here to aviod deadlock
print(f"Running evaluation on {args.eval_dataset}...")
dataset = [json.loads(i) for i in open(f"datasets/eval/{args.eval_dataset}.jsonl").readlines()]
print("Adding adapter...")
model = add_adapter(model, args.adapter_dir, return_original_states=False)
print("Start Evaluating...")
mp.spawn(eval_model, args=(args, model, dataset), nprocs=args.world_size, join=True)

View File

@ -1,12 +1,24 @@
# first, download adapter models and put them to the corresponding directories
# first: download adapter models and put them to the corresponding directories
python scripts/eval.py \
python eval_multigpu.py \
--eval_datasets=translation \
--base_model_path=deepseek-ai/ESFT-vanilla-lite \
--adapter_dir=all_models/adapters/token \
--output_dir=results/completions/token \
--max_new_tokens=512 \
--openai_api_key=REPLACE_WITH_YOUR_KEY \
--eval_batch_size=2
--eval_batch_size=2 \
--world_size=4 \
--gpus_per_rank=2
# this script is used for single-gpu training and has been deprecated. If you have no multiple gpus, you can set above world_size=1 and gpus_per_rank=1
# python scripts/eval.py \
# --eval_datasets=translation \
# --base_model_path=deepseek-ai/ESFT-vanilla-lite \
# --adapter_dir=all_models/adapters/token \
# --output_dir=results/completions/token \
# --max_new_tokens=512 \
# --openai_api_key=REPLACE_WITH_YOUR_KEY \
# --eval_batch_size=2

View File

@ -1,10 +1,12 @@
python scripts/get_expert_scores.py \
--eval_datasets=intent,summary,law,translation \
python scripts/expert/get_expert_scores.py \
--eval_dataset=translation \
--base_model_path=deepseek-ai/ESFT-vanilla-lite \
--output_dir=results/expert_scores \
--n_sample_tokens=8192 # this sample size is a hyperparameter
--output_dir=results/expert_scores/translation \
--n_sample_tokens=131072 \
--world_size=4 \
--gpus_per_rank=2
python scripts/generate_expert_config.py \
python scripts/expert/generate_expert_config.py \
--eval_datasets=intent,summary,law,translation \
--expert_scores_dir=results/expert_scores \
--output_dir=results/expert_configs \

View File

@ -0,0 +1,97 @@
import argparse
import json
import os
from multiprocessing import Pool
import numpy as np
def parse_line(line):
expert_ids, expert_weights = line.split("\t\t")
expert_ids = [int(i) for i in expert_ids.split("\t")]
expert_weights = [float(i) for i in expert_weights.split("\t")]
return expert_ids, expert_weights
def get_summary(files):
TOP_K=6
N_EXPERTS=64
N_LAYERS=26 # 27 layers totally, the first layer is not MoE
gate_scores = np.zeros((N_LAYERS, N_EXPERTS))
token_scores = np.zeros((N_LAYERS, N_EXPERTS))
print("loading files")
for rank, file in files:
layer_id = int(file.split(".")[0].split("_")[2]) - 1
with open(os.path.join(args.expert_scores_dir, rank, file)) as f:
data = f.readlines()
for line in data:
expert_ids, expert_weights = parse_line(line)
np.add.at(gate_scores[layer_id], expert_ids, expert_weights)
np.add.at(token_scores[layer_id], expert_ids, np.ones_like(expert_weights) / TOP_K)
gate_scores = gate_scores / np.sum(gate_scores, axis=0)
token_scores = token_scores / np.sum(token_scores, axis=0)
summary = {"token_scores": token_scores, "gate_scores": gate_scores}
summary = {k: {str(i+1): {str(j): round(v, 4) for j, v in enumerate(l)} for i, l in enumerate(v)} for k, v in summary.items()}
return summary
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--eval_dataset", type=str, required=True)
parser.add_argument("--expert_scores_dir", type=str, required=True)
parser.add_argument("--output_path", type=str, required=True)
parser.add_argument("--score_function", type=str, required=True)
parser.add_argument("--top_p", type=float, required=True)
parser.add_argument("--train_shared_experts", action="store_true")
parser.add_argument("--train_non_expert_modules", action="store_true")
args = parser.parse_args()
expert_cfg = { # initialize expert config
"experts": {},
"shared_experts": args.train_shared_experts,
"non_expert_modules": args.train_non_expert_modules
}
# let's walk inside args.expert_scores_dir and get abs file names
file_names = []
for rank in [i for i in os.listdir(args.expert_scores_dir) if 'rank' in i]:
for file in os.listdir(os.path.join(args.expert_scores_dir, rank)):
file_names.append([rank, file])
summary_file = os.path.join(args.expert_scores_dir, "summary.json")
summary = get_summary(file_names)
with open(summary_file, "w") as f:
f.write(json.dumps(summary))
scores = summary[f"{args.score_function}_scores"]
for layer, l_score in scores.items():
l_score = [(int(k), v) for k,v in l_score.items()]
l_score = sorted(l_score, key=lambda x: x[1], reverse=True)
selected_experts = []
current_score = 0
for expert, score in l_score:
if current_score >= args.top_p:
break
selected_experts.append(expert)
current_score += score
expert_cfg["experts"][layer] = selected_experts
top_p = args.top_p
train_shared_experts = args.train_shared_experts
train_non_expert_modules = args.train_non_expert_modules
os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
with open(args.output_path, "w") as f:
json.dump(expert_cfg, f)

View File

@ -0,0 +1,78 @@
import json
import os
import torch
import argparse
import random
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import get_formatted_input_and_target
import torch.multiprocessing as mp
from itertools import accumulate
from accelerate import dispatch_model
def infer_auto_device_map(model, pp_splits, visible_devices):
assert len(pp_splits) == len(visible_devices)
device_map = {
"model.embed_tokens": 0,
"model.norm": len(pp_splits) - 1,
"lm_head": len(pp_splits) - 1
}
assert len(model.model.layers) == sum(pp_splits)
pp_splits = [0, *list(accumulate(pp_splits))]
for idx, (start, end) in enumerate(zip(pp_splits[:-1], pp_splits[1:])):
for i in range(start, end):
device_map.update({f"model.layers.{i}": idx})
for k, v in device_map.items():
device_map[k] = visible_devices[v]
return device_map
def eval_expert(rank, args, model, dataset):
try:
print(f"Rank {rank} starting expert evaluation...", flush=True)
tokenizer = AutoTokenizer.from_pretrained(args.base_model_path)
visible_devices = list(range(rank * args.gpus_per_rank, (rank + 1) * args.gpus_per_rank))
device_map = infer_auto_device_map(model, [14, 13], visible_devices)
model = dispatch_model(model, device_map)
model.config.expert_log_dir = os.path.join(args.output_dir, f"rank_{rank}")
n_sample_tokens = args.n_sample_tokens // args.world_size
os.makedirs(os.path.join(args.output_dir, f"rank_{rank}"), exist_ok=True)
done_tokens = 0
cur_dataset = dataset[rank::args.world_size]
for instance in cur_dataset:
input_ids, target_ids = get_formatted_input_and_target(instance['messages'], tokenizer, -100)
model(input_ids=torch.tensor(input_ids).unsqueeze(0), labels=torch.tensor(target_ids).unsqueeze(0))
done_tokens += len(input_ids)
if done_tokens >= n_sample_tokens:
break
except Exception as e:
print(f"Error in process {rank}: {e}", flush=True)
raise
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Evaluate a model with adapters on a specified dataset.")
parser.add_argument("--eval_dataset", type=str, required=True, help="Name of the evaluation dataset")
parser.add_argument("--base_model_path", type=str, required=True, help="Path to the base model")
parser.add_argument("--output_dir", type=str, required=True, help="Path to save the evaluation results")
parser.add_argument("--world_size", type=int, default=4, help="Number of processes to use for evaluation")
parser.add_argument("--gpus_per_rank", type=int, default=2, help="Number of GPUs per process")
parser.add_argument("--n_sample_tokens", type=int, required=True, help="Token to sample for expert evaluation")
args = parser.parse_args()
random.seed(5934875)
print("Loading base model...")
model = AutoModelForCausalLM.from_pretrained(args.base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16) # not using tokenizer here to aviod deadlock
model.config.log_expert_weights = True
print(f"Running expert evaluation on {args.eval_dataset}...")
dataset = [json.loads(i) for i in open(f"datasets/train/{args.eval_dataset}.jsonl").readlines()]
random.shuffle(dataset)
print("Start Evaluating...")
mp.spawn(eval_expert, args=(args, model, dataset), nprocs=args.world_size, join=True)

View File

@ -1,50 +0,0 @@
import argparse
import json
import os
parser = argparse.ArgumentParser()
parser.add_argument("--eval_datasets", type=str, required=True)
parser.add_argument("--expert_scores_dir", type=str, required=True)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument("--score_function", type=str, required=True)
parser.add_argument("--top_p", type=float, required=True)
parser.add_argument("--train_shared_experts", action="store_true")
parser.add_argument("--train_non_expert_modules", action="store_true")
args = parser.parse_args()
eval_datasets = args.eval_datasets.split(",")
expert_scores_dir = args.expert_scores_dir
output_dir = args.output_dir
score_function = args.score_function
top_p = args.top_p
train_shared_experts = args.train_shared_experts
train_non_expert_modules = args.train_non_expert_modules
for dataset_name in eval_datasets:
summary_file = f"{expert_scores_dir}/{dataset_name}/summary.json"
expert_cfg = {"experts": {}, "shared_experts": train_shared_experts, "non_expert_modules": train_non_expert_modules}
with open(summary_file) as f:
data = json.load(f)
assert score_function in ["gate", "token"], f"Unknown score function: {score_function}"
scores = data[f"{score_function}_scores"]
for layer, l_score in scores.items():
l_score = [(int(k), v) for k,v in l_score.items()]
l_score = sorted(l_score, key=lambda x: x[1], reverse=True)
# get the top experts that make the threshold exceed top_p
selected_experts = []
current_score = 0
for expert, score in l_score:
if current_score >= top_p:
break
selected_experts.append(expert)
current_score += score
expert_cfg["experts"][layer] = selected_experts
os.makedirs(output_dir, exist_ok=True)
with open(f"{output_dir}/{dataset_name}.json", "w") as f:
json.dump(expert_cfg, f)

View File

@ -1,75 +0,0 @@
import json
from benchmarks import *
import os
import torch
from torch import nn
import argparse
from random import shuffle
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import get_formatted_input_and_target
# constants for deepseek-v2-lite
TOP_K=6
N_EXPERTS=64
parser = argparse.ArgumentParser()
parser.add_argument("--base_model_path", type=str, required=True)
parser.add_argument("--eval_datasets", type=str, required=True)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument("--n_sample_tokens", type=int, required=True)
args = parser.parse_args()
eval_datasets = args.eval_datasets.split(",")
output_dir = args.output_dir
base_model_path = args.base_model_path
n_sample_tokens = args.n_sample_tokens
model, tokenizer = AutoModelForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto"), AutoTokenizer.from_pretrained(base_model_path)
model.config.log_expert_weights = True
for dataset_name in eval_datasets:
dataset = [json.loads(i) for i in open(f"datasets/train/{dataset_name}.jsonl").readlines()]
shuffle(dataset)
model.config.expert_log_dir = os.path.join(args.output_dir, dataset_name)
# make dir -p this
os.makedirs(os.path.join(args.output_dir, dataset_name), exist_ok=True)
done_tokens = 0
for instance in dataset:
input_ids, target_ids = get_formatted_input_and_target(instance['messages'], tokenizer, -100)
model(input_ids=torch.tensor(input_ids).unsqueeze(0), labels=torch.tensor(target_ids).unsqueeze(0))
done_tokens += len(input_ids)
if done_tokens >= n_sample_tokens:
break
# open all files under os.path.join(args.output_dir, dataset_name). For each file, generate a summary of it
# and write it to a file in the same directory
files = os.listdir(os.path.join(args.output_dir, dataset_name))
summary_file = os.path.join(args.output_dir, dataset_name, "summary.json")
token_scores = {}
gate_scores = {}
for file in files:
if not file.endswith(".txt"):
continue
layer_idx = file.split("_")[2].split(".")[0]
token_scores[layer_idx] = {expert:0 for expert in range(N_EXPERTS)}
gate_scores[layer_idx] = {expert:0 for expert in range(N_EXPERTS)}
with open(os.path.join(args.output_dir, dataset_name, file)) as f:
data = f.readlines()
for line in data:
expert_ids, expert_weights = line.split("\t\t")
expert_ids = [int(i) for i in expert_ids.split("\t")]
expert_weights = [float(i) for i in expert_weights.split("\t")]
for expert_id, expert_weight in zip(expert_ids, expert_weights):
gate_scores[layer_idx][expert_id] += expert_weight
token_scores[layer_idx][expert_id] += 1. / TOP_K
total = sum(token_scores[layer_idx].values())
gate_scores[layer_idx] = {expert: round(gate_scores[layer_idx][expert] / total, 4) for expert in gate_scores[layer_idx]}
token_scores[layer_idx] = {expert: round(token_scores[layer_idx][expert] / total, 4) for expert in token_scores[layer_idx]}
with open(summary_file, "w") as f:
f.write(json.dumps({"token_scores": token_scores, "gate_scores": gate_scores}))

View File

@ -2,7 +2,7 @@
export TOKENIZERS_PARALLELISM=false
exp_name="test/eval_translation"
base_model_path="/hf3fs-jd/prod/deepseek/shared/wangzihan/models/huggingface/vanilla_model"
base_model_path="deepseek-ai/esft-vanilla-lite"
# turn above to for loop
python train.py \
--base_model_path=${base_model_path} \

View File

@ -2,7 +2,7 @@
export TOKENIZERS_PARALLELISM=false
exp_name="test/eval_translation"
base_model_path="/hf3fs-jd/prod/deepseek/shared/wangzihan/models/huggingface/vanilla_model"
base_model_path="deepseek-ai/esft-vanilla-lite"
torchrun --nproc-per-node=8 train_ep.py \
--base_model_path=${base_model_path} \
--expert_config=results/expert_configs/translation.json \