import torch
from typing import Any, Optional, Tuple

# noinspection PyUnresolvedReferences
from deep_ep_cpp import Config, EventHandle


class EventOverlap:
    """
    A wrapper class to manage CUDA events, also for better overlapping convenience.

    Attributes:
        event: the CUDA event captured.
        extra_tensors: an easier way to simulate PyTorch tensor `record_stream`, may be useful with CUDA graph.
    """

    def __init__(self, event: Optional[EventHandle] = None,
                 extra_tensors: Optional[Tuple[torch.Tensor]] = None) -> None:
        """
        Initialize the class.

        Arguments:
            event: the CUDA event captured.
            extra_tensors: an easier way to simulate PyTorch tensor `record_stream`, may be useful with CUDA graph.
        """
        self.event = event

        # NOTES: we use extra tensors to achieve stream recording, otherwise,
        # stream recording will be incompatible with CUDA graph.
        self.extra_tensors = extra_tensors

    def current_stream_wait(self) -> None:
        """
        The current stream `torch.cuda.current_stream()` waits for the event to finish.
        """
        assert self.event is not None
        self.event.current_stream_wait()

    def __enter__(self) -> Any:
        """
        Utility for overlapping and Python `with` syntax.

        You can overlap the kernels on the current stream with the following example:
        ```python
        event_overlap = event_after_all_to_all_kernels()
        with event_overlap:
            do_something_on_current_stream()
        # After exiting the `with` scope, the current stream will wait for the event to finish.
        ```
        """
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """
        Utility for overlapping and Python `with` syntax.
        Please follow the example in the `__enter__` function.
        """
        if self.event is not None:
            self.event.current_stream_wait()
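

# --- Usage sketch (illustrative only, not part of the library) ---------------
# `async_dispatch` below is a hypothetical placeholder for any call that returns
# a tensor together with an `EventOverlap`; the pattern shows how the wrapper is
# meant to overlap communication with computation on the current stream.
#
#   recv_x, event = async_dispatch(x)    # hypothetical asynchronous communication call
#   with event:                          # enter: no-op; exit: current stream waits on the event
#       run_independent_compute()        # overlapped work on the current stream
#   use(recv_x)                          # safe to consume after the `with` block
#
# Equivalently, without the context manager:
#
#   recv_x, event = async_dispatch(x)
#   run_independent_compute()
#   event.current_stream_wait()          # explicit synchronization point
#   use(recv_x)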