# DualPipe/dualpipe/utils.py
# Snapshot: 2025-02-27 10:12:10 +08:00 — 81 lines, 2.2 KiB, Python

import queue
from typing import List, Callable
import torch
from torch.autograd import Variable
class WeightGradStore:
    """Class-level store that defers weight-gradient callables for later replay.

    :meth:`put` accumulates callables into ``cache``; :meth:`flush` seals the
    current cache as one batch onto ``funcs_queue``; :meth:`pop` executes the
    oldest sealed batch in FIFO order.
    """

    enabled: bool = False          # toggled externally to switch deferral on/off
    cache: List[Callable] = []     # callables collected since the last flush
    funcs_queue = queue.Queue()    # FIFO of flushed batches

    @classmethod
    def put(cls, func: Callable) -> None:
        """Record *func* for deferred execution."""
        cls.cache.append(func)

    @classmethod
    def flush(cls) -> None:
        """Seal the current cache as one batch and start a fresh cache."""
        cls.funcs_queue.put(cls.cache)
        cls.cache = []

    @classmethod
    def pop(cls) -> None:
        """Execute every callable in the oldest flushed batch."""
        assert not cls.funcs_queue.empty(), "Pop empty queue."
        for deferred in cls.funcs_queue.get():
            deferred()

    @classmethod
    def clear(cls) -> None:
        """Drop all pending callables and all flushed batches."""
        cls.cache = []
        cls.funcs_queue = queue.Queue()
def run_backward(tensors: List[torch.Tensor], grad_tensors: List[torch.Tensor]) -> None:
    """Run autograd backward on *tensors* with the given output gradients.

    Calls torch's C++ execution engine directly instead of going through
    ``torch.autograd.backward``.

    NOTE(review): ``Variable._execution_engine`` is a private API and the
    accepted keyword names (e.g. ``keep_graph`` vs ``retain_graph``) differ
    across torch releases — confirm against the pinned torch version.
    """
    kwargs = dict(
        keep_graph=False,        # free the graph after this backward pass
        create_graph=False,      # no higher-order gradients
        allow_unreachable=True,  # tolerate tensors with no grad path
        accumulate_grad=True,    # accumulate into .grad like a normal backward
    )
    Variable._execution_engine.run_backward(tensors, grad_tensors, **kwargs)
def chunk_tensor(x, chunks, dim):
    """Split *x* into *chunks* near-equal pieces along *dim*.

    A ``None`` input is propagated as a list of *chunks* ``None`` placeholders.
    """
    if x is None:
        return [None] * chunks
    return x.tensor_split(chunks, dim=dim)
def cat_tensor(x, dim):
    """Concatenate a sequence of chunks back into one tensor along *dim*.

    A length-1 sequence is unwrapped as-is, and a sequence of ``None``
    placeholders collapses to a single ``None``.
    """
    if isinstance(x, (tuple, list)):
        if len(x) == 1:
            return x[0]
        if x[0] is None:
            assert all(piece is None for piece in x)
            return None
    return torch.cat(x, dim=dim)
def scatter(inputs, chunks, dim):
    """Split each input tensor into *chunks* micro-batches along *dim*.

    Accepts a single tensor or a tuple/list of tensors (entries may be
    ``None``).  Returns a list of *chunks* tuples, each holding the matching
    slice of every input; ``None`` inputs stay ``None`` in every micro-batch.
    """
    assert isinstance(inputs, (torch.Tensor, tuple, list))
    if isinstance(inputs, torch.Tensor):
        inputs = (inputs,)
    assert all(x is None or isinstance(x, torch.Tensor) for x in inputs)
    per_input = [
        [None] * chunks if x is None else x.tensor_split(chunks, dim=dim)
        for x in inputs
    ]
    microbatches = list(zip(*per_input))
    if not microbatches:
        # No inputs at all: still emit one empty tuple per requested chunk.
        microbatches = [()] * chunks
    return microbatches
def gather(micro_outputs, dim):
    """Merge per-micro-batch outputs back into full-batch tensors along *dim*.

    Accepts a list of tensors or a list of tuples/lists of tensors; returns a
    tuple with one concatenated tensor (or ``None``) per output slot.  A slot
    that is ``None`` in every micro-batch stays ``None``.
    """
    assert isinstance(micro_outputs[0], (torch.Tensor, tuple, list))
    if isinstance(micro_outputs[0], torch.Tensor):
        micro_outputs = [(out,) for out in micro_outputs]
    merged = []
    for group in zip(*micro_outputs):
        if len(group) == 1:
            merged.append(group[0])
        elif group[0] is None:
            assert all(piece is None for piece in group)
            merged.append(None)
        else:
            merged.append(torch.cat(group, dim=dim))
    return tuple(merged)