Merge pull request #5 from GeeeekExplorer/main

Update train_ep.py
This commit is contained in:
Zihan Wang 2024-08-12 10:30:33 +08:00 committed by GitHub
commit 4e2defea82
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -126,7 +126,6 @@ def main():
backward(loss, **kwargs)
if not self.sync_gradients or edp_size == 1:
return
-return
for p in expert_params:
g = p.grad if p.grad is not None else torch.zeros_like(p)
dist.all_reduce(g, op=dist.ReduceOp.AVG, group=edp_group)
@@ -145,10 +144,10 @@ def main():
if local_rank == 0:
trainer.save_model(ckpt_path)
tokenizer.save_pretrained(ckpt_path)
-elif 0 < local_rank < ep_size:
+elif local_rank < ep_size:
model.save_pretrained(ckpt_path)
print("Training complete")
if __name__ == "__main__":
main()
main()