mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Several code lints
This commit is contained in:
@@ -413,7 +413,7 @@ __device__ static __forceinline__ void ibgda_write_amo_add_wqe(
|
||||
__device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, const int& value, int pe, int qp_id, bool is_local_copy = false) {
|
||||
if (is_local_copy) {
|
||||
// Fallback to NVSHMEM legacy API
|
||||
nvshmemx_signal_op(reinterpret_cast<uint64_t*>(rptr), value, NVSHMEM_SIGNAL_ADD, pe);
|
||||
nvshmemx_signal_op(static_cast<uint64_t*>(rptr), value, NVSHMEM_SIGNAL_ADD, pe);
|
||||
} else {
|
||||
nvshmemi_ibgda_device_qp_t *qp = ibgda_get_rc(pe, qp_id);
|
||||
|
||||
|
||||
@@ -573,10 +573,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
// Issue RDMA for non-local ranks
|
||||
if (dst_rdma_rank != rdma_rank and lane_id == 0) {
|
||||
nvshmemi_ibgda_put_nbi_thread(reinterpret_cast<uint64_t>(rdma_channel_meta.recv_buffer(rdma_rank)),
|
||||
reinterpret_cast<uint64_t>(rdma_channel_meta.send_buffer(dst_rdma_rank)),
|
||||
sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2),
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank),
|
||||
channel_id, false);
|
||||
reinterpret_cast<uint64_t>(rdma_channel_meta.send_buffer(dst_rdma_rank)),
|
||||
sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2),
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank),
|
||||
channel_id, false);
|
||||
}
|
||||
}
|
||||
sync_rdma_sender_smem();
|
||||
@@ -724,9 +724,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
const size_t num_bytes_per_msg = num_bytes_per_rdma_token * num_tokens_to_issue;
|
||||
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
|
||||
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
|
||||
if (lane_id == dst_rdma_rank)
|
||||
if (lane_id == dst_rdma_rank) {
|
||||
nvshmemi_ibgda_put_nbi_thread(dst_ptr, src_ptr, num_bytes_per_msg,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, false);
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, false);
|
||||
}
|
||||
} else {
|
||||
// Lighter fence for local RDMA rank
|
||||
memory_fence();
|
||||
@@ -1573,9 +1574,11 @@ combine(int4* combined_x, float* combined_topk_weights,
|
||||
const size_t num_bytes_per_msg = num_chunked_tokens * num_bytes_per_rdma_token;
|
||||
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
|
||||
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
|
||||
if (lane_id == 0)
|
||||
if (lane_id == 0) {
|
||||
// TODO: use the full warp to do this
|
||||
nvshmemi_ibgda_put_nbi_thread(dst_ptr, src_ptr, num_bytes_per_msg,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, false);
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, false);
|
||||
}
|
||||
} else {
|
||||
memory_fence();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user