Merge pull request #66 from dzhulgakov/combine-out-arg

Allow passing output tensor in low_latency_combine
Chenggang Zhao 2025-03-13 09:18:06 +08:00 committed by GitHub
commit 7128ba3e39
4 changed files with 17 additions and 6 deletions
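
In short: `low_latency_combine` gains an optional `out` argument, so callers can hand in a preallocated output tensor instead of having every call allocate a fresh one. A minimal usage sketch (the `buffer`, `handle`, input tensors, and sizes below are illustrative assumptions, not part of this commit):

    import torch

    # Assumed context (hypothetical): `buffer` is a deep_ep.Buffer in low-latency
    # mode, `handle` was returned by a prior low-latency dispatch, and `x`,
    # `topk_idx`, `topk_weights` match that dispatch.
    num_tokens, hidden = 128, 7168  # example sizes only

    # Preallocate once (e.g. at model setup) and reuse it every decoding step.
    out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')

    combined_x, event, hook = buffer.low_latency_combine(x, topk_idx, topk_weights,
                                                         handle, out=out)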

@@ -1100,7 +1100,7 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
 Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
                             const torch::Tensor& src_info, const torch::Tensor& layout_range,
                             int num_max_dispatch_tokens_per_rank, int num_experts,
-                            bool async, bool return_recv_hook) {
+                            bool async, bool return_recv_hook, std::optional<torch::Tensor> out) {
     EP_HOST_ASSERT(low_latency_mode);

     // Tensor checks
@@ -1138,7 +1138,15 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx,
     stream_wait(launch_stream, compute_stream);

     // Allocate output tensor
-    auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
+    torch::Tensor combined_x;
+    if (out.has_value()) {
+        EP_HOST_ASSERT(out->dim() == 2 and out->is_contiguous());
+        EP_HOST_ASSERT(out->size(0) == num_combined_tokens and out->size(1) == hidden);
+        EP_HOST_ASSERT(out->scalar_type() == x.scalar_type());
+        combined_x = out.value();
+    } else {
+        combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
+    }

     // Kernel launch
     auto next_clean_meta = next_buffer.clean_meta();
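
The three host asserts above spell out the contract for a caller-provided buffer: it must be 2-D, contiguous, shaped `[num_combined_tokens, hidden]`, and of the same dtype as `x`. A small illustration of tensors that would pass or trip these checks (sizes are assumed examples):

    import torch

    num_combined_tokens, hidden = 128, 7168  # example values only

    # Satisfies all three EP_HOST_ASSERT checks (assuming `x` is bf16).
    ok = torch.empty((num_combined_tokens, hidden), dtype=torch.bfloat16, device='cuda')

    # Each of these would trip exactly one check:
    bad_dtype = torch.empty_like(ok, dtype=torch.float32)  # scalar_type() mismatch
    bad_shape = torch.empty((num_combined_tokens, hidden + 1),
                            dtype=torch.bfloat16, device='cuda')  # size(1) != hidden
    bad_layout = torch.empty((num_combined_tokens, 2 * hidden),
                             dtype=torch.bfloat16, device='cuda')[:, ::2]  # right shape
    assert not bad_layout.is_contiguous()  # ... but a strided view is not contiguous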

@@ -143,7 +143,7 @@ public:
     low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
                         const torch::Tensor& src_info, const torch::Tensor& layout_range,
                         int num_max_dispatch_tokens_per_rank, int num_experts,
-                        bool async, bool return_recv_hook);
+                        bool async, bool return_recv_hook, std::optional<torch::Tensor> out = std::nullopt);
 };

 } // namespace deep_ep
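
Because the new parameter defaults to `std::nullopt` in C++ (and, below, to `None` in Python), existing call sites keep compiling and behaving as before; passing `out` only changes where the result is written.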

@@ -497,7 +497,8 @@ class Buffer:

     # noinspection PyTypeChecker
     def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
-                            handle: tuple, async_finish: bool = False, return_recv_hook: bool = False) -> \
+                            handle: tuple, async_finish: bool = False, return_recv_hook: bool = False,
+                            out: Optional[torch.Tensor] = None) -> \
             Tuple[torch.Tensor, EventOverlap, Callable]:
         """
         A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
@@ -520,6 +521,7 @@ class Buffer:
             return_recv_hook: return a receiving hook if set. If set, the kernel will only issue the RDMA requests,
                 but **without actually receiving the data**. You must call the returned hook to ensure the data has arrived.
                 If you do not set this flag, the kernel itself will ensure the data has arrived.
+            out: the in-place output tensor; if set, the kernel will write the result to this tensor and return it directly.

         Returns:
             combined_x: the reduced token tensor, with shape `[num_combined_tokens, hidden]` and type `torch.bfloat16`.
@@ -529,6 +531,6 @@ class Buffer:
         src_info, layout_range, num_max_dispatch_tokens_per_rank, num_experts = handle
         combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
                                                                    num_max_dispatch_tokens_per_rank, num_experts,
-                                                                   async_finish, return_recv_hook)
+                                                                   async_finish, return_recv_hook, out)
         tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x)
         return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook

@@ -73,8 +73,9 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
             hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])

     # Check combine correctness
+    out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
     combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
-                                                         async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
+                                                         async_finish=not return_recv_hook, return_recv_hook=return_recv_hook, out=out)
     hook() if return_recv_hook else event.current_stream_wait()
     if do_check:
         diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
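
A possible follow-up assertion (not in this commit) would verify that the write really happened in place, i.e. that the returned tensor aliases the preallocated buffer:

    # Hypothetical extra check: `combined_x` should share storage with `out`
    # rather than being a fresh allocation.
    assert combined_x.data_ptr() == out.data_ptr()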