mirror of
https://github.com/deepseek-ai/ESFT
synced 2025-06-26 18:15:50 +00:00
add training code
This commit is contained in:
154
train_ep.py
Normal file
154
train_ep.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import argparse
|
||||
import json
|
||||
import yaml
|
||||
import os
|
||||
import random
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from types import MethodType
|
||||
from torch.utils.data import TensorDataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, logging
|
||||
|
||||
from benchmarks import *
|
||||
from utils import get_formatted_input_and_target, get_examples_from_buffer_pad, init_parallel_groups
|
||||
from esft import to_esft
|
||||
from deepseek.modeling_deepseek import DeepseekV2ForCausalLM
|
||||
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
os.environ["NCCL_AVOID_RECORD_STREAMS"] = "1"
|
||||
logging.set_verbosity_error()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--base_model_path", type=str, required=True)
|
||||
parser.add_argument("--expert_config", type=str, required=True)
|
||||
parser.add_argument("--train_dataset", type=str, required=True)
|
||||
parser.add_argument("--output_dir", type=str, required=True)
|
||||
parser.add_argument("--train_config", type=str, required=True)
|
||||
parser.add_argument("--wandb_api_key", type=str, required=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
expert_config = json.load(open(args.expert_config))
|
||||
output_dir = args.output_dir
|
||||
base_model_path = args.base_model_path
|
||||
config = yaml.safe_load(open(args.train_config))
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
seed = config['seed']
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
|
||||
if args.wandb_api_key is not None:
|
||||
import wandb
|
||||
wandb.login(key=args.wandb_api_key)
|
||||
|
||||
ep_size = config.get("ep_size", 1)
|
||||
world_size, local_rank, ep_group, edp_group = init_parallel_groups(ep_size)
|
||||
edp_size = world_size // ep_size
|
||||
|
||||
# Prepare data
|
||||
tokenizer = AutoTokenizer.from_pretrained(base_model_path)
|
||||
samples = [json.loads(i) for i in open(f"datasets/train/{args.train_dataset}.jsonl").readlines()]
|
||||
buffer = []
|
||||
for instance in samples:
|
||||
input_ids, target_ids = get_formatted_input_and_target(instance['messages'], tokenizer, -100)
|
||||
buffer.append((input_ids, target_ids))
|
||||
seq_length = config['seq_length']
|
||||
random_concat_ratio = config['random_concat_ratio']
|
||||
concated_examples = get_examples_from_buffer_pad(buffer, seq_length, tokenizer, random_concat_ratio)
|
||||
|
||||
dataset = TensorDataset(concated_examples['input_ids'], concated_examples['labels'])
|
||||
train_dataset, valid_dataset = torch.utils.data.random_split(dataset, [int(len(dataset) * 0.98), len(dataset) - int(len(dataset) * 0.98)])
|
||||
|
||||
# Training arguments
|
||||
training_args = TrainingArguments(
|
||||
output_dir=output_dir,
|
||||
max_steps=config['steps'],
|
||||
per_device_train_batch_size=config['per_device_batch_size'],
|
||||
per_device_eval_batch_size=config['per_device_batch_size'],
|
||||
warmup_steps=config['warmup_steps'],
|
||||
weight_decay=config['weight_decay'],
|
||||
logging_dir=f"{output_dir}/logs",
|
||||
logging_steps=config['logging_steps'],
|
||||
save_steps=config['save_steps'],
|
||||
eval_strategy="steps",
|
||||
eval_steps=config['eval_steps'],
|
||||
gradient_accumulation_steps=config['gradient_accumulation_steps'],
|
||||
load_best_model_at_end=True,
|
||||
metric_for_best_model="loss",
|
||||
greater_is_better=False,
|
||||
bf16=True,
|
||||
lr_scheduler_type='constant',
|
||||
save_total_limit=5,
|
||||
learning_rate=config['learning_rate'],
|
||||
optim=config['optim'],
|
||||
adam_beta1=config['adam_beta1'],
|
||||
adam_beta2=config['adam_beta2'],
|
||||
disable_tqdm=False,
|
||||
gradient_checkpointing=config['gradient_checkpointing'],
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False} if config['gradient_checkpointing'] else {}, # if set to True, backward will raise bug
|
||||
)
|
||||
|
||||
def data_collator(data):
|
||||
input_ids = torch.stack([item[0] for item in data])
|
||||
labels = torch.stack([item[1] for item in data])
|
||||
return {"input_ids": input_ids, "labels": labels}
|
||||
|
||||
|
||||
model = DeepseekV2ForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16,
|
||||
ep_size=ep_size, attn_implementation="flash_attention_2")
|
||||
model._ddp_params_and_buffers_to_ignore = [n for n, _ in model.named_parameters() if ".expert" in n] # we manage grad synchronization of expert parameters
|
||||
to_esft(model, expert_config)
|
||||
model.dummy = torch.nn.Parameter(torch.zeros(1, dtype=model.dtype)) # prevent DDP from having no trainable parameters
|
||||
model._keys_to_ignore_on_save = ["dummy"]
|
||||
expert_params = [p for n, p in model.named_parameters() if p.requires_grad and ".expert" in n]
|
||||
for layer in model.model.layers:
|
||||
if type(layer.mlp).__name__ != "DeepseekV2MoE":
|
||||
continue
|
||||
layer.mlp.ep_group = ep_group
|
||||
|
||||
# Initialize Trainer
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=valid_dataset,
|
||||
data_collator=data_collator,
|
||||
)
|
||||
|
||||
accelerator = trainer.accelerator
|
||||
backward = accelerator.backward
|
||||
def custom_backward(self, loss, **kwargs):
|
||||
backward(loss, **kwargs)
|
||||
if not self.sync_gradients or edp_size == 1:
|
||||
return
|
||||
return
|
||||
for p in expert_params:
|
||||
g = p.grad if p.grad is not None else torch.zeros_like(p)
|
||||
dist.all_reduce(g, op=dist.ReduceOp.AVG, group=edp_group)
|
||||
if p.grad is not g:
|
||||
p.grad = g
|
||||
accelerator.backward = MethodType(custom_backward, accelerator)
|
||||
|
||||
# Training
|
||||
ckpt_path = f"{output_dir}/last_checkpoint_ep{local_rank}"
|
||||
if os.path.exists(output_dir) and len(os.listdir(output_dir)) > 1: # has checkpoints already
|
||||
trainer.train(resume_from_checkpoint=ckpt_path)
|
||||
else:
|
||||
trainer.train()
|
||||
|
||||
# Save the model and tokenizer
|
||||
if local_rank == 0:
|
||||
trainer.save_model(ckpt_path)
|
||||
tokenizer.save_pretrained(ckpt_path)
|
||||
elif 0 < local_rank < ep_size:
|
||||
model.save_pretrained(ckpt_path)
|
||||
|
||||
print("Training complete")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user