Support BF16 for low-latency kernels

Mirror of https://github.com/deepseek-ai/DeepEP
commit ed7487c15e
parent 1fc40d50f3
@@ -282,7 +282,7 @@ For two micro-batch overlapping, you can refer to the following figure. With our
 - [x] AR support
 - [ ] Refactor low-latency mode AR code
 - [ ] A100 support (intranode only)
-- [ ] Support BF16 for the low-latency dispatch kernel
+- [x] Support BF16 for the low-latency dispatch kernel
 - [ ] Support NVLink protocol for intranode low-latency kernels
 - [ ] SM-free normal kernels
@@ -128,7 +128,7 @@ struct LowLatencyLayout {

     // Message sizes
     EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden);
-    size_t num_bytes_per_dispatch_msg = hidden + num_scales * sizeof(float) + sizeof(int4);
+    size_t num_bytes_per_dispatch_msg = sizeof(int4) + std::max(hidden * sizeof(nv_bfloat16), hidden + num_scales * sizeof(float));
     size_t num_bytes_per_combine_msg = sizeof(int4) + hidden * sizeof(nv_bfloat16);

     // Send buffer
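A note on the hunk above: each dispatch slot must now be able to hold either an FP8 payload plus its per-128-channel scales or a full BF16 payload, hence the `std::max`. A rough back-of-the-envelope check in plain Python (the hidden size of 7168 and the 128-wide scale groups are assumptions for illustration, not taken from this diff):

```python
# Illustrative only: recomputes the dispatch/combine message sizes from the diff above.
# Assumes hidden = 7168 and 128-element FP8 scaling groups.
hidden = 7168
num_scales = hidden // 128          # one float scale per 128 channels
int4_bytes = 16                     # sizeof(int4): header holding the source index + reserved ints

fp8_payload  = hidden * 1 + num_scales * 4   # FP8 data (1 byte/elem) + per-group float scales
bf16_payload = hidden * 2                    # BF16 data (2 bytes/elem), no scales

# New layout: one header plus the larger of the two payloads,
# so the same buffer works for both `use_fp8=True` and `use_fp8=False`.
num_bytes_per_dispatch_msg = int4_bytes + max(bf16_payload, fp8_payload)
num_bytes_per_combine_msg  = int4_bytes + hidden * 2

print(num_bytes_per_dispatch_msg, num_bytes_per_combine_msg)  # 14352 14352
```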
@@ -1011,10 +1011,10 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
                                       at::cuda::getCurrentCUDAStream());
 }

-std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
+std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
 Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
                              int num_max_dispatch_tokens_per_rank, int num_experts,
-                             bool async, bool return_recv_hook) {
+                             bool use_fp8, bool async, bool return_recv_hook) {
     EP_HOST_ASSERT(low_latency_mode);

     // Tensor checks
@@ -1045,20 +1045,26 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
     stream_wait(launch_stream, compute_stream);

     // Allocate packed tensors
-    auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, x.options().dtype(torch::kFloat8_e4m3fn));
+    auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
+                                      x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn: torch::kBFloat16));
     auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
     auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
     auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));

     // Allocate column-majored scales
-    auto packed_recv_x_scales = torch::empty({num_local_experts, num_scales, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
-    packed_recv_x_scales = torch::transpose(packed_recv_x_scales, 1, 2);
+    auto packed_recv_x_scales = std::optional<torch::Tensor>();
+    float* packed_recv_x_scales_ptr = nullptr;
+    if (use_fp8) {
+        EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
+        packed_recv_x_scales = torch::empty({num_local_experts, num_scales, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
+        packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
+        packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr<float>();
+    }

     // Kernel launch
     auto next_clean_meta = next_buffer.clean_meta();
     auto launcher = [=](int phases) {
-        internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales.data_ptr<float>(),
+        internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr,
                                packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(),
                                packed_recv_count.data_ptr<int>(),
                                buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
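A note on the scales allocation above: the tensor is created as `[num_local_experts, num_scales, num_tokens]` and then transposed, so the token index ends up with stride 1, i.e. the last two dimensions are column-major as the TMA consumers expect. A small PyTorch illustration of the same trick (the shapes here are arbitrary placeholders):

```python
import torch

# Illustrative shapes only; the real tensors live on CUDA and use the sizes from the diff above.
num_local_experts, num_scales, num_tokens = 2, 56, 8

# Allocate with the scale dimension in the middle, then swap the last two dims.
scales = torch.empty(num_local_experts, num_scales, num_tokens, dtype=torch.float32)
scales = scales.transpose(1, 2)

print(scales.shape)    # torch.Size([2, 8, 56]): [experts, tokens, scales] as seen by callers
print(scales.stride()) # (448, 1, 8): the token index is contiguous, i.e. column-major in the last two dims
```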
@@ -1066,7 +1072,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
                                x.data_ptr(), topk_idx.data_ptr<int64_t>(),
                                next_clean_meta.first, next_clean_meta.second,
                                num_tokens, hidden, num_max_dispatch_tokens_per_rank,
-                               num_topk, num_experts, rank, num_ranks,
+                               num_topk, num_experts, rank, num_ranks, use_fp8,
                                workspace, launch_stream, phases);
     };
     launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
@@ -134,10 +134,10 @@ public:

     void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);

-    std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
+    std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
     low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
                          int num_max_dispatch_tokens_per_rank, int num_experts,
-                         bool async, bool return_recv_hook);
+                         bool use_fp8, bool async, bool return_recv_hook);

     std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
     low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
@@ -137,7 +137,7 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
               const void* x, const int64_t* topk_idx,
               int* next_clean, int num_next_clean_int,
               int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
-              int num_topk, int num_experts, int rank, int num_ranks,
+              int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
               void* workspace, cudaStream_t stream, int phases);

 void combine(void* combined_x,
@@ -36,7 +36,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
                               clean_0, num_clean_int_0, clean_1, num_clean_int_1);
 }

-template <int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
+template <bool kUseFP8, int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
 __global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
 dispatch(void* packed_recv_x, float* packed_recv_x_scales,
          int* packed_recv_src_info, int64_t* packed_recv_layout_range,
@@ -62,11 +62,13 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
     constexpr int kNumPerChannels = 128;
     constexpr float kFP8Margin = 1e-4, kFP8Amax = 448, kFP8AmaxInv = 1.0f / 448.0f;
     const int num_scales = kHidden / kNumPerChannels;
-    const size_t hidden_int4 = kHidden / sizeof(int4);
+    const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16));
+    const size_t hidden_int4 = hidden_bytes / sizeof(int4);

     // Message package: hidden data, FP8 scales, index at source
     // NOTES: currently we have 3 reserved int fields for future use
-    const size_t num_bytes_per_msg = kHidden + num_scales * sizeof(float) + sizeof(int4);
+    using vec_t = typename std::conditional<kUseFP8, int2, int4>::type;
+    const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(nv_bfloat16)));
     const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
     EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
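Putting the kernel-side constants above together: each per-token message now starts with a single `int4` header (source index plus reserved ints), followed by the payload, with FP8 scales appended only in the FP8 case. A sketch of the resulting byte offsets in plain Python (hidden size and scale width are again illustrative assumptions):

```python
# Illustrative layout of one dispatch message after this change.
# Offsets are in bytes; hidden = 7168 and 128-wide scale groups are assumptions for the example.
hidden, num_scales, int4_bytes = 7168, 7168 // 128, 16

def message_layout(use_fp8: bool) -> dict:
    # Header: source token index plus 3 reserved ints, packed into one int4.
    src_idx_offset = 0
    data_offset = int4_bytes
    data_bytes = hidden * (1 if use_fp8 else 2)   # FP8 is 1 byte/elem, BF16 is 2
    scales_offset = data_offset + data_bytes      # only present in the FP8 path
    total = int4_bytes + (data_bytes + num_scales * 4 if use_fp8 else data_bytes)
    return {'src_idx': src_idx_offset, 'data': data_offset,
            'scales': scales_offset if use_fp8 else None, 'total_bytes': total}

print(message_layout(True))   # {'src_idx': 0, 'data': 16, 'scales': 7184, 'total_bytes': 7408}
print(message_layout(False))  # {'src_idx': 0, 'data': 16, 'scales': None, 'total_bytes': 14352}
```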
@@ -89,9 +91,9 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,

     for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
         const auto x_int4 = reinterpret_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
-        const auto rdma_x_int2 = reinterpret_cast<int2*>(reinterpret_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
-        const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_int2) + kHidden);
-        const auto rdma_x_src_idx = reinterpret_cast<int*>(rdma_x_scales + num_scales);
+        const auto rdma_x_src_idx = reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
+        const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
+        const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);

         // Overlap top-k index read and source token index write
         auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
@@ -100,8 +102,11 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
         // FP8 cast
         #pragma unroll
         for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
-            // Read and calculate local amax
+            // Read
             auto int4_value = __ldg(x_int4 + i);
+
+            if (kUseFP8) {
+                // Calculate local amax
                 auto bf16_values = reinterpret_cast<nv_bfloat16*>(&int4_value);
                 float fp32_values[kNumElemsPerRead];
                 float amax = kFP8Margin, scale, scale_inv;
@@ -118,14 +123,18 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
                 rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;

                 // Cast into send buffer
-                int2 int2_value;
+                vec_t int2_value;
                 auto fp8x2_values = reinterpret_cast<__nv_fp8x2_storage_t*>(&int2_value);
                 #pragma unroll
                 for (int j = 0; j < kNumElemsPerRead; j += 2) {
                     float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale};
                     fp8x2_values[j / 2] = __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3);
                 }
-                rdma_x_int2[i] = int2_value;
+                rdma_x_vec[i] = int2_value;
+            } else {
+                // Reinterpret-cast is for C++14 compatibility
+                rdma_x_vec[i] = *reinterpret_cast<vec_t*>(&int4_value);
+            }
         }
         asm volatile("bar.sync 1, %0;" :: "r"(num_threads));
@@ -135,7 +144,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
             slot_idx = __shfl_sync(0xffffffff, slot_idx, 0);
             const auto dst_rank = dst_expert_idx / num_local_experts;
             const auto dst_expert_local_idx = dst_expert_idx % num_local_experts;
-            const auto src_ptr = reinterpret_cast<uint64_t>(rdma_x_int2);
+            const auto src_ptr = reinterpret_cast<uint64_t>(rdma_x_src_idx);
             const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) +
                                  dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
                                  rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
@@ -273,26 +282,28 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
         // Copy tokens
         EP_DEVICE_ASSERT(num_scales <= 64);
         for (int i = sub_warp_id; i < num_recv_tokens; i += kNumWarpsPerGroup) {
+            // Copy source info
+            const auto src_src_idx = reinterpret_cast<int*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
+            if (lane_id == 0)
+                recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx);
+            __syncwarp();
+
             // Copy data
             // NOTES: only 2 load iterations for 7K hidden with 7 unrolls
-            const auto src = reinterpret_cast<int4*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
-            const auto dst = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;
-            UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst, src, ld_nc_global, st_na_global);
+            const auto src_data = reinterpret_cast<int4*>(reinterpret_cast<uint8_t*>(src_src_idx) + sizeof(int4));
+            const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;
+            UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);

             // Copy scales
-            const auto src_scales = reinterpret_cast<float*>(rdma_recv_x_uint8 + i * num_bytes_per_msg + kHidden);
+            if (kUseFP8) {
+                const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
                 const auto dst_scales = reinterpret_cast<float*>(recv_x_scales + recv_token_begin_idx + i);
                 const auto scale_stride = num_ranks * num_max_dispatch_tokens_per_rank;
                 auto scale_0 = lane_id < num_scales ? ld_nc_global(src_scales + lane_id) : 0;
                 auto scale_1 = (lane_id + 32) < num_scales ? ld_nc_global(src_scales + lane_id + 32) : 0;
                 lane_id < num_scales ? dst_scales[lane_id * scale_stride] = scale_0 : 0.0f;
                 (lane_id + 32) < num_scales ? dst_scales[(lane_id + 32) * scale_stride] = scale_1 : 0.0f;
-
-            // Copy source info
-            const auto src_src_idx = reinterpret_cast<int*>(src_scales + num_scales);
-            if (lane_id == 0)
-                recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx);
-            __syncwarp();
+            }
         }
     }
 }
@@ -304,7 +315,7 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
               const void* x, const int64_t* topk_idx,
               int* next_clean, int num_next_clean_int,
               int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
-              int num_topk, int num_experts, int rank, int num_ranks,
+              int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
               void* workspace, cudaStream_t stream, int phases) {
     constexpr int kNumMaxTopK = 9;
     constexpr int kNumWarpsPerGroup = 10;
@@ -314,15 +325,16 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
     const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
     const auto num_sms = cell_div(num_experts, kNumWarpGroups);
     EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
     EP_HOST_ASSERT(cell_div(static_cast<int>(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2);

     // Workspace checks
     auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
     auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
     EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);

-#define DISPATCH_LAUNCH_CASE(hidden) \
-LAUNCH_KERNEL(&cfg, dispatch<kNumWarpGroups, kNumWarpsPerGroup, hidden>, \
+#define DISPATCH_LAUNCH_CASE(hidden) { \
+auto dispatch_func = use_fp8 ? dispatch<true, kNumWarpGroups, kNumWarpsPerGroup, hidden> : \
+                               dispatch<false, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
+LAUNCH_KERNEL(&cfg, dispatch_func, \
         packed_recv_x, packed_recv_x_scales, \
         packed_recv_src_info, packed_recv_layout_range, \
         packed_recv_count, \
@@ -331,7 +343,7 @@ LAUNCH_KERNEL(&cfg, dispatch<kNumWarpGroups, kNumWarpsPerGroup, hidden>, \
         atomic_counter_per_expert, atomic_finish_counter_per_expert, \
         next_clean, num_next_clean_int, \
         num_tokens, num_max_dispatch_tokens_per_rank, \
-        num_topk, num_experts, rank, num_ranks, phases); break
+        num_topk, num_experts, rank, num_ranks, phases); } break

     SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
     SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
@@ -444,10 +444,10 @@ class Buffer:
     # noinspection PyTypeChecker
     def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
                              num_max_dispatch_tokens_per_rank: int, num_experts: int,
-                             async_finish: bool = False, return_recv_hook: bool = False) -> \
+                             use_fp8: bool = True, async_finish: bool = False, return_recv_hook: bool = False) -> \
             Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
         """
-        A low-latency implementation for dispatching with IBGDA **with implicit FP8 casting**.
+        A low-latency implementation for dispatching with IBGDA.
         This kernel requires that all the ranks (intranode or internode) be visible via RDMA
         (specifically, IBGDA must be enabled).
         Even for ranks in the same node, NVLink is fully disabled for simplicity.
@@ -461,19 +461,23 @@ class Buffer:
                 are supported. `-1` indices (not selecting any expert) are supported.
             num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
             num_experts: the number of all experts.
+            use_fp8: whether to enable FP8 casting; if set, the received data will be a tuple of an FP8 tensor and its scaling factors.
             async_finish: the current stream will not wait for the communication kernels to be finished if set.
             return_recv_hook: return a receiving hook if set. If set, the kernel will only issue the RDMA requests,
                 but **without actually receiving the data**. You must call the returned hook to ensure the data's arrival.
                 If you do not set this flag, the kernel will ensure the data's arrival.

         Returns:
-            recv_x: a tuple with received tokens for each expert. The first element is a `torch.Tensor` shaped as
+            recv_x: a tensor or tuple with received tokens for each expert.
+                With `use_fp8=True`: the first element is a `torch.Tensor` shaped as
                 `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
                 The second tensor is the corresponding scales for the first element with shape
                 `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`.
                 Notice that the last two dimensions of the scaling tensor are in column-major order for TMA compatibility.
+                With `use_fp8=False`, the result would be a tensor shaped as
+                `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`.
                 Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
-                as we do not synchronize the CPU received count with the GPU (also not incompatible with CUDA graph).
+                as we do not synchronize the CPU received count with the GPU (still compatible with CUDA graph if synced).
            recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
                 expert receives. As mentioned before, not all tokens are valid in `recv_x`.
             handle: the communication handle to be used in the `low_latency_combine` function.
@@ -483,12 +487,12 @@ class Buffer:
         packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \
             self.runtime.low_latency_dispatch(x, topk_idx,
                                               num_max_dispatch_tokens_per_rank, num_experts,
-                                              async_finish, return_recv_hook)
+                                              use_fp8, async_finish, return_recv_hook)
         handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, num_experts)
         tensors_to_record = (x, topk_idx,
                              packed_recv_x, packed_recv_x_scales, packed_recv_count,
                              packed_recv_src_info, packed_recv_layout_range)
-        return (packed_recv_x, packed_recv_x_scales), packed_recv_count, handle, \
+        return (packed_recv_x, packed_recv_x_scales) if use_fp8 else packed_recv_x, packed_recv_count, handle, \
             EventOverlap(event, tensors_to_record if async_finish else None), hook

     # noinspection PyTypeChecker
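For callers, the visible effect of this hunk is the new `use_fp8` keyword and the changed return shape. A minimal usage sketch, assuming an already-initialized low-latency `Buffer` named `buffer` and placeholder token/expert counts (this will not run outside a proper multi-rank DeepEP setup):

```python
import torch
# import deep_ep
# buffer = deep_ep.Buffer(...)  # low-latency-mode construction elided; see the DeepEP README

x = torch.randn(128, 7168, dtype=torch.bfloat16, device='cuda')
topk_idx = torch.randint(0, 256, (128, 8), dtype=torch.int64, device='cuda')

# FP8 path (default): recv_x is a (data, scales) tuple.
(recv_fp8, recv_scales), recv_count, handle, event, hook = \
    buffer.low_latency_dispatch(x, topk_idx, 128, 256, use_fp8=True)

# BF16 path: recv_x is a single bfloat16 tensor, with no scales to manage.
recv_bf16, recv_count, handle, event, hook = \
    buffer.low_latency_dispatch(x, topk_idx, 128, 256, use_fp8=False)
```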
@@ -33,19 +33,21 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
     do_check = True
     hash_value, num_times = 0, 0
     for return_recv_hook in (False, True):
+        for dispatch_use_fp8 in (False, True):
             num_times += 1
             for i in range((num_times % 2) + 1):
                 packed_recv_x, packed_recv_count, handle, event, hook = \
-                    buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
+                    buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts, use_fp8=dispatch_use_fp8,
                                                 async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
                 hook() if return_recv_hook else event.current_stream_wait()
-            packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous())
-            simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape)
+            packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
+            simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) \
+                if dispatch_use_fp8 else packed_recv_x.clone()
             all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda')
             dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
             for i in range(num_local_experts if do_check else 0):
                 expert_id = rank * num_local_experts + i
-                recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i])
+                recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i]
                 recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]

                 # Check expert indices
@@ -64,8 +66,11 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
                     begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
                     assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
                     assert (recv_x[begin_idx:begin_idx + count][:-128] - j).sum().item() == 0
+                if dispatch_use_fp8:
                     hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
                     hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
+                else:
+                    hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])

             # Check combine correctness
             combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
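`per_token_cast_back` comes from the test utilities and is not part of this diff; conceptually it multiplies each 128-channel FP8 group by its scale to recover BF16 values. A stand-alone sketch of that kind of dequantization, as an assumption-laden illustration only:

```python
import torch

def per_token_cast_back_sketch(x_fp8: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    # Illustrative only: multiply each 128-channel group of the FP8 payload by its float scale
    # to recover BF16 values. The real helper lives in the DeepEP test utilities and may differ.
    num_tokens, hidden = x_fp8.shape
    x = x_fp8.to(torch.float32).view(num_tokens, hidden // 128, 128)
    x = x * scales.view(num_tokens, hidden // 128, 1)
    return x.view(num_tokens, hidden).to(torch.bfloat16)

# With `use_fp8=False` there is nothing to cast back: the received tensor is already BF16,
# which is why the test clones `packed_recv_x` directly in that branch.
```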