mirror of
https://github.com/deepseek-ai/ESFT
synced 2024-11-22 11:37:57 +00:00
commit
cde92cbe45
@ -110,6 +110,14 @@ def main():
|
||||
if type(layer.mlp).__name__ != "DeepseekV2MoE":
|
||||
continue
|
||||
layer.mlp.ep_group = ep_group
|
||||
# Force all2all backward the same number of times
|
||||
if ep_size > 1 and not expert_config["non_expert_modules"]:
|
||||
min_layer_id = min(int(k) for k, v in expert_config["experts"].items() if v)
|
||||
mlp = model.model.layers[min_layer_id].mlp
|
||||
forward = mlp.forward
|
||||
def custom_forward(self, hidden_states: torch.Tensor):
|
||||
return forward(hidden_states.requires_grad_(torch.is_grad_enabled()))
|
||||
mlp.forward = MethodType(custom_forward, mlp)
|
||||
|
||||
# Initialize Trainer
|
||||
trainer = Trainer(
|
||||
|
Loading…
Reference in New Issue
Block a user