Update train_ep.py

This commit is contained in:
Xingkai Yu 2024-08-12 10:28:35 +08:00 committed by GitHub
parent 4bfa99486b
commit 7a96df9be6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -126,7 +126,6 @@ def main():
backward(loss, **kwargs)
if not self.sync_gradients or edp_size == 1:
return
return
for p in expert_params:
g = p.grad if p.grad is not None else torch.zeros_like(p)
dist.all_reduce(g, op=dist.ReduceOp.AVG, group=edp_group)
@ -145,10 +144,10 @@ def main():
if local_rank == 0:
trainer.save_model(ckpt_path)
tokenizer.save_pretrained(ckpt_path)
elif 0 < local_rank < ep_size:
elif local_rank < ep_size:
model.save_pretrained(ckpt_path)
print("Training complete")
if __name__ == "__main__":
main()
main()