Add get_comm_stream. (#256)

* Add `get_comm_stream`.

* Fix style.
This commit is contained in:
Shangyan Zhou 2025-06-25 13:02:13 +08:00 committed by GitHub
parent a15faa9ff0
commit b80e55e21f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 17 additions and 0 deletions

View File

@@ -163,6 +163,10 @@ torch::Tensor Buffer::get_local_buffer_tensor(const pybind11::object& dtype, int
return torch::from_blob(base_ptr, num_bytes / element_bytes, torch::TensorOptions().dtype(casted_dtype).device(at::kCUDA));
}
// Accessor for the dedicated communication stream owned by this buffer.
// Returns the stream by value (torch::Stream is a lightweight handle), so
// callers cannot mutate the buffer's internal state through it.
torch::Stream Buffer::get_comm_stream() const {
    return this->comm_stream;
}
void Buffer::sync(const std::vector<int> &device_ids,
const std::vector<std::optional<pybind11::bytearray>> &all_gathered_handles,
const std::optional<pybind11::bytearray>& root_unique_id_opt) {
@@ -1303,6 +1307,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("get_local_ipc_handle", &deep_ep::Buffer::get_local_ipc_handle)
.def("get_local_nvshmem_unique_id", &deep_ep::Buffer::get_local_nvshmem_unique_id)
.def("get_local_buffer_tensor", &deep_ep::Buffer::get_local_buffer_tensor)
.def("get_comm_stream", &deep_ep::Buffer::get_comm_stream)
.def("sync", &deep_ep::Buffer::sync)
.def("get_dispatch_layout", &deep_ep::Buffer::get_dispatch_layout)
.def("intranode_dispatch", &deep_ep::Buffer::intranode_dispatch)

View File

@@ -94,6 +94,8 @@ public:
torch::Tensor get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const;
torch::Stream get_comm_stream() const;
void sync(const std::vector<int>& device_ids, const std::vector<std::optional<pybind11::bytearray>>& all_gathered_handles, const std::optional<pybind11::bytearray>& root_unique_id_opt);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::optional<EventHandle>>

View File

@@ -147,6 +147,16 @@ class Buffer:
size: the RDMA buffer size recommended.
"""
return deep_ep_cpp.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts)
def get_comm_stream(self) -> torch.Stream:
    """
    Get the communication stream.

    Returns:
        stream: the communication stream.
    """
    # The C++ runtime hands back a generic `torch.Stream`; re-wrap it as a
    # `torch.cuda.Stream` so callers get the CUDA-specific Python interface.
    raw_stream: torch.Stream = self.runtime.get_comm_stream()
    return torch.cuda.Stream(stream_id=raw_stream.stream_id,
                             device_index=raw_stream.device_index,
                             device_type=raw_stream.device_type)
def get_local_buffer_tensor(self, dtype: torch.dtype, size: Optional[torch.Size] = None,
offset: int = 0, use_rdma_buffer: bool = False) -> torch.Tensor: