From b80e55e21f6c06f7816462d4ee50084fd7763298 Mon Sep 17 00:00:00 2001 From: Shangyan Zhou Date: Wed, 25 Jun 2025 13:02:13 +0800 Subject: [PATCH] Add `get_comm_stream`. (#256) * Add `get_comm_stream`. * Fix style. --- csrc/deep_ep.cpp | 5 +++++ csrc/deep_ep.hpp | 2 ++ deep_ep/buffer.py | 10 ++++++++++ 3 files changed, 17 insertions(+) diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp index 75906f6..d09e52e 100644 --- a/csrc/deep_ep.cpp +++ b/csrc/deep_ep.cpp @@ -163,6 +163,10 @@ torch::Tensor Buffer::get_local_buffer_tensor(const pybind11::object& dtype, int return torch::from_blob(base_ptr, num_bytes / element_bytes, torch::TensorOptions().dtype(casted_dtype).device(at::kCUDA)); } +torch::Stream Buffer::get_comm_stream() const { + return comm_stream; +} + void Buffer::sync(const std::vector &device_ids, const std::vector> &all_gathered_handles, const std::optional& root_unique_id_opt) { @@ -1303,6 +1307,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("get_local_ipc_handle", &deep_ep::Buffer::get_local_ipc_handle) .def("get_local_nvshmem_unique_id", &deep_ep::Buffer::get_local_nvshmem_unique_id) .def("get_local_buffer_tensor", &deep_ep::Buffer::get_local_buffer_tensor) + .def("get_comm_stream", &deep_ep::Buffer::get_comm_stream) .def("sync", &deep_ep::Buffer::sync) .def("get_dispatch_layout", &deep_ep::Buffer::get_dispatch_layout) .def("intranode_dispatch", &deep_ep::Buffer::intranode_dispatch) diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp index dfa2202..d96d726 100644 --- a/csrc/deep_ep.hpp +++ b/csrc/deep_ep.hpp @@ -94,6 +94,8 @@ public: torch::Tensor get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const; + torch::Stream get_comm_stream() const; + void sync(const std::vector& device_ids, const std::vector>& all_gathered_handles, const std::optional& root_unique_id_opt); std::tuple, torch::Tensor, torch::Tensor, std::optional> diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py index 36f3c54..3738aba 100644 --- a/deep_ep/buffer.py +++ b/deep_ep/buffer.py @@ -147,6 +147,16 @@ class Buffer: size: the RDMA buffer size recommended. """ return deep_ep_cpp.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts) + + def get_comm_stream(self) -> torch.Stream: + """ + Get the communication stream. + + Returns: + stream: the communication stream. + """ + ts: torch.Stream = self.runtime.get_comm_stream() + return torch.cuda.Stream(stream_id=ts.stream_id, device_index=ts.device_index, device_type=ts.device_type) def get_local_buffer_tensor(self, dtype: torch.dtype, size: Optional[torch.Size] = None, offset: int = 0, use_rdma_buffer: bool = False) -> torch.Tensor: