mirror of
https://github.com/deepseek-ai/DeepSeek-Coder
synced 2024-12-04 18:14:44 +00:00
add deepspeed finetune
This commit is contained in:
parent
118e71a1af
commit
4f0b860d30
44
finetune/README.md
Normal file
@@ -0,0 +1,44 @@
## How to Fine-tune DeepSeek-Coder

We provide the script `finetune_deepseekcoder.py` for users to finetune our models on downstream tasks.

The script supports training with [DeepSpeed](https://github.com/microsoft/DeepSpeed). You need to install the required packages first:

```bash
pip install -r requirements.txt
```

Please follow the [Sample Dataset Format](https://huggingface.co/datasets/nickrosh/Evol-Instruct-Code-80k-v1) to prepare your training data.
Each line is a JSON-serialized string with two required fields, `instruction` and `output`; see the illustrative sketch below.
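
For concreteness, here is a minimal sketch of writing such a JSON-lines file from Python. The example record and the `train.jsonl` filename are illustrative, not required by the script; only the `instruction` and `output` fields matter.

```python
import json

# Illustrative example; only "instruction" and "output" are required fields.
examples = [
    {
        "instruction": "Write a Python function that returns the factorial of n.",
        "output": "def factorial(n):\n    return 1 if n <= 1 else n * factorial(n - 1)",
    },
]

# One JSON object per line (JSON Lines), which load_dataset('json', ...) can read.
with open("train.jsonl", "w", encoding="utf-8") as f:
    for example in examples:
        f.write(json.dumps(example) + "\n")
```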

After data preparation, you can use the sample shell script below to finetune `deepseek-ai/deepseek-coder-6.7b-instruct`.
Remember to specify `DATA_PATH` and `OUTPUT_PATH`,
and please choose appropriate hyper-parameters (e.g., `learning_rate`, `per_device_train_batch_size`) for your scenario.

```bash
DATA_PATH="<your_data_path>"
OUTPUT_PATH="<your_output_path>"
MODEL_PATH="deepseek-ai/deepseek-coder-6.7b-instruct"

deepspeed finetune_deepseekcoder.py \
    --model_name_or_path $MODEL_PATH \
    --data_path $DATA_PATH \
    --output_dir $OUTPUT_PATH \
    --num_train_epochs 3 \
    --model_max_length 1024 \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 100 \
    --save_total_limit 100 \
    --learning_rate 2e-5 \
    --warmup_steps 10 \
    --logging_steps 1 \
    --lr_scheduler_type "cosine" \
    --gradient_checkpointing True \
    --report_to "tensorboard" \
    --deepspeed configs/ds_config_zero3.json \
    --bf16 True
```
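Once training finishes, `OUTPUT_PATH` holds a standard Hugging Face checkpoint, so it can be reloaded for inference. A minimal sketch, assuming the run saved both the model weights and the tokenizer to that directory; the prompt shown is an abbreviated form of the template in `build_instruction_prompt`:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

output_path = "<your_output_path>"  # the OUTPUT_PATH used for training

tokenizer = AutoTokenizer.from_pretrained(output_path)
model = AutoModelForCausalLM.from_pretrained(output_path, torch_dtype=torch.bfloat16)

prompt = "### Instruction:\nWrite a quicksort function in Python.\n### Response:\n"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=256)
# Decode only the newly generated tokens.
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```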
51
finetune/configs/ds_config_zero3.json
Normal file
@@ -0,0 +1,51 @@
{
    "bf16": {
        "enabled": "auto"
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },

    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto"
        }
    },

    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 20,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}
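The `"auto"` entries in this config are not literal values: when the script is launched with `--deepspeed configs/ds_config_zero3.json`, the Hugging Face Trainer's DeepSpeed integration replaces each `"auto"` with the matching `TrainingArguments` value at initialization. A minimal sketch of the equivalent programmatic setup; the argument values simply mirror the sample shell script above:

```python
from transformers import TrainingArguments

# "auto" fields in ds_config_zero3.json (lr, betas, warmup_num_steps,
# train_micro_batch_size_per_gpu, gradient_accumulation_steps, ...) are
# filled in from these arguments by the Trainer's DeepSpeed integration.
training_args = TrainingArguments(
    output_dir="<your_output_path>",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    warmup_steps=10,
    bf16=True,
    deepspeed="configs/ds_config_zero3.json",
)
```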
207
finetune/finetune_deepseekcoder.py
Normal file
@@ -0,0 +1,207 @@
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import random
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence

import torch
import torch.distributed
import transformers
from transformers import Trainer
from datasets import load_dataset


IGNORE_INDEX = -100
EOT_TOKEN = "<|EOT|>"

def build_instruction_prompt(instruction: str):
    return '''
You are an AI programming assistant, utilizing the DeepSeek Coder model, developed by DeepSeek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.
### Instruction:
{}
### Response:
'''.format(instruction.strip()).lstrip()

@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="deepseek-ai/deepseek-coder-6.7b-instruct")

@dataclass
class DataArguments:
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dumps it to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]

    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]

    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
    input_ids = examples_tokenized["input_ids"]

    # Mask out the prompt ("source") tokens so the loss is computed only on the response.
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)

@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""
    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = [torch.tensor(x) for x in input_ids]
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = [torch.tensor(x) for x in labels]
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)

        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

def train_tokenize_function(examples, tokenizer):
    sources = [
        build_instruction_prompt(instruction)
        for instruction in examples['instruction']
    ]
    targets = [f"{output}\n{EOT_TOKEN}" for output in examples['output']]
    data_dict = preprocess(sources, targets, tokenizer)
    return data_dict

def train():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if training_args.local_rank == 0:
        print('='*100)
        print(training_args)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=True,
        trust_remote_code=True
    )

    print("PAD Token:", tokenizer.pad_token, tokenizer.pad_token_id)
    print("BOS Token:", tokenizer.bos_token, tokenizer.bos_token_id)
    print("EOS Token:", tokenizer.eos_token, tokenizer.eos_token_id)

    if training_args.local_rank == 0:
        print("Load tokenizer from {} over.".format(model_args.model_name_or_path))

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        torch_dtype=torch.bfloat16
    )

    if training_args.local_rank == 0:
        print("Load model from {} over.".format(model_args.model_name_or_path))

    raw_train_datasets = load_dataset(
        'json',
        data_files=data_args.data_path,
        split="train",
        cache_dir=training_args.cache_dir
    )
    # Rank 0 tokenizes and caches the dataset first; the other ranks wait at the
    # barrier and then reuse the cached result.
    if training_args.local_rank > 0:
        torch.distributed.barrier()

    train_dataset = raw_train_datasets.map(
        train_tokenize_function,
        batched=True,
        batch_size=3000,
        num_proc=32,
        remove_columns=raw_train_datasets.column_names,
        load_from_cache_file=True,  # not args.overwrite_cache
        desc="Running Encoding",
        fn_kwargs={"tokenizer": tokenizer}
    )

    if training_args.local_rank == 0:
        torch.distributed.barrier()

    if training_args.local_rank == 0:
        print("Training dataset samples:", len(train_dataset))
        for index in random.sample(range(len(train_dataset)), 3):
            print(f"Sample {index} of the training set: {train_dataset[index]['input_ids']}, {train_dataset[index]['labels']}.")
            print(f"Sample {index} of the training set: {tokenizer.decode(list(train_dataset[index]['input_ids']))}.")

    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    data_module = dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)

    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)

    trainer.train()
    trainer.save_state()
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()
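For intuition on the label masking in `preprocess` above, here is a small self-contained sketch (toy token IDs, no tokenizer needed) of how prompt positions are set to `IGNORE_INDEX` so that the cross-entropy loss covers only the response and the `<|EOT|>` terminator:

```python
import torch

IGNORE_INDEX = -100

# Toy sequence: the first 4 positions are the tokenized prompt ("source"),
# the remaining positions are the tokenized response ("target").
input_ids = torch.tensor([11, 12, 13, 14, 21, 22, 23])
source_len = 4

labels = input_ids.clone()
labels[:source_len] = IGNORE_INDEX  # prompt tokens are ignored by the loss

print(labels)  # tensor([-100, -100, -100, -100, 21, 22, 23])
```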
10
finetune/requirements.txt
Normal file
@@ -0,0 +1,10 @@
torch>=2.0.1
tokenizers>=0.14.0
transformers>=4.35.0
accelerate
attrdict
tqdm

deepspeed
datasets
tensorboardX