# Mirror of https://github.com/deepseek-ai/DualPipe
# Synced 2025-04-05 04:59:15 +00:00 (81 lines, 2.2 KiB, Python)
import queue
|
|
from typing import List, Callable
|
|
|
|
import torch
|
|
from torch.autograd import Variable
|
|
|
|
|
|
class WeightGradStore:
    """Class-level store that defers weight-gradient computations.

    When ``enabled``, weight-gradient closures are buffered via :meth:`put`
    and grouped into batches with :meth:`flush`; :meth:`pop` later executes
    one flushed batch. All state is shared at class level — there are no
    instances.
    """

    enabled: bool = False        # whether deferral is currently active
    cache: List[Callable] = []   # closures collected since the last flush
    funcs_queue = queue.Queue()  # FIFO of flushed closure batches

    @classmethod
    def put(cls, func: Callable) -> None:
        """Buffer one deferred weight-gradient closure."""
        cls.cache.append(func)

    @classmethod
    def flush(cls) -> None:
        """Enqueue everything buffered so far as a single batch."""
        cls.funcs_queue.put(cls.cache)
        cls.cache = []

    @classmethod
    def pop(cls) -> None:
        """Execute every closure in the oldest flushed batch."""
        assert not cls.funcs_queue.empty(), "Pop empty queue."
        for fn in cls.funcs_queue.get():
            fn()

    @classmethod
    def clear(cls) -> None:
        """Discard all buffered and queued closures."""
        cls.cache = []
        cls.funcs_queue = queue.Queue()
|
|
|
|
|
|
def run_backward(tensors: List[torch.Tensor], grad_tensors: List[torch.Tensor]) -> None:
    """Run an autograd backward pass directly through the C++ engine.

    Invokes the private ``Variable._execution_engine.run_backward`` instead
    of ``torch.autograd.backward`` — NOTE(review): presumably to skip the
    Python-level preprocessing that the public entry point performs; confirm
    against the torch version in use.

    Args:
        tensors: output tensors to backpropagate from.
        grad_tensors: matching output gradients, one per tensor.
    """
    kwargs = dict(
        keep_graph=False,        # free the autograd graph after this pass
        create_graph=False,      # no higher-order gradient graph
        allow_unreachable=True,  # tolerate inputs not reached by the graph
        accumulate_grad=True,    # accumulate into existing .grad fields
    )
    # Private PyTorch API: the accepted kwargs may change across releases.
    Variable._execution_engine.run_backward(tensors, grad_tensors, **kwargs)
|
|
|
|
|
|
def chunk_tensor(x, chunks, dim):
    """Split ``x`` into ``chunks`` nearly-equal pieces along ``dim``.

    A ``None`` input yields ``chunks`` ``None`` placeholders, so optional
    tensors can be scattered alongside real ones.
    """
    if x is None:
        return [None] * chunks
    return x.tensor_split(chunks, dim=dim)
|
|
|
|
|
|
def cat_tensor(x, dim):
    """Concatenate a sequence of tensor chunks along ``dim``.

    A single-element sequence is unwrapped without copying; a sequence of
    ``None`` placeholders collapses to ``None``. Anything else is handed to
    ``torch.cat`` unchanged.
    """
    if isinstance(x, (tuple, list)):
        if len(x) == 1:
            return x[0]
        if x[0] is None:
            # Placeholder chunks: every element must be None to collapse.
            assert all(y is None for y in x)
            return None
    return torch.cat(x, dim=dim)
|
|
|
|
|
|
def scatter(inputs, chunks, dim):
    """Split one or more tensors into ``chunks`` microbatches along ``dim``.

    Accepts a single tensor or a tuple/list of optional tensors and returns
    a list of ``chunks`` tuples, each holding one chunk per input (``None``
    inputs contribute ``None`` placeholders).
    """
    assert isinstance(inputs, (torch.Tensor, tuple, list))
    if isinstance(inputs, torch.Tensor):
        inputs = (inputs,)
    assert all(x is None or isinstance(x, torch.Tensor) for x in inputs)
    per_input_chunks = [chunk_tensor(x, chunks, dim) for x in inputs]
    microbatches = list(zip(*per_input_chunks))
    if not microbatches:
        # No inputs at all: still emit one empty microbatch per chunk.
        microbatches = [()] * chunks
    return microbatches
|
|
|
|
|
|
def gather(micro_outputs, dim):
    """Merge per-microbatch outputs back into full tensors along ``dim``.

    The inverse of :func:`scatter`: takes a list where each element is a
    tensor (or tuple/list of optional tensors) produced by one microbatch
    and returns a tuple with each position concatenated across microbatches.
    """
    assert isinstance(micro_outputs[0], (torch.Tensor, tuple, list))
    if isinstance(micro_outputs[0], torch.Tensor):
        micro_outputs = [(x,) for x in micro_outputs]
    grouped = list(zip(*micro_outputs))
    return tuple(cat_tensor(chunk_group, dim=dim) for chunk_group in grouped)
|