Support UE8M0 data format. (#206)

This commit is contained in:
Shifang Xu
2025-06-12 09:38:19 +08:00
committed by GitHub
parent 9ec061204e
commit 21efbe9b48
14 changed files with 255 additions and 115 deletions

View File

@@ -178,6 +178,7 @@ class Buffer:
config: the recommended config.
"""
# TODO: automatically tune
config_map = {
2: Config(Buffer.num_sms, 24, 256, 6, 128),
4: Config(Buffer.num_sms, 6, 256, 6, 128),
@@ -205,6 +206,7 @@ class Buffer:
config: the recommended config.
"""
# TODO: automatically tune
config_map = {
2: Config(Buffer.num_sms, 10, 256, 6, 128),
4: Config(Buffer.num_sms, 9, 256, 6, 128),
@@ -486,14 +488,14 @@ class Buffer:
def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
num_max_dispatch_tokens_per_rank: int, num_experts: int,
cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
use_fp8: bool = True, async_finish: bool = False, return_recv_hook: bool = False) -> \
use_fp8: bool = True, round_scale: bool = False, use_ue8m0: bool = False,
async_finish: bool = False, return_recv_hook: bool = False) -> \
Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
"""
A low-latency implementation for dispatching with IBGDA.
This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
(specifically, IBGDA must be enabled).
Even for ranks in the same node, NVLink are fully disabled for simplicity.
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2
low-latency kernels' result tensors at a single moment.
Arguments:
@@ -507,17 +509,21 @@ class Buffer:
`[num_local_experts]` and be typed as `torch.int`. This is useful for online service EP load balance
monitoring.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
round_scale: whether round the scaling factors into power of 2.
use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`).
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you not set this flag, the kernel will ensure the data's arrival.
If you do not set this flag, the kernel will ensure the data's arrival.
Returns:
recv_x: a tensor or tuple with received tokens for each expert.
With `use_fp8=True`: the first element is a `torch.Tensor` shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
The second tensor is the corresponding scales for the first element with shape
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`.
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`,
if `use_ue8m0=False`. With `use_ue8m0=True`, the second one is packed and shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 512]` with type `torch.int`.
Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility.
With `use_fp8=False`, the result would be a tensor shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`.
@@ -533,7 +539,8 @@ class Buffer:
self.runtime.low_latency_dispatch(x, topk_idx,
cumulative_local_expert_recv_stats,
num_max_dispatch_tokens_per_rank, num_experts,
use_fp8, async_finish, return_recv_hook)
use_fp8, round_scale, use_ue8m0,
async_finish, return_recv_hook)
handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts)
tensors_to_record = (x, topk_idx,
packed_recv_x, packed_recv_x_scales, packed_recv_count,
@@ -551,9 +558,8 @@ class Buffer:
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
(specifically, IBGDA must be enabled).
Even for ranks in the same node, NVLink are fully disabled for simplicity.
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
low-latency kernels' result tensor at a single moment.
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2
low-latency kernels' result tensors at a single moment.
Arguments:
x: `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`,
@@ -569,7 +575,7 @@ class Buffer:
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you not set this flag, the kernel will ensure the data's arrival.
If you do not set this flag, the kernel will ensure the data's arrival.
out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
Returns: