Merge pull request #6 from GeeeekExplorer/main

Update train_ep.py
This commit is contained in:
Zihan Wang 2024-08-12 10:43:56 +08:00 committed by GitHub
commit cde92cbe45
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -110,6 +110,14 @@ def main():
if type(layer.mlp).__name__ != "DeepseekV2MoE":
continue
layer.mlp.ep_group = ep_group
# Force every expert-parallel rank to run the all2all backward the same number of times
if ep_size > 1 and not expert_config["non_expert_modules"]:
min_layer_id = min(int(k) for k, v in expert_config["experts"].items() if v)
mlp = model.model.layers[min_layer_id].mlp
forward = mlp.forward
def custom_forward(self, hidden_states: torch.Tensor):
return forward(hidden_states.requires_grad_(torch.is_grad_enabled()))
mlp.forward = MethodType(custom_forward, mlp)
# Initialize Trainer
trainer = Trainer(