mirror of
https://github.com/deepseek-ai/ESFT
synced 2024-11-21 19:17:39 +00:00
Update train_ep.py
This commit is contained in:
parent
4bfa99486b
commit
7a96df9be6
@ -126,7 +126,6 @@ def main():
|
||||
backward(loss, **kwargs)
|
||||
if not self.sync_gradients or edp_size == 1:
|
||||
return
|
||||
return
|
||||
for p in expert_params:
|
||||
g = p.grad if p.grad is not None else torch.zeros_like(p)
|
||||
dist.all_reduce(g, op=dist.ReduceOp.AVG, group=edp_group)
|
||||
@ -145,10 +144,10 @@ def main():
|
||||
if local_rank == 0:
|
||||
trainer.save_model(ckpt_path)
|
||||
tokenizer.save_pretrained(ckpt_path)
|
||||
elif 0 < local_rank < ep_size:
|
||||
elif local_rank < ep_size:
|
||||
model.save_pretrained(ckpt_path)
|
||||
|
||||
print("Training complete")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
Loading…
Reference in New Issue
Block a user