Merge pull request #5 from GeeeekExplorer/main

Update train_ep.py
This commit is contained in:
Zihan Wang 2024-08-12 10:30:33 +08:00 committed by GitHub
commit 4e2defea82
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -126,7 +126,6 @@ def main():
backward(loss, **kwargs) backward(loss, **kwargs)
if not self.sync_gradients or edp_size == 1: if not self.sync_gradients or edp_size == 1:
return return
return
for p in expert_params: for p in expert_params:
g = p.grad if p.grad is not None else torch.zeros_like(p) g = p.grad if p.grad is not None else torch.zeros_like(p)
dist.all_reduce(g, op=dist.ReduceOp.AVG, group=edp_group) dist.all_reduce(g, op=dist.ReduceOp.AVG, group=edp_group)
@ -145,7 +144,7 @@ def main():
if local_rank == 0: if local_rank == 0:
trainer.save_model(ckpt_path) trainer.save_model(ckpt_path)
tokenizer.save_pretrained(ckpt_path) tokenizer.save_pretrained(ckpt_path)
elif 0 < local_rank < ep_size: elif local_rank < ep_size:
model.save_pretrained(ckpt_path) model.save_pretrained(ckpt_path)
print("Training complete") print("Training complete")