# DeepEP/deep_ep/utils.py
# 2025-02-25 09:07:53 +08:00
#
# 61 lines
# 2.0 KiB
# Python

import torch
from typing import Any, Optional, Tuple
# noinspection PyUnresolvedReferences
from deep_ep_cpp import Config, EventHandle
class EventOverlap:
    """
    A wrapper class to manage CUDA events, also for better overlapping convenience.

    Attributes:
        event: the CUDA event captured, or `None` if there is nothing to wait on.
        extra_tensors: an easier way to simulate PyTorch tensor `record_stream`, may be useful with CUDA graph.
    """

    def __init__(self, event: Optional['EventHandle'] = None,
                 extra_tensors: Optional[Tuple[torch.Tensor, ...]] = None) -> None:
        """
        Initialize the class.

        Arguments:
            event: the CUDA event captured; `None` means there is nothing to wait on.
            extra_tensors: an easier way to simulate PyTorch tensor `record_stream`, may be useful with CUDA graph.
        """
        self.event = event

        # NOTES: we use extra tensors to achieve stream recording, otherwise,
        # stream recording will be incompatible with CUDA graph.
        # This class never reads `extra_tensors` itself; holding the references
        # keeps the tensors alive for as long as this object is alive.
        self.extra_tensors = extra_tensors

    def current_stream_wait(self) -> None:
        """
        The current stream `torch.cuda.current_stream()` waits for the event to be finished.

        Raises:
            AssertionError: if no event was captured (when assertions are enabled).
        """
        assert self.event is not None, 'EventOverlap has no captured event to wait on'
        self.event.current_stream_wait()

    def __enter__(self) -> 'EventOverlap':
        """
        Utility for overlapping and Python `with` syntax.

        You can overlap the kernels on the current stream with the following example:
        ```python
        event_overlap = event_after_all_to_all_kernels()
        with event_overlap:
            do_something_on_current_stream()
        # After exiting the `with` scope, the current stream will wait for the event to be finished.
        ```

        Returns:
            This object itself.
        """
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """
        Utility for overlapping and Python `with` syntax.
        Please follow the example in the `__enter__` function.

        If no event was captured, exiting is a no-op. The wait is performed even
        when the `with` body raised; returning `None` lets that exception propagate.
        """
        # Unlike `current_stream_wait`, a missing event is tolerated here.
        if self.event is not None:
            self.event.current_stream_wait()