diff --git a/train_ep.py b/train_ep.py index f26b568..820a952 100644 --- a/train_ep.py +++ b/train_ep.py @@ -110,6 +110,14 @@ def main(): if type(layer.mlp).__name__ != "DeepseekV2MoE": continue layer.mlp.ep_group = ep_group + # Force all2all backward the same number of times + if ep_size > 1 and not expert_config["non_expert_modules"]: + min_layer_id = min(int(k) for k, v in expert_config["experts"].items() if v) + mlp = model.model.layers[min_layer_id].mlp + forward = mlp.forward + def custom_forward(self, hidden_states: torch.Tensor): + return forward(hidden_states.requires_grad_(torch.is_grad_enabled())) + mlp.forward = MethodType(custom_forward, mlp) # Initialize Trainer trainer = Trainer(