mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-05-07 05:24:27 +00:00
Bugs fixed
This commit is contained in:
parent
592296cd45
commit
680e424bdc
@@ -282,6 +282,7 @@ For two micro-batch overlapping, you can refer to the following figure. With our
|
|||||||
|
|
||||||
## Roadmap
|
## Roadmap
|
||||||
|
|
||||||
|
- [ ] AR support (releasing soon)
|
||||||
- [ ] A100 support (intranode only)
|
- [ ] A100 support (intranode only)
|
||||||
- [ ] Support BF16 for the low-latency dispatch kernel
|
- [ ] Support BF16 for the low-latency dispatch kernel
|
||||||
- [ ] Support NVLink protocol for intranode low-latency kernels
|
- [ ] Support NVLink protocol for intranode low-latency kernels
|
||||||
|
@@ -383,8 +383,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
|
|||||||
|
|
||||||
// Calculate prefix sum
|
// Calculate prefix sum
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA ranks");
|
if (thread_id == 0) {
|
||||||
if (thread_id < kNumRDMARanks) {
|
|
||||||
auto prefix_row = rdma_channel_prefix_matrix + dst_rdma_rank * num_channels;
|
auto prefix_row = rdma_channel_prefix_matrix + dst_rdma_rank * num_channels;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 1; i < num_channels; ++ i)
|
for (int i = 1; i < num_channels; ++ i)
|
||||||
|
Loading…
Reference in New Issue
Block a user