mirror of https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00

Remove the low-latency usage flag

commit 74f4ef7b22
parent 1b92be8a71
@@ -78,13 +78,6 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
         CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_rdma_counter_mapped, const_cast<int*>(moe_recv_rdma_counter), 0));
         *moe_recv_rdma_counter = -1;
     }
-
-    // Low-latency kernels' usage flag
-    if (low_latency_mode) {
-        CUDA_CHECK(cudaMallocHost(&low_latency_usage_flag, sizeof(int), cudaHostAllocMapped));
-        CUDA_CHECK(cudaHostGetDevicePointer(&low_latency_usage_flag_mapped, const_cast<int*>(low_latency_usage_flag), 0));
-        *low_latency_usage_flag = 0;
-    }
 }
 
 Buffer::~Buffer() noexcept(false) {
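
For reference, the removed constructor code follows the standard CUDA mapped (zero-copy) pinned-memory pattern: the flag lives in page-locked host memory, the kernel writes to it through a device-side alias, and the CPU can poll it without any cudaMemcpy or stream synchronization. Below is a minimal standalone sketch of that pattern; the names are illustrative and not part of DeepEP.

#include <cstdio>
#include <cuda_runtime.h>

// Toy kernel: bump the host-visible flag through its device-mapped alias.
__global__ void bump_flag(int* flag_mapped) {
    if (blockIdx.x == 0 && threadIdx.x == 0)
        atomicAdd_system(flag_mapped, 1);  // system scope so the CPU observes the update
}

int main() {
    volatile int* flag_host = nullptr;  // polled by the CPU
    int* flag_mapped = nullptr;         // written by the GPU

    // Pinned + mapped allocation, then fetch the device-side alias of the same memory.
    cudaHostAlloc((void**) &flag_host, sizeof(int), cudaHostAllocMapped);
    cudaHostGetDevicePointer((void**) &flag_mapped, (void*) flag_host, 0);
    *flag_host = 0;

    bump_flag<<<1, 32>>>(flag_mapped);
    while (*flag_host == 0);  // host-side spin, no cudaMemcpy needed
    std::printf("flag = %d\n", *flag_host);

    cudaDeviceSynchronize();
    cudaFreeHost((void*) flag_host);
    return 0;
}
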
@@ -1028,16 +1021,6 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten
 #endif
 }
 
-uint64_t Buffer::get_low_latency_usage_flag() const {
-#ifndef DISABLE_NVSHMEM
-    EP_HOST_ASSERT(low_latency_usage_flag != nullptr);
-    return reinterpret_cast<uint64_t>(low_latency_usage_flag);
-#else
-    EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation");
-    return 0;
-#endif
-}
-
 void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) {
 #ifndef DISABLE_NVSHMEM
     EP_HOST_ASSERT(low_latency_mode);
@@ -1143,9 +1126,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
             num_tokens, hidden, num_max_dispatch_tokens_per_rank,
             num_topk, num_experts, rank, num_ranks,
             use_fp8, round_scale, use_ue8m0,
-            workspace, low_latency_usage_flag_mapped,
-            num_device_sms, launch_stream,
-            phases);
+            workspace, num_device_sms,
+            launch_stream, phases);
     };
     launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
 
@@ -1237,9 +1219,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
             next_clean_meta.first, next_clean_meta.second,
             num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
             num_topk, num_experts, rank, num_ranks,
-            workspace, low_latency_usage_flag_mapped,
-            num_device_sms, launch_stream,
-            phases, zero_copy);
+            workspace, num_device_sms,
+            launch_stream, phases, zero_copy);
     };
     launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
 
@@ -1328,7 +1309,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         .def("intranode_combine", &deep_ep::Buffer::intranode_combine)
         .def("internode_dispatch", &deep_ep::Buffer::internode_dispatch)
         .def("internode_combine", &deep_ep::Buffer::internode_combine)
-        .def("get_low_latency_usage_flag", &deep_ep::Buffer::get_low_latency_usage_flag)
         .def("clean_low_latency_buffer", &deep_ep::Buffer::clean_low_latency_buffer)
         .def("low_latency_dispatch", &deep_ep::Buffer::low_latency_dispatch)
         .def("low_latency_combine", &deep_ep::Buffer::low_latency_combine)
@@ -71,10 +71,6 @@ private:
     volatile int* moe_recv_rdma_counter = nullptr;
     int* moe_recv_rdma_counter_mapped = nullptr;
 
-    // Host-side low-latency kernels' usages
-    volatile int* low_latency_usage_flag = nullptr;
-    int* low_latency_usage_flag_mapped = nullptr;
-
 public:
     Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode);
 
@@ -134,8 +130,6 @@ public:
                       const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head,
                       const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
 
-    uint64_t get_low_latency_usage_flag() const;
-
     void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);
 
     std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
@@ -147,9 +147,8 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
               int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
               int num_topk, int num_experts, int rank, int num_ranks,
               bool use_fp8, bool round_scale, bool use_ue8m0,
-              void* workspace, int* usage_flag,
-              int num_device_sms, cudaStream_t stream,
-              int phases);
+              void* workspace, int num_device_sms,
+              cudaStream_t stream, int phases);
 
 void combine(void* combined_x,
              void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
@@ -158,9 +157,8 @@ void combine(void* combined_x,
              int* next_clean, int num_next_clean_int,
              int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
              int num_topk, int num_experts, int rank, int num_ranks,
-             void* workspace, int* usage_flag,
-             int num_device_sms, cudaStream_t stream,
-             int phases, bool zero_copy);
+             void* workspace, int num_device_sms,
+             cudaStream_t stream, int phases, bool zero_copy);
 
 } // namespace internode_ll
 
@@ -48,9 +48,8 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
          int* next_clean, int num_next_clean_int,
          int num_tokens, int num_max_dispatch_tokens_per_rank,
          int num_topk, int num_experts, int rank, int num_ranks,
-         bool round_scale, int* usage_flag,
          int num_warp_groups, int num_warps_per_group,
-         int phases) {
+         bool round_scale, int phases) {
     const auto sm_id = static_cast<int>(blockIdx.x);
     const auto thread_id = static_cast<int>(threadIdx.x);
     const auto warp_id = thread_id / 32, lane_id = get_lane_id();
@@ -189,10 +188,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
 #pragma unroll
         for (int i = lane_id; i < num_experts; i += 32)
             atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
-    } else if (sm_id == 1) {
-        // The second SM is also responsible for notifying PCIe usage
-        if (lane_id == 0)
-            atomicAdd_system(usage_flag, 1);
     }
 
     // This SM should be responsible for some destination experts, read `topk_idx` for them
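
The removed branch above is the device side of the usage flag: a single lane on a designated SM bumps the host-mapped counter with a system-scope atomic, so the CPU learns that the dispatch kernel has started pushing data. A hedged sketch of that single-lane notification pattern in CUDA (illustrative names, not the actual DeepEP kernel):

// One designated block/lane performs the system-scope add; all other threads skip it.
// `usage_flag` is the device alias of a pinned host int (see the constructor sketch above).
__global__ void dispatch_like_kernel(int* usage_flag /*, ... real arguments ... */) {
    const auto sm_id = static_cast<int>(blockIdx.x);
    const auto lane_id = static_cast<int>(threadIdx.x % 32);

    if (sm_id == 1 && lane_id == 0)
        atomicAdd_system(usage_flag, 1);  // visible to a CPU thread polling the pinned int

    // ... the real kernel would continue with its dispatch work here ...
}
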
@@ -341,9 +336,8 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
               int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
               int num_topk, int num_experts, int rank, int num_ranks,
               bool use_fp8, bool round_scale, bool use_ue8m0,
-              void* workspace, int* usage_flag,
-              int num_device_sms, cudaStream_t stream,
-              int phases) {
+              void* workspace, int num_device_sms,
+              cudaStream_t stream, int phases) {
     constexpr int kNumMaxTopK = 9;
     const int num_warp_groups = ceil_div(num_experts, num_device_sms);
     const int num_warps_per_group = 32 / num_warp_groups;
@@ -380,9 +374,8 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
               next_clean, num_next_clean_int, \
               num_tokens, num_max_dispatch_tokens_per_rank, \
               num_topk, num_experts, rank, num_ranks, \
-              round_scale, usage_flag, \
               num_warp_groups, num_warps_per_group, \
-              phases); } break
+              round_scale, phases); } break
 
     SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
     SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
@@ -400,7 +393,6 @@ combine(void* combined_x,
         int num_combined_tokens, int hidden, int num_topk,
         int num_max_dispatch_tokens_per_rank,
         int num_experts, int rank, int num_ranks,
-        int* usage_flag,
         int num_warp_groups, int num_warps_per_group,
         int phases, bool zero_copy) {
     const auto sm_id = static_cast<int>(blockIdx.x);
@@ -497,13 +489,11 @@ combine(void* combined_x,
     if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
         return;
 
-    // Wait all ranks to arrive and notify usages
+    // Wait all ranks to arrive
     if (responsible_expert_idx < num_experts) {
         EP_DEVICE_ASSERT(num_warps_per_group > 1);
         if (sub_warp_id == 0 and lane_id == 0) {
             while (ld_acquire_sys_global(rdma_recv_flag + responsible_expert_idx) == 0);
-        } else if (sm_id == 0 and sub_warp_id == 1 and lane_id == 0) {
-            atomicAdd_system(usage_flag, 1);
         }
     }
     cg::this_grid().sync();
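
The surviving wait loop relies on a system-scope acquire load so that the flag written by a remote rank is observed with the right ordering before the combined data is read. A hedged sketch of what such a helper typically looks like (the real ld_acquire_sys_global lives in DeepEP's utility headers and may differ in detail):

// Assumed sketch of a system-scope acquire load (PTX, sm_70+); DeepEP's helper may differ.
__device__ __forceinline__ int ld_acquire_sys(const int* ptr) {
    int value;
    asm volatile("ld.acquire.sys.global.s32 %0, [%1];" : "=r"(value) : "l"(ptr) : "memory");
    return value;
}

// One lane per expert spins until the remote rank has flagged its payload as ready.
__device__ void wait_recv_flag(const int* rdma_recv_flag, int responsible_expert_idx) {
    while (ld_acquire_sys(rdma_recv_flag + responsible_expert_idx) == 0);
}
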
@@ -555,9 +545,8 @@ void combine(void* combined_x,
              int* next_clean, int num_next_clean_int,
              int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
              int num_topk, int num_experts, int rank, int num_ranks,
-             void* workspace, int* usage_flag,
-             int num_device_sms, cudaStream_t stream,
-             int phases, bool zero_copy) {
+             void* workspace, int num_device_sms,
+             cudaStream_t stream, int phases, bool zero_copy) {
     constexpr int kNumMaxTopk = 9;
     const int num_warp_groups = ceil_div(num_experts, num_device_sms);
     const int num_warps_per_group = 32 / num_warp_groups;
@@ -582,7 +571,6 @@ LAUNCH_KERNEL(&cfg, combine_func, \
               num_combined_tokens, hidden, num_topk, \
               num_max_dispatch_tokens_per_rank, \
               num_experts, rank, num_ranks, \
-              usage_flag, \
               num_warp_groups, num_warps_per_group, \
               phases, zero_copy); } break
 
@@ -457,19 +457,6 @@ class Buffer:
                                                      async_finish, allocate_on_comm_stream)
         return combined_x, combined_topk_weights, EventOverlap(event)
 
-    def get_low_latency_usage_flag(self):
-        """
-        Return a host-side integer flag, which indicates the stages of low-latency kernels.
-        The initial value is 0, the low-latency dispatch will add 1 before communication, the low-latency combine
-            will add 1 after communication.
-        This is useful when there is no two-batch overlap, and you want to overlap H2D/D2H transfer with attention layers.
-
-        Returns:
-            flag: the host-side integer flag pointer. The value is in `int`, but returns a `uint64_t` pointer. Please
-                `reinterpret_cast` the returned value into `int*`.
-        """
-        return self.runtime.get_low_latency_usage_flag()
-
     def clean_low_latency_buffer(self, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int) -> None:
         """
        As low-latency kernels require part of the buffer to be zero-initialized, so it is vital to clean the buffer
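
The removed docstring is the clearest description of how the flag was meant to be consumed: the Python call returned the raw address of the pinned int as a uint64_t, and native code reinterpreted it as an int* and polled it, e.g. to overlap H2D/D2H transfers with attention layers when there is no two-batch overlap. A hedged sketch of that consumption pattern in C++ (`flag_addr` stands for the value previously returned by get_low_latency_usage_flag(); the helper name is illustrative):

#include <cstdint>

// Spin until the low-latency flag reaches `expected`: per the removed docstring,
// dispatch adds 1 before communication and combine adds 1 after communication.
inline void wait_for_usage_flag(uint64_t flag_addr, int expected) {
    auto* flag = reinterpret_cast<volatile int*>(flag_addr);
    while (*flag < expected) {
        // Busy-wait: the GPU updates the pinned int with a system-scope atomic,
        // so no CUDA API call is required to observe the change.
    }
}
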
@@ -166,7 +166,6 @@ def test_loop(local_rank: int, num_local_ranks: int):
     print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
     buffer = deep_ep.Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True,
                             num_qps_per_rank=num_experts // num_ranks)
-    buffer.get_low_latency_usage_flag()
     test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=1)
 
     do_pressure_test = False