Update assertion of num_rc_per_pe.

This commit is contained in:
Shangyan Zhou 2025-06-13 15:16:23 +08:00
parent 05df5554ff
commit 483f00af84

View File

@ -357,14 +357,15 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
kNVLReceivers
};
const auto num_sms = static_cast<int>(gridDim.x);
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / 32;
const auto thread_id = static_cast<int>(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id();
const auto num_channels = static_cast<int>(gridDim.x) / 2, channel_id = sm_id / 2;
const auto num_channels = num_sms / 2, channel_id = sm_id / 2;
const bool is_forwarder = sm_id % 2 == 0;
const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe >= num_channels);
EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe == num_channels || ibgda_get_state()->num_rc_per_pe >= num_sms);
const auto role_meta = [=]() -> std::pair<WarpRole, int> {
if (is_forwarder) {