Add get_comm_stream. (#256)

* Add `get_comm_stream`.

* Fix style.
This commit is contained in:
Shangyan Zhou
2025-06-25 13:02:13 +08:00
parent 9eb2f84b3e
commit 85adc566e2
3 changed files with 17 additions and 0 deletions

View File

@@ -147,6 +147,16 @@ class Buffer:
size: the RDMA buffer size recommended.
"""
return deep_ep_cpp.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts)
def get_comm_stream(self) -> torch.Stream:
"""
Get the communication stream.
Returns:
stream: the communication stream.
"""
ts: torch.Stream = self.runtime.get_comm_stream()
return torch.cuda.Stream(stream_id=ts.stream_id, device_index=ts.device_index, device_type=ts.device_type)
def get_local_buffer_tensor(self, dtype: torch.dtype, size: Optional[torch.Size] = None,
offset: int = 0, use_rdma_buffer: bool = False) -> torch.Tensor: