From 7a96df9be6e005013c04717d41370d566652301b Mon Sep 17 00:00:00 2001
From: Xingkai Yu <38156925+GeeeekExplorer@users.noreply.github.com>
Date: Mon, 12 Aug 2024 10:28:35 +0800
Subject: [PATCH] Update train_ep.py

---
 train_ep.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/train_ep.py b/train_ep.py
index 3ba98e8..f26b568 100644
--- a/train_ep.py
+++ b/train_ep.py
@@ -126,7 +126,6 @@ def main():
         backward(loss, **kwargs)
         if not self.sync_gradients or edp_size == 1:
             return
-        return
         for p in expert_params:
             g = p.grad if p.grad is not None else torch.zeros_like(p)
             dist.all_reduce(g, op=dist.ReduceOp.AVG, group=edp_group)
@@ -145,10 +144,10 @@ def main():
     if local_rank == 0:
         trainer.save_model(ckpt_path)
         tokenizer.save_pretrained(ckpt_path)
-    elif 0 < local_rank < ep_size:
+    elif local_rank < ep_size:
         model.save_pretrained(ckpt_path)
     print("Training complete")
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
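
For context: the stray `return` deleted in the first hunk appears to have made the expert-gradient all-reduce below it unreachable, so expert parameter gradients were never averaged across the expert data-parallel (EDP) group. Below is a minimal sketch of the hook this patch restores, assuming `backward`, `accelerator`, `expert_params`, `edp_group`, and `edp_size` are provided by the surrounding train_ep.py as the diff context suggests; their exact wiring is not shown in these hunks, and the helper name here is hypothetical.

import torch
import torch.distributed as dist

def make_expert_grad_sync(backward, accelerator, expert_params, edp_group, edp_size):
    # Hypothetical helper (not from the patch): wraps the original
    # `accelerator.backward` so that, whenever gradients are synchronized,
    # expert-parameter gradients are averaged across the EDP group.
    def wrapped(loss, **kwargs):
        backward(loss, **kwargs)
        # Skip the collective when this step does not sync gradients,
        # or when there is only one EDP replica.
        if not accelerator.sync_gradients or edp_size == 1:
            return
        # This loop was dead code while the stray `return` was in place.
        for p in expert_params:
            g = p.grad if p.grad is not None else torch.zeros_like(p)
            # ReduceOp.AVG requires the NCCL backend.
            dist.all_reduce(g, op=dist.ReduceOp.AVG, group=edp_group)
            p.grad = g
    return wrapped

The second hunk is a simplification rather than a behavior change: rank 0 is already handled by the preceding `if local_rank == 0`, so the old `0 <` bound in the `elif` was redundant.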