Mirror of https://github.com/deepseek-ai/DeepSeek-MoE, synced 2025-01-22 10:35:57 +00:00
import copy
import random
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence
import logging
import os

import torch
import torch.distributed
import transformers
from transformers import Trainer, BitsAndBytesConfig
from datasets import load_dataset
import datasets
import numpy as np
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training, PeftModel
from peft.tuners.lora import LoraLayer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

IGNORE_INDEX = -100
EOT_TOKEN = "<|EOT|>"
logger = logging.getLogger(__name__)

def build_instruction_prompt(instruction: str):
    return '''
You are an AI assistant, developed by DeepSeek Company. For politically sensitive questions, security and privacy issues, you will refuse to answer.
### Instruction:
{}
### Response:
'''.format(instruction.strip()).lstrip()

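# Illustrative example (added for clarity, not part of the upstream script):
# build_instruction_prompt("Sort a list in Python") renders the fixed system line, then
#   ### Instruction:
#   Sort a list in Python
#   ### Response:
# and leaves the response body for the model to generate; the EOT_TOKEN terminator is appended
# to the target side later, in train_tokenize_function.
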
@dataclass
class ModelArguments:
    trainable: Optional[str] = field(default="q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj")
    lora_rank: Optional[int] = field(default=8)
    lora_dropout: Optional[float] = field(default=0.1)
    lora_alpha: Optional[float] = field(default=32.)
    modules_to_save: Optional[str] = field(default="embed_tokens,lm_head")
    use_lora: Optional[bool] = field(default=False)
    model_name_or_path: Optional[str] = field(default="deepseek-ai/deepseek-moe-16b")
    attn_implementation: Optional[str] = field(default="flash_attention_2")
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."}
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
    )
    bits: int = field(
        default=16,
        metadata={"help": "How many bits to use."}
    )

@dataclass
class DataArguments:
    data_path: Optional[str] = field(default=None, metadata={"help": "Path to the training data."})

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )

class SavePeftModelCallback(transformers.TrainerCallback):
    """Saves the PEFT adapter weights (and tokenizer) alongside each Trainer checkpoint,
    and writes a 'completed' marker file when training ends."""

    def save_model(self, args, state, kwargs):
        logger.info('Saving PEFT checkpoint...')
        if state.best_model_checkpoint is not None:
            checkpoint_folder = os.path.join(state.best_model_checkpoint, "adapter_model")
        else:
            checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")

        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)
        kwargs["tokenizer"].save_pretrained(peft_model_path)

    def on_save(self, args, state, control, **kwargs):
        self.save_model(args, state, kwargs)
        return control

    def on_train_end(self, args, state, control, **kwargs):
        def touch(fname, times=None):
            with open(fname, 'a'):
                os.utime(fname, times)

        touch(os.path.join(args.output_dir, 'completed'))
        self.save_model(args, state, kwargs)

def get_last_checkpoint(checkpoint_dir):
    if os.path.isdir(checkpoint_dir):
        is_completed = os.path.exists(os.path.join(checkpoint_dir, 'completed'))
        if is_completed: return None  # already finished
        max_step = 0
        for filename in os.listdir(checkpoint_dir):
            if os.path.isdir(os.path.join(checkpoint_dir, filename)) and filename.startswith(PREFIX_CHECKPOINT_DIR):
                max_step = max(max_step, int(filename.replace(PREFIX_CHECKPOINT_DIR + '-', '')))
        if max_step == 0: return None
        latest_ckpt_dir = os.path.join(checkpoint_dir, f'{PREFIX_CHECKPOINT_DIR}-{max_step}')
        logger.info(f"Found a previous checkpoint at: {checkpoint_dir}")
        return latest_ckpt_dir
    return None  # first training

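# Resumption sketch (hypothetical directory layout, added for clarity): given an output_dir containing
#   checkpoint-500/  checkpoint-1000/  completed
# the 'completed' marker makes get_last_checkpoint return None (training already finished);
# without the marker, the highest-numbered checkpoint-* directory (here checkpoint-1000) is returned.
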
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dumps it to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa

def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            # return_tensors="pt",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [np.array(tokenized.input_ids) for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        len(tokenized.input_ids) for tokenized in tokenized_list
    ]

    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )

def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)

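# Masking sketch (illustrative token ids, added for clarity): for source tokens [1, 2, 3] and
# target tokens [4, 5], preprocess produces
#   input_ids = [1, 2, 3, 4, 5]
#   labels    = [-100, -100, -100, 4, 5]
# so the cross-entropy loss is computed only on the response span.
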
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""
    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = [torch.tensor(x) for x in input_ids]
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = [torch.tensor(x) for x in labels]
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)

        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

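# Minimal usage sketch (assumed inputs, added for clarity, not from the original repo):
#   collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
#   batch = collator([
#       {"input_ids": [1, 2, 3], "labels": [IGNORE_INDEX, 2, 3]},
#       {"input_ids": [1, 2],    "labels": [IGNORE_INDEX, 2]},
#   ])
# The shorter example is right-padded with pad_token_id / IGNORE_INDEX, giving batch["input_ids"]
# of shape (2, 3) and an attention_mask that is False on the padded position.
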
def train_tokenize_function(examples, tokenizer):
    sources = [
        build_instruction_prompt(instruction)
        for instruction in examples['instruction']
    ]
    targets = [f"{output}\n{EOT_TOKEN}" for output in examples['output']]
    data_dict = preprocess(sources, targets, tokenizer)
    return data_dict

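# Data format note (inferred from the code above; the values shown are illustrative): each record
# in the training parquet file is expected to carry 'instruction' and 'output' columns, e.g.
#   {"instruction": "Explain what a Python decorator is.", "output": "A decorator is a function that ..."}
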
def build_model(model_args, training_args, checkpoint_dir):
    # Quantized (4/8-bit) loading is only wired up for the LoRA path; full fine-tuning expects 16/32-bit weights.
    if not model_args.use_lora: assert model_args.bits in [16, 32]
    compute_dtype = (torch.bfloat16 if training_args.bf16 else torch.float16)
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        load_in_4bit=model_args.bits == 4,
        load_in_8bit=model_args.bits == 8,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=model_args.bits == 4,
            load_in_8bit=model_args.bits == 8,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=model_args.double_quant,
            bnb_4bit_quant_type=model_args.quant_type,
        ) if model_args.use_lora else None,
        torch_dtype=compute_dtype,
        trust_remote_code=True,
    )

    if compute_dtype == torch.float16 and model_args.bits == 4:
        if torch.cuda.is_bf16_supported():
            logger.info('=' * 80)
            logger.info('Your GPU supports bfloat16, you can accelerate training with the argument --bf16')
            logger.info('=' * 80)
    setattr(model, 'model_parallel', True)
    setattr(model, 'is_parallelizable', True)
    model.config.torch_dtype = torch.bfloat16 if training_args.bf16 else torch.float32

    if model_args.use_lora and model_args.bits < 16:
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)

    if model_args.use_lora:
        if checkpoint_dir is not None:
            logger.info(f"Loading adapters from {checkpoint_dir}.")
            # os.path.join(checkpoint_dir, 'adapter_model')
            model = PeftModel.from_pretrained(model, checkpoint_dir, is_trainable=True)
        else:
            logger.info('Init LoRA modules...')
            target_modules = model_args.trainable.split(',')
            modules_to_save = model_args.modules_to_save
            if modules_to_save is not None:
                modules_to_save = modules_to_save.split(',')
            lora_rank = model_args.lora_rank
            lora_dropout = model_args.lora_dropout
            lora_alpha = model_args.lora_alpha
            peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                target_modules=target_modules,
                inference_mode=False,
                r=lora_rank, lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                modules_to_save=modules_to_save)
            model = get_peft_model(model, peft_config)

    # Keep LoRA layers in bf16 when requested, but hold norm/gate modules in fp32 for numerical stability.
    for name, module in model.named_modules():
        if isinstance(module, LoraLayer):
            if training_args.bf16:
                module = module.to(torch.bfloat16)
        if 'norm' in name or 'gate' in name:
            module = module.to(torch.float32)
        if 'lm_head' in name or 'embed_tokens' in name:
            if hasattr(module, 'weight'):
                if training_args.bf16 and module.weight.dtype == torch.float32:
                    module = module.to(torch.bfloat16)
    return model

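# Mode summary (descriptive note, derived from build_model above): with --use_lora the function
# returns a PeftModel, optionally on top of 4/8-bit bitsandbytes quantization; without --use_lora
# it returns the full model in bf16/fp16 and asserts that --bits is 16 or 32.
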
def train():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    if training_args.local_rank == 0:
        logger.info('=' * 100)
        logger.info(training_args)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=True,
        trust_remote_code=True
    )

    logger.info(f"PAD Token: {tokenizer.pad_token} {tokenizer.pad_token_id}")
    logger.info(f"BOS Token: {tokenizer.bos_token} {tokenizer.bos_token_id}")
    logger.info(f"EOS Token: {tokenizer.eos_token} {tokenizer.eos_token_id}")

    if training_args.local_rank == 0:
        logger.info("Finished loading tokenizer from {}.".format(model_args.model_name_or_path))

    resume_from_checkpoint_dir = get_last_checkpoint(training_args.output_dir)
    model = build_model(model_args, training_args, resume_from_checkpoint_dir)

    raw_train_datasets = load_dataset(
        'parquet',
        data_files=data_args.data_path,
        split="train",
        cache_dir=training_args.cache_dir
    )
    if training_args.local_rank > 0:
        # Non-main ranks wait here so that rank 0 tokenizes (and caches) the dataset first.
        torch.distributed.barrier()

    train_dataset = raw_train_datasets.map(
        train_tokenize_function,
        batched=True,
        batch_size=3000,
        num_proc=32,
        remove_columns=raw_train_datasets.column_names,
        load_from_cache_file=True,  # not args.overwrite_cache
        desc="Running Encoding",
        fn_kwargs={"tokenizer": tokenizer}
    )

    if training_args.local_rank == 0:
        # Release the other ranks once the tokenized dataset has been cached.
        torch.distributed.barrier()

    if training_args.local_rank == 0:
        logger.info(f"Training dataset samples: {len(train_dataset)}")
        for index in random.sample(range(len(train_dataset)), 3):
            logger.info(f"Sample {index} of the training set: {train_dataset[index]['input_ids']}, {train_dataset[index]['labels']}.")
            logger.info(f"Sample {index} of the training set: {tokenizer.decode(list(train_dataset[index]['input_ids']))}.")

    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    data_module = dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)

    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
    if model_args.use_lora:
        trainer.add_callback(SavePeftModelCallback)
    trainer.train(resume_from_checkpoint=resume_from_checkpoint_dir)
    trainer.save_state()
    if not model_args.use_lora:
        safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)

if __name__ == "__main__":
    train()
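
# Example launch (a sketch only: the script filename, paths, and values below are placeholders;
# the flags map onto the ModelArguments/DataArguments/TrainingArguments fields defined above):
#   torchrun --nproc_per_node=8 finetune.py \
#       --model_name_or_path deepseek-ai/deepseek-moe-16b \
#       --data_path ./data/train.parquet \
#       --output_dir ./output \
#       --model_max_length 512 \
#       --bf16 True \
#       --use_lora True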