Mirror of https://github.com/deepseek-ai/DeepEP (synced 2025-06-26 18:28:11 +00:00)
Support BF16 for low-latency kernels
@@ -444,10 +444,10 @@ class Buffer:
     # noinspection PyTypeChecker
     def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
                              num_max_dispatch_tokens_per_rank: int, num_experts: int,
-                             async_finish: bool = False, return_recv_hook: bool = False) -> \
+                             use_fp8: bool = True, async_finish: bool = False, return_recv_hook: bool = False) -> \
             Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
         """
-        A low-latency implementation for dispatching with IBGDA **with implicit FP8 casting**.
+        A low-latency implementation for dispatching with IBGDA.
         This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
             (specifically, IBGDA must be enabled).
         Even for ranks in the same node, NVLink are fully disabled for simplicity.
@@ -461,19 +461,23 @@ class Buffer:
                 are supported. `-1` indices (not selecting any expert) are supported.
             num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
             num_experts: the number of all experts.
+            use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
             async_finish: the current stream will not wait for the communication kernels to be finished if set.
             return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
                 but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
                 If you not set this flag, the kernel will ensure the data's arrival.
 
         Returns:
-            recv_x: a tuple with received tokens for each expert. The first element is a `torch.Tensor` shaped as
+            recv_x: a tensor or tuple with received tokens for each expert.
+                With `use_fp8=True`: the first element is a `torch.Tensor` shaped as
                 `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
                 The second tensor is the corresponding scales for the first element with shape
                 `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`.
                 Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility.
+                With `use_fp8=False`, the result would be a tensor shaped as
+                `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`.
                 Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
-                as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph).
+                as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced).
             recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
                 expert receive. As mentioned before, all not tokens are valid in `recv_x`.
             handle: the communication handle to be used in the `low_latency_combine` function.
@@ -483,12 +487,12 @@ class Buffer:
         packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \
             self.runtime.low_latency_dispatch(x, topk_idx,
                                               num_max_dispatch_tokens_per_rank, num_experts,
-                                              async_finish, return_recv_hook)
+                                              use_fp8, async_finish, return_recv_hook)
         handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, num_experts)
         tensors_to_record = (x, topk_idx,
                              packed_recv_x, packed_recv_x_scales, packed_recv_count,
                              packed_recv_src_info, packed_recv_layout_range)
-        return (packed_recv_x, packed_recv_x_scales), packed_recv_count, handle, \
+        return (packed_recv_x, packed_recv_x_scales) if use_fp8 else packed_recv_x, packed_recv_count, handle, \
             EventOverlap(event, tensors_to_record if async_finish else None), hook
 
     # noinspection PyTypeChecker
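Because this commit makes the first return value of `low_latency_dispatch` depend on `use_fp8` (a `(data, scales)` tuple for FP8, a single tensor for BF16), callers that support both modes may want to normalize it. The following is a minimal sketch, not part of the commit: `expected_recv_layout`, `unpack_recv_x`, and `SCALE_GROUP` are hypothetical names introduced here for illustration, and the shapes and dtypes are taken only from the docstring in this diff. It uses plain Python so that it runs without a GPU or `torch`.

```python
from typing import Any, Dict, Optional, Tuple

# Assumption drawn from the docstring: one FP8 scale per 128 hidden elements.
SCALE_GROUP = 128


def expected_recv_layout(use_fp8: bool, num_local_experts: int,
                         num_max_dispatch_tokens_per_rank: int,
                         num_ranks: int, hidden: int) -> Dict[str, Any]:
    """Return the (shape, dtype name) of `recv_x` parts as the docstring describes them."""
    # Each local expert gets a fixed number of slots; only some hold valid tokens.
    slots = num_max_dispatch_tokens_per_rank * num_ranks
    data_shape = (num_local_experts, slots, hidden)
    if use_fp8:
        scales_shape = (num_local_experts, slots, hidden // SCALE_GROUP)
        return {"data": (data_shape, "torch.float8_e4m3fn"),
                "scales": (scales_shape, "torch.float")}
    return {"data": (data_shape, "torch.bfloat16"), "scales": None}


def unpack_recv_x(recv_x: Any, use_fp8: bool) -> Tuple[Any, Optional[Any]]:
    """Normalize `recv_x` to `(data, scales)`; `scales` is None in the BF16 case."""
    if use_fp8:
        data, scales = recv_x  # FP8 path: tuple of tensor and scaling factors
        return data, scales
    return recv_x, None  # BF16 path: a single tensor


if __name__ == "__main__":
    # Hypothetical sizes chosen only for illustration.
    print(expected_recv_layout(True, 2, 128, 8, 7168))
    print(expected_recv_layout(False, 2, 128, 8, 7168))
```

In either mode, the number of valid rows per expert has to be read from `recv_count`, since the slot dimension is sized for the worst case.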