mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Fix bugs
This commit is contained in:
parent
901cdf79be
commit
fdb41efbd3
@ -470,7 +470,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
// Read RDMA rank existence
|
||||
uint64_t is_token_in_rank_uint64 = 0;
|
||||
if (lane_id < kNumRDMARanks) {
|
||||
is_token_in_rank_uint64 = *reinterpret_cast<const uint64_t*>(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS);
|
||||
is_token_in_rank_uint64 = __ldg(reinterpret_cast<const uint64_t*>(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS));
|
||||
global_rdma_tail_idx += (is_token_in_rank_uint64 != 0);
|
||||
}
|
||||
__syncwarp();
|
||||
|
Loading…
Reference in New Issue
Block a user