Mirror of https://github.com/deepseek-ai/DeepEP (synced 2025-06-16 19:29:14 +00:00)
Support statistics tensor for low-latency kernels (#196)
commit 5a2e37fa28, parent 0d1a855d81
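In short: the low-latency dispatch path gains an optional `cumulative_local_expert_recv_stats` tensor that is threaded from the Python `Buffer.low_latency_dispatch` wrapper through the C++ entry point down to the CUDA dispatch kernel, which atomically accumulates how many tokens each local expert receives. A minimal usage sketch, not part of the commit, assuming an initialized `deep_ep.Buffer` named `buffer`, inputs `x`/`topk_idx`, and `num_ranks` taken from the caller's process group:

    import torch

    num_local_experts = num_experts // num_ranks
    stats = torch.zeros(num_local_experts, dtype=torch.int, device='cuda')
    recv_x, recv_count, handle, event, hook = buffer.low_latency_dispatch(
        x, topk_idx, num_max_dispatch_tokens_per_rank, num_experts,
        cumulative_local_expert_recv_stats=stats)
    # After the dispatch finishes, stats[e] has grown by the number of tokens
    # local expert `e` received on this rank; the kernels never reset it.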
@@ -1030,6 +1030,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
 std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
 Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
+                             const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
                              int num_max_dispatch_tokens_per_rank, int num_experts,
                              bool use_fp8, bool async, bool return_recv_hook) {
     EP_HOST_ASSERT(low_latency_mode);
@@ -1042,6 +1043,11 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
     EP_HOST_ASSERT(x.size(0) == topk_idx.size(0) and x.size(0) <= num_max_dispatch_tokens_per_rank);
     EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64);
     EP_HOST_ASSERT(num_experts % num_ranks == 0);
+    if (cumulative_local_expert_recv_stats.has_value()) {
+        EP_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt);
+        EP_HOST_ASSERT(cumulative_local_expert_recv_stats->dim() == 1 and cumulative_local_expert_recv_stats->is_contiguous());
+        EP_HOST_ASSERT(cumulative_local_expert_recv_stats->size(0) == num_experts / num_ranks);
+    }
 
     auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
     auto num_scales = hidden / 128, num_topk = static_cast<int>(topk_idx.size(1));
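The three host-side asserts above fix the contract for the optional tensor: `torch.int` dtype, a single contiguous dimension, and exactly `num_experts / num_ranks` entries. A Python-side pre-check mirroring them might look like this (a sketch, not part of the commit):

    import torch

    def check_recv_stats(stats: torch.Tensor, num_experts: int, num_ranks: int) -> None:
        # Mirrors the EP_HOST_ASSERT conditions added in the hunk above.
        assert stats.dtype == torch.int32                    # torch::kInt
        assert stats.dim() == 1 and stats.is_contiguous()
        assert stats.size(0) == num_experts // num_ranks     # one slot per local expert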
@@ -1084,6 +1090,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
     internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr,
                            packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(),
                            packed_recv_count.data_ptr<int>(),
+                           cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr<int>() : nullptr,
                            buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
                            buffer.dispatch_rdma_send_buffer,
                            x.data_ptr(), topk_idx.data_ptr<int64_t>(),
@@ -142,6 +142,7 @@ public:
     std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
     low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
+                         const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
                          int num_max_dispatch_tokens_per_rank, int num_experts,
                          bool use_fp8, bool async, bool return_recv_hook);
 
@@ -133,6 +133,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
 void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
               int* packed_recv_src_info, int64_t* packed_recv_layout_range,
               int* packed_recv_count,
+              int* cumulative_local_expert_recv_stats,
               void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
               const void* x, const int64_t* topk_idx,
               int* next_clean, int num_next_clean_int,
@@ -41,6 +41,7 @@ __global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
 dispatch(void* packed_recv_x, float* packed_recv_x_scales,
          int* packed_recv_src_info, int64_t* packed_recv_layout_range,
          int* packed_recv_count,
+         int* cumulative_local_expert_recv_stats,
          void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
          const void* x, const int64_t* topk_idx,
          int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
@@ -273,6 +274,8 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
             shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
             shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
             recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
+            if (cumulative_local_expert_recv_stats != nullptr)
+                atomicAdd(cumulative_local_expert_recv_stats + local_expert_idx, num_recv_tokens);
         }
         asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(kNumWarpsPerGroup * 32));
         num_recv_tokens = shared_num_recv_tokens[warp_group_id];
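This is the only device-side change: once a warp group knows how many tokens it received for its local expert from a source rank, a single guarded `atomicAdd` folds that count into the statistics tensor, so the feature costs nothing when the pointer is null. One way the accumulated counters might be consumed for load-balance monitoring (illustrative names only, not part of the commit):

    import torch

    def expert_load_report(stats: torch.Tensor) -> tuple[float, float]:
        # `stats` is this rank's cumulative_local_expert_recv_stats tensor:
        # shape [num_local_experts], dtype torch.int.
        counts = stats.float()
        mean = counts.mean().item()
        peak = counts.max().item()
        # peak / mean == 1.0 means the local experts are perfectly balanced.
        return mean, peak / max(mean, 1e-6)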
@@ -310,6 +313,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
 void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
               int* packed_recv_src_info, int64_t* packed_recv_layout_range,
               int* packed_recv_count,
+              int* cumulative_local_expert_recv_stats,
               void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
               const void* x, const int64_t* topk_idx,
               int* next_clean, int num_next_clean_int,
@@ -338,6 +342,7 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
                   packed_recv_x, packed_recv_x_scales, \
                   packed_recv_src_info, packed_recv_layout_range, \
                   packed_recv_count, \
+                  cumulative_local_expert_recv_stats, \
                   rdma_recv_x, rdma_recv_count, rdma_x, \
                   x, topk_idx, \
                   atomic_counter_per_expert, atomic_finish_counter_per_expert, \
@@ -473,6 +473,7 @@ class Buffer:
     # noinspection PyTypeChecker
     def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
                              num_max_dispatch_tokens_per_rank: int, num_experts: int,
+                             cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
                              use_fp8: bool = True, async_finish: bool = False, return_recv_hook: bool = False) -> \
             Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
         """
@@ -481,7 +482,7 @@ class Buffer:
             (specifically, IBGDA must be enabled).
         Even for ranks in the same node, NVLink are fully disabled for simplicity.
         Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
-            low-latency kernels' result tensor at a single moment.
+            low-latency kernels' result tensors at a single moment.
 
         Arguments:
             x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are
@@ -490,6 +491,9 @@ class Buffer:
                 are supported. `-1` indices (not selecting any expert) are supported.
             num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
             num_experts: the number of all experts.
+            cumulative_local_expert_recv_stats: a cumulative expert count tensor for statistics, which should have shape
+                `[num_local_experts]` and be typed as `torch.int`. This is useful for online service EP load balance
+                monitoring.
             use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
             async_finish: the current stream will not wait for the communication kernels to be finished if set.
             return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
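Since the docstring describes the tensor as cumulative, clearing it between monitoring windows is the caller's job; one possible pattern, assuming the setup from the earlier sketch (`window_steps` is an assumed loop bound, not from the commit):

    cumulative_local_expert_recv_stats.zero_()   # start of a monitoring window
    for _ in range(window_steps):
        recv_x, recv_count, handle, event, hook = buffer.low_latency_dispatch(
            x, topk_idx, num_max_dispatch_tokens_per_rank, num_experts,
            cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats)
        # ... combine, expert FFN, etc. ...
    per_expert_tokens = cumulative_local_expert_recv_stats.clone()   # snapshot for reporting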
@@ -508,19 +512,21 @@ class Buffer:
                 Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
                 as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced).
             recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
-                expert receive. As mentioned before, not all tokens are valid in `recv_x`.
+                expert receives. As mentioned before, not all tokens are valid in `recv_x`.
             handle: the communication handle to be used in the `low_latency_combine` function.
             event: the event after executing the kernel (valid only if `async_finish` is set).
             hook: the receiving hook function (valid only if `return_recv_hook` is set).
         """
         packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \
             self.runtime.low_latency_dispatch(x, topk_idx,
+                                              cumulative_local_expert_recv_stats,
                                               num_max_dispatch_tokens_per_rank, num_experts,
                                               use_fp8, async_finish, return_recv_hook)
         handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts)
         tensors_to_record = (x, topk_idx,
                              packed_recv_x, packed_recv_x_scales, packed_recv_count,
-                             packed_recv_src_info, packed_recv_layout_range)
+                             packed_recv_src_info, packed_recv_layout_range,
+                             cumulative_local_expert_recv_stats)
         return (packed_recv_x, packed_recv_x_scales) if use_fp8 else packed_recv_x, packed_recv_count, handle, \
             EventOverlap(event, tensors_to_record if async_finish else None), hook
 
@@ -36,8 +36,10 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
     for dispatch_use_fp8 in (False, True):
         num_times += 1
         for i in range((num_times % 2) + 1):
+            cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='cuda')
             packed_recv_x, packed_recv_count, handle, event, hook = \
                 buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts, use_fp8=dispatch_use_fp8,
+                                            cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
                                             async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
             hook() if return_recv_hook else event.current_stream_wait()
             packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
@@ -53,6 +55,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
         # Check expert indices
         int_mask = (2 ** 32) - 1
         num_valid_tokens = recv_count.item()
+        assert cumulative_local_expert_recv_stats[i].item() == num_valid_tokens, f'{cumulative_local_expert_recv_stats[i].item()} != {num_valid_tokens}'
         assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()'
         assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}'
 
@@ -108,6 +111,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
     def test_func(zero_copy: bool, return_recv_hook: bool):
         recv_x, recv_count, handle, event, hook = \
             buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
+                                        cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
                                         async_finish=False, return_recv_hook=return_recv_hook)
         large_gemm_with_hook(hook) if return_recv_hook else None
         if zero_copy: