add training code

This commit is contained in:
Xingkai Yu 2024-08-09 16:23:48 +08:00 committed by ZihanWang314
parent 26b4fc4a8a
commit 809d0e377e
11 changed files with 2533 additions and 8 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
__pycache__/

31
configs/base.yaml Normal file
View File

@ -0,0 +1,31 @@
seed: 5934875
# Model settings
seq_length: 4096 # Maximum sequence length
# Data settings
per_device_batch_size: 1
n_device: 8 # Number of devices
# Training settings
optim: adamw_torch_fused
steps: 500 # Number of training steps
learning_rate: 0.00001 # Learning rate
weight_decay: 0.1 # Weight decay for optimizer
warmup_steps: 0 # Number of warmup steps for learning rate scheduler
logging_steps: 10 # Log every X steps
adam_beta1: 0.9
adam_beta2: 0.95
random_concat_ratio: 0.2 # Ratio of random concatenation
# Evaluation settings
eval_steps: 100 # Evaluate every X steps
save_steps: 100 # Save model every X steps
# Tokenizer settings
# Additional settings (if needed)
gradient_checkpointing: true
gradient_accumulation_steps: 16 # Number of updates steps to accumulate before performing a backward/update pass
max_grad_norm: 1.0 # Max gradient norm for gradient clipping
ep_size: 2

0
deepseek/__init__.py Normal file
View File

View File

@ -0,0 +1,206 @@
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
class DeepseekV2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`DeepseekV2Model`]. It is used to instantiate an DeepSeek
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the DeepSeek-V2.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 102400):
Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`DeepseekV2Model`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 11008):
Dimension of the MLP representations.
moe_intermediate_size (`int`, *optional*, defaults to 1407):
Dimension of the MoE representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
n_shared_experts (`int`, *optional*, defaults to None):
Number of shared experts, None means dense model.
n_routed_experts (`int`, *optional*, defaults to None):
Number of routed experts, None means dense model.
routed_scaling_factor (`float`, *optional*, defaults to 1.0):
Scaling factor or routed experts.
topk_method (`str`, *optional*, defaults to `gready`):
Topk method used in routed gate.
n_group (`int`, *optional*, defaults to None):
Number of groups for routed experts.
topk_group (`int`, *optional*, defaults to None):
Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups).
num_experts_per_tok (`int`, *optional*, defaults to None):
Number of selected experts, None means dense model.
moe_layer_freq (`int`, *optional*, defaults to 1):
The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
first_k_dense_replace (`int`, *optional*, defaults to 0):
Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
\--k dense layers--/
norm_topk_prob (`bool`, *optional*, defaults to False):
Whether to normalize the weights of the routed experts.
scoring_func (`str`, *optional*, defaults to 'softmax'):
Method of computing expert weights.
aux_loss_alpha (`float`, *optional*, defaults to 0.001):
Auxiliary loss weight coefficient.
seq_aux = (`bool`, *optional*, defaults to True):
Whether to compute the auxiliary loss for each individual sample.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 1):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 2):
End of stream token id.
pretraining_tp (`int`, *optional*, defaults to 1):
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
issue](https://github.com/pytorch/pytorch/issues/76232).
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum.
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from transformers import DeepseekV2Model, DeepseekV2Config
>>> # Initializing a Deepseek-V2 style configuration
>>> configuration = DeepseekV2Config()
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "deepseek_v2"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=102400,
hidden_size=4096,
intermediate_size=11008,
moe_intermediate_size = 1407,
num_hidden_layers=30,
num_attention_heads=32,
num_key_value_heads=32,
n_shared_experts = None,
n_routed_experts = None,
ep_size = 1,
routed_scaling_factor = 1.0,
kv_lora_rank = 512,
q_lora_rank = 1536,
qk_rope_head_dim = 64,
v_head_dim = 128,
qk_nope_head_dim = 128,
topk_method = 'gready',
n_group = None,
topk_group = None,
num_experts_per_tok = None,
moe_layer_freq = 1,
first_k_dense_replace = 0,
norm_topk_prob = False,
scoring_func = 'softmax',
aux_loss_alpha = 0.001,
seq_aux = True,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=100000,
eos_token_id=100001,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.moe_intermediate_size = moe_intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.n_shared_experts = n_shared_experts
self.n_routed_experts = n_routed_experts
self.ep_size = ep_size
self.routed_scaling_factor = routed_scaling_factor
self.kv_lora_rank = kv_lora_rank
self.q_lora_rank = q_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.topk_method = topk_method
self.n_group = n_group
self.topk_group = topk_group
self.num_experts_per_tok = num_experts_per_tok
self.moe_layer_freq = moe_layer_freq
self.first_k_dense_replace = first_k_dense_replace
self.norm_topk_prob = norm_topk_prob
self.scoring_func = scoring_func
self.aux_loss_alpha = aux_loss_alpha
self.seq_aux = seq_aux
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

File diff suppressed because it is too large Load Diff

29
esft.py
View File

@ -7,6 +7,8 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
def to_buffer(module, mark_param=True): def to_buffer(module, mark_param=True):
"""Turns all parameters of a module into buffers.""" """Turns all parameters of a module into buffers."""
if module is None:
return
modules = module.modules() modules = module.modules()
module = next(modules) module = next(modules)
delattrs = [] delattrs = []
@ -25,6 +27,8 @@ def to_buffer(module, mark_param=True):
def to_param(module): def to_param(module):
"""Turns all buffers of a module into parameterss.""" """Turns all buffers of a module into parameterss."""
if module is None:
return
modules = module.modules() modules = module.modules()
module = next(modules) module = next(modules)
param_list = getattr(module, 'param_list', []) param_list = getattr(module, 'param_list', [])
@ -57,7 +61,7 @@ def to_esft(model, adapter_config):
to_buffer(model) to_buffer(model)
else: else:
to_param(model) to_param(model)
for idx, layer in enumerate(model.layers): for idx, layer in enumerate(model.model.layers):
if type(layer.mlp).__name__ != "DeepseekV2MoE": if type(layer.mlp).__name__ != "DeepseekV2MoE":
continue continue
if adapter_config.get('shared_experts', False): if adapter_config.get('shared_experts', False):
@ -72,15 +76,25 @@ def to_esft(model, adapter_config):
to_buffer(layer.mlp.experts[expert_id]) to_buffer(layer.mlp.experts[expert_id])
return model return model
def load_state_dict(folder_path): def load_state_dict(folder_path):
# 初始化空的 state_dict
combined_state_dict = {} combined_state_dict = {}
# 遍历文件夹中的所有文件
for file_name in os.listdir(folder_path): for file_name in os.listdir(folder_path):
if file_name.endswith('.safetensors'): if file_name.endswith('.safetensors'):
file_path = os.path.join(folder_path, file_name) file_path = os.path.join(folder_path, file_name)
state_dict = load_file(file_path) state_dict = load_file(file_path)
combined_state_dict.update(state_dict) combined_state_dict.update(state_dict)
# legacy for loading v1 checkpoints: add prefix "model." for parameters
for k in list(combined_state_dict.keys()):
if k.startswith("layers"):
k_new = "model." + k
combined_state_dict[k_new] = combined_state_dict[k]
del combined_state_dict[k]
return combined_state_dict return combined_state_dict
@ -89,20 +103,23 @@ def load_esft_model(base_model_path, adapter_dir):
adapter_state_dict = load_state_dict(adapter_dir) adapter_state_dict = load_state_dict(adapter_dir)
# load pretrained model: # load pretrained model:
model, tokenizer = AutoModelForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto"), AutoTokenizer.from_pretrained(base_model_path) model, tokenizer = AutoModelForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16), AutoTokenizer.from_pretrained(base_model_path)
to_esft(model.model, adapter_config) to_esft(model, adapter_config)
model.model.load_state_dict(adapter_state_dict) model.load_state_dict(adapter_state_dict)
return model, tokenizer return model, tokenizer
def load_base_model(base_model_path): def load_base_model(base_model_path):
# load pretrained model: # load pretrained model:
model, tokenizer = AutoModelForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto"), AutoTokenizer.from_pretrained(base_model_path) model, tokenizer = AutoModelForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16), AutoTokenizer.from_pretrained(base_model_path)
return model, tokenizer return model, tokenizer
def add_adapter(base_model, adapter_dir, return_original_states=False): def add_adapter(base_model, adapter_dir, return_original_states=False, expert_config=None):
if expert_config is not None:
adapter_config = json.load(open(expert_config))
else:
adapter_config = json.load(open(adapter_dir + "/expert_cfg.json")) adapter_config = json.load(open(adapter_dir + "/expert_cfg.json"))
adapter_state_dict = load_state_dict(adapter_dir) adapter_state_dict = load_state_dict(adapter_dir)

12
scripts/train.sh Normal file
View File

@ -0,0 +1,12 @@
export TOKENIZERS_PARALLELISM=false
exp_name="test/eval_translation"
base_model_path="/hf3fs-jd/prod/deepseek/shared/wangzihan/models/huggingface/vanilla_model"
# turn above to for loop
python train.py \
--base_model_path=${base_model_path} \
--expert_config=results/expert_configs/translation.json \
--train_dataset=translation \
--train_config=configs/base.yaml \
--output_dir=results/checkpoints/${exp_name}

11
scripts/train_ep.sh Normal file
View File

@ -0,0 +1,11 @@
export TOKENIZERS_PARALLELISM=false
exp_name="test/eval_translation"
base_model_path="/hf3fs-jd/prod/deepseek/shared/wangzihan/models/huggingface/vanilla_model"
torchrun --nproc-per-node=8 train_ep.py \
--base_model_path=${base_model_path} \
--expert_config=results/expert_configs/translation.json \
--train_dataset=translation \
--train_config=configs/base.yaml \
--output_dir=results/checkpoints/${exp_name}

117
train.py Normal file
View File

@ -0,0 +1,117 @@
import argparse
import json
import yaml
import os
import random
import torch
from torch.utils.data import TensorDataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, logging
from benchmarks import *
from utils import get_formatted_input_and_target, get_examples_from_buffer_pad
from esft import to_esft
from deepseek.modeling_deepseek import DeepseekV2ForCausalLM
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--base_model_path", type=str, required=True)
parser.add_argument("--expert_config", type=str, required=True)
parser.add_argument("--train_dataset", type=str, required=True)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument("--train_config", type=str, required=True)
parser.add_argument("--wandb_api_key", type=str, required=False)
args = parser.parse_args()
expert_config = json.load(open(args.expert_config))
output_dir = args.output_dir
base_model_path = args.base_model_path
config = yaml.safe_load(open(args.train_config))
os.makedirs(args.output_dir, exist_ok=True)
seed = config['seed']
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
random.seed(seed)
if args.wandb_api_key is not None:
import wandb
wandb.login(key=args.wandb_api_key)
# Prepare data
tokenizer = AutoTokenizer.from_pretrained(base_model_path)
samples = [json.loads(i) for i in open(f"datasets/train/{args.train_dataset}.jsonl").readlines()]
buffer = []
for instance in samples:
input_ids, target_ids = get_formatted_input_and_target(instance['messages'], tokenizer, -100)
buffer.append((input_ids, target_ids))
seq_length = config['seq_length']
random_concat_ratio = config['random_concat_ratio']
concated_examples = get_examples_from_buffer_pad(buffer, seq_length, tokenizer, random_concat_ratio)
dataset = TensorDataset(concated_examples['input_ids'], concated_examples['labels'])
train_dataset, valid_dataset = torch.utils.data.random_split(dataset, [int(len(dataset) * 0.98), len(dataset) - int(len(dataset) * 0.98)])
# Training arguments
training_args = TrainingArguments(
output_dir=output_dir,
max_steps=config['steps'],
per_device_train_batch_size=config['per_device_batch_size'],
per_device_eval_batch_size=config['per_device_batch_size'],
warmup_steps=config['warmup_steps'],
weight_decay=config['weight_decay'],
logging_dir=f"{output_dir}/logs",
logging_steps=config['logging_steps'],
save_steps=config['save_steps'],
eval_strategy="steps",
eval_steps=config['eval_steps'],
gradient_accumulation_steps=config['gradient_accumulation_steps'],
load_best_model_at_end=True,
metric_for_best_model="loss",
greater_is_better=False,
bf16=True,
lr_scheduler_type='constant',
save_total_limit=5,
learning_rate=config['learning_rate'],
optim=config['optim'],
adam_beta1=config['adam_beta1'],
adam_beta2=config['adam_beta2'],
gradient_checkpointing=config['gradient_checkpointing'],
gradient_checkpointing_kwargs={"use_reentrant": False} if config['gradient_checkpointing'] else {}, # if set to True, backward will raise bug
)
def data_collator(data):
input_ids = torch.stack([item[0] for item in data])
labels = torch.stack([item[1] for item in data])
return {"input_ids": input_ids, "labels": labels}
model = DeepseekV2ForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
to_esft(model, expert_config)
# Initialize Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=valid_dataset,
data_collator=data_collator,
)
# Training
if os.path.exists(output_dir) and len(os.listdir(output_dir)) > 1: # has checkpoints already
trainer.train(resume_from_checkpoint=True)
else:
trainer.train()
# Save the model and tokenizer
trainer.save_model(output_dir + "/last_checkpoint")
tokenizer.save_pretrained(output_dir + "/last_checkpoint")
print("Training complete")
if __name__ == "__main__":
main()

154
train_ep.py Normal file
View File

@ -0,0 +1,154 @@
import argparse
import json
import yaml
import os
import random
import torch
import torch.distributed as dist
from types import MethodType
from torch.utils.data import TensorDataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, logging
from benchmarks import *
from utils import get_formatted_input_and_target, get_examples_from_buffer_pad, init_parallel_groups
from esft import to_esft
from deepseek.modeling_deepseek import DeepseekV2ForCausalLM
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["NCCL_AVOID_RECORD_STREAMS"] = "1"
logging.set_verbosity_error()
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--base_model_path", type=str, required=True)
parser.add_argument("--expert_config", type=str, required=True)
parser.add_argument("--train_dataset", type=str, required=True)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument("--train_config", type=str, required=True)
parser.add_argument("--wandb_api_key", type=str, required=False)
args = parser.parse_args()
expert_config = json.load(open(args.expert_config))
output_dir = args.output_dir
base_model_path = args.base_model_path
config = yaml.safe_load(open(args.train_config))
os.makedirs(args.output_dir, exist_ok=True)
seed = config['seed']
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
random.seed(seed)
if args.wandb_api_key is not None:
import wandb
wandb.login(key=args.wandb_api_key)
ep_size = config.get("ep_size", 1)
world_size, local_rank, ep_group, edp_group = init_parallel_groups(ep_size)
edp_size = world_size // ep_size
# Prepare data
tokenizer = AutoTokenizer.from_pretrained(base_model_path)
samples = [json.loads(i) for i in open(f"datasets/train/{args.train_dataset}.jsonl").readlines()]
buffer = []
for instance in samples:
input_ids, target_ids = get_formatted_input_and_target(instance['messages'], tokenizer, -100)
buffer.append((input_ids, target_ids))
seq_length = config['seq_length']
random_concat_ratio = config['random_concat_ratio']
concated_examples = get_examples_from_buffer_pad(buffer, seq_length, tokenizer, random_concat_ratio)
dataset = TensorDataset(concated_examples['input_ids'], concated_examples['labels'])
train_dataset, valid_dataset = torch.utils.data.random_split(dataset, [int(len(dataset) * 0.98), len(dataset) - int(len(dataset) * 0.98)])
# Training arguments
training_args = TrainingArguments(
output_dir=output_dir,
max_steps=config['steps'],
per_device_train_batch_size=config['per_device_batch_size'],
per_device_eval_batch_size=config['per_device_batch_size'],
warmup_steps=config['warmup_steps'],
weight_decay=config['weight_decay'],
logging_dir=f"{output_dir}/logs",
logging_steps=config['logging_steps'],
save_steps=config['save_steps'],
eval_strategy="steps",
eval_steps=config['eval_steps'],
gradient_accumulation_steps=config['gradient_accumulation_steps'],
load_best_model_at_end=True,
metric_for_best_model="loss",
greater_is_better=False,
bf16=True,
lr_scheduler_type='constant',
save_total_limit=5,
learning_rate=config['learning_rate'],
optim=config['optim'],
adam_beta1=config['adam_beta1'],
adam_beta2=config['adam_beta2'],
disable_tqdm=False,
gradient_checkpointing=config['gradient_checkpointing'],
gradient_checkpointing_kwargs={"use_reentrant": False} if config['gradient_checkpointing'] else {}, # if set to True, backward will raise bug
)
def data_collator(data):
input_ids = torch.stack([item[0] for item in data])
labels = torch.stack([item[1] for item in data])
return {"input_ids": input_ids, "labels": labels}
model = DeepseekV2ForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16,
ep_size=ep_size, attn_implementation="flash_attention_2")
model._ddp_params_and_buffers_to_ignore = [n for n, _ in model.named_parameters() if ".expert" in n] # we manage grad synchronization of expert parameters
to_esft(model, expert_config)
model.dummy = torch.nn.Parameter(torch.zeros(1, dtype=model.dtype)) # prevent DDP from having no trainable parameters
model._keys_to_ignore_on_save = ["dummy"]
expert_params = [p for n, p in model.named_parameters() if p.requires_grad and ".expert" in n]
for layer in model.model.layers:
if type(layer.mlp).__name__ != "DeepseekV2MoE":
continue
layer.mlp.ep_group = ep_group
# Initialize Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=valid_dataset,
data_collator=data_collator,
)
accelerator = trainer.accelerator
backward = accelerator.backward
def custom_backward(self, loss, **kwargs):
backward(loss, **kwargs)
if not self.sync_gradients or edp_size == 1:
return
return
for p in expert_params:
g = p.grad if p.grad is not None else torch.zeros_like(p)
dist.all_reduce(g, op=dist.ReduceOp.AVG, group=edp_group)
if p.grad is not g:
p.grad = g
accelerator.backward = MethodType(custom_backward, accelerator)
# Training
ckpt_path = f"{output_dir}/last_checkpoint_ep{local_rank}"
if os.path.exists(output_dir) and len(os.listdir(output_dir)) > 1: # has checkpoints already
trainer.train(resume_from_checkpoint=ckpt_path)
else:
trainer.train()
# Save the model and tokenizer
if local_rank == 0:
trainer.save_model(ckpt_path)
tokenizer.save_pretrained(ckpt_path)
elif 0 < local_rank < ep_size:
model.save_pretrained(ckpt_path)
print("Training complete")
if __name__ == "__main__":
main()

View File

@ -1,3 +1,7 @@
import os
import random
import torch
import torch.distributed as dist
# given a message object, convert to prompt and response # given a message object, convert to prompt and response
PROMPT_USER: str = 'User: {input}\n\n' PROMPT_USER: str = 'User: {input}\n\n'
@ -38,3 +42,57 @@ def get_formatted_input_and_target(messages, tokenizer, IGNORE_TOKEN_ID=-100, ma
assert False, f"Unknown role: {message['role']}" assert False, f"Unknown role: {message['role']}"
return [input_ids, target_ids] return [input_ids, target_ids]
def get_examples_from_buffer_pad(buffer, seq_length, tokenizer, random_concat_ratio, IGNORE_TOKEN_ID=-100):
all_input_ids_list, all_target_ids_list = [], []
all_input_ids, all_target_ids = [], []
for input_ids, target_ids in buffer:
if len(input_ids) > seq_length - len(all_input_ids):
input_ids = input_ids[-(seq_length - len(all_input_ids)):]
target_ids = target_ids[-(seq_length - len(all_target_ids)):]
if len(all_input_ids) > 0 and random.random() < random_concat_ratio:
input_ids = input_ids[1:]
target_ids = target_ids[1:]
all_input_ids.extend(input_ids)
all_target_ids.extend(target_ids)
if len(all_input_ids) >= seq_length:
assert len(all_input_ids) == seq_length, f"{len(all_input_ids)=}, {seq_length=}, {len(buffer)=}"
all_input_ids_list.append(all_input_ids)
all_target_ids_list.append(all_target_ids)
all_input_ids, all_target_ids = [], []
all_input_ids = all_input_ids + [tokenizer.pad_token_id for i in range(seq_length - len(all_input_ids))]
all_target_ids = all_target_ids + [IGNORE_TOKEN_ID for i in range(seq_length - len(all_target_ids))]
all_input_ids_list.append(all_input_ids)
all_target_ids_list.append(all_target_ids)
if len(all_input_ids) <= 0:
return None
return {
"input_ids": torch.tensor(all_input_ids_list, dtype=torch.long),
"labels": torch.tensor(all_target_ids_list, dtype=torch.long)
}
def init_parallel_groups(ep_size=1):
dist.init_process_group("nccl")
world_size = int(os.getenv("WORLD_SIZE", "0"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
torch.cuda.set_device(local_rank)
ep_group = edp_group = None
for i in range(0, world_size, ep_size):
ranks = list(range(i, i + ep_size))
group = dist.new_group(ranks)
if local_rank in ranks:
ep_group = group
edp_group = None
for i in range(ep_size):
ranks = list(range(i, world_size, ep_size))
group = dist.new_group(ranks)
if local_rank in ranks:
edp_group = group
dist.all_reduce(torch.zeros(1, device="cuda"), group=ep_group)
dist.all_reduce(torch.zeros(1, device="cuda"), group=edp_group)
return world_size, local_rank, ep_group, edp_group