# Mirror of https://github.com/deepseek-ai/ESFT
# (synced 2024-11-25 21:27:57 +00:00; original file: 117 lines, 4.4 KiB, Python)
|
import argparse
|
||
|
import json
|
||
|
import yaml
|
||
|
import os
|
||
|
import random
|
||
|
import torch
|
||
|
from torch.utils.data import TensorDataset
|
||
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, logging
|
||
|
|
||
|
from benchmarks import *
|
||
|
from utils import get_formatted_input_and_target, get_examples_from_buffer_pad
|
||
|
from esft import to_esft
|
||
|
from deepseek.modeling_deepseek import DeepseekV2ForCausalLM
|
||
|
|
||
|
|
||
|
# Turn off the Hugging Face `tokenizers` library's internal parallelism for
# this process (and any child processes that inherit the environment).
# NOTE(review): presumably done to avoid the tokenizers fork warning/deadlock
# once DataLoader workers fork after the tokenizer has been used — confirm.
os.environ.update(TOKENIZERS_PARALLELISM="false")
|
||
|
|
||
|
|
||
|
def _read_json(path):
    """Parse and return the JSON document stored at ``path``.

    Uses a ``with`` block so the file handle is closed immediately instead of
    being left open until garbage collection.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _read_yaml(path):
    """Parse and return the YAML document at ``path`` with ``yaml.safe_load``."""
    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f)


def _read_jsonl(path):
    """Return a list with one parsed JSON value per non-blank line of ``path``.

    Blank or whitespace-only lines are skipped rather than passed to
    ``json.loads``, which would raise on them.
    """
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _seed_everything(seed):
    """Seed torch (CPU and current CUDA device) and Python's ``random`` module."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)


def _train_valid_sizes(total, train_fraction=0.98):
    """Return ``[n_train, n_valid]`` for ``random_split``.

    ``n_train`` is ``int(total * train_fraction)``; every remaining example goes
    to validation, so the two sizes always add up to ``total``.
    """
    n_train = int(total * train_fraction)
    return [n_train, total - n_train]


def _stack_collator(batch):
    """Collate ``(input_ids, labels)`` tensor pairs into the dict ``Trainer`` expects.

    Each item is one row of the ``TensorDataset`` built in ``main``; the two
    components are stacked along a new leading batch dimension. Defined at module
    level (not nested in ``main``) so it can be pickled if DataLoader workers
    are ever started with the ``spawn`` method.
    """
    input_ids = torch.stack([item[0] for item in batch])
    labels = torch.stack([item[1] for item in batch])
    return {"input_ids": input_ids, "labels": labels}


def _build_training_args(config, output_dir, seed):
    """Build the ``TrainingArguments`` for this run from the parsed YAML ``config``."""
    use_gradient_checkpointing = config['gradient_checkpointing']
    return TrainingArguments(
        output_dir=output_dir,
        max_steps=config['steps'],
        per_device_train_batch_size=config['per_device_batch_size'],
        per_device_eval_batch_size=config['per_device_batch_size'],
        warmup_steps=config['warmup_steps'],
        weight_decay=config['weight_decay'],
        logging_dir=f"{output_dir}/logs",
        logging_steps=config['logging_steps'],
        save_steps=config['save_steps'],
        eval_strategy="steps",
        eval_steps=config['eval_steps'],
        gradient_accumulation_steps=config['gradient_accumulation_steps'],
        load_best_model_at_end=True,
        metric_for_best_model="loss",
        greater_is_better=False,
        bf16=True,
        lr_scheduler_type='constant',
        save_total_limit=5,
        learning_rate=config['learning_rate'],
        optim=config['optim'],
        adam_beta1=config['adam_beta1'],
        adam_beta2=config['adam_beta2'],
        # Trainer calls set_seed(args.seed) in its constructor; without this the
        # library default (42) would override config['seed'] for training-time
        # randomness such as data shuffling and dropout.
        seed=seed,
        gradient_checkpointing=use_gradient_checkpointing,
        # NOTE(review): per the original author, backward fails with
        # use_reentrant=True, so the non-reentrant variant is forced.
        gradient_checkpointing_kwargs=(
            {"use_reentrant": False} if use_gradient_checkpointing else {}
        ),
    )


def main():
    """Command-line entry point: fine-tune a ``DeepseekV2ForCausalLM`` prepared by ``to_esft``.

    Command-line arguments:
        --base_model_path  path/name passed to ``from_pretrained`` for model and tokenizer.
        --expert_config    JSON file handed to ``to_esft`` together with the model.
        --train_dataset    dataset name; examples are read from
                           ``datasets/train/<name>.jsonl``, one JSON object per
                           line, each with a ``messages`` field.
        --output_dir       Trainer output directory; the final model and tokenizer
                           are additionally saved to ``<output_dir>/last_checkpoint``.
        --train_config     YAML file with the keys: seed, seq_length,
                           random_concat_ratio, steps, per_device_batch_size,
                           warmup_steps, weight_decay, logging_steps, save_steps,
                           eval_steps, gradient_accumulation_steps, learning_rate,
                           optim, adam_beta1, adam_beta2, gradient_checkpointing.
        --wandb_api_key    optional; when given, ``wandb.login`` is called with it.

    If ``output_dir`` already contains a ``checkpoint-<step>`` directory, training
    resumes from the most recent one; otherwise it starts from scratch.
    """
    from transformers.trainer_utils import get_last_checkpoint

    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model_path", type=str, required=True)
    parser.add_argument("--expert_config", type=str, required=True)
    parser.add_argument("--train_dataset", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--train_config", type=str, required=True)
    parser.add_argument("--wandb_api_key", type=str, required=False)
    args = parser.parse_args()

    expert_config = _read_json(args.expert_config)
    output_dir = args.output_dir
    base_model_path = args.base_model_path
    config = _read_yaml(args.train_config)
    os.makedirs(output_dir, exist_ok=True)

    seed = config['seed']
    _seed_everything(seed)

    if args.wandb_api_key is not None:
        import wandb
        wandb.login(key=args.wandb_api_key)

    # Prepare data
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
    samples = _read_jsonl(f"datasets/train/{args.train_dataset}.jsonl")
    buffer = []
    for instance in samples:
        # NOTE(review): -100 is presumably the label value for positions that
        # should not contribute to the loss (PyTorch's default ignore_index);
        # confirm against utils.get_formatted_input_and_target.
        input_ids, target_ids = get_formatted_input_and_target(instance['messages'], tokenizer, -100)
        buffer.append((input_ids, target_ids))

    seq_length = config['seq_length']
    random_concat_ratio = config['random_concat_ratio']
    concated_examples = get_examples_from_buffer_pad(buffer, seq_length, tokenizer, random_concat_ratio)

    dataset = TensorDataset(concated_examples['input_ids'], concated_examples['labels'])
    # 98% / 2% train/validation split, drawn from torch's (seeded) global RNG.
    train_dataset, valid_dataset = torch.utils.data.random_split(dataset, _train_valid_sizes(len(dataset)))

    # Training arguments
    training_args = _build_training_args(config, output_dir, seed)

    model = DeepseekV2ForCausalLM.from_pretrained(
        base_model_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
    to_esft(model, expert_config)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        data_collator=_stack_collator,
    )

    # Resume only when a real checkpoint-<step> directory exists. The former
    # "more than one entry in output_dir" heuristic misfired both ways:
    # - with e.g. logs/ plus another entry but no checkpoint, resume_from_checkpoint=True raised ValueError;
    # - with a lone checkpoint, training silently restarted from scratch.
    last_checkpoint = get_last_checkpoint(output_dir)
    trainer.train(resume_from_checkpoint=last_checkpoint)

    # Save the model and tokenizer
    final_dir = os.path.join(output_dir, "last_checkpoint")
    trainer.save_model(final_dir)
    tokenizer.save_pretrained(final_dir)

    print("Training complete")


if __name__ == "__main__":
    main()
|