Mirror of https://github.com/deepseek-ai/DeepEP, synced 2025-06-26 18:28:11 +00:00
Allow passing output tensor in low_latency_combine
This commit is contained in:
parent ed7487c15e
commit b3b61ef5ef
@@ -1100,7 +1100,7 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
 Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
                             const torch::Tensor& src_info, const torch::Tensor& layout_range,
                             int num_max_dispatch_tokens_per_rank, int num_experts,
-                            bool async, bool return_recv_hook) {
+                            bool async, bool return_recv_hook, std::optional<torch::Tensor> out) {
     EP_HOST_ASSERT(low_latency_mode);
 
     // Tensor checks
@@ -1138,7 +1138,15 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
     stream_wait(launch_stream, compute_stream);
 
     // Allocate output tensor
-    auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
+    torch::Tensor combined_x;
+    if (out.has_value()) {
+        EP_HOST_ASSERT(out->dim() == 2 and out->is_contiguous());
+        EP_HOST_ASSERT(out->size(0) == num_combined_tokens and out->size(1) == hidden);
+        EP_HOST_ASSERT(out->scalar_type() == x.scalar_type());
+        combined_x = out.value();
+    } else {
+        combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
+    }
 
     // Kernel launch
     auto next_clean_meta = next_buffer.clean_meta();
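The `EP_HOST_ASSERT` lines above define the contract a caller-provided buffer must satisfy: 2-D, contiguous, shaped `[num_combined_tokens, hidden]`, and matching the dtype of `x`. A minimal Python sketch of the same validation (the `check_combine_out` helper is illustrative only, not part of this commit):

    import torch

    def check_combine_out(out: torch.Tensor, x: torch.Tensor,
                          num_combined_tokens: int, hidden: int) -> torch.Tensor:
        # Mirrors the EP_HOST_ASSERT checks in the C++ hunk above.
        assert out.dim() == 2 and out.is_contiguous()
        assert out.size(0) == num_combined_tokens and out.size(1) == hidden
        assert out.dtype == x.dtype
        return out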
@@ -143,7 +143,7 @@ public:
     low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
                         const torch::Tensor& src_info, const torch::Tensor& layout_range,
                         int num_max_dispatch_tokens_per_rank, int num_experts,
-                        bool async, bool return_recv_hook);
+                        bool async, bool return_recv_hook, std::optional<torch::Tensor> out = std::nullopt);
 };
 
 } // namespace deep_ep
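Because the new parameter defaults to `std::nullopt` here (and to `None` on the Python side below), existing call sites keep compiling and running unchanged. A sketch of the two call forms at the Python level, assuming `buffer` is an initialized `deep_ep.Buffer` and the other arguments are prepared as in the test at the end of this diff:

    # Old form: the kernel allocates combined_x internally.
    combined_x, event, hook = buffer.low_latency_combine(x, topk_idx, topk_weights, handle)

    # New form: the kernel writes into a caller-owned tensor and returns it.
    combined_x, event, hook = buffer.low_latency_combine(x, topk_idx, topk_weights, handle, out=out)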
@@ -497,7 +497,8 @@ class Buffer:
 
     # noinspection PyTypeChecker
     def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
-                            handle: tuple, async_finish: bool = False, return_recv_hook: bool = False) -> \
+                            handle: tuple, async_finish: bool = False, return_recv_hook: bool = False,
+                            out: torch.Tensor | None = None) -> \
             Tuple[torch.Tensor, EventOverlap, Callable]:
         """
         A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
@@ -520,6 +521,7 @@ class Buffer:
         return_recv_hook: return a receiving hook if set. If set, the kernel will only issue the RDMA requests,
             but **without actually receiving the data**. You must call the returned hook to ensure the data's arrival.
             If you do not set this flag, the kernel will ensure the data's arrival.
+        out: the in-place output tensor; if set, the kernel will write the result into this tensor and return it directly.
 
     Returns:
         combined_x: the reduced token tensor, with shape `[num_combined_tokens, hidden]` and type `torch.bfloat16`.
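The practical payoff of `out` is skipping the per-call `torch.empty` when the combine output can live in a persistent buffer, e.g. one reused across decode steps. A sketch under that assumption (`persistent_out` and `num_decode_steps` are illustrative names, not from the commit):

    # Allocate once; shape and dtype must match what the kernel would
    # otherwise allocate: [num_combined_tokens, hidden], same dtype as x.
    persistent_out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
    for _ in range(num_decode_steps):
        combined_x, event, hook = buffer.low_latency_combine(
            x, topk_idx, topk_weights, handle, out=persistent_out)
        # combined_x aliases persistent_out; no new allocation happened above.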
@@ -529,6 +531,6 @@ class Buffer:
         src_info, layout_range, num_max_dispatch_tokens_per_rank, num_experts = handle
         combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
                                                                    num_max_dispatch_tokens_per_rank, num_experts,
-                                                                   async_finish, return_recv_hook)
+                                                                   async_finish, return_recv_hook, out)
         tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x)
         return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook
@@ -73,8 +73,9 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
         hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
 
     # Check combine correctness
+    out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
     combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
-                                                         async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
+                                                         async_finish=not return_recv_hook, return_recv_hook=return_recv_hook, out=out)
     hook() if return_recv_hook else event.current_stream_wait()
     if do_check:
         diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
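The reference value on the left of `calc_diff` is `x` scaled per-row by each token's total valid routing weight: in this test every expert copy of a token equals the original row, so the weighted combine collapses to a scalar multiply per token. A standalone numeric sketch of that identity (illustrative shapes, not from the commit):

    import torch

    x = torch.randn(4, 8)                                         # [num_tokens, hidden]
    topk_idx = torch.tensor([[0, 1], [2, -1], [1, 3], [-1, -1]])  # -1 marks an invalid slot
    topk_weights = torch.rand(4, 2)

    # Weighted sum of identical expert copies == x times the sum of valid weights.
    expected = x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1)
    combined = sum(torch.where((topk_idx[:, k] != -1).view(-1, 1),
                               topk_weights[:, k].view(-1, 1) * x,
                               torch.zeros_like(x))
                   for k in range(2))
    assert torch.allclose(expected, combined)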