On the SMs that compute metadata in notify_dispatch, each warp computes the metadata of one channel. The default configuration has 8 warps (256 threads) for 10 channels, so the per-channel loop needs two iterations. Setting the number of warps to at least the number of channels lets a single iteration cover all channels.
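
The arithmetic behind this change can be sketched as follows. This is a minimal host-side illustration, not code from the patched file: `kWarpSize`, `rounds_needed`, and `threads_for_channels` are hypothetical names introduced here, and the sketch assumes each warp handles exactly one channel per loop iteration, as the commit message describes.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical helpers for illustration only; they do not appear in the patched source.
constexpr int kWarpSize = 32;  // threads per warp on NVIDIA GPUs

// Loop iterations needed when each warp handles one channel per iteration.
inline int rounds_needed(int num_channels, int num_warps) {
    assert(num_channels >= 0 && num_warps > 0);
    return (num_channels + num_warps - 1) / num_warps;  // ceiling division
}

// Thread count proposed by the commit: at least 256, and at least one warp per channel.
inline int threads_for_channels(int num_channels) {
    return std::max(256, kWarpSize * num_channels);
}
```

With the defaults from the commit message, `rounds_needed(10, 8)` is 2, while `threads_for_channels(10)` gives 320 threads (10 warps), so `rounds_needed(10, 10)` is 1. The sketch computes the value at run time; whether `num_channels` is a compile-time constant at the patched line is not visible in this hunk. Note also that a CUDA thread block is limited to 1024 threads, so this formula only stays valid up to 32 channels.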

songhexiang 2025-03-28 06:43:29 +00:00
parent e130cc6e7d
commit 4dd1e68ac8


@@ -428,7 +428,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe
buffer_ptrs, task_fifo_ptrs, head, rank, \
cpu_rdma_team); } break
-constexpr int kNumThreads = 256;
+constexpr int kNumThreads = std::max(256, 32 * num_channels);
const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
// Get clean meta