From f27ae60863e58151a265ea2d85f0d0ef96c20f58 Mon Sep 17 00:00:00 2001
From: Xingkai Yu <38156925+GeeeekExplorer@users.noreply.github.com>
Date: Mon, 12 Aug 2024 10:43:00 +0800
Subject: [PATCH] Update train_ep.py

Force all2all backward the same number of times
---
 train_ep.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/train_ep.py b/train_ep.py
index f26b568..820a952 100644
--- a/train_ep.py
+++ b/train_ep.py
@@ -110,6 +110,14 @@ def main():
         if type(layer.mlp).__name__ != "DeepseekV2MoE":
             continue
         layer.mlp.ep_group = ep_group
+    # Force all2all backward the same number of times
+    if ep_size > 1 and not expert_config["non_expert_modules"]:
+        min_layer_id = min(int(k) for k, v in expert_config["experts"].items() if v)
+        mlp = model.model.layers[min_layer_id].mlp
+        forward = mlp.forward
+        def custom_forward(self, hidden_states: torch.Tensor):
+            return forward(hidden_states.requires_grad_(torch.is_grad_enabled()))
+        mlp.forward = MethodType(custom_forward, mlp)

     # Initialize Trainer
     trainer = Trainer(
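
Why the patch helps, as I read the commit message and code: when ep_size > 1
and expert_config["non_expert_modules"] is empty, it can happen on some ranks
that nothing upstream of a given MoE layer requires grad, so autograd never
traverses that layer and its backward all2all is skipped there while other
ranks still issue it; the per-rank collective call counts then diverge and
the expert-parallel group can hang. Forcing the input of the lowest-indexed
trainable MoE layer (min_layer_id) to require grad keeps every rank's
backward pass issuing the same number of all2all calls.

Below is a minimal, self-contained sketch of the MethodType monkey-patch the
diff applies. The ToyMoE module is a hypothetical stand-in invented for
illustration; only the wrapping of mlp.forward mirrors the patch.

    from types import MethodType

    import torch
    import torch.nn as nn

    class ToyMoE(nn.Module):
        # Stand-in for a MoE block whose backward would run an all2all.
        def __init__(self, dim: int = 8):
            super().__init__()
            self.proj = nn.Linear(dim, dim)

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            return self.proj(hidden_states)

    mlp = ToyMoE()
    forward = mlp.forward  # keep a reference to the original bound method

    def custom_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Marking the input as requiring grad (only while grad mode is on)
        # guarantees autograd records this layer, so its backward (the
        # all2all inside a real MoE layer) runs on every rank.
        return forward(hidden_states.requires_grad_(torch.is_grad_enabled()))

    mlp.forward = MethodType(custom_forward, mlp)

    # Even if the incoming activation does not require grad (e.g. all
    # upstream modules are frozen), backward now reaches the patched layer:
    out = mlp(torch.randn(2, 8))
    out.sum().backward()
    assert all(p.grad is not None for p in mlp.parameters())

Patching only the first trainable MoE layer appears sufficient: once grad
flows through it, every later layer participates in backward as well, so a
single wrapper keeps the whole expert-parallel group in lockstep.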