Merge pull request #6 from GeeeekExplorer/main

Update train_ep.py
This commit is contained in:
Zihan Wang
2024-08-12 10:43:56 +08:00
committed by GitHub

View File

@@ -110,6 +110,14 @@ def main():
# NOTE(review): diff hunk from the body of main() — the enclosing
# `for`-loop over model layers and the function header are outside this
# view, so the relative indentation below is reconstructed; confirm
# against the full train_ep.py.
    # Only DeepseekV2MoE layers take part in expert parallelism; attach
    # the expert-parallel process group to each such layer's mlp.
    if type(layer.mlp).__name__ != "DeepseekV2MoE":
        continue
    layer.mlp.ep_group = ep_group
# Force all2all backward the same number of times
# When expert parallelism is active (ep_size > 1) and this config has no
# non-expert modules, wrap the first expert layer's forward so its input
# is marked requires_grad whenever grad mode is on — presumably so the
# backward pass always traverses the all2all the same number of times on
# every rank (per the comment above); verify against the MoE implementation.
if ep_size > 1 and not expert_config["non_expert_modules"]:
    # Smallest layer index that has a truthy experts entry in the config.
    # NOTE(review): assumes expert_config["experts"] keys are layer ids as
    # strings — confirm against the config schema.
    min_layer_id = min(int(k) for k, v in expert_config["experts"].items() if v)
    mlp = model.model.layers[min_layer_id].mlp
    # Capture the original bound forward before replacing it; the wrapper
    # closes over this reference.
    forward = mlp.forward
    def custom_forward(self, hidden_states: torch.Tensor):
        # requires_grad_ is applied only when grad is enabled
        # (torch.is_grad_enabled()), so inference-mode calls are unchanged.
        return forward(hidden_states.requires_grad_(torch.is_grad_enabled()))
    # Rebind the wrapper as an instance method so it replaces mlp.forward
    # in place for this one mlp only.
    mlp.forward = MethodType(custom_forward, mlp)
# Initialize Trainer
trainer = Trainer(