mirror of
https://github.com/deepseek-ai/DeepSeek-Coder-V2
synced 2024-11-25 05:19:44 +00:00
add finetuning code with lora support
This commit is contained in:
parent
09209c4aed
commit
19bdb69bb5
90
finetune/README.md
Normal file
90
finetune/README.md
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
## How to Fine-tune DeepSeek-Coder-V2
|
||||||
|
|
||||||
|
We provide script `finetune_deepseekcoder.py` for users to finetune our models on downstream tasks.
|
||||||
|
|
||||||
|
The script supports the training with [DeepSpeed](https://github.com/microsoft/DeepSpeed). You need install required packages by:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install flash-attn --no-build-isolation
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
Please follow [Sample Dataset Format](https://huggingface.co/datasets/nickrosh/Evol-Instruct-Code-80k-v1) to prepare your training data.
|
||||||
|
|
||||||
|
You can download the sample dataset from [HuggingFace](https://huggingface.co/datasets/nickrosh/Evol-Instruct-Code-80k-v1) by:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
wget https://huggingface.co/datasets/nickrosh/Evol-Instruct-Code-80k-v1/resolve/main/EvolInstruct-Code-80k.json
|
||||||
|
```
|
||||||
|
|
||||||
|
Each line is a json-serialized string with two required fields `instruction` and `output`.
|
||||||
|
|
||||||
|
After data preparation, you can use the sample shell script to finetune `deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct`. Remember to specify `DATA_PATH`, `OUTPUT_PATH`.
|
||||||
|
And please choose appropriate hyper-parameters(e.g., `learning_rate`, `per_device_train_batch_size`) according to your scenario. For devices supported by flash_attention, you can refer [here](https://github.com/Dao-AILab/flash-attention).
|
||||||
|
For this configuration, zero_stage needs to be set to 3.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
DATA_PATH="<your_data_path>"
|
||||||
|
OUTPUT_PATH="<your_output_path>"
|
||||||
|
MODEL_PATH="deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
|
||||||
|
|
||||||
|
deepspeed finetune_deepseekcoder.py \
|
||||||
|
--model_name_or_path $MODEL_PATH \
|
||||||
|
--data_path $DATA_PATH \
|
||||||
|
--output_dir $OUTPUT_PATH \
|
||||||
|
--num_train_epochs 3 \
|
||||||
|
--model_max_length 1024 \
|
||||||
|
--per_device_train_batch_size 16 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 4 \
|
||||||
|
--evaluation_strategy "no" \
|
||||||
|
--save_strategy "steps" \
|
||||||
|
--save_steps 100 \
|
||||||
|
--save_total_limit 100 \
|
||||||
|
--learning_rate 2e-5 \
|
||||||
|
--warmup_steps 10 \
|
||||||
|
--logging_steps 1 \
|
||||||
|
--lr_scheduler_type "cosine" \
|
||||||
|
--gradient_checkpointing True \
|
||||||
|
--report_to "tensorboard" \
|
||||||
|
--deepspeed configs/ds_config_zero3.json \
|
||||||
|
--bf16 True \
|
||||||
|
--use_lora False
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also finetune the model with 4/8-bits qlora, feel free to try it. For this configuration, it is possible to run on a single A100 80G GPU, and adjustments can be made according to your resources.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
DATA_PATH="<your_data_path>"
|
||||||
|
OUTPUT_PATH="<your_output_path>"
|
||||||
|
MODEL_PATH="<your_model_path>"
|
||||||
|
|
||||||
|
deepspeed finetune.py \
|
||||||
|
--model_name_or_path $MODEL_PATH \
|
||||||
|
--data_path $DATA_PATH \
|
||||||
|
--output_dir $OUTPUT_PATH \
|
||||||
|
--num_train_epochs 3 \
|
||||||
|
--model_max_length 1024 \
|
||||||
|
--per_device_train_batch_size 16 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 4 \
|
||||||
|
--evaluation_strategy "no" \
|
||||||
|
--save_strategy "steps" \
|
||||||
|
--save_steps 100 \
|
||||||
|
--save_total_limit 100 \
|
||||||
|
--learning_rate 2e-5 \
|
||||||
|
--warmup_steps 10 \
|
||||||
|
--logging_steps 1 \
|
||||||
|
--lr_scheduler_type "cosine" \
|
||||||
|
--gradient_checkpointing True \
|
||||||
|
--report_to "tensorboard" \
|
||||||
|
--deepspeed configs/ds_config_zero2_no_offload.json \
|
||||||
|
--bf16 True \
|
||||||
|
--use_lora True \
|
||||||
|
--bits 4 \
|
||||||
|
--max_grad_norm 0.3 \
|
||||||
|
--double_quant \
|
||||||
|
--lora_r 64 \
|
||||||
|
--lora_alpha 16 \
|
||||||
|
--quant_type nf4 \
|
||||||
|
```
|
22
finetune/configs/ds_config_zero2_no_offload.json
Normal file
22
finetune/configs/ds_config_zero2_no_offload.json
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
{
|
||||||
|
"bf16": {
|
||||||
|
"enabled": true
|
||||||
|
},
|
||||||
|
|
||||||
|
"zero_optimization": {
|
||||||
|
"stage": 2,
|
||||||
|
"allgather_partitions": true,
|
||||||
|
"allgather_bucket_size": 1e8,
|
||||||
|
"overlap_comm": true,
|
||||||
|
"reduce_scatter": true,
|
||||||
|
"reduce_bucket_size": 1e8,
|
||||||
|
"contiguous_gradients": true
|
||||||
|
},
|
||||||
|
|
||||||
|
"gradient_accumulation_steps": "auto",
|
||||||
|
"gradient_clipping": "auto",
|
||||||
|
"steps_per_print": 2000,
|
||||||
|
"train_batch_size": "auto",
|
||||||
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
|
"wall_clock_breakdown": false
|
||||||
|
}
|
51
finetune/configs/ds_config_zero3.json
Normal file
51
finetune/configs/ds_config_zero3.json
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
{
|
||||||
|
"bf16": {
|
||||||
|
"enabled": "auto"
|
||||||
|
},
|
||||||
|
"optimizer": {
|
||||||
|
"type": "AdamW",
|
||||||
|
"params": {
|
||||||
|
"lr": "auto",
|
||||||
|
"betas": "auto",
|
||||||
|
"eps": "auto",
|
||||||
|
"weight_decay": "auto"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
"scheduler": {
|
||||||
|
"type": "WarmupLR",
|
||||||
|
"params": {
|
||||||
|
"warmup_min_lr": "auto",
|
||||||
|
"warmup_max_lr": "auto",
|
||||||
|
"warmup_num_steps": "auto"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
"zero_optimization": {
|
||||||
|
"stage": 3,
|
||||||
|
"offload_optimizer": {
|
||||||
|
"device": "cpu",
|
||||||
|
"pin_memory": true
|
||||||
|
},
|
||||||
|
"offload_param": {
|
||||||
|
"device": "cpu",
|
||||||
|
"pin_memory": true
|
||||||
|
},
|
||||||
|
"overlap_comm": true,
|
||||||
|
"contiguous_gradients": true,
|
||||||
|
"sub_group_size": 1e9,
|
||||||
|
"reduce_bucket_size": "auto",
|
||||||
|
"stage3_prefetch_bucket_size": "auto",
|
||||||
|
"stage3_param_persistence_threshold": "auto",
|
||||||
|
"stage3_max_live_parameters": 1e9,
|
||||||
|
"stage3_max_reuse_distance": 1e9,
|
||||||
|
"stage3_gather_16bit_weights_on_model_save": true
|
||||||
|
},
|
||||||
|
|
||||||
|
"gradient_accumulation_steps": "auto",
|
||||||
|
"gradient_clipping": "auto",
|
||||||
|
"steps_per_print": 20,
|
||||||
|
"train_batch_size": "auto",
|
||||||
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
|
"wall_clock_breakdown": false
|
||||||
|
}
|
366
finetune/finetune_deepseekcoder.py
Normal file
366
finetune/finetune_deepseekcoder.py
Normal file
@ -0,0 +1,366 @@
|
|||||||
|
import copy
|
||||||
|
import random
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
import transformers
|
||||||
|
import datasets
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional, Dict, Sequence
|
||||||
|
from transformers import Trainer, BitsAndBytesConfig
|
||||||
|
from datasets import load_dataset
|
||||||
|
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training, PeftModel
|
||||||
|
from peft.tuners.lora import LoraLayer
|
||||||
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||||
|
from rich import print
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
IGNORE_INDEX = -100
|
||||||
|
EOT_TOKEN = "<|EOT|>"
|
||||||
|
|
||||||
|
def build_instruction_prompt(instruction: str):
|
||||||
|
return '''
|
||||||
|
You are an AI programming assistant, utilizing the DeepSeek Coder model, developed by DeepSeek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.
|
||||||
|
### Instruction:
|
||||||
|
{}
|
||||||
|
### Response:
|
||||||
|
'''.format(instruction.strip()).lstrip()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelArguments:
|
||||||
|
trainable : Optional[str] = field(default="q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj")
|
||||||
|
lora_rank : Optional[int] = field(default=8)
|
||||||
|
lora_dropout : Optional[float] = field(default=0.1)
|
||||||
|
lora_alpha : Optional[float] = field(default=32.)
|
||||||
|
modules_to_save : Optional[str] = field(default="embed_tokens,lm_head")
|
||||||
|
use_lora : Optional[bool] = field(default=False)
|
||||||
|
model_name_or_path: Optional[str] = field(default="deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct")
|
||||||
|
attn_implementation : Optional[str] = field(default="flash_attention_2")
|
||||||
|
double_quant: bool = field(
|
||||||
|
default=True,
|
||||||
|
metadata={"help": "Compress the quantization statistics through double quantization."}
|
||||||
|
)
|
||||||
|
quant_type: str = field(
|
||||||
|
default="nf4",
|
||||||
|
metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
|
||||||
|
)
|
||||||
|
bits: int = field(
|
||||||
|
default=16,
|
||||||
|
metadata={"help": "How many bits to use."}
|
||||||
|
)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DataArguments:
|
||||||
|
data_path: str = field(default=None, metadata={"help": "Path to the training data."})
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainingArguments(transformers.TrainingArguments):
|
||||||
|
cache_dir: Optional[str] = field(default=None)
|
||||||
|
optim: str = field(default="adamw_torch")
|
||||||
|
model_max_length: int = field(
|
||||||
|
default=512,
|
||||||
|
metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
|
||||||
|
)
|
||||||
|
|
||||||
|
class SavePeftModelCallback(transformers.TrainerCallback):
|
||||||
|
def save_model(self, args, state, kwargs):
|
||||||
|
logger.info('Saving PEFT checkpoint...')
|
||||||
|
if state.best_model_checkpoint is not None:
|
||||||
|
checkpoint_folder = os.path.join(state.best_model_checkpoint, "adapter_model")
|
||||||
|
else:
|
||||||
|
checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
|
||||||
|
|
||||||
|
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
|
||||||
|
kwargs["model"].save_pretrained(peft_model_path)
|
||||||
|
kwargs["tokenizer"].save_pretrained(peft_model_path)
|
||||||
|
|
||||||
|
def on_save(self, args, state, control, **kwargs):
|
||||||
|
self.save_model(args, state, kwargs)
|
||||||
|
return control
|
||||||
|
|
||||||
|
def on_train_end(self, args, state, control, **kwargs):
|
||||||
|
def touch(fname, times=None):
|
||||||
|
with open(fname, 'a'):
|
||||||
|
os.utime(fname, times)
|
||||||
|
touch(os.path.join(args.output_dir, 'completed'))
|
||||||
|
self.save_model(args, state, kwargs)
|
||||||
|
logger.info("Saved model successfully")
|
||||||
|
|
||||||
|
def get_last_checkpoint(checkpoint_dir):
|
||||||
|
if os.path.isdir(checkpoint_dir):
|
||||||
|
is_completed = os.path.exists(os.path.join(checkpoint_dir, 'completed'))
|
||||||
|
if is_completed: return None # already finished
|
||||||
|
max_step = 0
|
||||||
|
for filename in os.listdir(checkpoint_dir):
|
||||||
|
if os.path.isdir(os.path.join(checkpoint_dir, filename)) and filename.startswith(PREFIX_CHECKPOINT_DIR):
|
||||||
|
max_step = max(max_step, int(filename.replace(PREFIX_CHECKPOINT_DIR + '-', '')))
|
||||||
|
if max_step == 0: return None
|
||||||
|
latest_ckpt_dir = os.path.join(checkpoint_dir, f'{PREFIX_CHECKPOINT_DIR}-{max_step}')
|
||||||
|
logger.info(f"Found a previous checkpoint at: {checkpoint_dir}")
|
||||||
|
return latest_ckpt_dir
|
||||||
|
return None # first training
|
||||||
|
|
||||||
|
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
|
||||||
|
"""Collects the state dict and dump to disk."""
|
||||||
|
state_dict = trainer.model.state_dict()
|
||||||
|
if trainer.args.should_save:
|
||||||
|
cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
|
||||||
|
del state_dict
|
||||||
|
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
|
||||||
|
|
||||||
|
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
|
||||||
|
"""Tokenize a list of strings."""
|
||||||
|
tokenized_list = [
|
||||||
|
tokenizer(
|
||||||
|
text,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding="longest",
|
||||||
|
max_length=tokenizer.model_max_length,
|
||||||
|
truncation=True,
|
||||||
|
)
|
||||||
|
for text in strings
|
||||||
|
]
|
||||||
|
|
||||||
|
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
|
||||||
|
input_ids_lens = labels_lens = [
|
||||||
|
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
|
||||||
|
]
|
||||||
|
|
||||||
|
return dict(
|
||||||
|
input_ids=input_ids,
|
||||||
|
labels=labels,
|
||||||
|
input_ids_lens=input_ids_lens,
|
||||||
|
labels_lens=labels_lens,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess(
|
||||||
|
sources: Sequence[str],
|
||||||
|
targets: Sequence[str],
|
||||||
|
tokenizer: transformers.PreTrainedTokenizer,
|
||||||
|
) -> Dict:
|
||||||
|
"""Preprocess the data by tokenizing."""
|
||||||
|
examples = [s + t for s, t in zip(sources, targets)]
|
||||||
|
examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
|
||||||
|
input_ids = examples_tokenized["input_ids"]
|
||||||
|
|
||||||
|
labels = copy.deepcopy(input_ids)
|
||||||
|
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
|
||||||
|
label[:source_len] = IGNORE_INDEX
|
||||||
|
return dict(input_ids=input_ids, labels=labels)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DataCollatorForSupervisedDataset(object):
|
||||||
|
"""Collate examples for supervised fine-tuning."""
|
||||||
|
tokenizer: transformers.PreTrainedTokenizer
|
||||||
|
|
||||||
|
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
||||||
|
input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
|
||||||
|
input_ids = [torch.tensor(x) for x in input_ids]
|
||||||
|
input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||||
|
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
|
||||||
|
)
|
||||||
|
labels = [torch.tensor(x) for x in labels]
|
||||||
|
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
|
||||||
|
|
||||||
|
return dict(
|
||||||
|
input_ids=input_ids,
|
||||||
|
labels=labels,
|
||||||
|
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
def train_tokenize_function(examples, tokenizer):
|
||||||
|
sources = [
|
||||||
|
build_instruction_prompt(instruction)
|
||||||
|
for instruction in examples['instruction']
|
||||||
|
]
|
||||||
|
targets = [f"{output}\n{EOT_TOKEN}" for output in examples['output']]
|
||||||
|
data_dict = preprocess(sources, targets, tokenizer)
|
||||||
|
return data_dict
|
||||||
|
|
||||||
|
def build_model(model_args, training_args, checkpoint_dir):
|
||||||
|
logger.info("Starting model building process...")
|
||||||
|
if not model_args.use_lora:
|
||||||
|
assert model_args.bits in [16, 32]
|
||||||
|
logger.info(f"Not using LoRA. Model bits: {model_args.bits}")
|
||||||
|
compute_dtype = (torch.bfloat16 if training_args.bf16 else torch.float16)
|
||||||
|
logger.info(f"Compute dtype: {compute_dtype}")
|
||||||
|
|
||||||
|
logger.info(f"Loading model from: {model_args.model_name_or_path}")
|
||||||
|
model = transformers.AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_args.model_name_or_path,
|
||||||
|
quantization_config=BitsAndBytesConfig(
|
||||||
|
load_in_4bit=model_args.bits == 4,
|
||||||
|
load_in_8bit=model_args.bits == 8,
|
||||||
|
llm_int8_threshold=6.0,
|
||||||
|
llm_int8_has_fp16_weight=False,
|
||||||
|
bnb_4bit_compute_dtype=compute_dtype,
|
||||||
|
bnb_4bit_use_double_quant=model_args.double_quant,
|
||||||
|
bnb_4bit_quant_type=model_args.quant_type,
|
||||||
|
) if model_args.use_lora else None,
|
||||||
|
torch_dtype=compute_dtype,
|
||||||
|
trust_remote_code=True,
|
||||||
|
attn_implementation=model_args.attn_implementation,
|
||||||
|
)
|
||||||
|
logger.info("Model loaded successfully")
|
||||||
|
|
||||||
|
if compute_dtype == torch.float16 and model_args.bits == 4:
|
||||||
|
if torch.cuda.is_bf16_supported():
|
||||||
|
logger.info('='*80)
|
||||||
|
logger.info('Your GPU supports bfloat16, you can accelerate training with the argument --bf16')
|
||||||
|
logger.info('='*80)
|
||||||
|
|
||||||
|
logger.info("Setting model attributes...")
|
||||||
|
setattr(model, 'model_parallel', True)
|
||||||
|
setattr(model, 'is_parallelizable', True)
|
||||||
|
model.config.torch_dtype=torch.bfloat16 if training_args.bf16 else torch.float32
|
||||||
|
logger.info(f"Model torch dtype set to: {model.config.torch_dtype}")
|
||||||
|
|
||||||
|
if model_args.use_lora and model_args.bits < 16:
|
||||||
|
logger.info("Preparing model for k-bit training...")
|
||||||
|
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
|
||||||
|
logger.info("Model prepared for k-bit training")
|
||||||
|
|
||||||
|
if model_args.use_lora:
|
||||||
|
logger.info("LoRA is enabled. Proceeding with LoRA setup...")
|
||||||
|
if checkpoint_dir is not None:
|
||||||
|
logger.info(f"Loading adapters from {checkpoint_dir}.")
|
||||||
|
model = PeftModel.from_pretrained(model, checkpoint_dir, is_trainable=True)
|
||||||
|
else:
|
||||||
|
logger.info(f'Init LoRA modules...')
|
||||||
|
target_modules = model_args.trainable.split(',')
|
||||||
|
logger.info(f"Target modules for LoRA: {target_modules}")
|
||||||
|
|
||||||
|
modules_to_save = model_args.modules_to_save
|
||||||
|
if modules_to_save is not None:
|
||||||
|
modules_to_save = modules_to_save.split(',')
|
||||||
|
logger.info(f"Modules to save: {modules_to_save}")
|
||||||
|
else:
|
||||||
|
logger.info("No modules to save specified")
|
||||||
|
|
||||||
|
lora_rank = model_args.lora_rank
|
||||||
|
lora_dropout = model_args.lora_dropout
|
||||||
|
lora_alpha = model_args.lora_alpha
|
||||||
|
logger.info(f"LoRA parameters: rank={lora_rank}, dropout={lora_dropout}, alpha={lora_alpha}")
|
||||||
|
|
||||||
|
peft_config = LoraConfig(
|
||||||
|
task_type=TaskType.CAUSAL_LM,
|
||||||
|
target_modules=target_modules,
|
||||||
|
inference_mode=False,
|
||||||
|
r=lora_rank, lora_alpha=lora_alpha,
|
||||||
|
lora_dropout=lora_dropout,
|
||||||
|
modules_to_save=modules_to_save)
|
||||||
|
logger.info(f"LoRA configuration: {peft_config}")
|
||||||
|
|
||||||
|
model = get_peft_model(model, peft_config)
|
||||||
|
logger.info("LoRA model preparation completed")
|
||||||
|
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if isinstance(module, LoraLayer):
|
||||||
|
if training_args.bf16:
|
||||||
|
module = module.to(torch.bfloat16)
|
||||||
|
if 'norm' in name or 'gate' in name:
|
||||||
|
module = module.to(torch.float32)
|
||||||
|
if 'lm_head' in name or 'embed_tokens' in name:
|
||||||
|
if hasattr(module, 'weight'):
|
||||||
|
if training_args.bf16 and module.weight.dtype == torch.float32:
|
||||||
|
module = module.to(torch.bfloat16)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def train():
|
||||||
|
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
|
||||||
|
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||||
|
|
||||||
|
log_level = training_args.get_process_log_level()
|
||||||
|
datasets.utils.logging.set_verbosity(log_level)
|
||||||
|
transformers.utils.logging.set_verbosity(log_level)
|
||||||
|
transformers.utils.logging.enable_default_handler()
|
||||||
|
transformers.utils.logging.enable_explicit_format()
|
||||||
|
|
||||||
|
if training_args.local_rank == 0:
|
||||||
|
logger.info('='*100)
|
||||||
|
logger.info(f"Training arguments: {training_args}")
|
||||||
|
|
||||||
|
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
||||||
|
model_args.model_name_or_path,
|
||||||
|
model_max_length=training_args.model_max_length,
|
||||||
|
padding_side="right",
|
||||||
|
use_fast=True,
|
||||||
|
trust_remote_code=True
|
||||||
|
)
|
||||||
|
|
||||||
|
print("PAD Token:", tokenizer.pad_token, tokenizer.pad_token_id)
|
||||||
|
print("BOS Token:", tokenizer.bos_token, tokenizer.bos_token_id)
|
||||||
|
print("EOS Token:", tokenizer.eos_token, tokenizer.eos_token_id)
|
||||||
|
|
||||||
|
if training_args.local_rank == 0:
|
||||||
|
logger.info("Load tokenizer from {} over.".format(model_args.model_name_or_path))
|
||||||
|
|
||||||
|
resume_from_checkpoint_dir = get_last_checkpoint(training_args.output_dir)
|
||||||
|
model = build_model(model_args, training_args, resume_from_checkpoint_dir)
|
||||||
|
|
||||||
|
if training_args.local_rank == 0:
|
||||||
|
logger.info("Load model from {} over.".format(model_args.model_name_or_path))
|
||||||
|
|
||||||
|
raw_train_datasets = load_dataset(
|
||||||
|
'json', # can be also parquet, csv, etc.
|
||||||
|
data_files=data_args.data_path,
|
||||||
|
split="train",
|
||||||
|
cache_dir=training_args.cache_dir
|
||||||
|
)
|
||||||
|
if training_args.local_rank > 0:
|
||||||
|
torch.distributed.barrier()
|
||||||
|
|
||||||
|
logger.info("Starting dataset mapping")
|
||||||
|
train_dataset = raw_train_datasets.map(
|
||||||
|
train_tokenize_function,
|
||||||
|
batched=True,
|
||||||
|
batch_size=3000,
|
||||||
|
num_proc=os.cpu_count(),
|
||||||
|
remove_columns=raw_train_datasets.column_names,
|
||||||
|
load_from_cache_file=True, # not args.overwrite_cache
|
||||||
|
desc="Running Encoding",
|
||||||
|
fn_kwargs={ "tokenizer": tokenizer }
|
||||||
|
)
|
||||||
|
logger.info("Dataset mapping completed")
|
||||||
|
|
||||||
|
if training_args.local_rank == 0:
|
||||||
|
torch.distributed.barrier()
|
||||||
|
|
||||||
|
if training_args.local_rank == 0:
|
||||||
|
logger.info(f"Training dataset samples: {len(train_dataset)}")
|
||||||
|
for index in random.sample(range(len(train_dataset)), 3):
|
||||||
|
logger.info(f"Sample {index} of the training set: {train_dataset[index]['input_ids']}, {train_dataset[index]['labels']}.")
|
||||||
|
logger.info(f"Sample {index} decoded: {tokenizer.decode(list(train_dataset[index]['input_ids']))}")
|
||||||
|
|
||||||
|
logger.info("Creating data collator")
|
||||||
|
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
|
||||||
|
logger.info("Setting up data module")
|
||||||
|
data_module = dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
|
||||||
|
|
||||||
|
logger.info("Initializing Trainer")
|
||||||
|
trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
|
||||||
|
|
||||||
|
if model_args.use_lora:
|
||||||
|
logger.info("Adding SavePeftModelCallback for LoRA")
|
||||||
|
trainer.add_callback(SavePeftModelCallback)
|
||||||
|
|
||||||
|
logger.info("Starting training")
|
||||||
|
trainer.train(resume_from_checkpoint = resume_from_checkpoint_dir)
|
||||||
|
|
||||||
|
logger.info("Saving trainer state")
|
||||||
|
trainer.save_state()
|
||||||
|
|
||||||
|
if not model_args.use_lora:
|
||||||
|
logger.info("Saving full model (non-LoRA)")
|
||||||
|
safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
|
||||||
|
|
||||||
|
logger.info("Training completed")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
train()
|
13
finetune/requirements.txt
Normal file
13
finetune/requirements.txt
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
accelerate==0.33.0
|
||||||
|
attrdict==2.0.1
|
||||||
|
bitsandbytes==0.43.3
|
||||||
|
datasets==2.21.0
|
||||||
|
deepspeed==0.14.4
|
||||||
|
peft==0.12.0
|
||||||
|
tokenizers==0.19.1
|
||||||
|
torch==2.4.0
|
||||||
|
tqdm==4.66.5
|
||||||
|
transformers==4.44.2
|
||||||
|
rich
|
||||||
|
tqdm
|
||||||
|
loguru
|
Loading…
Reference in New Issue
Block a user