From 809d0e377e7c5ce3d4b29d6503d67c31fbb229c2 Mon Sep 17 00:00:00 2001 From: Xingkai Yu Date: Fri, 9 Aug 2024 16:23:48 +0800 Subject: [PATCH 1/2] add training code --- .gitignore | 1 + configs/base.yaml | 31 + deepseek/__init__.py | 0 deepseek/configuration_deepseek.py | 206 +++ deepseek/modeling_deepseek.py | 1918 ++++++++++++++++++++++++++++ esft.py | 33 +- scripts/train.sh | 12 + scripts/train_ep.sh | 11 + train.py | 117 ++ train_ep.py | 154 +++ utils.py | 58 + 11 files changed, 2533 insertions(+), 8 deletions(-) create mode 100644 .gitignore create mode 100644 configs/base.yaml create mode 100644 deepseek/__init__.py create mode 100644 deepseek/configuration_deepseek.py create mode 100644 deepseek/modeling_deepseek.py create mode 100644 scripts/train.sh create mode 100644 scripts/train_ep.sh create mode 100644 train.py create mode 100644 train_ep.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ba0430d --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +__pycache__/ \ No newline at end of file diff --git a/configs/base.yaml b/configs/base.yaml new file mode 100644 index 0000000..14b530f --- /dev/null +++ b/configs/base.yaml @@ -0,0 +1,31 @@ +seed: 5934875 +# Model settings +seq_length: 4096 # Maximum sequence length + +# Data settings +per_device_batch_size: 1 +n_device: 8 # Number of devices + +# Training settings +optim: adamw_torch_fused +steps: 500 # Number of training steps +learning_rate: 0.00001 # Learning rate +weight_decay: 0.1 # Weight decay for optimizer +warmup_steps: 0 # Number of warmup steps for learning rate scheduler +logging_steps: 10 # Log every X steps +adam_beta1: 0.9 +adam_beta2: 0.95 +random_concat_ratio: 0.2 # Ratio of random concatenation + + +# Evaluation settings +eval_steps: 100 # Evaluate every X steps +save_steps: 100 # Save model every X steps + +# Tokenizer settings + +# Additional settings (if needed) +gradient_checkpointing: true +gradient_accumulation_steps: 16 # Number of updates steps to accumulate before performing a backward/update pass +max_grad_norm: 1.0 # Max gradient norm for gradient clipping +ep_size: 2 \ No newline at end of file diff --git a/deepseek/__init__.py b/deepseek/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/deepseek/configuration_deepseek.py b/deepseek/configuration_deepseek.py new file mode 100644 index 0000000..82e0f5d --- /dev/null +++ b/deepseek/configuration_deepseek.py @@ -0,0 +1,206 @@ +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {} +class DeepseekV2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DeepseekV2Model`]. It is used to instantiate an DeepSeek + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the DeepSeek-V2. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 102400): + Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`DeepseekV2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + moe_intermediate_size (`int`, *optional*, defaults to 1407): + Dimension of the MoE representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + n_shared_experts (`int`, *optional*, defaults to None): + Number of shared experts, None means dense model. + n_routed_experts (`int`, *optional*, defaults to None): + Number of routed experts, None means dense model. + routed_scaling_factor (`float`, *optional*, defaults to 1.0): + Scaling factor or routed experts. + topk_method (`str`, *optional*, defaults to `gready`): + Topk method used in routed gate. + n_group (`int`, *optional*, defaults to None): + Number of groups for routed experts. + topk_group (`int`, *optional*, defaults to None): + Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). + num_experts_per_tok (`int`, *optional*, defaults to None): + Number of selected experts, None means dense model. + moe_layer_freq (`int`, *optional*, defaults to 1): + The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. + first_k_dense_replace (`int`, *optional*, defaults to 0): + Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). + \--k dense layers--/ + norm_topk_prob (`bool`, *optional*, defaults to False): + Whether to normalize the weights of the routed experts. + scoring_func (`str`, *optional*, defaults to 'softmax'): + Method of computing expert weights. + aux_loss_alpha (`float`, *optional*, defaults to 0.001): + Auxiliary loss weight coefficient. + seq_aux = (`bool`, *optional*, defaults to True): + Whether to compute the auxiliary loss for each individual sample. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + ```python + >>> from transformers import DeepseekV2Model, DeepseekV2Config + + >>> # Initializing a Deepseek-V2 style configuration + >>> configuration = DeepseekV2Config() + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "deepseek_v2" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=102400, + hidden_size=4096, + intermediate_size=11008, + moe_intermediate_size = 1407, + num_hidden_layers=30, + num_attention_heads=32, + num_key_value_heads=32, + n_shared_experts = None, + n_routed_experts = None, + ep_size = 1, + routed_scaling_factor = 1.0, + kv_lora_rank = 512, + q_lora_rank = 1536, + qk_rope_head_dim = 64, + v_head_dim = 128, + qk_nope_head_dim = 128, + topk_method = 'gready', + n_group = None, + topk_group = None, + num_experts_per_tok = None, + moe_layer_freq = 1, + first_k_dense_replace = 0, + norm_topk_prob = False, + scoring_func = 'softmax', + aux_loss_alpha = 0.001, + seq_aux = True, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=100000, + eos_token_id=100001, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + self.aux_loss_alpha = aux_loss_alpha + self.seq_aux = seq_aux + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) \ No newline at end of file diff --git a/deepseek/modeling_deepseek.py b/deepseek/modeling_deepseek.py new file mode 100644 index 0000000..f9465ba --- /dev/null +++ b/deepseek/modeling_deepseek.py @@ -0,0 +1,1918 @@ +# coding=utf-8 +# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch DeepSeek model.""" +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, + _prepare_4d_causal_attention_mask, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ( + ALL_LAYERNORM_LAYERS, + is_torch_greater_or_equal_than_1_13, +) +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.import_utils import is_torch_fx_available +from .configuration_deepseek import DeepseekV2Config +import torch.distributed as dist + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. +if is_torch_fx_available(): + if not is_torch_greater_or_equal_than_1_13: + import torch.fx + + _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "DeepseekV2Config" + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) + ) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class All2All(torch.autograd.Function): + @staticmethod + def forward(ctx, input: torch.Tensor, output_splits: List[int], input_splits: List[int], group=None): + ctx.output_splits = output_splits + ctx.input_splits = input_splits + ctx.group = group + output = input.new_empty(sum(output_splits), *input.shape[1:]) \ + if output_splits else torch.empty_like(input) + dist.all_to_all_single(output, input, output_splits, input_splits, group) + return output + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + output_splits = ctx.output_splits + input_splits = ctx.input_splits + group = ctx.group + grad_input = grad_output.new_empty(sum(input_splits), *grad_output.shape[1:]) \ + if input_splits else torch.empty_like(grad_output) + dist.all_to_all_single(grad_input, grad_output, input_splits, output_splits, group) + return grad_input, None, None, None + + +class DeepseekV2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + DeepseekV2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(DeepseekV2RMSNorm) + + +class DeepseekV2RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + self.max_seq_len_cached = None + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.outer(t, self.inv_freq.to(t.device)) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV2 +class DeepseekV2LinearScalingRotaryEmbedding(DeepseekV2RotaryEmbedding): + """DeepseekV2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + t = t / self.scaling_factor + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV2 +class DeepseekV2DynamicNTKScalingRotaryEmbedding(DeepseekV2RotaryEmbedding): + """DeepseekV2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# Inverse dim formula to find dim based on number of rotations +def yarn_find_correction_dim( + num_rotations, dim, base=10000, max_position_embeddings=2048 +): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + +# Find dim range bounds based on rotations +def yarn_find_correction_range( + low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 +): + low = math.floor( + yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +def yarn_linear_ramp_mask(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +class DeepseekV2YarnRotaryEmbedding(DeepseekV2RotaryEmbedding): + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + original_max_position_embeddings=4096, + beta_fast=32, + beta_slow=1, + mscale=1, + mscale_all_dim=0, + ): + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = mscale + self.mscale_all_dim = mscale_all_dim + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + dim = self.dim + + freq_extra = 1.0 / ( + self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + freq_inter = 1.0 / ( + self.scaling_factor + * self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to( + device=device, dtype=torch.float32 + ) + inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(seq_len, device=device, dtype=torch.float32) + + freqs = torch.outer(t, inv_freq) + + _mscale = float( + yarn_get_mscale(self.scaling_factor, self.mscale) + / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + ) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False + ) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class DeepseekV2MLP(nn.Module): + def __init__(self, config, hidden_size=None, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = ( + config.intermediate_size if intermediate_size is None else intermediate_size + ) + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + # convert dtype in ESFT so trainable experts of fp32 can be aggregated with frozen experts of bf16 + if x.dtype != self.up_proj.weight.dtype: + xdtype = x.dtype + x = x.to(self.up_proj.weight.dtype) + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + down_proj = down_proj.to(xdtype) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class MoEGate(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.scoring_func = config.scoring_func + self.alpha = config.aux_loss_alpha + self.seq_aux = config.seq_aux + self.topk_method = config.topk_method + self.n_group = config.n_group + self.topk_group = config.topk_group + + # topk selection algorithm + self.norm_topk_prob = config.norm_topk_prob + self.gating_dim = config.hidden_size + self.weight = nn.Parameter( + torch.empty((self.n_routed_experts, self.gating_dim)) + ) + self.reset_parameters() + + def reset_parameters(self) -> None: + import torch.nn.init as init + + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + ### compute gating score + hidden_states = hidden_states.view(-1, h) + logits = F.linear( + hidden_states.type(torch.float32), self.weight.type(torch.float32), None + ) + if self.scoring_func == "softmax": + scores = logits.softmax(dim=-1, dtype=torch.float32) + else: + raise NotImplementedError( + f"insupportable scoring function for MoE gating: {self.scoring_func}" + ) + + ### select top-k experts + if self.topk_method == "greedy": + topk_weight, topk_idx = torch.topk( + scores, k=self.top_k, dim=-1, sorted=False + ) + elif self.topk_method == "group_limited_greedy": + group_scores = ( + scores.view(bsz * seq_len, self.n_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk( + group_scores, k=self.topk_group, dim=-1, sorted=False + )[ + 1 + ] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand( + bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group + ) + .reshape(bsz * seq_len, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] + topk_weight, topk_idx = torch.topk( + tmp_scores, k=self.top_k, dim=-1, sorted=False + ) + + ### norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + else: + topk_weight = topk_weight * self.routed_scaling_factor + ### expert-level computation auxiliary loss + if self.training and self.alpha > 0.0: + scores_for_aux = scores + aux_topk = self.top_k + # always compute aux loss based on the naive greedy topk method + topk_idx_for_aux_loss = topk_idx.view(bsz, -1) + if self.seq_aux: + scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1) + ce = torch.zeros( + bsz, self.n_routed_experts, device=hidden_states.device + ) + ce.scatter_add_( + 1, + topk_idx_for_aux_loss, + torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device), + ).div_(seq_len * aux_topk / self.n_routed_experts) + aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum( + dim=1 + ).mean() * self.alpha + else: + mask_ce = F.one_hot( + topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts + ) + ce = mask_ce.float().mean(0) + Pi = scores_for_aux.mean(0) + fi = ce * self.n_routed_experts + aux_loss = (Pi * fi).sum() * self.alpha + else: + aux_loss = None + return topk_idx, topk_weight, aux_loss + + +class AddAuxiliaryLoss(torch.autograd.Function): + """ + The trick function of adding auxiliary (aux) loss, + which includes the gradient of the aux loss during backpropagation. + """ + + @staticmethod + def forward(ctx, x, loss): + assert loss.numel() == 1 + ctx.dtype = loss.dtype + ctx.required_aux_loss = loss.requires_grad + return x + + @staticmethod + def backward(ctx, grad_output): + grad_loss = None + if ctx.required_aux_loss: + grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) + return grad_output, grad_loss + + +class DeepseekV2MoE(nn.Module): + """ + A mixed expert module containing shared experts. + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.num_experts_per_tok = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + + if hasattr(config, "ep_size") and config.ep_size > 1: + assert config.n_routed_experts % config.ep_size == 0 + self.ep_group = None + self.ep_size = config.ep_size + self.experts_per_rank = config.n_routed_experts // config.ep_size + self.ep_rank = dist.get_rank() % config.ep_size + self.experts = nn.ModuleList( + [ + ( + DeepseekV2MLP( + config, intermediate_size=config.moe_intermediate_size + ) + if i >= self.ep_rank * self.experts_per_rank + and i < (self.ep_rank + 1) * self.experts_per_rank + else None + ) + for i in range(config.n_routed_experts) + ] + ) + else: + self.ep_size = 1 + self.experts_per_rank = config.n_routed_experts + self.ep_rank = 0 + self.experts = nn.ModuleList( + [ + DeepseekV2MLP( + config, intermediate_size=config.moe_intermediate_size + ) + for i in range(config.n_routed_experts) + ] + ) + self.gate = MoEGate(config) + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = DeepseekV2MLP( + config=config, intermediate_size=intermediate_size + ) + + def forward(self, hidden_states): + identity = hidden_states + orig_shape = hidden_states.shape + topk_idx, topk_weight, aux_loss = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + flat_topk_idx = topk_idx.view(-1) + if self.ep_size == 1: + hidden_states = hidden_states.repeat_interleave(self.num_experts_per_tok, dim=0) + y = torch.empty_like(hidden_states) + for i, expert in enumerate(self.experts): + y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i]) + else: + y = self.moe_ep(hidden_states, topk_idx) + y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) + y = y.to(hidden_states.dtype).view(*orig_shape) + if self.training: + y = AddAuxiliaryLoss.apply(y, aux_loss) + if self.config.n_shared_experts is not None: + y = y + self.shared_experts(identity) + return y + + def moe_ep(self, x, topk_ids): + cnts = topk_ids.new_zeros((topk_ids.shape[0], self.n_routed_experts)) + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + sorted_tokens = x[idxs // self.num_experts_per_tok] + if self.ep_size > 1: + tokens_per_expert_group = torch.empty_like(tokens_per_expert) + dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert, group=self.ep_group) + output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(dim=1).cpu().tolist() + input_splits = tokens_per_expert.view(self.ep_size, -1).sum(dim=1).cpu().tolist() + gathered_tokens = All2All.apply(sorted_tokens, output_splits, input_splits, self.ep_group) + gatherd_idxs = idxs.new_empty(gathered_tokens.shape[0], device="cpu") + s = 0 + for i, k in enumerate(tokens_per_expert_group.cpu()): + gatherd_idxs[s : s + k] = i % self.experts_per_rank + s += k + gatherd_idxs = gatherd_idxs.to(idxs.device).argsort() + sorted_tokens = gathered_tokens[gatherd_idxs] + tokens_per_expert = tokens_per_expert_group.view(self.ep_size, -1).sum(dim=0) + tokens_per_expert = tokens_per_expert.cpu().numpy() + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + if num_tokens == 0: + continue + end_idx = start_idx + num_tokens + expert = self.experts[i + self.ep_rank * self.experts_per_rank] + outputs.append(expert(sorted_tokens[start_idx:end_idx])) + start_idx = end_idx + + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + if self.ep_size > 1: + sorted_tokens = torch.empty_like(outs) + sorted_tokens[gatherd_idxs] = outs + gathered_tokens = All2All.apply(sorted_tokens, input_splits, output_splits, self.ep_group) + outs = gathered_tokens + + y = torch.empty_like(outs) + y[idxs] = outs + return y + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2 +class DeepseekV2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + + self.is_causal = True + + if self.q_lora_rank is None: + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.q_head_dim, bias=False + ) + else: + self.q_a_proj = nn.Linear( + self.hidden_size, config.q_lora_rank, bias=config.attention_bias + ) + self.q_a_layernorm = DeepseekV2RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear( + config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False + ) + + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + config.kv_lora_rank + config.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = DeepseekV2RMSNorm(config.kv_lora_rank) + self.kv_b_proj = nn.Linear( + config.kv_lora_rank, + self.num_heads + * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=config.attention_bias, + ) + self._init_rope() + + self.softmax_scale = self.q_head_dim ** (-0.5) + if self.config.rope_scaling is not None: + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = DeepseekV2RotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = DeepseekV2DynamicNTKScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "yarn": + kwargs = { + key: self.config.rope_scaling[key] + for key in [ + "original_max_position_embeddings", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] + if key in self.config.rope_scaling + } + self.rotary_emb = DeepseekV2YarnRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + **kwargs, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attn_weights = ( + torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale + ) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + assert attention_mask is not None + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV2 +class DeepseekV2FlashAttention2(DeepseekV2Attention): + """ + DeepseekV2 flash attention module. This module inherits from `DeepseekV2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # DeepseekV2FlashAttention2 attention does not support output_attentions + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + kv_seq_len = value_states.shape[-2] + + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + + if self.q_head_dim != self.v_head_dim: + value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (DeepseekV2RMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + elif torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + else: + target_dtype = ( + self.q_proj.weight.dtype + if self.q_lora_rank is None + else self.q_a_proj.weight.dtype + ) + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + softmax_scale=self.softmax_scale, + ) + if self.q_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape( + bsz, q_len, self.num_heads * self.v_head_dim + ).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV2FlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + ( + query_states, + key_states, + value_states, + indices_q, + cu_seq_lens, + max_seq_lens, + ) = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input( + attn_output_unpad, indices_q, batch_size, query_length + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + return attn_output + + def _upad_input( + self, query_layer, key_layer, value_layer, attention_mask, query_length + ): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), + indices_k, + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( + query_layer, attention_mask + ) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +ATTENTION_CLASSES = { + "eager": DeepseekV2Attention, + "flash_attention_2": DeepseekV2FlashAttention2, +} + + +class DeepseekV2DecoderLayer(nn.Module): + def __init__(self, config: DeepseekV2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = ATTENTION_CLASSES[config._attn_implementation]( + config=config, layer_idx=layer_idx + ) + + self.mlp = ( + DeepseekV2MoE(config) + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ) + else DeepseekV2MLP(config) + ) + self.input_layernorm = DeepseekV2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = DeepseekV2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +DeepseekV2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`DeepseekV2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.", + DeepseekV2_START_DOCSTRING, +) +class DeepseekV2PreTrainedModel(PreTrainedModel): + config_class = DeepseekV2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["DeepseekV2DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +DeepseekV2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.", + DeepseekV2_START_DOCSTRING, +) +class DeepseekV2Model(DeepseekV2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV2DecoderLayer`] + + Args: + config: DeepseekV2Config + """ + + def __init__(self, config: DeepseekV2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + DeepseekV2DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self.norm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers." + ) + use_cache = False + + past_key_values_length = 0 + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = ( + attention_mask + if (attention_mask is not None and 0 in attention_mask) + else None + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() + if use_legacy_cache + else next_decoder_cache + ) + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = DeepseekV2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, DeepseekV2ForCausalLM + + >>> model = DeepseekV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, + ): + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # input) + if ( + attention_mask is not None + and attention_mask.shape[1] > input_ids.shape[1] + ): + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ), + ) + return reordered_past + + +@add_start_docstrings( + """ + The DeepseekV2 Model transformer with a sequence classification head on top (linear layer). + + [`DeepseekV2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + DeepseekV2_START_DOCSTRING, +) +class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = DeepseekV2Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, transformers., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError( + "Cannot handle batch sizes > 1 if no padding token is defined." + ) + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = ( + torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + ).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[ + torch.arange(batch_size, device=logits.device), sequence_lengths + ] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int + ): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct( + pooled_logits.view(-1, self.num_labels), labels.view(-1) + ) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/esft.py b/esft.py index e1620cc..8780a2e 100644 --- a/esft.py +++ b/esft.py @@ -7,6 +7,8 @@ from transformers import AutoModelForCausalLM, AutoTokenizer def to_buffer(module, mark_param=True): """Turns all parameters of a module into buffers.""" + if module is None: + return modules = module.modules() module = next(modules) delattrs = [] @@ -25,6 +27,8 @@ def to_buffer(module, mark_param=True): def to_param(module): """Turns all buffers of a module into parameterss.""" + if module is None: + return modules = module.modules() module = next(modules) param_list = getattr(module, 'param_list', []) @@ -57,7 +61,7 @@ def to_esft(model, adapter_config): to_buffer(model) else: to_param(model) - for idx, layer in enumerate(model.layers): + for idx, layer in enumerate(model.model.layers): if type(layer.mlp).__name__ != "DeepseekV2MoE": continue if adapter_config.get('shared_experts', False): @@ -72,15 +76,25 @@ def to_esft(model, adapter_config): to_buffer(layer.mlp.experts[expert_id]) return model + def load_state_dict(folder_path): + # 初始化空的 state_dict combined_state_dict = {} + # 遍历文件夹中的所有文件 for file_name in os.listdir(folder_path): if file_name.endswith('.safetensors'): file_path = os.path.join(folder_path, file_name) state_dict = load_file(file_path) combined_state_dict.update(state_dict) - + + # legacy for loading v1 checkpoints: add prefix "model." for parameters + for k in list(combined_state_dict.keys()): + if k.startswith("layers"): + k_new = "model." + k + combined_state_dict[k_new] = combined_state_dict[k] + del combined_state_dict[k] + return combined_state_dict @@ -89,21 +103,24 @@ def load_esft_model(base_model_path, adapter_dir): adapter_state_dict = load_state_dict(adapter_dir) # load pretrained model: - model, tokenizer = AutoModelForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto"), AutoTokenizer.from_pretrained(base_model_path) + model, tokenizer = AutoModelForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16), AutoTokenizer.from_pretrained(base_model_path) - to_esft(model.model, adapter_config) - model.model.load_state_dict(adapter_state_dict) + to_esft(model, adapter_config) + model.load_state_dict(adapter_state_dict) return model, tokenizer def load_base_model(base_model_path): # load pretrained model: - model, tokenizer = AutoModelForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto"), AutoTokenizer.from_pretrained(base_model_path) + model, tokenizer = AutoModelForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16), AutoTokenizer.from_pretrained(base_model_path) return model, tokenizer -def add_adapter(base_model, adapter_dir, return_original_states=False): - adapter_config = json.load(open(adapter_dir + "/expert_cfg.json")) +def add_adapter(base_model, adapter_dir, return_original_states=False, expert_config=None): + if expert_config is not None: + adapter_config = json.load(open(expert_config)) + else: + adapter_config = json.load(open(adapter_dir + "/expert_cfg.json")) adapter_state_dict = load_state_dict(adapter_dir) to_esft(base_model, adapter_config) diff --git a/scripts/train.sh b/scripts/train.sh new file mode 100644 index 0000000..5671421 --- /dev/null +++ b/scripts/train.sh @@ -0,0 +1,12 @@ + +export TOKENIZERS_PARALLELISM=false + +exp_name="test/eval_translation" +base_model_path="/hf3fs-jd/prod/deepseek/shared/wangzihan/models/huggingface/vanilla_model" +# turn above to for loop +python train.py \ + --base_model_path=${base_model_path} \ + --expert_config=results/expert_configs/translation.json \ + --train_dataset=translation \ + --train_config=configs/base.yaml \ + --output_dir=results/checkpoints/${exp_name} \ No newline at end of file diff --git a/scripts/train_ep.sh b/scripts/train_ep.sh new file mode 100644 index 0000000..0785e9a --- /dev/null +++ b/scripts/train_ep.sh @@ -0,0 +1,11 @@ + +export TOKENIZERS_PARALLELISM=false + +exp_name="test/eval_translation" +base_model_path="/hf3fs-jd/prod/deepseek/shared/wangzihan/models/huggingface/vanilla_model" +torchrun --nproc-per-node=8 train_ep.py \ + --base_model_path=${base_model_path} \ + --expert_config=results/expert_configs/translation.json \ + --train_dataset=translation \ + --train_config=configs/base.yaml \ + --output_dir=results/checkpoints/${exp_name} \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000..6be9bda --- /dev/null +++ b/train.py @@ -0,0 +1,117 @@ +import argparse +import json +import yaml +import os +import random +import torch +from torch.utils.data import TensorDataset +from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, logging + +from benchmarks import * +from utils import get_formatted_input_and_target, get_examples_from_buffer_pad +from esft import to_esft +from deepseek.modeling_deepseek import DeepseekV2ForCausalLM + + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument("--base_model_path", type=str, required=True) + parser.add_argument("--expert_config", type=str, required=True) + parser.add_argument("--train_dataset", type=str, required=True) + parser.add_argument("--output_dir", type=str, required=True) + parser.add_argument("--train_config", type=str, required=True) + parser.add_argument("--wandb_api_key", type=str, required=False) + args = parser.parse_args() + + expert_config = json.load(open(args.expert_config)) + output_dir = args.output_dir + base_model_path = args.base_model_path + config = yaml.safe_load(open(args.train_config)) + os.makedirs(args.output_dir, exist_ok=True) + + seed = config['seed'] + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + random.seed(seed) + + if args.wandb_api_key is not None: + import wandb + wandb.login(key=args.wandb_api_key) + + # Prepare data + tokenizer = AutoTokenizer.from_pretrained(base_model_path) + samples = [json.loads(i) for i in open(f"datasets/train/{args.train_dataset}.jsonl").readlines()] + buffer = [] + for instance in samples: + input_ids, target_ids = get_formatted_input_and_target(instance['messages'], tokenizer, -100) + buffer.append((input_ids, target_ids)) + seq_length = config['seq_length'] + random_concat_ratio = config['random_concat_ratio'] + concated_examples = get_examples_from_buffer_pad(buffer, seq_length, tokenizer, random_concat_ratio) + + dataset = TensorDataset(concated_examples['input_ids'], concated_examples['labels']) + train_dataset, valid_dataset = torch.utils.data.random_split(dataset, [int(len(dataset) * 0.98), len(dataset) - int(len(dataset) * 0.98)]) + + # Training arguments + training_args = TrainingArguments( + output_dir=output_dir, + max_steps=config['steps'], + per_device_train_batch_size=config['per_device_batch_size'], + per_device_eval_batch_size=config['per_device_batch_size'], + warmup_steps=config['warmup_steps'], + weight_decay=config['weight_decay'], + logging_dir=f"{output_dir}/logs", + logging_steps=config['logging_steps'], + save_steps=config['save_steps'], + eval_strategy="steps", + eval_steps=config['eval_steps'], + gradient_accumulation_steps=config['gradient_accumulation_steps'], + load_best_model_at_end=True, + metric_for_best_model="loss", + greater_is_better=False, + bf16=True, + lr_scheduler_type='constant', + save_total_limit=5, + learning_rate=config['learning_rate'], + optim=config['optim'], + adam_beta1=config['adam_beta1'], + adam_beta2=config['adam_beta2'], + gradient_checkpointing=config['gradient_checkpointing'], + gradient_checkpointing_kwargs={"use_reentrant": False} if config['gradient_checkpointing'] else {}, # if set to True, backward will raise bug + ) + + def data_collator(data): + input_ids = torch.stack([item[0] for item in data]) + labels = torch.stack([item[1] for item in data]) + return {"input_ids": input_ids, "labels": labels} + + + model = DeepseekV2ForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2") + to_esft(model, expert_config) + + # Initialize Trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=valid_dataset, + data_collator=data_collator, + ) + # Training + if os.path.exists(output_dir) and len(os.listdir(output_dir)) > 1: # has checkpoints already + trainer.train(resume_from_checkpoint=True) + else: + trainer.train() + + # Save the model and tokenizer + trainer.save_model(output_dir + "/last_checkpoint") + tokenizer.save_pretrained(output_dir + "/last_checkpoint") + + print("Training complete") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/train_ep.py b/train_ep.py new file mode 100644 index 0000000..3ba98e8 --- /dev/null +++ b/train_ep.py @@ -0,0 +1,154 @@ +import argparse +import json +import yaml +import os +import random +import torch +import torch.distributed as dist +from types import MethodType +from torch.utils.data import TensorDataset +from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, logging + +from benchmarks import * +from utils import get_formatted_input_and_target, get_examples_from_buffer_pad, init_parallel_groups +from esft import to_esft +from deepseek.modeling_deepseek import DeepseekV2ForCausalLM + + +os.environ["TOKENIZERS_PARALLELISM"] = "false" +os.environ["NCCL_AVOID_RECORD_STREAMS"] = "1" +logging.set_verbosity_error() + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument("--base_model_path", type=str, required=True) + parser.add_argument("--expert_config", type=str, required=True) + parser.add_argument("--train_dataset", type=str, required=True) + parser.add_argument("--output_dir", type=str, required=True) + parser.add_argument("--train_config", type=str, required=True) + parser.add_argument("--wandb_api_key", type=str, required=False) + args = parser.parse_args() + + expert_config = json.load(open(args.expert_config)) + output_dir = args.output_dir + base_model_path = args.base_model_path + config = yaml.safe_load(open(args.train_config)) + os.makedirs(args.output_dir, exist_ok=True) + + seed = config['seed'] + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + random.seed(seed) + + if args.wandb_api_key is not None: + import wandb + wandb.login(key=args.wandb_api_key) + + ep_size = config.get("ep_size", 1) + world_size, local_rank, ep_group, edp_group = init_parallel_groups(ep_size) + edp_size = world_size // ep_size + + # Prepare data + tokenizer = AutoTokenizer.from_pretrained(base_model_path) + samples = [json.loads(i) for i in open(f"datasets/train/{args.train_dataset}.jsonl").readlines()] + buffer = [] + for instance in samples: + input_ids, target_ids = get_formatted_input_and_target(instance['messages'], tokenizer, -100) + buffer.append((input_ids, target_ids)) + seq_length = config['seq_length'] + random_concat_ratio = config['random_concat_ratio'] + concated_examples = get_examples_from_buffer_pad(buffer, seq_length, tokenizer, random_concat_ratio) + + dataset = TensorDataset(concated_examples['input_ids'], concated_examples['labels']) + train_dataset, valid_dataset = torch.utils.data.random_split(dataset, [int(len(dataset) * 0.98), len(dataset) - int(len(dataset) * 0.98)]) + + # Training arguments + training_args = TrainingArguments( + output_dir=output_dir, + max_steps=config['steps'], + per_device_train_batch_size=config['per_device_batch_size'], + per_device_eval_batch_size=config['per_device_batch_size'], + warmup_steps=config['warmup_steps'], + weight_decay=config['weight_decay'], + logging_dir=f"{output_dir}/logs", + logging_steps=config['logging_steps'], + save_steps=config['save_steps'], + eval_strategy="steps", + eval_steps=config['eval_steps'], + gradient_accumulation_steps=config['gradient_accumulation_steps'], + load_best_model_at_end=True, + metric_for_best_model="loss", + greater_is_better=False, + bf16=True, + lr_scheduler_type='constant', + save_total_limit=5, + learning_rate=config['learning_rate'], + optim=config['optim'], + adam_beta1=config['adam_beta1'], + adam_beta2=config['adam_beta2'], + disable_tqdm=False, + gradient_checkpointing=config['gradient_checkpointing'], + gradient_checkpointing_kwargs={"use_reentrant": False} if config['gradient_checkpointing'] else {}, # if set to True, backward will raise bug + ) + + def data_collator(data): + input_ids = torch.stack([item[0] for item in data]) + labels = torch.stack([item[1] for item in data]) + return {"input_ids": input_ids, "labels": labels} + + + model = DeepseekV2ForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, + ep_size=ep_size, attn_implementation="flash_attention_2") + model._ddp_params_and_buffers_to_ignore = [n for n, _ in model.named_parameters() if ".expert" in n] # we manage grad synchronization of expert parameters + to_esft(model, expert_config) + model.dummy = torch.nn.Parameter(torch.zeros(1, dtype=model.dtype)) # prevent DDP from having no trainable parameters + model._keys_to_ignore_on_save = ["dummy"] + expert_params = [p for n, p in model.named_parameters() if p.requires_grad and ".expert" in n] + for layer in model.model.layers: + if type(layer.mlp).__name__ != "DeepseekV2MoE": + continue + layer.mlp.ep_group = ep_group + + # Initialize Trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=valid_dataset, + data_collator=data_collator, + ) + + accelerator = trainer.accelerator + backward = accelerator.backward + def custom_backward(self, loss, **kwargs): + backward(loss, **kwargs) + if not self.sync_gradients or edp_size == 1: + return + return + for p in expert_params: + g = p.grad if p.grad is not None else torch.zeros_like(p) + dist.all_reduce(g, op=dist.ReduceOp.AVG, group=edp_group) + if p.grad is not g: + p.grad = g + accelerator.backward = MethodType(custom_backward, accelerator) + + # Training + ckpt_path = f"{output_dir}/last_checkpoint_ep{local_rank}" + if os.path.exists(output_dir) and len(os.listdir(output_dir)) > 1: # has checkpoints already + trainer.train(resume_from_checkpoint=ckpt_path) + else: + trainer.train() + + # Save the model and tokenizer + if local_rank == 0: + trainer.save_model(ckpt_path) + tokenizer.save_pretrained(ckpt_path) + elif 0 < local_rank < ep_size: + model.save_pretrained(ckpt_path) + + print("Training complete") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/utils.py b/utils.py index ff1cd65..f7da51f 100644 --- a/utils.py +++ b/utils.py @@ -1,3 +1,7 @@ +import os +import random +import torch +import torch.distributed as dist # given a message object, convert to prompt and response PROMPT_USER: str = 'User: {input}\n\n' @@ -38,3 +42,57 @@ def get_formatted_input_and_target(messages, tokenizer, IGNORE_TOKEN_ID=-100, ma assert False, f"Unknown role: {message['role']}" return [input_ids, target_ids] + + +def get_examples_from_buffer_pad(buffer, seq_length, tokenizer, random_concat_ratio, IGNORE_TOKEN_ID=-100): + all_input_ids_list, all_target_ids_list = [], [] + all_input_ids, all_target_ids = [], [] + + for input_ids, target_ids in buffer: + if len(input_ids) > seq_length - len(all_input_ids): + input_ids = input_ids[-(seq_length - len(all_input_ids)):] + target_ids = target_ids[-(seq_length - len(all_target_ids)):] + if len(all_input_ids) > 0 and random.random() < random_concat_ratio: + input_ids = input_ids[1:] + target_ids = target_ids[1:] + all_input_ids.extend(input_ids) + all_target_ids.extend(target_ids) + if len(all_input_ids) >= seq_length: + assert len(all_input_ids) == seq_length, f"{len(all_input_ids)=}, {seq_length=}, {len(buffer)=}" + all_input_ids_list.append(all_input_ids) + all_target_ids_list.append(all_target_ids) + all_input_ids, all_target_ids = [], [] + + all_input_ids = all_input_ids + [tokenizer.pad_token_id for i in range(seq_length - len(all_input_ids))] + all_target_ids = all_target_ids + [IGNORE_TOKEN_ID for i in range(seq_length - len(all_target_ids))] + all_input_ids_list.append(all_input_ids) + all_target_ids_list.append(all_target_ids) + + if len(all_input_ids) <= 0: + return None + return { + "input_ids": torch.tensor(all_input_ids_list, dtype=torch.long), + "labels": torch.tensor(all_target_ids_list, dtype=torch.long) + } + + +def init_parallel_groups(ep_size=1): + dist.init_process_group("nccl") + world_size = int(os.getenv("WORLD_SIZE", "0")) + local_rank = int(os.getenv("LOCAL_RANK", "0")) + torch.cuda.set_device(local_rank) + ep_group = edp_group = None + for i in range(0, world_size, ep_size): + ranks = list(range(i, i + ep_size)) + group = dist.new_group(ranks) + if local_rank in ranks: + ep_group = group + edp_group = None + for i in range(ep_size): + ranks = list(range(i, world_size, ep_size)) + group = dist.new_group(ranks) + if local_rank in ranks: + edp_group = group + dist.all_reduce(torch.zeros(1, device="cuda"), group=ep_group) + dist.all_reduce(torch.zeros(1, device="cuda"), group=edp_group) + return world_size, local_rank, ep_group, edp_group From 36c6194349e42969bfb782fe7583529e4ba86783 Mon Sep 17 00:00:00 2001 From: ZihanWang314 <510642032wzh@gmail.com> Date: Fri, 9 Aug 2024 18:06:57 +0800 Subject: [PATCH 2/2] update eval and readme --- README.md | 59 +++++++---- eval_multigpu.py | 92 ++++++++++++++++++ scripts/eval.sh | 18 +++- ...nerate_expert_config.sh => eval_expert.sh} | 12 ++- scripts/expert/generate_expert_config.py | 97 +++++++++++++++++++ scripts/expert/get_expert_scores.py | 78 +++++++++++++++ scripts/generate_expert_config.py | 50 ---------- scripts/get_expert_scores.py | 75 -------------- scripts/train.sh | 2 +- scripts/train_ep.sh | 2 +- 10 files changed, 333 insertions(+), 152 deletions(-) create mode 100644 eval_multigpu.py rename scripts/{generate_expert_config.sh => eval_expert.sh} (58%) create mode 100644 scripts/expert/generate_expert_config.py create mode 100644 scripts/expert/get_expert_scores.py delete mode 100644 scripts/generate_expert_config.py delete mode 100644 scripts/get_expert_scores.py diff --git a/README.md b/README.md index f2ce48e..e001aab 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,9 @@ Y. Wu. **ESFT** aims to efficiently customize Large Language Models (LLMs) with Mixture-of-Experts (MoE) architecture by adjusting only task-relevant parts, improving efficiency and performance while using fewer resources and storage. +## 📰 News + +📅 **2024.8.11:** We now release the **ESFT training code**! ✨ You can now try it with your own models and dataset! ## 🚀 Quick Start @@ -19,9 +22,9 @@ git clone https://github.com/deepseek-ai/ESFT.git cd esft ``` -### Install dependencies +### Install required dependencies ```bash -pip install transformers torch safetensors +pip install transformers torch safetensors accelerate ``` ### Download necessary adapters @@ -32,35 +35,38 @@ bash scripts/download_adapters.sh ## 🔧Key Scripts -1. **eval.py** -This script evaluates the performance of the model on various datasets. **Usage:** +1. **eval_multigpu.py** +This script evaluates the performance of the model on various datasets. See **scripts/eval.sh** for detailed configs and explanations. + +**Usage:** ```bash -python scripts/eval.py \ - --eval_datasets=translation \ +python eval_multigpu.py \ + --eval_dataset=translation \ --base_model_path=deepseek-ai/ESFT-vanilla-lite \ - --adapter_dir=all_models/adapters/token \ - --output_dir=results/completions/token \ - --max_new_tokens=512 \ - --openai_api_key=REPLACE_WITH_YOUR_KEY \ - --eval_batch_size=2 + --adapter_dir=all_models/adapters/token/translation \ + --output_path=results/completions/token/translation.jsonl \ + --openai_api_key=YOUR_OPENAI_API_KEY ``` + 2. **get_expert_scores.py** This script calculates the scores for each expert based on the evaluation datasets. **Usage:** ```bash -python scripts/get_expert_scores.py \ - --eval_datasets=intent,summary,law,translation \ +python scripts/expert/get_expert_scores.py \ + --eval_dataset=translation \ --base_model_path=deepseek-ai/ESFT-vanilla-lite \ - --output_dir=results/expert_scores \ - --n_sample_tokens=8192 # the sample size hyperparameter + --output_dir=results/expert_scores/translation \ + --n_sample_tokens=131072 \ + --world_size=4 \ + --gpus_per_rank=2 ``` 3. **generate_expert_config.py** This script generates the configuration to convert a MoE model with only task-relevant tasks trained based on evaluation scores. **Usage:** ```bash -python scripts/generate_expert_config.py \ +python scripts/expert/generate_expert_config.py \ --eval_datasets=intent,summary,law,translation \ --expert_scores_dir=results/expert_scores \ --output_dir=results/expert_configs \ @@ -68,13 +74,32 @@ python scripts/generate_expert_config.py \ --top_p=0.2 # the scoring function and top_p are hyperparameters ``` +4. **train.py** and **train_ep.py** +This script trains the model with the expert configuration generated by the previous script. The train_ep.py file uses expert parallel and has been optimized for multi-GPU training. +**Usage:** +```bash +python train.py \ + --base_model_path=deepseek-ai/ESFT-vanilla-lite \ + --expert_config=results/expert_configs/intent.json \ + --train_dataset=intent \ + --train_config=configs/base.yaml \ + --output_dir=results/checkpoints/intent + +torchrun --nproc-per-node=8 train_ep.py \ + --base_model_path=deepseek-ai/ESFT-vanilla-lite \ + --expert_config=results/expert_configs/translation.json \ + --train_dataset=translation \ + --train_config=configs/base.yaml \ + --output_dir=results/checkpoints/translation + +``` ## Contact and Support For bug reports, feature requests, and general inquiries, please open an issue on our GitHub Issues page. Make sure to include as much detail as possible to help us address your issue quickly. ## 🌟Todo list - ☑️ 📝 Update models, evaluation scripts, and expert selection scripts -- 🔲 🔧 Update training scripts +- ☑️ 🔧 Update training scripts - 🔲 🚀 More... diff --git a/eval_multigpu.py b/eval_multigpu.py new file mode 100644 index 0000000..5516166 --- /dev/null +++ b/eval_multigpu.py @@ -0,0 +1,92 @@ +import json +import argparse + +from torch import device +from benchmarks import * +import os +from esft import load_base_model, add_adapter +import torch.multiprocessing as mp +from itertools import accumulate +from accelerate import dispatch_model +from transformers import AutoModelForCausalLM, AutoTokenizer + +def infer_auto_device_map(model, pp_splits, visible_devices): + assert len(pp_splits) == len(visible_devices) + device_map = { + "model.embed_tokens": 0, + "model.norm": len(pp_splits) - 1, + "lm_head": len(pp_splits) - 1 + } + assert len(model.model.layers) == sum(pp_splits) + pp_splits = [0, *list(accumulate(pp_splits))] + for idx, (start, end) in enumerate(zip(pp_splits[:-1], pp_splits[1:])): + for i in range(start, end): + device_map.update({f"model.layers.{i}": idx}) + for k, v in device_map.items(): + device_map[k] = visible_devices[v] + return device_map + + +def eval_model(rank, args, model, dataset): + config = { + "max_new_tokens": args.max_new_tokens, + "eval_batch_size": args.eval_batch_size, + "openai_api_key": args.openai_api_key + } + evaluator_map = { + "intent": IntentEvaluator, + "summary": SummaryEvaluator, + "law": LawEvaluator, + "translation": TranslationEvaluator + } + try: + evaluator_cls = evaluator_map[args.eval_dataset] + print(f"Rank {rank} starting evaluation...", flush=True) + tokenizer = AutoTokenizer.from_pretrained(args.base_model_path) + visible_devices = list(range(rank * args.gpus_per_rank, (rank + 1) * args.gpus_per_rank)) + device_map = infer_auto_device_map(model, [14, 13], visible_devices) + model = dispatch_model(model, device_map) + cur_dataset = dataset[rank::args.world_size] + evaluator = evaluator_cls(cur_dataset, config) + with torch.no_grad(): + results, metrics = evaluator.evaluate(model, tokenizer) + os.makedirs(os.path.dirname(args.output_path), exist_ok=True) + with open(args.output_path + f".rank_{rank}", "w") as f: + for res, m in zip(results, metrics): + obj = { + "example": res, + "score": m + } + f.write(json.dumps(obj, ensure_ascii=False) + "\n") + + except Exception as e: + print(f"Error in process {rank}: {e}", flush=True) + raise + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Evaluate a model with adapters on a specified dataset.") + parser.add_argument("--eval_dataset", type=str, required=True, help="Name of the evaluation dataset") + parser.add_argument("--base_model_path", type=str, required=True, help="Path to the base model") + parser.add_argument("--adapter_dir", type=str, required=True, help="Directory containing the adapter") + parser.add_argument("--output_path", type=str, required=True, help="Path to save the evaluation results") + parser.add_argument("--max_new_tokens", type=int, default=128, help="Maximum number of new tokens") + parser.add_argument("--openai_api_key", type=str, required=True, help="API key for OpenAI") + parser.add_argument("--eval_batch_size", type=int, default=1, help="Batch size for evaluation") + parser.add_argument("--world_size", type=int, default=4, help="Number of processes to use for evaluation") + parser.add_argument("--gpus_per_rank", type=int, default=2, help="Number of GPUs per process") + + args = parser.parse_args() + + + + print("Loading base model...") + model = AutoModelForCausalLM.from_pretrained(args.base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16) # not using tokenizer here to aviod deadlock + + print(f"Running evaluation on {args.eval_dataset}...") + dataset = [json.loads(i) for i in open(f"datasets/eval/{args.eval_dataset}.jsonl").readlines()] + + print("Adding adapter...") + model = add_adapter(model, args.adapter_dir, return_original_states=False) + + print("Start Evaluating...") + mp.spawn(eval_model, args=(args, model, dataset), nprocs=args.world_size, join=True) diff --git a/scripts/eval.sh b/scripts/eval.sh index dc1f69a..b86d88c 100644 --- a/scripts/eval.sh +++ b/scripts/eval.sh @@ -1,12 +1,24 @@ -# first, download adapter models and put them to the corresponding directories +# first: download adapter models and put them to the corresponding directories -python scripts/eval.py \ +python eval_multigpu.py \ --eval_datasets=translation \ --base_model_path=deepseek-ai/ESFT-vanilla-lite \ --adapter_dir=all_models/adapters/token \ --output_dir=results/completions/token \ --max_new_tokens=512 \ --openai_api_key=REPLACE_WITH_YOUR_KEY \ - --eval_batch_size=2 + --eval_batch_size=2 \ + --world_size=4 \ + --gpus_per_rank=2 +# this script is used for single-gpu training and has been deprecated. If you have no multiple gpus, you can set above world_size=1 and gpus_per_rank=1 + +# python scripts/eval.py \ +# --eval_datasets=translation \ +# --base_model_path=deepseek-ai/ESFT-vanilla-lite \ +# --adapter_dir=all_models/adapters/token \ +# --output_dir=results/completions/token \ +# --max_new_tokens=512 \ +# --openai_api_key=REPLACE_WITH_YOUR_KEY \ +# --eval_batch_size=2 diff --git a/scripts/generate_expert_config.sh b/scripts/eval_expert.sh similarity index 58% rename from scripts/generate_expert_config.sh rename to scripts/eval_expert.sh index 13be3f4..2c8e8aa 100644 --- a/scripts/generate_expert_config.sh +++ b/scripts/eval_expert.sh @@ -1,10 +1,12 @@ -python scripts/get_expert_scores.py \ - --eval_datasets=intent,summary,law,translation \ +python scripts/expert/get_expert_scores.py \ + --eval_dataset=translation \ --base_model_path=deepseek-ai/ESFT-vanilla-lite \ - --output_dir=results/expert_scores \ - --n_sample_tokens=8192 # this sample size is a hyperparameter + --output_dir=results/expert_scores/translation \ + --n_sample_tokens=131072 \ + --world_size=4 \ + --gpus_per_rank=2 -python scripts/generate_expert_config.py \ +python scripts/expert/generate_expert_config.py \ --eval_datasets=intent,summary,law,translation \ --expert_scores_dir=results/expert_scores \ --output_dir=results/expert_configs \ diff --git a/scripts/expert/generate_expert_config.py b/scripts/expert/generate_expert_config.py new file mode 100644 index 0000000..8ed4a8e --- /dev/null +++ b/scripts/expert/generate_expert_config.py @@ -0,0 +1,97 @@ +import argparse +import json +import os +from multiprocessing import Pool +import numpy as np + + +def parse_line(line): + expert_ids, expert_weights = line.split("\t\t") + expert_ids = [int(i) for i in expert_ids.split("\t")] + expert_weights = [float(i) for i in expert_weights.split("\t")] + return expert_ids, expert_weights + + +def get_summary(files): + TOP_K=6 + N_EXPERTS=64 + N_LAYERS=26 # 27 layers totally, the first layer is not MoE + + gate_scores = np.zeros((N_LAYERS, N_EXPERTS)) + token_scores = np.zeros((N_LAYERS, N_EXPERTS)) + + print("loading files") + for rank, file in files: + layer_id = int(file.split(".")[0].split("_")[2]) - 1 + + with open(os.path.join(args.expert_scores_dir, rank, file)) as f: + data = f.readlines() + for line in data: + expert_ids, expert_weights = parse_line(line) + np.add.at(gate_scores[layer_id], expert_ids, expert_weights) + np.add.at(token_scores[layer_id], expert_ids, np.ones_like(expert_weights) / TOP_K) + + gate_scores = gate_scores / np.sum(gate_scores, axis=0) + token_scores = token_scores / np.sum(token_scores, axis=0) + + summary = {"token_scores": token_scores, "gate_scores": gate_scores} + summary = {k: {str(i+1): {str(j): round(v, 4) for j, v in enumerate(l)} for i, l in enumerate(v)} for k, v in summary.items()} + + return summary + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--eval_dataset", type=str, required=True) + parser.add_argument("--expert_scores_dir", type=str, required=True) + parser.add_argument("--output_path", type=str, required=True) + parser.add_argument("--score_function", type=str, required=True) + parser.add_argument("--top_p", type=float, required=True) + parser.add_argument("--train_shared_experts", action="store_true") + parser.add_argument("--train_non_expert_modules", action="store_true") + + args = parser.parse_args() + + expert_cfg = { # initialize expert config + "experts": {}, + "shared_experts": args.train_shared_experts, + "non_expert_modules": args.train_non_expert_modules + } + + # let's walk inside args.expert_scores_dir and get abs file names + file_names = [] + for rank in [i for i in os.listdir(args.expert_scores_dir) if 'rank' in i]: + for file in os.listdir(os.path.join(args.expert_scores_dir, rank)): + file_names.append([rank, file]) + + + summary_file = os.path.join(args.expert_scores_dir, "summary.json") + summary = get_summary(file_names) + + with open(summary_file, "w") as f: + f.write(json.dumps(summary)) + + + scores = summary[f"{args.score_function}_scores"] + for layer, l_score in scores.items(): + l_score = [(int(k), v) for k,v in l_score.items()] + l_score = sorted(l_score, key=lambda x: x[1], reverse=True) + selected_experts = [] + current_score = 0 + for expert, score in l_score: + if current_score >= args.top_p: + break + selected_experts.append(expert) + current_score += score + expert_cfg["experts"][layer] = selected_experts + + top_p = args.top_p + train_shared_experts = args.train_shared_experts + train_non_expert_modules = args.train_non_expert_modules + + + + os.makedirs(os.path.dirname(args.output_path), exist_ok=True) + with open(args.output_path, "w") as f: + json.dump(expert_cfg, f) diff --git a/scripts/expert/get_expert_scores.py b/scripts/expert/get_expert_scores.py new file mode 100644 index 0000000..bd94e88 --- /dev/null +++ b/scripts/expert/get_expert_scores.py @@ -0,0 +1,78 @@ +import json +import os +import torch +import argparse +import random +from transformers import AutoModelForCausalLM, AutoTokenizer +from utils import get_formatted_input_and_target +import torch.multiprocessing as mp +from itertools import accumulate +from accelerate import dispatch_model + + +def infer_auto_device_map(model, pp_splits, visible_devices): + assert len(pp_splits) == len(visible_devices) + device_map = { + "model.embed_tokens": 0, + "model.norm": len(pp_splits) - 1, + "lm_head": len(pp_splits) - 1 + } + assert len(model.model.layers) == sum(pp_splits) + pp_splits = [0, *list(accumulate(pp_splits))] + for idx, (start, end) in enumerate(zip(pp_splits[:-1], pp_splits[1:])): + for i in range(start, end): + device_map.update({f"model.layers.{i}": idx}) + for k, v in device_map.items(): + device_map[k] = visible_devices[v] + return device_map + + +def eval_expert(rank, args, model, dataset): + try: + print(f"Rank {rank} starting expert evaluation...", flush=True) + tokenizer = AutoTokenizer.from_pretrained(args.base_model_path) + visible_devices = list(range(rank * args.gpus_per_rank, (rank + 1) * args.gpus_per_rank)) + device_map = infer_auto_device_map(model, [14, 13], visible_devices) + model = dispatch_model(model, device_map) + model.config.expert_log_dir = os.path.join(args.output_dir, f"rank_{rank}") + n_sample_tokens = args.n_sample_tokens // args.world_size + os.makedirs(os.path.join(args.output_dir, f"rank_{rank}"), exist_ok=True) + done_tokens = 0 + cur_dataset = dataset[rank::args.world_size] + for instance in cur_dataset: + input_ids, target_ids = get_formatted_input_and_target(instance['messages'], tokenizer, -100) + model(input_ids=torch.tensor(input_ids).unsqueeze(0), labels=torch.tensor(target_ids).unsqueeze(0)) + done_tokens += len(input_ids) + if done_tokens >= n_sample_tokens: + break + + + except Exception as e: + print(f"Error in process {rank}: {e}", flush=True) + raise + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Evaluate a model with adapters on a specified dataset.") + parser.add_argument("--eval_dataset", type=str, required=True, help="Name of the evaluation dataset") + parser.add_argument("--base_model_path", type=str, required=True, help="Path to the base model") + parser.add_argument("--output_dir", type=str, required=True, help="Path to save the evaluation results") + parser.add_argument("--world_size", type=int, default=4, help="Number of processes to use for evaluation") + parser.add_argument("--gpus_per_rank", type=int, default=2, help="Number of GPUs per process") + parser.add_argument("--n_sample_tokens", type=int, required=True, help="Token to sample for expert evaluation") + args = parser.parse_args() + random.seed(5934875) + + + print("Loading base model...") + model = AutoModelForCausalLM.from_pretrained(args.base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16) # not using tokenizer here to aviod deadlock + model.config.log_expert_weights = True + + + print(f"Running expert evaluation on {args.eval_dataset}...") + dataset = [json.loads(i) for i in open(f"datasets/train/{args.eval_dataset}.jsonl").readlines()] + random.shuffle(dataset) + + + print("Start Evaluating...") + mp.spawn(eval_expert, args=(args, model, dataset), nprocs=args.world_size, join=True) diff --git a/scripts/generate_expert_config.py b/scripts/generate_expert_config.py deleted file mode 100644 index d641cff..0000000 --- a/scripts/generate_expert_config.py +++ /dev/null @@ -1,50 +0,0 @@ -import argparse -import json -import os - -parser = argparse.ArgumentParser() -parser.add_argument("--eval_datasets", type=str, required=True) -parser.add_argument("--expert_scores_dir", type=str, required=True) -parser.add_argument("--output_dir", type=str, required=True) -parser.add_argument("--score_function", type=str, required=True) -parser.add_argument("--top_p", type=float, required=True) -parser.add_argument("--train_shared_experts", action="store_true") -parser.add_argument("--train_non_expert_modules", action="store_true") - -args = parser.parse_args() - -eval_datasets = args.eval_datasets.split(",") -expert_scores_dir = args.expert_scores_dir -output_dir = args.output_dir -score_function = args.score_function -top_p = args.top_p -train_shared_experts = args.train_shared_experts -train_non_expert_modules = args.train_non_expert_modules - -for dataset_name in eval_datasets: - summary_file = f"{expert_scores_dir}/{dataset_name}/summary.json" - expert_cfg = {"experts": {}, "shared_experts": train_shared_experts, "non_expert_modules": train_non_expert_modules} - - with open(summary_file) as f: - data = json.load(f) - assert score_function in ["gate", "token"], f"Unknown score function: {score_function}" - scores = data[f"{score_function}_scores"] - - for layer, l_score in scores.items(): - l_score = [(int(k), v) for k,v in l_score.items()] - l_score = sorted(l_score, key=lambda x: x[1], reverse=True) - # get the top experts that make the threshold exceed top_p - selected_experts = [] - current_score = 0 - for expert, score in l_score: - if current_score >= top_p: - break - selected_experts.append(expert) - current_score += score - expert_cfg["experts"][layer] = selected_experts - - os.makedirs(output_dir, exist_ok=True) - with open(f"{output_dir}/{dataset_name}.json", "w") as f: - json.dump(expert_cfg, f) - - diff --git a/scripts/get_expert_scores.py b/scripts/get_expert_scores.py deleted file mode 100644 index 34dfd90..0000000 --- a/scripts/get_expert_scores.py +++ /dev/null @@ -1,75 +0,0 @@ -import json -from benchmarks import * -import os -import torch -from torch import nn -import argparse -from random import shuffle -from transformers import AutoModelForCausalLM, AutoTokenizer -from utils import get_formatted_input_and_target - -# constants for deepseek-v2-lite -TOP_K=6 -N_EXPERTS=64 - -parser = argparse.ArgumentParser() -parser.add_argument("--base_model_path", type=str, required=True) -parser.add_argument("--eval_datasets", type=str, required=True) -parser.add_argument("--output_dir", type=str, required=True) -parser.add_argument("--n_sample_tokens", type=int, required=True) -args = parser.parse_args() - -eval_datasets = args.eval_datasets.split(",") -output_dir = args.output_dir -base_model_path = args.base_model_path -n_sample_tokens = args.n_sample_tokens - -model, tokenizer = AutoModelForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto"), AutoTokenizer.from_pretrained(base_model_path) -model.config.log_expert_weights = True - -for dataset_name in eval_datasets: - dataset = [json.loads(i) for i in open(f"datasets/train/{dataset_name}.jsonl").readlines()] - shuffle(dataset) - model.config.expert_log_dir = os.path.join(args.output_dir, dataset_name) - # make dir -p this - os.makedirs(os.path.join(args.output_dir, dataset_name), exist_ok=True) - done_tokens = 0 - for instance in dataset: - input_ids, target_ids = get_formatted_input_and_target(instance['messages'], tokenizer, -100) - model(input_ids=torch.tensor(input_ids).unsqueeze(0), labels=torch.tensor(target_ids).unsqueeze(0)) - done_tokens += len(input_ids) - if done_tokens >= n_sample_tokens: - break - - # open all files under os.path.join(args.output_dir, dataset_name). For each file, generate a summary of it - # and write it to a file in the same directory - files = os.listdir(os.path.join(args.output_dir, dataset_name)) - summary_file = os.path.join(args.output_dir, dataset_name, "summary.json") - token_scores = {} - gate_scores = {} - - for file in files: - if not file.endswith(".txt"): - continue - layer_idx = file.split("_")[2].split(".")[0] - token_scores[layer_idx] = {expert:0 for expert in range(N_EXPERTS)} - gate_scores[layer_idx] = {expert:0 for expert in range(N_EXPERTS)} - - with open(os.path.join(args.output_dir, dataset_name, file)) as f: - data = f.readlines() - for line in data: - expert_ids, expert_weights = line.split("\t\t") - expert_ids = [int(i) for i in expert_ids.split("\t")] - expert_weights = [float(i) for i in expert_weights.split("\t")] - for expert_id, expert_weight in zip(expert_ids, expert_weights): - gate_scores[layer_idx][expert_id] += expert_weight - token_scores[layer_idx][expert_id] += 1. / TOP_K - total = sum(token_scores[layer_idx].values()) - gate_scores[layer_idx] = {expert: round(gate_scores[layer_idx][expert] / total, 4) for expert in gate_scores[layer_idx]} - token_scores[layer_idx] = {expert: round(token_scores[layer_idx][expert] / total, 4) for expert in token_scores[layer_idx]} - - - with open(summary_file, "w") as f: - f.write(json.dumps({"token_scores": token_scores, "gate_scores": gate_scores})) - - diff --git a/scripts/train.sh b/scripts/train.sh index 5671421..2a33603 100644 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -2,7 +2,7 @@ export TOKENIZERS_PARALLELISM=false exp_name="test/eval_translation" -base_model_path="/hf3fs-jd/prod/deepseek/shared/wangzihan/models/huggingface/vanilla_model" +base_model_path="deepseek-ai/esft-vanilla-lite" # turn above to for loop python train.py \ --base_model_path=${base_model_path} \ diff --git a/scripts/train_ep.sh b/scripts/train_ep.sh index 0785e9a..e6f702d 100644 --- a/scripts/train_ep.sh +++ b/scripts/train_ep.sh @@ -2,7 +2,7 @@ export TOKENIZERS_PARALLELISM=false exp_name="test/eval_translation" -base_model_path="/hf3fs-jd/prod/deepseek/shared/wangzihan/models/huggingface/vanilla_model" +base_model_path="deepseek-ai/esft-vanilla-lite" torchrun --nproc-per-node=8 train_ep.py \ --base_model_path=${base_model_path} \ --expert_config=results/expert_configs/translation.json \