mirror of https://github.com/deepseek-ai/DeepEP
synced 2025-04-29 18:31:25 +00:00

Remove useless control metadata for low-latency combine

This commit is contained in:
parent 2a0b3d7a5d
commit 42494864ba
@@ -122,7 +122,6 @@ struct LowLatencyLayout {
     LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
         const int num_scales = hidden / 128;
         const int num_local_experts = num_experts / num_ranks;
 
         // Dispatch and combine layout:
         // - 2 symmetric odd/even send buffer
@@ -130,9 +129,10 @@ struct LowLatencyLayout {
         // - 2 symmetric odd/even signaling buffers
 
         // Message sizes
+        // NOTES: you should add a control `int4` for combine messages if you want to do data transformation
         EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden);
         size_t num_bytes_per_dispatch_msg = sizeof(int4) + std::max(hidden * sizeof(nv_bfloat16), hidden + num_scales * sizeof(float));
-        size_t num_bytes_per_combine_msg = sizeof(int4) + hidden * sizeof(nv_bfloat16);
+        size_t num_bytes_per_combine_msg = hidden * sizeof(nv_bfloat16);
 
         // Send buffer
         size_t dispatch_send_buffer_bytes = num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
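The layout change is easiest to see in bytes. Below is a minimal host-side sketch (not part of the commit) that mirrors the two size formulas above; `hidden = 7168` is a hypothetical value, and the 16-byte `int4` / 2-byte `nv_bfloat16` sizes are the standard CUDA ones:

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
    const size_t hidden = 7168;              // hypothetical hidden size
    const size_t num_scales = hidden / 128;  // one float scale per 128 elements
    const size_t int4_bytes = 16;            // sizeof(int4) on CUDA devices
    const size_t bf16_bytes = 2;             // sizeof(nv_bfloat16)

    // Dispatch messages keep their leading control int4.
    const size_t dispatch_msg = int4_bytes +
        std::max(hidden * bf16_bytes, hidden + num_scales * sizeof(float));

    // Combine messages: with vs. without the control int4 header.
    const size_t combine_msg_old = int4_bytes + hidden * bf16_bytes;
    const size_t combine_msg_new = hidden * bf16_bytes;

    std::printf("dispatch %zu B, combine %zu -> %zu B per token slot\n",
                dispatch_msg, combine_msg_old, combine_msg_new);
}

The saving is a constant sizeof(int4) = 16 bytes per combine slot; it also means the BF16 payload now starts at byte 0 of each slot, which the kernel hunks below rely on.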
@@ -167,7 +167,7 @@ struct LowLatencyLayout {
             advance(rdma_buffer, send_buffer_bytes * i),
             advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
             advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
-            advance(rdma_buffer, send_buffer_bytes * i + sizeof(int4)),
+            advance(rdma_buffer, send_buffer_bytes * i),
             num_bytes_per_combine_msg
         };
     }
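`advance` is the repository's pointer-offset helper used throughout this constructor. A sketch of what such a helper can look like follows; the exact DeepEP signature is not shown in this diff, so the template shape here is an assumption:

#include <cstddef>
#include <cstdint>

// Sketch of a byte-offset helper in the style of `advance` above; the real
// signature in the repository may differ (this one is an assumption).
template <typename OutPtr = void*>
static OutPtr advance(void* base, size_t bytes) {
    return reinterpret_cast<OutPtr>(reinterpret_cast<uint8_t*>(base) + bytes);
}

With the `+ sizeof(int4)` term dropped from the changed entry, the combine send pointer and the raw slot base are now the same address, consistent with the header-free `num_bytes_per_combine_msg`.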
@@ -369,8 +369,7 @@ combine(void* combined_x,
     const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;
 
     // Message package
-    // BF16 mode: always use BF16 for hidden data (ignoring the extra flag slot)
-    constexpr size_t num_bytes_per_slot = sizeof(int4) + kHidden * sizeof(nv_bfloat16);
+    constexpr size_t num_bytes_per_slot = kHidden * sizeof(nv_bfloat16);
     EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
 
     // Sending phase
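With the header gone, the `EP_STATIC_ASSERT` now constrains `kHidden` directly: `kHidden * sizeof(nv_bfloat16)` must be a multiple of `sizeof(int4)` (16 bytes), i.e. `kHidden % 8 == 0`. (The old slot size satisfied the assert whenever the payload did, since the extra 16-byte header is itself int4-aligned.) A standalone compile-time sketch of the same condition; the template and the `7168` instantiation are hypothetical:

#include <cstddef>

// Mirrors the EP_STATIC_ASSERT above for a header-free slot; `kHidden` is a
// hypothetical template parameter, not the kernel's actual one.
template <size_t kHidden>
constexpr bool slot_is_int4_aligned() {
    return (kHidden * 2) % 16 == 0;  // sizeof(nv_bfloat16) == 2, sizeof(int4) == 16
}
static_assert(slot_is_int4_aligned<7168>(), "kHidden must be a multiple of 8");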
@@ -409,12 +408,12 @@ combine(void* combined_x,
         for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += kNumWarpsPerGroup) {
             const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
             const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
-            const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row + 4);
+            const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row);
 
             // Copy directly to local rank, or copy to buffer and issue RDMA
             auto src_idx = __ldg(local_src_info + token_idx);
             const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
-            const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot + sizeof(int4);
+            const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot;
             if (dst_rank == rank) {
                 const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
                 UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
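The destination address is plain flat slot indexing: expert-major, then source token, with the trailing `+ sizeof(int4)` now removed. A host-side sketch of the same arithmetic; the function name is hypothetical:

#include <cstddef>
#include <cstdint>

// Sketch of the slot addressing used above (the function name is hypothetical).
// Each (expert, source-token) pair owns one fixed-size slot; with the control
// int4 removed, the payload starts at byte 0 of the slot.
uint64_t combine_slot_addr(uint64_t rdma_recv_x,
                           int global_expert_idx,
                           int num_max_dispatch_tokens_per_rank,
                           int src_idx,
                           size_t num_bytes_per_slot) {
    const size_t slot = size_t(global_expert_idx) * num_max_dispatch_tokens_per_rank + src_idx;
    return rdma_recv_x + slot * num_bytes_per_slot;  // previously: ... + sizeof(int4)
}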
@@ -473,7 +472,7 @@ combine(void* combined_x,
             for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) {
                 // Read from sources
                 auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) + (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
-                auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type + 4);
+                auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type);
 
                 // Reduce
                 auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id);
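Note that the old `+ 4` was int-pointer arithmetic, not a 4-byte skip: advancing an `int*` by 4 elements moves 16 bytes, exactly one `int4` header. A self-contained illustration; the names are hypothetical:

#include <cassert>
#include <cstdint>

int main() {
    alignas(16) int slot[8] = {};  // stand-in for the start of a recv slot
    const int* type_ptr = slot;    // where the control int4 used to live
    const auto payload_old = reinterpret_cast<const uint8_t*>(type_ptr + 4);
    const auto payload_new = reinterpret_cast<const uint8_t*>(type_ptr);
    assert(payload_old - payload_new == 16);  // 4 * sizeof(int) = 16 bytes
}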
@@ -84,7 +84,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
         if do_check:
             diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
             assert torch.isnan(combined_x).sum().item() == 0
-            assert diff < 1e-5, f'Error: diff={diff}'
+            assert diff < 1e-5, f'Error: {diff=}, {zero_copy=}'
             hash_value ^= hash_tensor(combined_x)
 
 def create_test_cast_with_outliers(num_outliers):