mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-15 18:58:17 +00:00
Support statistics tensor for low-latency kernels (#196)
This commit is contained in:
parent
0d1a855d81
commit
5a2e37fa28
@ -1030,6 +1030,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
|
||||
|
||||
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
|
||||
Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
|
||||
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
|
||||
int num_max_dispatch_tokens_per_rank, int num_experts,
|
||||
bool use_fp8, bool async, bool return_recv_hook) {
|
||||
EP_HOST_ASSERT(low_latency_mode);
|
||||
@ -1042,6 +1043,11 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
|
||||
EP_HOST_ASSERT(x.size(0) == topk_idx.size(0) and x.size(0) <= num_max_dispatch_tokens_per_rank);
|
||||
EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64);
|
||||
EP_HOST_ASSERT(num_experts % num_ranks == 0);
|
||||
if (cumulative_local_expert_recv_stats.has_value()) {
|
||||
EP_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt);
|
||||
EP_HOST_ASSERT(cumulative_local_expert_recv_stats->dim() == 1 and cumulative_local_expert_recv_stats->is_contiguous());
|
||||
EP_HOST_ASSERT(cumulative_local_expert_recv_stats->size(0) == num_experts / num_ranks);
|
||||
}
|
||||
|
||||
auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
|
||||
auto num_scales = hidden / 128, num_topk = static_cast<int>(topk_idx.size(1));
|
||||
@ -1084,6 +1090,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
|
||||
internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr,
|
||||
packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(),
|
||||
packed_recv_count.data_ptr<int>(),
|
||||
cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr<int>() : nullptr,
|
||||
buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
|
||||
buffer.dispatch_rdma_send_buffer,
|
||||
x.data_ptr(), topk_idx.data_ptr<int64_t>(),
|
||||
|
@ -142,6 +142,7 @@ public:
|
||||
|
||||
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
|
||||
low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
|
||||
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
|
||||
int num_max_dispatch_tokens_per_rank, int num_experts,
|
||||
bool use_fp8, bool async, bool return_recv_hook);
|
||||
|
||||
|
@ -133,6 +133,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
|
||||
void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
|
||||
int* packed_recv_count,
|
||||
int* cumulative_local_expert_recv_stats,
|
||||
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
|
||||
const void* x, const int64_t* topk_idx,
|
||||
int* next_clean, int num_next_clean_int,
|
||||
|
@ -41,6 +41,7 @@ __global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
|
||||
dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
|
||||
int* packed_recv_count,
|
||||
int* cumulative_local_expert_recv_stats,
|
||||
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
|
||||
const void* x, const int64_t* topk_idx,
|
||||
int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
|
||||
@ -273,6 +274,8 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
|
||||
shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
|
||||
recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
|
||||
if (cumulative_local_expert_recv_stats != nullptr)
|
||||
atomicAdd(cumulative_local_expert_recv_stats + local_expert_idx, num_recv_tokens);
|
||||
}
|
||||
asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(kNumWarpsPerGroup * 32));
|
||||
num_recv_tokens = shared_num_recv_tokens[warp_group_id];
|
||||
@ -310,6 +313,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
|
||||
int* packed_recv_count,
|
||||
int* cumulative_local_expert_recv_stats,
|
||||
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
|
||||
const void* x, const int64_t* topk_idx,
|
||||
int* next_clean, int num_next_clean_int,
|
||||
@ -338,6 +342,7 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
|
||||
packed_recv_x, packed_recv_x_scales, \
|
||||
packed_recv_src_info, packed_recv_layout_range, \
|
||||
packed_recv_count, \
|
||||
cumulative_local_expert_recv_stats, \
|
||||
rdma_recv_x, rdma_recv_count, rdma_x, \
|
||||
x, topk_idx, \
|
||||
atomic_counter_per_expert, atomic_finish_counter_per_expert, \
|
||||
|
@ -473,6 +473,7 @@ class Buffer:
|
||||
# noinspection PyTypeChecker
|
||||
def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
|
||||
num_max_dispatch_tokens_per_rank: int, num_experts: int,
|
||||
cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
|
||||
use_fp8: bool = True, async_finish: bool = False, return_recv_hook: bool = False) -> \
|
||||
Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
|
||||
"""
|
||||
@ -481,7 +482,7 @@ class Buffer:
|
||||
(specifically, IBGDA must be enabled).
|
||||
Even for ranks in the same node, NVLink are fully disabled for simplicity.
|
||||
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
|
||||
low-latency kernels' result tensor at a single moment.
|
||||
low-latency kernels' result tensors at a single moment.
|
||||
|
||||
Arguments:
|
||||
x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are
|
||||
@ -490,6 +491,9 @@ class Buffer:
|
||||
are supported. `-1` indices (not selecting any expert) are supported.
|
||||
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
|
||||
num_experts: the number of all experts.
|
||||
cumulative_local_expert_recv_stats: a cumulative expert count tensor for statistics, which should have shape
|
||||
`[num_local_experts]` and be typed as `torch.int`. This is useful for online service EP load balance
|
||||
monitoring.
|
||||
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
|
||||
async_finish: the current stream will not wait for the communication kernels to be finished if set.
|
||||
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
|
||||
@ -508,19 +512,21 @@ class Buffer:
|
||||
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
|
||||
as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced).
|
||||
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
|
||||
expert receive. As mentioned before, not all tokens are valid in `recv_x`.
|
||||
expert receives. As mentioned before, not all tokens are valid in `recv_x`.
|
||||
handle: the communication handle to be used in the `low_latency_combine` function.
|
||||
event: the event after executing the kernel (valid only if `async_finish` is set).
|
||||
hook: the receiving hook function (valid only if `return_recv_hook` is set).
|
||||
"""
|
||||
packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \
|
||||
self.runtime.low_latency_dispatch(x, topk_idx,
|
||||
cumulative_local_expert_recv_stats,
|
||||
num_max_dispatch_tokens_per_rank, num_experts,
|
||||
use_fp8, async_finish, return_recv_hook)
|
||||
handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts)
|
||||
tensors_to_record = (x, topk_idx,
|
||||
packed_recv_x, packed_recv_x_scales, packed_recv_count,
|
||||
packed_recv_src_info, packed_recv_layout_range)
|
||||
packed_recv_src_info, packed_recv_layout_range,
|
||||
cumulative_local_expert_recv_stats)
|
||||
return (packed_recv_x, packed_recv_x_scales) if use_fp8 else packed_recv_x, packed_recv_count, handle, \
|
||||
EventOverlap(event, tensors_to_record if async_finish else None), hook
|
||||
|
||||
|
@ -36,8 +36,10 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
|
||||
for dispatch_use_fp8 in (False, True):
|
||||
num_times += 1
|
||||
for i in range((num_times % 2) + 1):
|
||||
cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='cuda')
|
||||
packed_recv_x, packed_recv_count, handle, event, hook = \
|
||||
buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts, use_fp8=dispatch_use_fp8,
|
||||
cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
|
||||
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
|
||||
hook() if return_recv_hook else event.current_stream_wait()
|
||||
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
|
||||
@ -53,6 +55,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
|
||||
# Check expert indices
|
||||
int_mask = (2 ** 32) - 1
|
||||
num_valid_tokens = recv_count.item()
|
||||
assert cumulative_local_expert_recv_stats[i].item() == num_valid_tokens, f'{cumulative_local_expert_recv_stats[i].item()} != {num_valid_tokens}'
|
||||
assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()'
|
||||
assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}'
|
||||
|
||||
@ -108,6 +111,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
|
||||
def test_func(zero_copy: bool, return_recv_hook: bool):
|
||||
recv_x, recv_count, handle, event, hook = \
|
||||
buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
|
||||
cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
|
||||
async_finish=False, return_recv_hook=return_recv_hook)
|
||||
large_gemm_with_hook(hook) if return_recv_hook else None
|
||||
if zero_copy:
|
||||
|
Loading…
Reference in New Issue
Block a user