mirror of
https://github.com/deepseek-ai/ESFT
synced 2024-11-25 21:27:57 +00:00
commit
4e2defea82
@ -126,7 +126,6 @@ def main():
|
|||||||
backward(loss, **kwargs)
|
backward(loss, **kwargs)
|
||||||
if not self.sync_gradients or edp_size == 1:
|
if not self.sync_gradients or edp_size == 1:
|
||||||
return
|
return
|
||||||
return
|
|
||||||
for p in expert_params:
|
for p in expert_params:
|
||||||
g = p.grad if p.grad is not None else torch.zeros_like(p)
|
g = p.grad if p.grad is not None else torch.zeros_like(p)
|
||||||
dist.all_reduce(g, op=dist.ReduceOp.AVG, group=edp_group)
|
dist.all_reduce(g, op=dist.ReduceOp.AVG, group=edp_group)
|
||||||
@ -145,7 +144,7 @@ def main():
|
|||||||
if local_rank == 0:
|
if local_rank == 0:
|
||||||
trainer.save_model(ckpt_path)
|
trainer.save_model(ckpt_path)
|
||||||
tokenizer.save_pretrained(ckpt_path)
|
tokenizer.save_pretrained(ckpt_path)
|
||||||
elif 0 < local_rank < ep_size:
|
elif local_rank < ep_size:
|
||||||
model.save_pretrained(ckpt_path)
|
model.save_pretrained(ckpt_path)
|
||||||
|
|
||||||
print("Training complete")
|
print("Training complete")
|
||||||
|
Loading…
Reference in New Issue
Block a user