tma support indivisible num_n_blocks/num_m_blocks

This commit is contained in:
yukuai
2025-04-22 14:35:31 +08:00
parent 891f35adf5
commit ee4204ad98
3 changed files with 64 additions and 37 deletions

View File

@@ -194,6 +194,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
// Assign TMA multicast number into A and B
constexpr int kNumTMAMulticastOnA = kIsTMAMulticastOnA ? kNumTMAMulticast : 1;
constexpr int kNumTMAMulticastOnB = kIsTMAMulticastOnA ? 1 : kNumTMAMulticast;
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
// NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all
// shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant
@@ -202,19 +203,29 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
// Wait consumer release
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
// Issue TMA A
// NOTES: There may be additional odd rows/columns or cases where multicast is not possible.
// In grouped contiguous GEMM, different m_block_idx values can also lead to the inability to multicast.
// We use is_tma_multicast_valid to determine whether multicast is possible.
// Issue TMA A
auto& full_barrier = *full_barriers[s];
int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
tma_copy<kNumTMAMulticastOnA>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
tma_copy<kNumTMAMulticastOnA>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_scales_a[s], m_block_idx * BLOCK_M,
scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
if (kNumTMAMulticastOnA > 1 and scheduler.is_tma_multicast_valid(m_block_idx)) {
tma_copy<kNumTMAMulticastOnA>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
tma_copy<kNumTMAMulticastOnA>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_scales_a[s], m_block_idx * BLOCK_M,
scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
}
else {
tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
tma_copy(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_scales_a[s], m_block_idx * BLOCK_M,
scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
}
// Issue TMA B
if (kNumTMAMulticastOnB > 1 and scheduler.is_tma_multicast_b_valid(m_block_idx)) {
// NOTES: in grouped contiguous GEMM, different `m_block_idx` values may correspond to blocks of different groups (B),
// requiring additional checks before multicast operations.
if (kNumTMAMulticastOnB > 1 and scheduler.is_tma_multicast_valid(m_block_idx)) {
DG_STATIC_ASSERT(kNumTMAMulticastOnB <= 2, "Scheduler does not support > 2 TMA multicast");
tma_copy<kNumTMAMulticastOnB>(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
@@ -281,7 +292,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
if constexpr (kNumTMAMulticast == 1) {
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
} else {
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void();
auto target_cta_idx = scheduler.is_tma_multicast_valid(m_block_idx) ? lane_idx : cute::block_rank_in_cluster();
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta_idx) : void();
}
};

View File

@@ -28,6 +28,7 @@ struct Scheduler {
int* grouped_layout;
// Only used for masked layout
uint32_t curr_group_idx, curr_cumsum;
int num_blocks_in_group;
__device__ __forceinline__ explicit Scheduler(const uint32_t shape_m,
int* grouped_layout = nullptr) {
@@ -43,15 +44,21 @@ struct Scheduler {
}
}
__device__ __forceinline__ bool is_tma_multicast_b_valid(const uint32_t& m_block_idx) {
if constexpr (kGemmType == GemmType::Normal) {
return true;
} else if constexpr (kGemmType == GemmType::GroupedContiguous) {
auto group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M);
auto peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M);
return group_idx == peer_group_idx;
} else if constexpr (kGemmType == GemmType::GroupedMasked) {
__device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) {
return true;
if (num_blocks_in_group == 1)
return false;
if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::GroupedMasked) {
return true;
} else {
DG_STATIC_ASSERT(kGemmType == GemmType::GroupedContiguous, "Invalid Gemm type");
if constexpr (kIsTMAMulticastOnA) {
return true;
} else {
auto group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M);
auto peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M);
return group_idx == peer_group_idx;
}
}
}
@@ -59,23 +66,30 @@ struct Scheduler {
DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
// Swizzle for better L2 usages
// TODO: unify these 2 branches
auto primary_num_blocks = kIsTMAMulticastOnA ? kNumNBlocks : num_m_blocks;
auto secondary_num_blocks = kIsTMAMulticastOnA ? num_m_blocks : kNumNBlocks;
auto num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup;
auto group_idx = block_idx / num_blocks_per_group;
auto first_block_idx = group_idx * kNum1DBlocksPerGroup;
auto in_group_idx = block_idx % num_blocks_per_group;
num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx);
if (kNumTMAMulticast > 1 and num_blocks_in_group % 2 != 0) {
if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) {
num_blocks_in_group = num_blocks_in_group ^ 1;
}
else {
in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks;
first_block_idx += num_blocks_in_group ^ 1;
num_blocks_in_group = 1;
}
}
if constexpr (kIsTMAMulticastOnA) {
auto num_blocks_per_group = num_m_blocks * kNum1DBlocksPerGroup;
auto group_idx = block_idx / num_blocks_per_group;
auto first_n_block_idx = group_idx * kNum1DBlocksPerGroup;
auto num_n_blocks_in_group = min(kNum1DBlocksPerGroup, kNumNBlocks - first_n_block_idx);
auto in_group_idx = block_idx % num_blocks_per_group;
m_block_idx = in_group_idx / num_n_blocks_in_group;
n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
} else {
auto num_blocks_per_group = kNumNBlocks * kNum1DBlocksPerGroup;
auto group_idx = block_idx / num_blocks_per_group;
auto first_m_block_idx = group_idx * kNum1DBlocksPerGroup;
auto num_m_blocks_in_group = min(kNum1DBlocksPerGroup, num_m_blocks - first_m_block_idx);
auto in_group_idx = block_idx % num_blocks_per_group;
m_block_idx = first_m_block_idx + in_group_idx % num_m_blocks_in_group;
n_block_idx = in_group_idx / num_m_blocks_in_group;
m_block_idx = in_group_idx / num_blocks_in_group;
n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
}
else {
m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
n_block_idx = in_group_idx / num_blocks_in_group;
}
}