mirror of
https://github.com/deepseek-ai/DualPipe
synced 2025-06-26 18:16:46 +00:00
Initial commit
This commit is contained in:
80
dualpipe/utils.py
Normal file
80
dualpipe/utils.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import queue
|
||||
from typing import List, Callable
|
||||
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
|
||||
|
||||
class WeightGradStore:
|
||||
|
||||
enabled: bool = False
|
||||
cache: List[Callable] = []
|
||||
funcs_queue = queue.Queue()
|
||||
|
||||
@classmethod
|
||||
def put(cls, func: Callable) -> None:
|
||||
cls.cache.append(func)
|
||||
|
||||
@classmethod
|
||||
def flush(cls) -> None:
|
||||
cls.funcs_queue.put(cls.cache)
|
||||
cls.cache = []
|
||||
|
||||
@classmethod
|
||||
def pop(cls) -> None:
|
||||
assert not cls.funcs_queue.empty(), "Pop empty queue."
|
||||
funcs = cls.funcs_queue.get()
|
||||
for func in funcs:
|
||||
func()
|
||||
|
||||
@classmethod
|
||||
def clear(cls) -> None:
|
||||
cls.cache = []
|
||||
cls.funcs_queue = queue.Queue()
|
||||
|
||||
|
||||
def run_backward(tensors: List[torch.Tensor], grad_tensors: List[torch.Tensor]) -> None:
|
||||
kwargs = dict(
|
||||
keep_graph=False,
|
||||
create_graph=False,
|
||||
allow_unreachable=True,
|
||||
accumulate_grad=True,
|
||||
)
|
||||
Variable._execution_engine.run_backward(tensors, grad_tensors, **kwargs)
|
||||
|
||||
|
||||
def chunk_tensor(x, chunks, dim):
|
||||
if x is None:
|
||||
return [None for _ in range(chunks)]
|
||||
return x.tensor_split(chunks, dim=dim)
|
||||
|
||||
|
||||
def cat_tensor(x, dim):
|
||||
if (isinstance(x, tuple) or isinstance(x, list)):
|
||||
if len(x) == 1:
|
||||
return x[0]
|
||||
elif x[0] is None:
|
||||
assert all(y is None for y in x)
|
||||
return None
|
||||
return torch.cat(x, dim=dim)
|
||||
|
||||
|
||||
def scatter(inputs, chunks, dim):
|
||||
assert isinstance(inputs, (torch.Tensor, tuple, list))
|
||||
if isinstance(inputs, torch.Tensor):
|
||||
inputs = (inputs,)
|
||||
assert all(x is None or isinstance(x, torch.Tensor) for x in inputs)
|
||||
inputs = [chunk_tensor(x, chunks, dim) for x in inputs]
|
||||
microbatches = [microbatch for microbatch in zip(*inputs)]
|
||||
if len(microbatches) == 0:
|
||||
microbatches = [() for _ in range(chunks)]
|
||||
return microbatches
|
||||
|
||||
|
||||
def gather(micro_outputs, dim):
|
||||
assert isinstance(micro_outputs[0], (torch.Tensor, tuple, list))
|
||||
if isinstance(micro_outputs[0], torch.Tensor):
|
||||
micro_outputs = [(x,) for x in micro_outputs]
|
||||
outputs = [x for x in zip(*micro_outputs)]
|
||||
outputs = tuple(cat_tensor(x, dim=dim) for x in outputs)
|
||||
return outputs
|
||||
Reference in New Issue
Block a user