mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
* Update README * Update `setup.py` * Fix headers * Add `DISABLE_NVSHMEM` for APIs * Fix launch * Fix TMA settings * Fix TMA usages * Fix dlink * Separate layout kernels * Update version * Add `is_sm90_compiled` * Fix tests * Add NVLink connection checks * Update README * Fix tests * Add some comments * Minor fix * Minor fix * Fix bugs
89 lines
3.3 KiB
Python
89 lines
3.3 KiB
Python
import os
|
|
import subprocess
|
|
import torch
|
|
import torch.distributed as dist
|
|
from typing import Any, Optional, Tuple
|
|
|
|
# noinspection PyUnresolvedReferences
|
|
from deep_ep_cpp import Config, EventHandle
|
|
|
|
|
|
class EventOverlap:
    """
    A wrapper class to manage CUDA events, also for better overlapping convenience.

    Attributes:
        event: the CUDA event captured (may be `None`).
        extra_tensors: an easier way to simulate PyTorch tensor `record_stream`, may be useful with CUDA graph.
            The tensors are simply kept referenced by this object.
    """

    def __init__(self, event: Optional[EventHandle] = None,
                 extra_tensors: Optional[Tuple[torch.Tensor, ...]] = None) -> None:
        """
        Initialize the class.

        Arguments:
            event: the CUDA event captured; `None` means there is nothing to wait for.
            extra_tensors: an easier way to simulate PyTorch tensor `record_stream`, may be useful with CUDA graph.
        """
        self.event = event

        # NOTES: we use extra tensors to achieve stream recording, otherwise,
        # stream recording will be incompatible with CUDA graph.
        # Holding the references here keeps the tensors alive as long as this object is alive.
        self.extra_tensors = extra_tensors

    def current_stream_wait(self) -> None:
        """
        Make the current stream `torch.cuda.current_stream()` wait for the event to be finished.

        Raises:
            AssertionError: if no event was captured (`self.event is None`); note that
                assertions are stripped when Python runs with `-O`.
        """
        assert self.event is not None, 'No CUDA event captured to wait for'
        self.event.current_stream_wait()

    def __enter__(self) -> Any:
        """
        Utility for overlapping and Python `with` syntax.

        You can overlap the kernels on the current stream with the following example:
        ```python
        event_overlap = event_after_all_to_all_kernels()
        with event_overlap:
            do_something_on_current_stream()
        # After exiting the `with` scope, the current stream will wait for the event to be finished.
        ```

        Returns:
            this object itself.
        """
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """
        Utility for overlapping and Python `with` syntax.

        Make the current stream wait for the captured event; this is a no-op when no event
        was captured. Exceptions raised inside the `with` body are not suppressed.
        Please follow the example in the `__enter__` function.
        """
        if self.event is not None:
            self.event.current_stream_wait()
|
|
|
|
|
|
def _physical_device_index() -> int:
    """Map the current logical CUDA device to the GPU index used by `nvidia-smi`."""
    logical_idx = torch.cuda.current_device()
    visible = os.environ.get('CUDA_VISIBLE_DEVICES')
    if visible is None:
        # No remapping requested: use the logical index directly, so machines with more
        # than 8 GPUs do not hit an `IndexError` on a hard-coded default list
        return logical_idx
    # NOTES(review): entries given as GPU UUIDs / MIG ids are not integers and will make
    # `int()` raise; this also assumes CUDA's enumeration order matches `nvidia-smi`'s
    # (e.g. `CUDA_DEVICE_ORDER=PCI_BUS_ID`) — TODO confirm for the target deployments
    devices = visible.strip(',').split(',')
    return int(devices[logical_idx])


def _parse_p2p_row(topo_output: str, gpu_idx: int) -> Optional[Tuple[str, ...]]:
    """Return the status cells of row `GPU{gpu_idx}` in `nvidia-smi topo -p2p n` output, or `None` if absent."""
    row_label = f'GPU{gpu_idx}'
    for line in topo_output.splitlines():
        tokens = line.split()
        # Exact token match: a prefix test would let `GPU1` also match `GPU10`, `GPU11`, ...
        # The column-header line also starts with `GPU0`, but only matrix rows contain the `X` (self) cell
        if tokens and tokens[0] == row_label and 'X' in tokens[1:]:
            return tuple(tokens[1:])
    return None


def check_nvlink_connections(group: dist.ProcessGroup):
    """
    Check NVLink connection between every pair of GPUs.

    Only PCIe GPUs are checked; for them the group must have at most 2 ranks and every
    peer GPU must report an `X` (self) or `OK` NVLink P2P status in `nvidia-smi topo -p2p n`.

    Arguments:
        group: the communication group.

    Raises:
        AssertionError: if the group is too large or some pair lacks an NVLink connection.
        subprocess.CalledProcessError: if `nvidia-smi` fails.
    """
    # NOTES: some A100 PCIE GPUs only have pairwise NVLink connection, so that we can only use EP2
    if 'PCIE' not in torch.cuda.get_device_name():
        return
    assert group.size() <= 2, 'No NVLink connection between all GPUs'

    # Gather every rank's physical GPU index so each rank can check its row against all peers
    physical_device_idx = _physical_device_index()
    physical_device_indices = [0] * group.size()
    dist.all_gather_object(physical_device_indices, physical_device_idx, group)

    # Get connection matrix from `nvidia-smi`
    topo_output = subprocess.check_output(['nvidia-smi', 'topo', '-p2p', 'n']).decode('utf-8')
    status = _parse_p2p_row(topo_output, physical_device_idx)
    if status is None:
        # NOTES(review): like the previous implementation, a missing row skips the check silently
        return
    for dst_gpu_idx in physical_device_indices:
        assert status[dst_gpu_idx] in ('X', 'OK'), \
            f'No NVLink connection between GPU {physical_device_idx} and GPU {dst_gpu_idx}'
|