mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
parent
a15faa9ff0
commit
b80e55e21f
@ -163,6 +163,10 @@ torch::Tensor Buffer::get_local_buffer_tensor(const pybind11::object& dtype, int
|
||||
return torch::from_blob(base_ptr, num_bytes / element_bytes, torch::TensorOptions().dtype(casted_dtype).device(at::kCUDA));
|
||||
}
|
||||
|
||||
/// @brief Accessor for the communication stream held by this buffer.
///
/// @return The `comm_stream` member, returned by value as a `torch::Stream`.
torch::Stream Buffer::get_comm_stream() const {
    return this->comm_stream;
}
|
||||
|
||||
void Buffer::sync(const std::vector<int> &device_ids,
|
||||
const std::vector<std::optional<pybind11::bytearray>> &all_gathered_handles,
|
||||
const std::optional<pybind11::bytearray>& root_unique_id_opt) {
|
||||
@ -1303,6 +1307,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
.def("get_local_ipc_handle", &deep_ep::Buffer::get_local_ipc_handle)
|
||||
.def("get_local_nvshmem_unique_id", &deep_ep::Buffer::get_local_nvshmem_unique_id)
|
||||
.def("get_local_buffer_tensor", &deep_ep::Buffer::get_local_buffer_tensor)
|
||||
.def("get_comm_stream", &deep_ep::Buffer::get_comm_stream)
|
||||
.def("sync", &deep_ep::Buffer::sync)
|
||||
.def("get_dispatch_layout", &deep_ep::Buffer::get_dispatch_layout)
|
||||
.def("intranode_dispatch", &deep_ep::Buffer::intranode_dispatch)
|
||||
|
@ -94,6 +94,8 @@ public:
|
||||
|
||||
torch::Tensor get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const;
|
||||
|
||||
torch::Stream get_comm_stream() const;
|
||||
|
||||
void sync(const std::vector<int>& device_ids, const std::vector<std::optional<pybind11::bytearray>>& all_gathered_handles, const std::optional<pybind11::bytearray>& root_unique_id_opt);
|
||||
|
||||
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
|
||||
|
@ -147,6 +147,16 @@ class Buffer:
|
||||
size: the RDMA buffer size recommended.
|
||||
"""
|
||||
return deep_ep_cpp.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts)
|
||||
|
||||
def get_comm_stream(self) -> torch.Stream:
    """
    Return the communication stream of the underlying runtime.

    The stream object reported by `self.runtime.get_comm_stream()` is re-wrapped
    as a `torch.cuda.Stream` carrying the same stream id, device index and device type.

    Returns:
        stream: the communication stream.
    """
    runtime_stream: torch.Stream = self.runtime.get_comm_stream()
    return torch.cuda.Stream(
        stream_id=runtime_stream.stream_id,
        device_index=runtime_stream.device_index,
        device_type=runtime_stream.device_type,
    )
|
||||
|
||||
def get_local_buffer_tensor(self, dtype: torch.dtype, size: Optional[torch.Size] = None,
|
||||
offset: int = 0, use_rdma_buffer: bool = False) -> torch.Tensor:
|
||||
|
Loading…
Reference in New Issue
Block a user