mirror of https://github.com/deepseek-ai/ESFT, synced 2024-11-21 19:17:39 +00:00
Update train_ep.py
Force the all2all backward pass to run the same number of times on every rank
parent 7a96df9be6
commit f27ae60863
@@ -110,6 +110,14 @@ def main():
         if type(layer.mlp).__name__ != "DeepseekV2MoE":
             continue
         layer.mlp.ep_group = ep_group
+    # Force all2all backward the same number of times
+    if ep_size > 1 and not expert_config["non_expert_modules"]:
+        min_layer_id = min(int(k) for k, v in expert_config["experts"].items() if v)
+        mlp = model.model.layers[min_layer_id].mlp
+        forward = mlp.forward
+        def custom_forward(self, hidden_states: torch.Tensor):
+            return forward(hidden_states.requires_grad_(torch.is_grad_enabled()))
+        mlp.forward = MethodType(custom_forward, mlp)

     # Initialize Trainer
     trainer = Trainer(
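Note on the pattern in this hunk: the forward of the first MoE layer that holds trainable experts is rebound with types.MethodType so that the layer's input is always marked as requiring grad. The apparent intent, per the commit title, is that backward (and so the all2all collective inside the expert-parallel MoE block) is reached the same number of times on each rank, keeping collective calls matched. Below is a minimal, self-contained sketch of the same wrapping idea on a toy module; DummyMoE and the surrounding names are illustrative stand-ins, not code from train_ep.py.

# Minimal sketch of the MethodType-wrapping pattern used above (illustrative
# only; DummyMoE is a hypothetical stand-in, not code from this repository).
from types import MethodType

import torch
import torch.nn as nn

class DummyMoE(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.proj(hidden_states)

mlp = DummyMoE()
forward = mlp.forward  # keep a reference to the original bound forward

def custom_forward(self, hidden_states: torch.Tensor):
    # Marking the input as requiring grad (when grad is enabled) guarantees
    # autograd records this layer, so its backward is reached on every step.
    return forward(hidden_states.requires_grad_(torch.is_grad_enabled()))

mlp.forward = MethodType(custom_forward, mlp)  # rebind as an instance method

out = mlp(torch.randn(2, 8))  # leaf input, so requires_grad_ is legal here
out.sum().backward()          # backward now always flows through the layer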