mirror of
https://github.com/deepseek-ai/ESFT
synced 2024-11-22 03:27:38 +00:00
Merge pull request #4 from ZihanWang314/main
add (multi-gpu) training and eval
This commit is contained in:
commit
4bfa99486b
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
__pycache__/
|
59
README.md
59
README.md
@ -10,6 +10,9 @@ Y. Wu.
|
|||||||
**ESFT** aims to efficiently customize Large Language Models (LLMs) with Mixture-of-Experts (MoE) architecture by adjusting only task-relevant parts, improving efficiency and performance while using fewer resources and storage.
|
**ESFT** aims to efficiently customize Large Language Models (LLMs) with Mixture-of-Experts (MoE) architecture by adjusting only task-relevant parts, improving efficiency and performance while using fewer resources and storage.
|
||||||
|
|
||||||
|
|
||||||
|
## 📰 News
|
||||||
|
|
||||||
|
📅 **2024.8.11:** We now release the **ESFT training code**! ✨ You can now try it with your own models and dataset!
|
||||||
|
|
||||||
|
|
||||||
## 🚀 Quick Start
|
## 🚀 Quick Start
|
||||||
@ -19,9 +22,9 @@ git clone https://github.com/deepseek-ai/ESFT.git
|
|||||||
cd esft
|
cd esft
|
||||||
```
|
```
|
||||||
|
|
||||||
### Install dependencies
|
### Install required dependencies
|
||||||
```bash
|
```bash
|
||||||
pip install transformers torch safetensors
|
pip install transformers torch safetensors accelerate
|
||||||
```
|
```
|
||||||
|
|
||||||
### Download necessary adapters
|
### Download necessary adapters
|
||||||
@ -32,35 +35,38 @@ bash scripts/download_adapters.sh
|
|||||||
|
|
||||||
|
|
||||||
## 🔧Key Scripts
|
## 🔧Key Scripts
|
||||||
1. **eval.py**
|
1. **eval_multigpu.py**
|
||||||
This script evaluates the performance of the model on various datasets. **Usage:**
|
This script evaluates the performance of the model on various datasets. See **scripts/eval.sh** for detailed configs and explanations.
|
||||||
|
|
||||||
|
**Usage:**
|
||||||
```bash
|
```bash
|
||||||
python scripts/eval.py \
|
python eval_multigpu.py \
|
||||||
--eval_datasets=translation \
|
--eval_dataset=translation \
|
||||||
--base_model_path=deepseek-ai/ESFT-vanilla-lite \
|
--base_model_path=deepseek-ai/ESFT-vanilla-lite \
|
||||||
--adapter_dir=all_models/adapters/token \
|
--adapter_dir=all_models/adapters/token/translation \
|
||||||
--output_dir=results/completions/token \
|
--output_path=results/completions/token/translation.jsonl \
|
||||||
--max_new_tokens=512 \
|
--openai_api_key=YOUR_OPENAI_API_KEY
|
||||||
--openai_api_key=REPLACE_WITH_YOUR_KEY \
|
|
||||||
--eval_batch_size=2
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
2. **get_expert_scores.py**
|
2. **get_expert_scores.py**
|
||||||
This script calculates the scores for each expert based on the evaluation datasets.
|
This script calculates the scores for each expert based on the evaluation datasets.
|
||||||
**Usage:**
|
**Usage:**
|
||||||
```bash
|
```bash
|
||||||
python scripts/get_expert_scores.py \
|
python scripts/expert/get_expert_scores.py \
|
||||||
--eval_datasets=intent,summary,law,translation \
|
--eval_dataset=translation \
|
||||||
--base_model_path=deepseek-ai/ESFT-vanilla-lite \
|
--base_model_path=deepseek-ai/ESFT-vanilla-lite \
|
||||||
--output_dir=results/expert_scores \
|
--output_dir=results/expert_scores/translation \
|
||||||
--n_sample_tokens=8192 # the sample size hyperparameter
|
--n_sample_tokens=131072 \
|
||||||
|
--world_size=4 \
|
||||||
|
--gpus_per_rank=2
|
||||||
```
|
```
|
||||||
|
|
||||||
3. **generate_expert_config.py**
|
3. **generate_expert_config.py**
|
||||||
This script generates the configuration to convert a MoE model with only task-relevant tasks trained based on evaluation scores.
|
This script generates the configuration to convert a MoE model with only task-relevant tasks trained based on evaluation scores.
|
||||||
**Usage:**
|
**Usage:**
|
||||||
```bash
|
```bash
|
||||||
python scripts/generate_expert_config.py \
|
python scripts/expert/generate_expert_config.py \
|
||||||
--eval_datasets=intent,summary,law,translation \
|
--eval_datasets=intent,summary,law,translation \
|
||||||
--expert_scores_dir=results/expert_scores \
|
--expert_scores_dir=results/expert_scores \
|
||||||
--output_dir=results/expert_configs \
|
--output_dir=results/expert_configs \
|
||||||
@ -68,13 +74,32 @@ python scripts/generate_expert_config.py \
|
|||||||
--top_p=0.2 # the scoring function and top_p are hyperparameters
|
--top_p=0.2 # the scoring function and top_p are hyperparameters
|
||||||
```
|
```
|
||||||
|
|
||||||
|
4. **train.py** and **train_ep.py**
|
||||||
|
This script trains the model with the expert configuration generated by the previous script. The train_ep.py file uses expert parallel and has been optimized for multi-GPU training.
|
||||||
|
**Usage:**
|
||||||
|
```bash
|
||||||
|
python train.py \
|
||||||
|
--base_model_path=deepseek-ai/ESFT-vanilla-lite \
|
||||||
|
--expert_config=results/expert_configs/intent.json \
|
||||||
|
--train_dataset=intent \
|
||||||
|
--train_config=configs/base.yaml \
|
||||||
|
--output_dir=results/checkpoints/intent
|
||||||
|
|
||||||
|
torchrun --nproc-per-node=8 train_ep.py \
|
||||||
|
--base_model_path=deepseek-ai/ESFT-vanilla-lite \
|
||||||
|
--expert_config=results/expert_configs/translation.json \
|
||||||
|
--train_dataset=translation \
|
||||||
|
--train_config=configs/base.yaml \
|
||||||
|
--output_dir=results/checkpoints/translation
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
## Contact and Support
|
## Contact and Support
|
||||||
For bug reports, feature requests, and general inquiries, please open an issue on our GitHub Issues page. Make sure to include as much detail as possible to help us address your issue quickly.
|
For bug reports, feature requests, and general inquiries, please open an issue on our GitHub Issues page. Make sure to include as much detail as possible to help us address your issue quickly.
|
||||||
|
|
||||||
## 🌟Todo list
|
## 🌟Todo list
|
||||||
- ☑️ 📝 Update models, evaluation scripts, and expert selection scripts
|
- ☑️ 📝 Update models, evaluation scripts, and expert selection scripts
|
||||||
- 🔲 🔧 Update training scripts
|
- ☑️ 🔧 Update training scripts
|
||||||
- 🔲 🚀 More...
|
- 🔲 🚀 More...
|
||||||
|
|
||||||
|
|
||||||
|
31
configs/base.yaml
Normal file
31
configs/base.yaml
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
seed: 5934875
|
||||||
|
# Model settings
|
||||||
|
seq_length: 4096 # Maximum sequence length
|
||||||
|
|
||||||
|
# Data settings
|
||||||
|
per_device_batch_size: 1
|
||||||
|
n_device: 8 # Number of devices
|
||||||
|
|
||||||
|
# Training settings
|
||||||
|
optim: adamw_torch_fused
|
||||||
|
steps: 500 # Number of training steps
|
||||||
|
learning_rate: 0.00001 # Learning rate
|
||||||
|
weight_decay: 0.1 # Weight decay for optimizer
|
||||||
|
warmup_steps: 0 # Number of warmup steps for learning rate scheduler
|
||||||
|
logging_steps: 10 # Log every X steps
|
||||||
|
adam_beta1: 0.9
|
||||||
|
adam_beta2: 0.95
|
||||||
|
random_concat_ratio: 0.2 # Ratio of random concatenation
|
||||||
|
|
||||||
|
|
||||||
|
# Evaluation settings
|
||||||
|
eval_steps: 100 # Evaluate every X steps
|
||||||
|
save_steps: 100 # Save model every X steps
|
||||||
|
|
||||||
|
# Tokenizer settings
|
||||||
|
|
||||||
|
# Additional settings (if needed)
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_accumulation_steps: 16 # Number of updates steps to accumulate before performing a backward/update pass
|
||||||
|
max_grad_norm: 1.0 # Max gradient norm for gradient clipping
|
||||||
|
ep_size: 2
|
0
deepseek/__init__.py
Normal file
0
deepseek/__init__.py
Normal file
206
deepseek/configuration_deepseek.py
Normal file
206
deepseek/configuration_deepseek.py
Normal file
@ -0,0 +1,206 @@
|
|||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
|
||||||
|
class DeepseekV2Config(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`DeepseekV2Model`]. It is used to instantiate an DeepSeek
|
||||||
|
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||||
|
defaults will yield a similar configuration to that of the DeepSeek-V2.
|
||||||
|
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_size (`int`, *optional*, defaults to 102400):
|
||||||
|
Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the
|
||||||
|
`inputs_ids` passed when calling [`DeepseekV2Model`]
|
||||||
|
hidden_size (`int`, *optional*, defaults to 4096):
|
||||||
|
Dimension of the hidden representations.
|
||||||
|
intermediate_size (`int`, *optional*, defaults to 11008):
|
||||||
|
Dimension of the MLP representations.
|
||||||
|
moe_intermediate_size (`int`, *optional*, defaults to 1407):
|
||||||
|
Dimension of the MoE representations.
|
||||||
|
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||||
|
Number of hidden layers in the Transformer decoder.
|
||||||
|
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||||
|
Number of attention heads for each attention layer in the Transformer decoder.
|
||||||
|
n_shared_experts (`int`, *optional*, defaults to None):
|
||||||
|
Number of shared experts, None means dense model.
|
||||||
|
n_routed_experts (`int`, *optional*, defaults to None):
|
||||||
|
Number of routed experts, None means dense model.
|
||||||
|
routed_scaling_factor (`float`, *optional*, defaults to 1.0):
|
||||||
|
Scaling factor or routed experts.
|
||||||
|
topk_method (`str`, *optional*, defaults to `gready`):
|
||||||
|
Topk method used in routed gate.
|
||||||
|
n_group (`int`, *optional*, defaults to None):
|
||||||
|
Number of groups for routed experts.
|
||||||
|
topk_group (`int`, *optional*, defaults to None):
|
||||||
|
Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups).
|
||||||
|
num_experts_per_tok (`int`, *optional*, defaults to None):
|
||||||
|
Number of selected experts, None means dense model.
|
||||||
|
moe_layer_freq (`int`, *optional*, defaults to 1):
|
||||||
|
The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
|
||||||
|
first_k_dense_replace (`int`, *optional*, defaults to 0):
|
||||||
|
Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
|
||||||
|
\--k dense layers--/
|
||||||
|
norm_topk_prob (`bool`, *optional*, defaults to False):
|
||||||
|
Whether to normalize the weights of the routed experts.
|
||||||
|
scoring_func (`str`, *optional*, defaults to 'softmax'):
|
||||||
|
Method of computing expert weights.
|
||||||
|
aux_loss_alpha (`float`, *optional*, defaults to 0.001):
|
||||||
|
Auxiliary loss weight coefficient.
|
||||||
|
seq_aux = (`bool`, *optional*, defaults to True):
|
||||||
|
Whether to compute the auxiliary loss for each individual sample.
|
||||||
|
num_key_value_heads (`int`, *optional*):
|
||||||
|
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||||
|
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||||
|
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||||
|
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||||
|
by meanpooling all the original heads within that group. For more details checkout [this
|
||||||
|
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
||||||
|
`num_attention_heads`.
|
||||||
|
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||||
|
The non-linear activation function (function or string) in the decoder.
|
||||||
|
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
||||||
|
The maximum sequence length that this model might ever be used with.
|
||||||
|
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||||
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
|
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||||
|
The epsilon used by the rms normalization layers.
|
||||||
|
use_cache (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||||
|
relevant if `config.is_decoder=True`.
|
||||||
|
pad_token_id (`int`, *optional*):
|
||||||
|
Padding token id.
|
||||||
|
bos_token_id (`int`, *optional*, defaults to 1):
|
||||||
|
Beginning of stream token id.
|
||||||
|
eos_token_id (`int`, *optional*, defaults to 2):
|
||||||
|
End of stream token id.
|
||||||
|
pretraining_tp (`int`, *optional*, defaults to 1):
|
||||||
|
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
|
||||||
|
document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
|
||||||
|
necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
|
||||||
|
issue](https://github.com/pytorch/pytorch/issues/76232).
|
||||||
|
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to tie weight embeddings
|
||||||
|
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||||
|
The base period of the RoPE embeddings.
|
||||||
|
rope_scaling (`Dict`, *optional*):
|
||||||
|
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
||||||
|
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
|
||||||
|
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
||||||
|
`max_position_embeddings` to the expected new maximum.
|
||||||
|
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
||||||
|
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||||
|
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for the attention probabilities.
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import DeepseekV2Model, DeepseekV2Config
|
||||||
|
|
||||||
|
>>> # Initializing a Deepseek-V2 style configuration
|
||||||
|
>>> configuration = DeepseekV2Config()
|
||||||
|
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```"""
|
||||||
|
|
||||||
|
model_type = "deepseek_v2"
|
||||||
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=102400,
|
||||||
|
hidden_size=4096,
|
||||||
|
intermediate_size=11008,
|
||||||
|
moe_intermediate_size = 1407,
|
||||||
|
num_hidden_layers=30,
|
||||||
|
num_attention_heads=32,
|
||||||
|
num_key_value_heads=32,
|
||||||
|
n_shared_experts = None,
|
||||||
|
n_routed_experts = None,
|
||||||
|
ep_size = 1,
|
||||||
|
routed_scaling_factor = 1.0,
|
||||||
|
kv_lora_rank = 512,
|
||||||
|
q_lora_rank = 1536,
|
||||||
|
qk_rope_head_dim = 64,
|
||||||
|
v_head_dim = 128,
|
||||||
|
qk_nope_head_dim = 128,
|
||||||
|
topk_method = 'gready',
|
||||||
|
n_group = None,
|
||||||
|
topk_group = None,
|
||||||
|
num_experts_per_tok = None,
|
||||||
|
moe_layer_freq = 1,
|
||||||
|
first_k_dense_replace = 0,
|
||||||
|
norm_topk_prob = False,
|
||||||
|
scoring_func = 'softmax',
|
||||||
|
aux_loss_alpha = 0.001,
|
||||||
|
seq_aux = True,
|
||||||
|
hidden_act="silu",
|
||||||
|
max_position_embeddings=2048,
|
||||||
|
initializer_range=0.02,
|
||||||
|
rms_norm_eps=1e-6,
|
||||||
|
use_cache=True,
|
||||||
|
pad_token_id=None,
|
||||||
|
bos_token_id=100000,
|
||||||
|
eos_token_id=100001,
|
||||||
|
pretraining_tp=1,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
rope_theta=10000.0,
|
||||||
|
rope_scaling=None,
|
||||||
|
attention_bias=False,
|
||||||
|
attention_dropout=0.0,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.moe_intermediate_size = moe_intermediate_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.n_shared_experts = n_shared_experts
|
||||||
|
self.n_routed_experts = n_routed_experts
|
||||||
|
self.ep_size = ep_size
|
||||||
|
self.routed_scaling_factor = routed_scaling_factor
|
||||||
|
self.kv_lora_rank = kv_lora_rank
|
||||||
|
self.q_lora_rank = q_lora_rank
|
||||||
|
self.qk_rope_head_dim = qk_rope_head_dim
|
||||||
|
self.v_head_dim = v_head_dim
|
||||||
|
self.qk_nope_head_dim = qk_nope_head_dim
|
||||||
|
self.topk_method = topk_method
|
||||||
|
self.n_group = n_group
|
||||||
|
self.topk_group = topk_group
|
||||||
|
self.num_experts_per_tok = num_experts_per_tok
|
||||||
|
self.moe_layer_freq = moe_layer_freq
|
||||||
|
self.first_k_dense_replace = first_k_dense_replace
|
||||||
|
self.norm_topk_prob = norm_topk_prob
|
||||||
|
self.scoring_func = scoring_func
|
||||||
|
self.aux_loss_alpha = aux_loss_alpha
|
||||||
|
self.seq_aux = seq_aux
|
||||||
|
# for backward compatibility
|
||||||
|
if num_key_value_heads is None:
|
||||||
|
num_key_value_heads = num_attention_heads
|
||||||
|
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.rms_norm_eps = rms_norm_eps
|
||||||
|
self.pretraining_tp = pretraining_tp
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.rope_scaling = rope_scaling
|
||||||
|
self.attention_bias = attention_bias
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
1918
deepseek/modeling_deepseek.py
Normal file
1918
deepseek/modeling_deepseek.py
Normal file
File diff suppressed because it is too large
Load Diff
33
esft.py
33
esft.py
@ -7,6 +7,8 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
|||||||
|
|
||||||
def to_buffer(module, mark_param=True):
|
def to_buffer(module, mark_param=True):
|
||||||
"""Turns all parameters of a module into buffers."""
|
"""Turns all parameters of a module into buffers."""
|
||||||
|
if module is None:
|
||||||
|
return
|
||||||
modules = module.modules()
|
modules = module.modules()
|
||||||
module = next(modules)
|
module = next(modules)
|
||||||
delattrs = []
|
delattrs = []
|
||||||
@ -25,6 +27,8 @@ def to_buffer(module, mark_param=True):
|
|||||||
|
|
||||||
def to_param(module):
|
def to_param(module):
|
||||||
"""Turns all buffers of a module into parameterss."""
|
"""Turns all buffers of a module into parameterss."""
|
||||||
|
if module is None:
|
||||||
|
return
|
||||||
modules = module.modules()
|
modules = module.modules()
|
||||||
module = next(modules)
|
module = next(modules)
|
||||||
param_list = getattr(module, 'param_list', [])
|
param_list = getattr(module, 'param_list', [])
|
||||||
@ -57,7 +61,7 @@ def to_esft(model, adapter_config):
|
|||||||
to_buffer(model)
|
to_buffer(model)
|
||||||
else:
|
else:
|
||||||
to_param(model)
|
to_param(model)
|
||||||
for idx, layer in enumerate(model.layers):
|
for idx, layer in enumerate(model.model.layers):
|
||||||
if type(layer.mlp).__name__ != "DeepseekV2MoE":
|
if type(layer.mlp).__name__ != "DeepseekV2MoE":
|
||||||
continue
|
continue
|
||||||
if adapter_config.get('shared_experts', False):
|
if adapter_config.get('shared_experts', False):
|
||||||
@ -72,15 +76,25 @@ def to_esft(model, adapter_config):
|
|||||||
to_buffer(layer.mlp.experts[expert_id])
|
to_buffer(layer.mlp.experts[expert_id])
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict(folder_path):
|
def load_state_dict(folder_path):
|
||||||
|
# 初始化空的 state_dict
|
||||||
combined_state_dict = {}
|
combined_state_dict = {}
|
||||||
|
|
||||||
|
# 遍历文件夹中的所有文件
|
||||||
for file_name in os.listdir(folder_path):
|
for file_name in os.listdir(folder_path):
|
||||||
if file_name.endswith('.safetensors'):
|
if file_name.endswith('.safetensors'):
|
||||||
file_path = os.path.join(folder_path, file_name)
|
file_path = os.path.join(folder_path, file_name)
|
||||||
state_dict = load_file(file_path)
|
state_dict = load_file(file_path)
|
||||||
combined_state_dict.update(state_dict)
|
combined_state_dict.update(state_dict)
|
||||||
|
|
||||||
|
# legacy for loading v1 checkpoints: add prefix "model." for parameters
|
||||||
|
for k in list(combined_state_dict.keys()):
|
||||||
|
if k.startswith("layers"):
|
||||||
|
k_new = "model." + k
|
||||||
|
combined_state_dict[k_new] = combined_state_dict[k]
|
||||||
|
del combined_state_dict[k]
|
||||||
|
|
||||||
return combined_state_dict
|
return combined_state_dict
|
||||||
|
|
||||||
|
|
||||||
@ -89,21 +103,24 @@ def load_esft_model(base_model_path, adapter_dir):
|
|||||||
adapter_state_dict = load_state_dict(adapter_dir)
|
adapter_state_dict = load_state_dict(adapter_dir)
|
||||||
|
|
||||||
# load pretrained model:
|
# load pretrained model:
|
||||||
model, tokenizer = AutoModelForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto"), AutoTokenizer.from_pretrained(base_model_path)
|
model, tokenizer = AutoModelForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16), AutoTokenizer.from_pretrained(base_model_path)
|
||||||
|
|
||||||
to_esft(model.model, adapter_config)
|
to_esft(model, adapter_config)
|
||||||
model.model.load_state_dict(adapter_state_dict)
|
model.load_state_dict(adapter_state_dict)
|
||||||
|
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
def load_base_model(base_model_path):
|
def load_base_model(base_model_path):
|
||||||
# load pretrained model:
|
# load pretrained model:
|
||||||
model, tokenizer = AutoModelForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto"), AutoTokenizer.from_pretrained(base_model_path)
|
model, tokenizer = AutoModelForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16), AutoTokenizer.from_pretrained(base_model_path)
|
||||||
|
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
def add_adapter(base_model, adapter_dir, return_original_states=False):
|
def add_adapter(base_model, adapter_dir, return_original_states=False, expert_config=None):
|
||||||
adapter_config = json.load(open(adapter_dir + "/expert_cfg.json"))
|
if expert_config is not None:
|
||||||
|
adapter_config = json.load(open(expert_config))
|
||||||
|
else:
|
||||||
|
adapter_config = json.load(open(adapter_dir + "/expert_cfg.json"))
|
||||||
adapter_state_dict = load_state_dict(adapter_dir)
|
adapter_state_dict = load_state_dict(adapter_dir)
|
||||||
|
|
||||||
to_esft(base_model, adapter_config)
|
to_esft(base_model, adapter_config)
|
||||||
|
92
eval_multigpu.py
Normal file
92
eval_multigpu.py
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from torch import device
|
||||||
|
from benchmarks import *
|
||||||
|
import os
|
||||||
|
from esft import load_base_model, add_adapter
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
from itertools import accumulate
|
||||||
|
from accelerate import dispatch_model
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
def infer_auto_device_map(model, pp_splits, visible_devices):
|
||||||
|
assert len(pp_splits) == len(visible_devices)
|
||||||
|
device_map = {
|
||||||
|
"model.embed_tokens": 0,
|
||||||
|
"model.norm": len(pp_splits) - 1,
|
||||||
|
"lm_head": len(pp_splits) - 1
|
||||||
|
}
|
||||||
|
assert len(model.model.layers) == sum(pp_splits)
|
||||||
|
pp_splits = [0, *list(accumulate(pp_splits))]
|
||||||
|
for idx, (start, end) in enumerate(zip(pp_splits[:-1], pp_splits[1:])):
|
||||||
|
for i in range(start, end):
|
||||||
|
device_map.update({f"model.layers.{i}": idx})
|
||||||
|
for k, v in device_map.items():
|
||||||
|
device_map[k] = visible_devices[v]
|
||||||
|
return device_map
|
||||||
|
|
||||||
|
|
||||||
|
def eval_model(rank, args, model, dataset):
|
||||||
|
config = {
|
||||||
|
"max_new_tokens": args.max_new_tokens,
|
||||||
|
"eval_batch_size": args.eval_batch_size,
|
||||||
|
"openai_api_key": args.openai_api_key
|
||||||
|
}
|
||||||
|
evaluator_map = {
|
||||||
|
"intent": IntentEvaluator,
|
||||||
|
"summary": SummaryEvaluator,
|
||||||
|
"law": LawEvaluator,
|
||||||
|
"translation": TranslationEvaluator
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
evaluator_cls = evaluator_map[args.eval_dataset]
|
||||||
|
print(f"Rank {rank} starting evaluation...", flush=True)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(args.base_model_path)
|
||||||
|
visible_devices = list(range(rank * args.gpus_per_rank, (rank + 1) * args.gpus_per_rank))
|
||||||
|
device_map = infer_auto_device_map(model, [14, 13], visible_devices)
|
||||||
|
model = dispatch_model(model, device_map)
|
||||||
|
cur_dataset = dataset[rank::args.world_size]
|
||||||
|
evaluator = evaluator_cls(cur_dataset, config)
|
||||||
|
with torch.no_grad():
|
||||||
|
results, metrics = evaluator.evaluate(model, tokenizer)
|
||||||
|
os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
|
||||||
|
with open(args.output_path + f".rank_{rank}", "w") as f:
|
||||||
|
for res, m in zip(results, metrics):
|
||||||
|
obj = {
|
||||||
|
"example": res,
|
||||||
|
"score": m
|
||||||
|
}
|
||||||
|
f.write(json.dumps(obj, ensure_ascii=False) + "\n")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error in process {rank}: {e}", flush=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Evaluate a model with adapters on a specified dataset.")
|
||||||
|
parser.add_argument("--eval_dataset", type=str, required=True, help="Name of the evaluation dataset")
|
||||||
|
parser.add_argument("--base_model_path", type=str, required=True, help="Path to the base model")
|
||||||
|
parser.add_argument("--adapter_dir", type=str, required=True, help="Directory containing the adapter")
|
||||||
|
parser.add_argument("--output_path", type=str, required=True, help="Path to save the evaluation results")
|
||||||
|
parser.add_argument("--max_new_tokens", type=int, default=128, help="Maximum number of new tokens")
|
||||||
|
parser.add_argument("--openai_api_key", type=str, required=True, help="API key for OpenAI")
|
||||||
|
parser.add_argument("--eval_batch_size", type=int, default=1, help="Batch size for evaluation")
|
||||||
|
parser.add_argument("--world_size", type=int, default=4, help="Number of processes to use for evaluation")
|
||||||
|
parser.add_argument("--gpus_per_rank", type=int, default=2, help="Number of GPUs per process")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
print("Loading base model...")
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(args.base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16) # not using tokenizer here to aviod deadlock
|
||||||
|
|
||||||
|
print(f"Running evaluation on {args.eval_dataset}...")
|
||||||
|
dataset = [json.loads(i) for i in open(f"datasets/eval/{args.eval_dataset}.jsonl").readlines()]
|
||||||
|
|
||||||
|
print("Adding adapter...")
|
||||||
|
model = add_adapter(model, args.adapter_dir, return_original_states=False)
|
||||||
|
|
||||||
|
print("Start Evaluating...")
|
||||||
|
mp.spawn(eval_model, args=(args, model, dataset), nprocs=args.world_size, join=True)
|
@ -1,12 +1,24 @@
|
|||||||
# first, download adapter models and put them to the corresponding directories
|
# first: download adapter models and put them to the corresponding directories
|
||||||
|
|
||||||
|
|
||||||
python scripts/eval.py \
|
python eval_multigpu.py \
|
||||||
--eval_datasets=translation \
|
--eval_datasets=translation \
|
||||||
--base_model_path=deepseek-ai/ESFT-vanilla-lite \
|
--base_model_path=deepseek-ai/ESFT-vanilla-lite \
|
||||||
--adapter_dir=all_models/adapters/token \
|
--adapter_dir=all_models/adapters/token \
|
||||||
--output_dir=results/completions/token \
|
--output_dir=results/completions/token \
|
||||||
--max_new_tokens=512 \
|
--max_new_tokens=512 \
|
||||||
--openai_api_key=REPLACE_WITH_YOUR_KEY \
|
--openai_api_key=REPLACE_WITH_YOUR_KEY \
|
||||||
--eval_batch_size=2
|
--eval_batch_size=2 \
|
||||||
|
--world_size=4 \
|
||||||
|
--gpus_per_rank=2
|
||||||
|
|
||||||
|
# this script is used for single-gpu training and has been deprecated. If you have no multiple gpus, you can set above world_size=1 and gpus_per_rank=1
|
||||||
|
|
||||||
|
# python scripts/eval.py \
|
||||||
|
# --eval_datasets=translation \
|
||||||
|
# --base_model_path=deepseek-ai/ESFT-vanilla-lite \
|
||||||
|
# --adapter_dir=all_models/adapters/token \
|
||||||
|
# --output_dir=results/completions/token \
|
||||||
|
# --max_new_tokens=512 \
|
||||||
|
# --openai_api_key=REPLACE_WITH_YOUR_KEY \
|
||||||
|
# --eval_batch_size=2
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
python scripts/get_expert_scores.py \
|
python scripts/expert/get_expert_scores.py \
|
||||||
--eval_datasets=intent,summary,law,translation \
|
--eval_dataset=translation \
|
||||||
--base_model_path=deepseek-ai/ESFT-vanilla-lite \
|
--base_model_path=deepseek-ai/ESFT-vanilla-lite \
|
||||||
--output_dir=results/expert_scores \
|
--output_dir=results/expert_scores/translation \
|
||||||
--n_sample_tokens=8192 # this sample size is a hyperparameter
|
--n_sample_tokens=131072 \
|
||||||
|
--world_size=4 \
|
||||||
|
--gpus_per_rank=2
|
||||||
|
|
||||||
python scripts/generate_expert_config.py \
|
python scripts/expert/generate_expert_config.py \
|
||||||
--eval_datasets=intent,summary,law,translation \
|
--eval_datasets=intent,summary,law,translation \
|
||||||
--expert_scores_dir=results/expert_scores \
|
--expert_scores_dir=results/expert_scores \
|
||||||
--output_dir=results/expert_configs \
|
--output_dir=results/expert_configs \
|
97
scripts/expert/generate_expert_config.py
Normal file
97
scripts/expert/generate_expert_config.py
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from multiprocessing import Pool
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def parse_line(line):
|
||||||
|
expert_ids, expert_weights = line.split("\t\t")
|
||||||
|
expert_ids = [int(i) for i in expert_ids.split("\t")]
|
||||||
|
expert_weights = [float(i) for i in expert_weights.split("\t")]
|
||||||
|
return expert_ids, expert_weights
|
||||||
|
|
||||||
|
|
||||||
|
def get_summary(files):
|
||||||
|
TOP_K=6
|
||||||
|
N_EXPERTS=64
|
||||||
|
N_LAYERS=26 # 27 layers totally, the first layer is not MoE
|
||||||
|
|
||||||
|
gate_scores = np.zeros((N_LAYERS, N_EXPERTS))
|
||||||
|
token_scores = np.zeros((N_LAYERS, N_EXPERTS))
|
||||||
|
|
||||||
|
print("loading files")
|
||||||
|
for rank, file in files:
|
||||||
|
layer_id = int(file.split(".")[0].split("_")[2]) - 1
|
||||||
|
|
||||||
|
with open(os.path.join(args.expert_scores_dir, rank, file)) as f:
|
||||||
|
data = f.readlines()
|
||||||
|
for line in data:
|
||||||
|
expert_ids, expert_weights = parse_line(line)
|
||||||
|
np.add.at(gate_scores[layer_id], expert_ids, expert_weights)
|
||||||
|
np.add.at(token_scores[layer_id], expert_ids, np.ones_like(expert_weights) / TOP_K)
|
||||||
|
|
||||||
|
gate_scores = gate_scores / np.sum(gate_scores, axis=0)
|
||||||
|
token_scores = token_scores / np.sum(token_scores, axis=0)
|
||||||
|
|
||||||
|
summary = {"token_scores": token_scores, "gate_scores": gate_scores}
|
||||||
|
summary = {k: {str(i+1): {str(j): round(v, 4) for j, v in enumerate(l)} for i, l in enumerate(v)} for k, v in summary.items()}
|
||||||
|
|
||||||
|
return summary
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--eval_dataset", type=str, required=True)
|
||||||
|
parser.add_argument("--expert_scores_dir", type=str, required=True)
|
||||||
|
parser.add_argument("--output_path", type=str, required=True)
|
||||||
|
parser.add_argument("--score_function", type=str, required=True)
|
||||||
|
parser.add_argument("--top_p", type=float, required=True)
|
||||||
|
parser.add_argument("--train_shared_experts", action="store_true")
|
||||||
|
parser.add_argument("--train_non_expert_modules", action="store_true")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
expert_cfg = { # initialize expert config
|
||||||
|
"experts": {},
|
||||||
|
"shared_experts": args.train_shared_experts,
|
||||||
|
"non_expert_modules": args.train_non_expert_modules
|
||||||
|
}
|
||||||
|
|
||||||
|
# let's walk inside args.expert_scores_dir and get abs file names
|
||||||
|
file_names = []
|
||||||
|
for rank in [i for i in os.listdir(args.expert_scores_dir) if 'rank' in i]:
|
||||||
|
for file in os.listdir(os.path.join(args.expert_scores_dir, rank)):
|
||||||
|
file_names.append([rank, file])
|
||||||
|
|
||||||
|
|
||||||
|
summary_file = os.path.join(args.expert_scores_dir, "summary.json")
|
||||||
|
summary = get_summary(file_names)
|
||||||
|
|
||||||
|
with open(summary_file, "w") as f:
|
||||||
|
f.write(json.dumps(summary))
|
||||||
|
|
||||||
|
|
||||||
|
scores = summary[f"{args.score_function}_scores"]
|
||||||
|
for layer, l_score in scores.items():
|
||||||
|
l_score = [(int(k), v) for k,v in l_score.items()]
|
||||||
|
l_score = sorted(l_score, key=lambda x: x[1], reverse=True)
|
||||||
|
selected_experts = []
|
||||||
|
current_score = 0
|
||||||
|
for expert, score in l_score:
|
||||||
|
if current_score >= args.top_p:
|
||||||
|
break
|
||||||
|
selected_experts.append(expert)
|
||||||
|
current_score += score
|
||||||
|
expert_cfg["experts"][layer] = selected_experts
|
||||||
|
|
||||||
|
top_p = args.top_p
|
||||||
|
train_shared_experts = args.train_shared_experts
|
||||||
|
train_non_expert_modules = args.train_non_expert_modules
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
|
||||||
|
with open(args.output_path, "w") as f:
|
||||||
|
json.dump(expert_cfg, f)
|
78
scripts/expert/get_expert_scores.py
Normal file
78
scripts/expert/get_expert_scores.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import argparse
|
||||||
|
import random
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
from utils import get_formatted_input_and_target
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
from itertools import accumulate
|
||||||
|
from accelerate import dispatch_model
|
||||||
|
|
||||||
|
|
||||||
|
def infer_auto_device_map(model, pp_splits, visible_devices):
|
||||||
|
assert len(pp_splits) == len(visible_devices)
|
||||||
|
device_map = {
|
||||||
|
"model.embed_tokens": 0,
|
||||||
|
"model.norm": len(pp_splits) - 1,
|
||||||
|
"lm_head": len(pp_splits) - 1
|
||||||
|
}
|
||||||
|
assert len(model.model.layers) == sum(pp_splits)
|
||||||
|
pp_splits = [0, *list(accumulate(pp_splits))]
|
||||||
|
for idx, (start, end) in enumerate(zip(pp_splits[:-1], pp_splits[1:])):
|
||||||
|
for i in range(start, end):
|
||||||
|
device_map.update({f"model.layers.{i}": idx})
|
||||||
|
for k, v in device_map.items():
|
||||||
|
device_map[k] = visible_devices[v]
|
||||||
|
return device_map
|
||||||
|
|
||||||
|
|
||||||
|
def eval_expert(rank, args, model, dataset):
|
||||||
|
try:
|
||||||
|
print(f"Rank {rank} starting expert evaluation...", flush=True)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(args.base_model_path)
|
||||||
|
visible_devices = list(range(rank * args.gpus_per_rank, (rank + 1) * args.gpus_per_rank))
|
||||||
|
device_map = infer_auto_device_map(model, [14, 13], visible_devices)
|
||||||
|
model = dispatch_model(model, device_map)
|
||||||
|
model.config.expert_log_dir = os.path.join(args.output_dir, f"rank_{rank}")
|
||||||
|
n_sample_tokens = args.n_sample_tokens // args.world_size
|
||||||
|
os.makedirs(os.path.join(args.output_dir, f"rank_{rank}"), exist_ok=True)
|
||||||
|
done_tokens = 0
|
||||||
|
cur_dataset = dataset[rank::args.world_size]
|
||||||
|
for instance in cur_dataset:
|
||||||
|
input_ids, target_ids = get_formatted_input_and_target(instance['messages'], tokenizer, -100)
|
||||||
|
model(input_ids=torch.tensor(input_ids).unsqueeze(0), labels=torch.tensor(target_ids).unsqueeze(0))
|
||||||
|
done_tokens += len(input_ids)
|
||||||
|
if done_tokens >= n_sample_tokens:
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error in process {rank}: {e}", flush=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Evaluate a model with adapters on a specified dataset.")
|
||||||
|
parser.add_argument("--eval_dataset", type=str, required=True, help="Name of the evaluation dataset")
|
||||||
|
parser.add_argument("--base_model_path", type=str, required=True, help="Path to the base model")
|
||||||
|
parser.add_argument("--output_dir", type=str, required=True, help="Path to save the evaluation results")
|
||||||
|
parser.add_argument("--world_size", type=int, default=4, help="Number of processes to use for evaluation")
|
||||||
|
parser.add_argument("--gpus_per_rank", type=int, default=2, help="Number of GPUs per process")
|
||||||
|
parser.add_argument("--n_sample_tokens", type=int, required=True, help="Token to sample for expert evaluation")
|
||||||
|
args = parser.parse_args()
|
||||||
|
random.seed(5934875)
|
||||||
|
|
||||||
|
|
||||||
|
print("Loading base model...")
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(args.base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16) # not using tokenizer here to aviod deadlock
|
||||||
|
model.config.log_expert_weights = True
|
||||||
|
|
||||||
|
|
||||||
|
print(f"Running expert evaluation on {args.eval_dataset}...")
|
||||||
|
dataset = [json.loads(i) for i in open(f"datasets/train/{args.eval_dataset}.jsonl").readlines()]
|
||||||
|
random.shuffle(dataset)
|
||||||
|
|
||||||
|
|
||||||
|
print("Start Evaluating...")
|
||||||
|
mp.spawn(eval_expert, args=(args, model, dataset), nprocs=args.world_size, join=True)
|
@ -1,50 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("--eval_datasets", type=str, required=True)
|
|
||||||
parser.add_argument("--expert_scores_dir", type=str, required=True)
|
|
||||||
parser.add_argument("--output_dir", type=str, required=True)
|
|
||||||
parser.add_argument("--score_function", type=str, required=True)
|
|
||||||
parser.add_argument("--top_p", type=float, required=True)
|
|
||||||
parser.add_argument("--train_shared_experts", action="store_true")
|
|
||||||
parser.add_argument("--train_non_expert_modules", action="store_true")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
eval_datasets = args.eval_datasets.split(",")
|
|
||||||
expert_scores_dir = args.expert_scores_dir
|
|
||||||
output_dir = args.output_dir
|
|
||||||
score_function = args.score_function
|
|
||||||
top_p = args.top_p
|
|
||||||
train_shared_experts = args.train_shared_experts
|
|
||||||
train_non_expert_modules = args.train_non_expert_modules
|
|
||||||
|
|
||||||
for dataset_name in eval_datasets:
|
|
||||||
summary_file = f"{expert_scores_dir}/{dataset_name}/summary.json"
|
|
||||||
expert_cfg = {"experts": {}, "shared_experts": train_shared_experts, "non_expert_modules": train_non_expert_modules}
|
|
||||||
|
|
||||||
with open(summary_file) as f:
|
|
||||||
data = json.load(f)
|
|
||||||
assert score_function in ["gate", "token"], f"Unknown score function: {score_function}"
|
|
||||||
scores = data[f"{score_function}_scores"]
|
|
||||||
|
|
||||||
for layer, l_score in scores.items():
|
|
||||||
l_score = [(int(k), v) for k,v in l_score.items()]
|
|
||||||
l_score = sorted(l_score, key=lambda x: x[1], reverse=True)
|
|
||||||
# get the top experts that make the threshold exceed top_p
|
|
||||||
selected_experts = []
|
|
||||||
current_score = 0
|
|
||||||
for expert, score in l_score:
|
|
||||||
if current_score >= top_p:
|
|
||||||
break
|
|
||||||
selected_experts.append(expert)
|
|
||||||
current_score += score
|
|
||||||
expert_cfg["experts"][layer] = selected_experts
|
|
||||||
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
|
||||||
with open(f"{output_dir}/{dataset_name}.json", "w") as f:
|
|
||||||
json.dump(expert_cfg, f)
|
|
||||||
|
|
||||||
|
|
@ -1,75 +0,0 @@
|
|||||||
import json
|
|
||||||
from benchmarks import *
|
|
||||||
import os
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
import argparse
|
|
||||||
from random import shuffle
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
||||||
from utils import get_formatted_input_and_target
|
|
||||||
|
|
||||||
# constants for deepseek-v2-lite
|
|
||||||
TOP_K=6
|
|
||||||
N_EXPERTS=64
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("--base_model_path", type=str, required=True)
|
|
||||||
parser.add_argument("--eval_datasets", type=str, required=True)
|
|
||||||
parser.add_argument("--output_dir", type=str, required=True)
|
|
||||||
parser.add_argument("--n_sample_tokens", type=int, required=True)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
eval_datasets = args.eval_datasets.split(",")
|
|
||||||
output_dir = args.output_dir
|
|
||||||
base_model_path = args.base_model_path
|
|
||||||
n_sample_tokens = args.n_sample_tokens
|
|
||||||
|
|
||||||
model, tokenizer = AutoModelForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto"), AutoTokenizer.from_pretrained(base_model_path)
|
|
||||||
model.config.log_expert_weights = True
|
|
||||||
|
|
||||||
for dataset_name in eval_datasets:
|
|
||||||
dataset = [json.loads(i) for i in open(f"datasets/train/{dataset_name}.jsonl").readlines()]
|
|
||||||
shuffle(dataset)
|
|
||||||
model.config.expert_log_dir = os.path.join(args.output_dir, dataset_name)
|
|
||||||
# make dir -p this
|
|
||||||
os.makedirs(os.path.join(args.output_dir, dataset_name), exist_ok=True)
|
|
||||||
done_tokens = 0
|
|
||||||
for instance in dataset:
|
|
||||||
input_ids, target_ids = get_formatted_input_and_target(instance['messages'], tokenizer, -100)
|
|
||||||
model(input_ids=torch.tensor(input_ids).unsqueeze(0), labels=torch.tensor(target_ids).unsqueeze(0))
|
|
||||||
done_tokens += len(input_ids)
|
|
||||||
if done_tokens >= n_sample_tokens:
|
|
||||||
break
|
|
||||||
|
|
||||||
# open all files under os.path.join(args.output_dir, dataset_name). For each file, generate a summary of it
|
|
||||||
# and write it to a file in the same directory
|
|
||||||
files = os.listdir(os.path.join(args.output_dir, dataset_name))
|
|
||||||
summary_file = os.path.join(args.output_dir, dataset_name, "summary.json")
|
|
||||||
token_scores = {}
|
|
||||||
gate_scores = {}
|
|
||||||
|
|
||||||
for file in files:
|
|
||||||
if not file.endswith(".txt"):
|
|
||||||
continue
|
|
||||||
layer_idx = file.split("_")[2].split(".")[0]
|
|
||||||
token_scores[layer_idx] = {expert:0 for expert in range(N_EXPERTS)}
|
|
||||||
gate_scores[layer_idx] = {expert:0 for expert in range(N_EXPERTS)}
|
|
||||||
|
|
||||||
with open(os.path.join(args.output_dir, dataset_name, file)) as f:
|
|
||||||
data = f.readlines()
|
|
||||||
for line in data:
|
|
||||||
expert_ids, expert_weights = line.split("\t\t")
|
|
||||||
expert_ids = [int(i) for i in expert_ids.split("\t")]
|
|
||||||
expert_weights = [float(i) for i in expert_weights.split("\t")]
|
|
||||||
for expert_id, expert_weight in zip(expert_ids, expert_weights):
|
|
||||||
gate_scores[layer_idx][expert_id] += expert_weight
|
|
||||||
token_scores[layer_idx][expert_id] += 1. / TOP_K
|
|
||||||
total = sum(token_scores[layer_idx].values())
|
|
||||||
gate_scores[layer_idx] = {expert: round(gate_scores[layer_idx][expert] / total, 4) for expert in gate_scores[layer_idx]}
|
|
||||||
token_scores[layer_idx] = {expert: round(token_scores[layer_idx][expert] / total, 4) for expert in token_scores[layer_idx]}
|
|
||||||
|
|
||||||
|
|
||||||
with open(summary_file, "w") as f:
|
|
||||||
f.write(json.dumps({"token_scores": token_scores, "gate_scores": gate_scores}))
|
|
||||||
|
|
||||||
|
|
12
scripts/train.sh
Normal file
12
scripts/train.sh
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
|
||||||
|
export TOKENIZERS_PARALLELISM=false
|
||||||
|
|
||||||
|
exp_name="test/eval_translation"
|
||||||
|
base_model_path="deepseek-ai/esft-vanilla-lite"
|
||||||
|
# turn above to for loop
|
||||||
|
python train.py \
|
||||||
|
--base_model_path=${base_model_path} \
|
||||||
|
--expert_config=results/expert_configs/translation.json \
|
||||||
|
--train_dataset=translation \
|
||||||
|
--train_config=configs/base.yaml \
|
||||||
|
--output_dir=results/checkpoints/${exp_name}
|
11
scripts/train_ep.sh
Normal file
11
scripts/train_ep.sh
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
|
||||||
|
export TOKENIZERS_PARALLELISM=false
|
||||||
|
|
||||||
|
exp_name="test/eval_translation"
|
||||||
|
base_model_path="deepseek-ai/esft-vanilla-lite"
|
||||||
|
torchrun --nproc-per-node=8 train_ep.py \
|
||||||
|
--base_model_path=${base_model_path} \
|
||||||
|
--expert_config=results/expert_configs/translation.json \
|
||||||
|
--train_dataset=translation \
|
||||||
|
--train_config=configs/base.yaml \
|
||||||
|
--output_dir=results/checkpoints/${exp_name}
|
117
train.py
Normal file
117
train.py
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import yaml
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import TensorDataset
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, logging
|
||||||
|
|
||||||
|
from benchmarks import *
|
||||||
|
from utils import get_formatted_input_and_target, get_examples_from_buffer_pad
|
||||||
|
from esft import to_esft
|
||||||
|
from deepseek.modeling_deepseek import DeepseekV2ForCausalLM
|
||||||
|
|
||||||
|
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument("--base_model_path", type=str, required=True)
|
||||||
|
parser.add_argument("--expert_config", type=str, required=True)
|
||||||
|
parser.add_argument("--train_dataset", type=str, required=True)
|
||||||
|
parser.add_argument("--output_dir", type=str, required=True)
|
||||||
|
parser.add_argument("--train_config", type=str, required=True)
|
||||||
|
parser.add_argument("--wandb_api_key", type=str, required=False)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
expert_config = json.load(open(args.expert_config))
|
||||||
|
output_dir = args.output_dir
|
||||||
|
base_model_path = args.base_model_path
|
||||||
|
config = yaml.safe_load(open(args.train_config))
|
||||||
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
seed = config['seed']
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
|
||||||
|
if args.wandb_api_key is not None:
|
||||||
|
import wandb
|
||||||
|
wandb.login(key=args.wandb_api_key)
|
||||||
|
|
||||||
|
# Prepare data
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(base_model_path)
|
||||||
|
samples = [json.loads(i) for i in open(f"datasets/train/{args.train_dataset}.jsonl").readlines()]
|
||||||
|
buffer = []
|
||||||
|
for instance in samples:
|
||||||
|
input_ids, target_ids = get_formatted_input_and_target(instance['messages'], tokenizer, -100)
|
||||||
|
buffer.append((input_ids, target_ids))
|
||||||
|
seq_length = config['seq_length']
|
||||||
|
random_concat_ratio = config['random_concat_ratio']
|
||||||
|
concated_examples = get_examples_from_buffer_pad(buffer, seq_length, tokenizer, random_concat_ratio)
|
||||||
|
|
||||||
|
dataset = TensorDataset(concated_examples['input_ids'], concated_examples['labels'])
|
||||||
|
train_dataset, valid_dataset = torch.utils.data.random_split(dataset, [int(len(dataset) * 0.98), len(dataset) - int(len(dataset) * 0.98)])
|
||||||
|
|
||||||
|
# Training arguments
|
||||||
|
training_args = TrainingArguments(
|
||||||
|
output_dir=output_dir,
|
||||||
|
max_steps=config['steps'],
|
||||||
|
per_device_train_batch_size=config['per_device_batch_size'],
|
||||||
|
per_device_eval_batch_size=config['per_device_batch_size'],
|
||||||
|
warmup_steps=config['warmup_steps'],
|
||||||
|
weight_decay=config['weight_decay'],
|
||||||
|
logging_dir=f"{output_dir}/logs",
|
||||||
|
logging_steps=config['logging_steps'],
|
||||||
|
save_steps=config['save_steps'],
|
||||||
|
eval_strategy="steps",
|
||||||
|
eval_steps=config['eval_steps'],
|
||||||
|
gradient_accumulation_steps=config['gradient_accumulation_steps'],
|
||||||
|
load_best_model_at_end=True,
|
||||||
|
metric_for_best_model="loss",
|
||||||
|
greater_is_better=False,
|
||||||
|
bf16=True,
|
||||||
|
lr_scheduler_type='constant',
|
||||||
|
save_total_limit=5,
|
||||||
|
learning_rate=config['learning_rate'],
|
||||||
|
optim=config['optim'],
|
||||||
|
adam_beta1=config['adam_beta1'],
|
||||||
|
adam_beta2=config['adam_beta2'],
|
||||||
|
gradient_checkpointing=config['gradient_checkpointing'],
|
||||||
|
gradient_checkpointing_kwargs={"use_reentrant": False} if config['gradient_checkpointing'] else {}, # if set to True, backward will raise bug
|
||||||
|
)
|
||||||
|
|
||||||
|
def data_collator(data):
|
||||||
|
input_ids = torch.stack([item[0] for item in data])
|
||||||
|
labels = torch.stack([item[1] for item in data])
|
||||||
|
return {"input_ids": input_ids, "labels": labels}
|
||||||
|
|
||||||
|
|
||||||
|
model = DeepseekV2ForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
|
||||||
|
to_esft(model, expert_config)
|
||||||
|
|
||||||
|
# Initialize Trainer
|
||||||
|
trainer = Trainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
eval_dataset=valid_dataset,
|
||||||
|
data_collator=data_collator,
|
||||||
|
)
|
||||||
|
# Training
|
||||||
|
if os.path.exists(output_dir) and len(os.listdir(output_dir)) > 1: # has checkpoints already
|
||||||
|
trainer.train(resume_from_checkpoint=True)
|
||||||
|
else:
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
# Save the model and tokenizer
|
||||||
|
trainer.save_model(output_dir + "/last_checkpoint")
|
||||||
|
tokenizer.save_pretrained(output_dir + "/last_checkpoint")
|
||||||
|
|
||||||
|
print("Training complete")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
154
train_ep.py
Normal file
154
train_ep.py
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import yaml
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from types import MethodType
|
||||||
|
from torch.utils.data import TensorDataset
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, logging
|
||||||
|
|
||||||
|
from benchmarks import *
|
||||||
|
from utils import get_formatted_input_and_target, get_examples_from_buffer_pad, init_parallel_groups
|
||||||
|
from esft import to_esft
|
||||||
|
from deepseek.modeling_deepseek import DeepseekV2ForCausalLM
|
||||||
|
|
||||||
|
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
os.environ["NCCL_AVOID_RECORD_STREAMS"] = "1"
|
||||||
|
logging.set_verbosity_error()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument("--base_model_path", type=str, required=True)
|
||||||
|
parser.add_argument("--expert_config", type=str, required=True)
|
||||||
|
parser.add_argument("--train_dataset", type=str, required=True)
|
||||||
|
parser.add_argument("--output_dir", type=str, required=True)
|
||||||
|
parser.add_argument("--train_config", type=str, required=True)
|
||||||
|
parser.add_argument("--wandb_api_key", type=str, required=False)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
expert_config = json.load(open(args.expert_config))
|
||||||
|
output_dir = args.output_dir
|
||||||
|
base_model_path = args.base_model_path
|
||||||
|
config = yaml.safe_load(open(args.train_config))
|
||||||
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
seed = config['seed']
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
|
||||||
|
if args.wandb_api_key is not None:
|
||||||
|
import wandb
|
||||||
|
wandb.login(key=args.wandb_api_key)
|
||||||
|
|
||||||
|
ep_size = config.get("ep_size", 1)
|
||||||
|
world_size, local_rank, ep_group, edp_group = init_parallel_groups(ep_size)
|
||||||
|
edp_size = world_size // ep_size
|
||||||
|
|
||||||
|
# Prepare data
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(base_model_path)
|
||||||
|
samples = [json.loads(i) for i in open(f"datasets/train/{args.train_dataset}.jsonl").readlines()]
|
||||||
|
buffer = []
|
||||||
|
for instance in samples:
|
||||||
|
input_ids, target_ids = get_formatted_input_and_target(instance['messages'], tokenizer, -100)
|
||||||
|
buffer.append((input_ids, target_ids))
|
||||||
|
seq_length = config['seq_length']
|
||||||
|
random_concat_ratio = config['random_concat_ratio']
|
||||||
|
concated_examples = get_examples_from_buffer_pad(buffer, seq_length, tokenizer, random_concat_ratio)
|
||||||
|
|
||||||
|
dataset = TensorDataset(concated_examples['input_ids'], concated_examples['labels'])
|
||||||
|
train_dataset, valid_dataset = torch.utils.data.random_split(dataset, [int(len(dataset) * 0.98), len(dataset) - int(len(dataset) * 0.98)])
|
||||||
|
|
||||||
|
# Training arguments
|
||||||
|
training_args = TrainingArguments(
|
||||||
|
output_dir=output_dir,
|
||||||
|
max_steps=config['steps'],
|
||||||
|
per_device_train_batch_size=config['per_device_batch_size'],
|
||||||
|
per_device_eval_batch_size=config['per_device_batch_size'],
|
||||||
|
warmup_steps=config['warmup_steps'],
|
||||||
|
weight_decay=config['weight_decay'],
|
||||||
|
logging_dir=f"{output_dir}/logs",
|
||||||
|
logging_steps=config['logging_steps'],
|
||||||
|
save_steps=config['save_steps'],
|
||||||
|
eval_strategy="steps",
|
||||||
|
eval_steps=config['eval_steps'],
|
||||||
|
gradient_accumulation_steps=config['gradient_accumulation_steps'],
|
||||||
|
load_best_model_at_end=True,
|
||||||
|
metric_for_best_model="loss",
|
||||||
|
greater_is_better=False,
|
||||||
|
bf16=True,
|
||||||
|
lr_scheduler_type='constant',
|
||||||
|
save_total_limit=5,
|
||||||
|
learning_rate=config['learning_rate'],
|
||||||
|
optim=config['optim'],
|
||||||
|
adam_beta1=config['adam_beta1'],
|
||||||
|
adam_beta2=config['adam_beta2'],
|
||||||
|
disable_tqdm=False,
|
||||||
|
gradient_checkpointing=config['gradient_checkpointing'],
|
||||||
|
gradient_checkpointing_kwargs={"use_reentrant": False} if config['gradient_checkpointing'] else {}, # if set to True, backward will raise bug
|
||||||
|
)
|
||||||
|
|
||||||
|
def data_collator(data):
|
||||||
|
input_ids = torch.stack([item[0] for item in data])
|
||||||
|
labels = torch.stack([item[1] for item in data])
|
||||||
|
return {"input_ids": input_ids, "labels": labels}
|
||||||
|
|
||||||
|
|
||||||
|
model = DeepseekV2ForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16,
|
||||||
|
ep_size=ep_size, attn_implementation="flash_attention_2")
|
||||||
|
model._ddp_params_and_buffers_to_ignore = [n for n, _ in model.named_parameters() if ".expert" in n] # we manage grad synchronization of expert parameters
|
||||||
|
to_esft(model, expert_config)
|
||||||
|
model.dummy = torch.nn.Parameter(torch.zeros(1, dtype=model.dtype)) # prevent DDP from having no trainable parameters
|
||||||
|
model._keys_to_ignore_on_save = ["dummy"]
|
||||||
|
expert_params = [p for n, p in model.named_parameters() if p.requires_grad and ".expert" in n]
|
||||||
|
for layer in model.model.layers:
|
||||||
|
if type(layer.mlp).__name__ != "DeepseekV2MoE":
|
||||||
|
continue
|
||||||
|
layer.mlp.ep_group = ep_group
|
||||||
|
|
||||||
|
# Initialize Trainer
|
||||||
|
trainer = Trainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
eval_dataset=valid_dataset,
|
||||||
|
data_collator=data_collator,
|
||||||
|
)
|
||||||
|
|
||||||
|
accelerator = trainer.accelerator
|
||||||
|
backward = accelerator.backward
|
||||||
|
def custom_backward(self, loss, **kwargs):
|
||||||
|
backward(loss, **kwargs)
|
||||||
|
if not self.sync_gradients or edp_size == 1:
|
||||||
|
return
|
||||||
|
return
|
||||||
|
for p in expert_params:
|
||||||
|
g = p.grad if p.grad is not None else torch.zeros_like(p)
|
||||||
|
dist.all_reduce(g, op=dist.ReduceOp.AVG, group=edp_group)
|
||||||
|
if p.grad is not g:
|
||||||
|
p.grad = g
|
||||||
|
accelerator.backward = MethodType(custom_backward, accelerator)
|
||||||
|
|
||||||
|
# Training
|
||||||
|
ckpt_path = f"{output_dir}/last_checkpoint_ep{local_rank}"
|
||||||
|
if os.path.exists(output_dir) and len(os.listdir(output_dir)) > 1: # has checkpoints already
|
||||||
|
trainer.train(resume_from_checkpoint=ckpt_path)
|
||||||
|
else:
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
# Save the model and tokenizer
|
||||||
|
if local_rank == 0:
|
||||||
|
trainer.save_model(ckpt_path)
|
||||||
|
tokenizer.save_pretrained(ckpt_path)
|
||||||
|
elif 0 < local_rank < ep_size:
|
||||||
|
model.save_pretrained(ckpt_path)
|
||||||
|
|
||||||
|
print("Training complete")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
58
utils.py
58
utils.py
@ -1,3 +1,7 @@
|
|||||||
|
import os
|
||||||
|
import random
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
# given a message object, convert to prompt and response
|
# given a message object, convert to prompt and response
|
||||||
|
|
||||||
PROMPT_USER: str = 'User: {input}\n\n'
|
PROMPT_USER: str = 'User: {input}\n\n'
|
||||||
@ -38,3 +42,57 @@ def get_formatted_input_and_target(messages, tokenizer, IGNORE_TOKEN_ID=-100, ma
|
|||||||
assert False, f"Unknown role: {message['role']}"
|
assert False, f"Unknown role: {message['role']}"
|
||||||
|
|
||||||
return [input_ids, target_ids]
|
return [input_ids, target_ids]
|
||||||
|
|
||||||
|
|
||||||
|
def get_examples_from_buffer_pad(buffer, seq_length, tokenizer, random_concat_ratio, IGNORE_TOKEN_ID=-100):
|
||||||
|
all_input_ids_list, all_target_ids_list = [], []
|
||||||
|
all_input_ids, all_target_ids = [], []
|
||||||
|
|
||||||
|
for input_ids, target_ids in buffer:
|
||||||
|
if len(input_ids) > seq_length - len(all_input_ids):
|
||||||
|
input_ids = input_ids[-(seq_length - len(all_input_ids)):]
|
||||||
|
target_ids = target_ids[-(seq_length - len(all_target_ids)):]
|
||||||
|
if len(all_input_ids) > 0 and random.random() < random_concat_ratio:
|
||||||
|
input_ids = input_ids[1:]
|
||||||
|
target_ids = target_ids[1:]
|
||||||
|
all_input_ids.extend(input_ids)
|
||||||
|
all_target_ids.extend(target_ids)
|
||||||
|
if len(all_input_ids) >= seq_length:
|
||||||
|
assert len(all_input_ids) == seq_length, f"{len(all_input_ids)=}, {seq_length=}, {len(buffer)=}"
|
||||||
|
all_input_ids_list.append(all_input_ids)
|
||||||
|
all_target_ids_list.append(all_target_ids)
|
||||||
|
all_input_ids, all_target_ids = [], []
|
||||||
|
|
||||||
|
all_input_ids = all_input_ids + [tokenizer.pad_token_id for i in range(seq_length - len(all_input_ids))]
|
||||||
|
all_target_ids = all_target_ids + [IGNORE_TOKEN_ID for i in range(seq_length - len(all_target_ids))]
|
||||||
|
all_input_ids_list.append(all_input_ids)
|
||||||
|
all_target_ids_list.append(all_target_ids)
|
||||||
|
|
||||||
|
if len(all_input_ids) <= 0:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"input_ids": torch.tensor(all_input_ids_list, dtype=torch.long),
|
||||||
|
"labels": torch.tensor(all_target_ids_list, dtype=torch.long)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def init_parallel_groups(ep_size=1):
|
||||||
|
dist.init_process_group("nccl")
|
||||||
|
world_size = int(os.getenv("WORLD_SIZE", "0"))
|
||||||
|
local_rank = int(os.getenv("LOCAL_RANK", "0"))
|
||||||
|
torch.cuda.set_device(local_rank)
|
||||||
|
ep_group = edp_group = None
|
||||||
|
for i in range(0, world_size, ep_size):
|
||||||
|
ranks = list(range(i, i + ep_size))
|
||||||
|
group = dist.new_group(ranks)
|
||||||
|
if local_rank in ranks:
|
||||||
|
ep_group = group
|
||||||
|
edp_group = None
|
||||||
|
for i in range(ep_size):
|
||||||
|
ranks = list(range(i, world_size, ep_size))
|
||||||
|
group = dist.new_group(ranks)
|
||||||
|
if local_rank in ranks:
|
||||||
|
edp_group = group
|
||||||
|
dist.all_reduce(torch.zeros(1, device="cuda"), group=ep_group)
|
||||||
|
dist.all_reduce(torch.zeros(1, device="cuda"), group=edp_group)
|
||||||
|
return world_size, local_rank, ep_group, edp_group
|
||||||
|
Loading…
Reference in New Issue
Block a user