diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py index 3fa069d..25fcff5 100644 --- a/deep_ep/buffer.py +++ b/deep_ep/buffer.py @@ -32,7 +32,8 @@ class Buffer: def __init__(self, group: dist.ProcessGroup, num_nvl_bytes: int = 0, num_rdma_bytes: int = 0, low_latency_mode: bool = False, num_qps_per_rank: int = 12, - allow_nvlink_for_low_latency_mode: bool = True) -> None: + allow_nvlink_for_low_latency_mode: bool = True, + allow_mnnvl: bool = False) -> None: """ Initialize the communication buffer. @@ -47,6 +48,7 @@ class Buffer: this is somehow incompatible with the hook-based overlapping. Warning: PCIe connections may lead to errors due to memory ordering issues, please make sure all connections are via NVLink. + allow_mnnvl: whether to allow MNNVL """ # Initialize the CPP runtime @@ -88,8 +90,9 @@ class Buffer: # NOTES: NVSHMEM initialization requires at least 256 MiB os.environ['NVSHMEM_CUMEM_GRANULARITY'] = f'{2 ** 29}' - # Disable multi-node NVLink detection - os.environ['NVSHMEM_DISABLE_MNNVL'] = '1' + if not allow_mnnvl: + # Disable multi-node NVLink detection + os.environ['NVSHMEM_DISABLE_MNNVL'] = '1' # Synchronize using the root ID nvshmem_unique_ids = [None, ] * self.group_size