Remove useless control metadata for low-latency combine

Chenggang Zhao 2025-04-07 09:55:39 +08:00
parent 2a0b3d7a5d
commit 42494864ba
3 changed files with 8 additions and 9 deletions
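
Each low-latency combine message previously reserved a leading `int4` (16 bytes) of control metadata that the BF16-only combine kernel never read; this change drops the slot, so combine payloads now begin at offset 0 of their message slot. Dispatch messages keep their leading control `int4`, and the new NOTES comment below records that the slot must be reintroduced if combine ever performs data transformation. One combine slot, roughly:

old: [ int4 control (16 B) | kHidden * nv_bfloat16 payload ]
new: [ kHidden * nv_bfloat16 payload ]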

csrc/config.hpp

@@ -122,7 +122,6 @@ struct LowLatencyLayout {
LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
const int num_scales = hidden / 128;
const int num_local_experts = num_experts / num_ranks;
// Dispatch and combine layout:
// - 2 symmetric odd/even send buffer
@@ -130,9 +129,10 @@ struct LowLatencyLayout {
// - 2 symmetric odd/even signaling buffers
// Message sizes
+ // NOTES: you should add a control `int4` for combine messages if you want to do data transformation
EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden);
size_t num_bytes_per_dispatch_msg = sizeof(int4) + std::max(hidden * sizeof(nv_bfloat16), hidden + num_scales * sizeof(float));
- size_t num_bytes_per_combine_msg = sizeof(int4) + hidden * sizeof(nv_bfloat16);
+ size_t num_bytes_per_combine_msg = hidden * sizeof(nv_bfloat16);
// Send buffer
size_t dispatch_send_buffer_bytes = num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
@@ -167,7 +167,7 @@ struct LowLatencyLayout {
advance(rdma_buffer, send_buffer_bytes * i),
advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
- advance(rdma_buffer, send_buffer_bytes * i + sizeof(int4)),
+ advance(rdma_buffer, send_buffer_bytes * i),
num_bytes_per_combine_msg
};
}
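
To make the size arithmetic concrete, here is a small standalone C++ sketch (not part of the commit; hidden = 7168 is just an illustrative value) that reproduces the per-message byte counts before and after the change:

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
    // Illustrative value only; DeepEP takes `hidden` from the model config.
    const std::size_t hidden = 7168;
    const std::size_t num_scales = hidden / 128;  // one float scale per 128 elements
    const std::size_t int4_bytes = 16;            // sizeof(int4) in CUDA
    const std::size_t bf16_bytes = 2;             // sizeof(nv_bfloat16)

    // Dispatch messages keep their leading control int4 (unchanged here).
    const std::size_t dispatch_msg =
        int4_bytes + std::max(hidden * bf16_bytes, hidden + num_scales * sizeof(float));

    // Combine messages lose the control int4: 16 fewer bytes per slot.
    const std::size_t combine_old = int4_bytes + hidden * bf16_bytes;  // 14352
    const std::size_t combine_new = hidden * bf16_bytes;               // 14336

    // EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0) still holds:
    // kHidden * 2 bytes is a multiple of 16 whenever kHidden % 8 == 0.
    std::printf("dispatch=%zu combine_old=%zu combine_new=%zu\n",
                dispatch_msg, combine_old, combine_new);
    return 0;
}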

csrc/kernels/internode_ll.cu

@@ -369,8 +369,7 @@ combine(void* combined_x,
const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;
// Message package
- // BF16 mode: always use BF16 for hidden data (ignoring the extra flag slot)
- constexpr size_t num_bytes_per_slot = sizeof(int4) + kHidden * sizeof(nv_bfloat16);
+ constexpr size_t num_bytes_per_slot = kHidden * sizeof(nv_bfloat16);
EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
// Sending phase
@@ -409,12 +408,12 @@ combine(void* combined_x,
for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += kNumWarpsPerGroup) {
const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
- const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row + 4);
+ const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row);
// Copy directly to local rank, or copy to buffer and issue RDMA
auto src_idx = __ldg(local_src_info + token_idx);
const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
- const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot + sizeof(int4);
+ const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot;
if (dst_rank == rank) {
const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
@@ -473,7 +472,7 @@ combine(void* combined_x,
for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) {
// Read from sources
auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) + (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
- auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type + 4);
+ auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type);
// Reduce
auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id);
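
Both combine hunks in this file apply the same 16-byte shift: the payload used to sit behind an int4 control field (`int* + 4` is sizeof(int4) == 16 bytes) and now sits at the start of the slot. A minimal host-side sketch of the old vs. new addressing, with hypothetical helper names:

#include <cstdint>

// Old layout: each slot began with an int4 control field, so the payload
// started 16 bytes into the slot.
static uint8_t* payload_old(uint8_t* slot) {
    auto type_row = reinterpret_cast<int*>(slot);     // control ints
    return reinterpret_cast<uint8_t*>(type_row + 4);  // +4 ints == +16 bytes
}

// New layout: the payload starts at the slot itself, which is also why the
// `+ sizeof(int4)` terms on `dst_ptr` and on the combine send-buffer base in
// LowLatencyLayout are gone.
static uint8_t* payload_new(uint8_t* slot) {
    return slot;
}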

tests/test_low_latency.py

@@ -84,7 +84,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
if do_check:
diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
assert torch.isnan(combined_x).sum().item() == 0
- assert diff < 1e-5, f'Error: diff={diff}'
+ assert diff < 1e-5, f'Error: {diff=}, {zero_copy=}'
hash_value ^= hash_tensor(combined_x)
def create_test_cast_with_outliers(num_outliers):