mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Initial commit
This commit is contained in:
60
deep_ep/utils.py
Normal file
60
deep_ep/utils.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import torch
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
from deep_ep_cpp import Config, EventHandle
|
||||
|
||||
|
||||
class EventOverlap:
    """
    A wrapper class to manage CUDA events, also for better overlapping convenience.

    Instances can be used as a context manager: leaving the `with` scope makes
    the current stream wait for the wrapped event (when one is set).

    Attributes:
        event: the CUDA event captured, or `None` if there is nothing to wait for.
        extra_tensors: an easier way to simulate PyTorch tensor `record_stream`, may be useful with CUDA graph.
    """

    def __init__(self, event: Optional["EventHandle"] = None,
                 extra_tensors: Optional[Tuple[torch.Tensor, ...]] = None) -> None:
        """
        Initialize the class.

        Arguments:
            event: the CUDA event captured.
            extra_tensors: an easier way to simulate PyTorch tensor `record_stream`, may be useful with CUDA graph.
                Any number of tensors may be given; they are only stored on the instance.
        """
        self.event = event

        # NOTES: we use extra tensors to achieve stream recording, otherwise,
        # stream recording will be incompatible with CUDA graph.
        self.extra_tensors = extra_tensors

    def current_stream_wait(self) -> None:
        """
        The current stream `torch.cuda.current_stream()` waits for the event to be finished.

        Raises:
            AssertionError: if no event was captured (`self.event` is `None`).
        """
        assert self.event is not None, 'EventOverlap holds no event to wait for'
        self.event.current_stream_wait()

    def __enter__(self) -> Any:
        """
        Utility for overlapping and Python `with` syntax.

        You can overlap the kernels on the current stream with the following example:
        ```python
        event_overlap = event_after_all_to_all_kernels()
        with event_overlap:
            do_something_on_current_stream()
        # After exiting the `with` scope, the current stream will wait for the event to be finished.
        ```

        Returns:
            This `EventOverlap` instance itself.
        """
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """
        Utility for overlapping and Python `with` syntax.

        Please follow the example in the `__enter__` function.

        The wait is issued whenever an event is present, whether or not the
        `with` body raised; the exception arguments are ignored and, because
        nothing truthy is returned, any exception keeps propagating.
        """
        # Unlike `current_stream_wait`, a missing event is tolerated here so
        # that an empty `EventOverlap()` is a harmless no-op context manager.
        if self.event is not None:
            self.current_stream_wait()
|
||||
Reference in New Issue
Block a user