diff --git a/train_ep.py b/train_ep.py
index 3ba98e8..f26b568 100644
--- a/train_ep.py
+++ b/train_ep.py
@@ -126,7 +126,6 @@ def main():
         backward(loss, **kwargs)
         if not self.sync_gradients or edp_size == 1:
             return
-        return
         for p in expert_params:
             g = p.grad if p.grad is not None else torch.zeros_like(p)
             dist.all_reduce(g, op=dist.ReduceOp.AVG, group=edp_group)
@@ -145,10 +144,10 @@ def main():
     if local_rank == 0:
         trainer.save_model(ckpt_path)
         tokenizer.save_pretrained(ckpt_path)
-    elif 0 < local_rank < ep_size:
+    elif local_rank < ep_size:
         model.save_pretrained(ckpt_path)
     print("Training complete")
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
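
The first hunk removes an unconditional `return` that made the expert-gradient all-reduce unreachable: every rank returned before the loop, so expert gradients were never averaged across the expert-data-parallel (EDP) group. Below is a minimal sketch of the repaired synchronization path as a standalone function, assuming `expert_params`, `edp_group`, and `edp_size` are built elsewhere in `train_ep.py`; the function name and the trailing `p.grad = g` assignment are illustrative, not taken from the repo.

import torch
import torch.distributed as dist

def sync_expert_grads(expert_params, edp_group, edp_size, sync_gradients=True):
    """Average expert-parameter gradients across the EDP group (illustrative sketch)."""
    # Mirror the guard kept by the diff: skip when gradient accumulation
    # defers the sync or when there is only a single EDP replica.
    if not sync_gradients or edp_size == 1:
        return
    for p in expert_params:
        # A rank whose experts received no tokens this step has p.grad == None;
        # it must still join the collective, so it contributes zeros instead.
        g = p.grad if p.grad is not None else torch.zeros_like(p)
        dist.all_reduce(g, op=dist.ReduceOp.AVG, group=edp_group)
        # all_reduce is in-place; reattach in case the gradient was None.
        p.grad = g

The second hunk is behavior-preserving: inside the `elif`, `local_rank` cannot be 0 (that case took the first branch), so `0 < local_rank < ep_size` and `local_rank < ep_size` select the same ranks.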