mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
tma support indivisible num_n_blocks/num_m_blocks
@@ -194,6 +194,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
// Assign TMA multicast number into A and B
constexpr int kNumTMAMulticastOnA = kIsTMAMulticastOnA ? kNumTMAMulticast : 1;
constexpr int kNumTMAMulticastOnB = kIsTMAMulticastOnA ? 1 : kNumTMAMulticast;
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");

// NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all
// shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant
@@ -202,19 +203,29 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
// Wait consumer release
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);

// Issue TMA A
// NOTES: there may be leftover (odd) rows/columns of blocks, or other cases where multicast is not possible.
// In grouped contiguous GEMM, certain `m_block_idx` values can also make multicast impossible.
// `is_tma_multicast_valid` is used to determine whether multicast is possible.
// Issue TMA A
auto& full_barrier = *full_barriers[s];
int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
tma_copy<kNumTMAMulticastOnA>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
                              smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
tma_copy<kNumTMAMulticastOnA>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
                              smem_scales_a[s], m_block_idx * BLOCK_M,
                              scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
if (kNumTMAMulticastOnA > 1 and scheduler.is_tma_multicast_valid(m_block_idx)) {
    tma_copy<kNumTMAMulticastOnA>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
                                  smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
    tma_copy<kNumTMAMulticastOnA>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
                                  smem_scales_a[s], m_block_idx * BLOCK_M,
                                  scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
}
else {
    tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
             smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
    tma_copy(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
             smem_scales_a[s], m_block_idx * BLOCK_M,
             scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
}

// Issue TMA B
if (kNumTMAMulticastOnB > 1 and scheduler.is_tma_multicast_b_valid(m_block_idx)) {
// NOTES: in grouped contiguous GEMM, different `m_block_idx` values may correspond to blocks of different groups (B),
// requiring additional checks before multicast operations.
if (kNumTMAMulticastOnB > 1 and scheduler.is_tma_multicast_valid(m_block_idx)) {
    DG_STATIC_ASSERT(kNumTMAMulticastOnB <= 2, "Scheduler does not support > 2 TMA multicast");
    tma_copy<kNumTMAMulticastOnB>(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
                                  smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
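The A-side branch above only enables the 2-CTA TMA multicast when the scheduler reports that the current tile has a valid peer; otherwise it falls back to a plain per-CTA copy. A minimal host-side sketch of that decision (illustrative Python, not the kernel; `pick_a_multicast_width` and its arguments are hypothetical names):

# Sketch: choose the multicast width for the A-side TMA load of one tile.
def pick_a_multicast_width(num_tma_multicast_on_a: int, multicast_valid: bool) -> int:
    # Broadcast to both CTAs in the cluster only when the scheduler reports a valid
    # peer block; otherwise issue an ordinary single-CTA copy.
    if num_tma_multicast_on_a > 1 and multicast_valid:
        return num_tma_multicast_on_a
    return 1

assert pick_a_multicast_width(2, True) == 2   # paired blocks: multicast both loads
assert pick_a_multicast_width(2, False) == 1  # leftover/odd block: single copy
assert pick_a_multicast_width(1, False) == 1  # multicast disabled on A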
@@ -281,7 +292,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
if constexpr (kNumTMAMulticast == 1) {
    lane_idx == 0 ? empty_barriers[s]->arrive() : void();
} else {
    lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void();
    auto target_cta_idx = scheduler.is_tma_multicast_valid(m_block_idx) ? lane_idx : cute::block_rank_in_cluster();
    lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta_idx) : void();
}
};
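With the per-tile validity check, the consumer-side barrier release also changes: when a tile could not be multicast, each CTA must release only its own empty barrier instead of the lane-indexed peer barriers. A small model of the arrival targets (illustrative Python; `arrival_targets` is a hypothetical helper and `my_cta_rank` stands for `cute::block_rank_in_cluster()`):

def arrival_targets(num_tma_multicast: int, multicast_valid: bool, my_cta_rank: int) -> list:
    # num_tma_multicast == 1: only lane 0 arrives, at the local barrier.
    if num_tma_multicast == 1:
        return [my_cta_rank]
    # Otherwise lanes 0..num_tma_multicast-1 arrive; with a valid multicast they target
    # one CTA each, while an invalid tile signals only the CTA's own barrier.
    return [lane if multicast_valid else my_cta_rank for lane in range(num_tma_multicast)]

assert arrival_targets(2, True, my_cta_rank=0) == [0, 1]   # release both CTAs of the cluster
assert arrival_targets(2, False, my_cta_rank=1) == [1, 1]  # release the local barrier only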
@@ -28,6 +28,7 @@ struct Scheduler {
int* grouped_layout;
// Only used for masked layout
uint32_t curr_group_idx, curr_cumsum;
int num_blocks_in_group;

__device__ __forceinline__ explicit Scheduler(const uint32_t shape_m,
                                              int* grouped_layout = nullptr) {
@@ -43,15 +44,21 @@ struct Scheduler {
    }
}

__device__ __forceinline__ bool is_tma_multicast_b_valid(const uint32_t& m_block_idx) {
    if constexpr (kGemmType == GemmType::Normal) {
        return true;
    } else if constexpr (kGemmType == GemmType::GroupedContiguous) {
        auto group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M);
        auto peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M);
        return group_idx == peer_group_idx;
    } else if constexpr (kGemmType == GemmType::GroupedMasked) {
__device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) {
        return true;
    if (num_blocks_in_group == 1)
        return false;
    if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::GroupedMasked) {
        return true;
    } else {
        DG_STATIC_ASSERT(kGemmType == GemmType::GroupedContiguous, "Invalid Gemm type");
        if constexpr (kIsTMAMulticastOnA) {
            return true;
        } else {
            auto group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M);
            auto peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M);
            return group_idx == peer_group_idx;
        }
    }
}
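The new `is_tma_multicast_valid` combines the old B-side group check with an extra guard on the swizzle group size. A host-side rendering of the device logic above (an illustrative sketch; `grouped_layout` is modeled as a per-row group-id list and `block_m` as the tile height, both assumptions for the example):

def is_tma_multicast_valid(gemm_type: str, is_multicast_on_a: bool, num_blocks_in_group: int,
                           m_block_idx: int, grouped_layout: list, block_m: int) -> bool:
    # A block that ended up alone in its swizzle group has no peer CTA to pair with.
    if num_blocks_in_group == 1:
        return False
    if gemm_type in ('Normal', 'GroupedMasked'):
        return True
    # GroupedContiguous with multicast on B: the peer block is m_block_idx ^ 1, and the
    # broadcast is only safe when both blocks read rows of the same group.
    assert gemm_type == 'GroupedContiguous'
    if is_multicast_on_a:
        return True
    group_idx = grouped_layout[m_block_idx * block_m]
    peer_group_idx = grouped_layout[(m_block_idx ^ 1) * block_m]
    return group_idx == peer_group_idx

# Two adjacent m-blocks from different groups cannot share a multicast B load:
layout = [0] * 128 + [1] * 128   # group ids per row, block_m = 128
assert not is_tma_multicast_valid('GroupedContiguous', False, 2, 0, layout, 128)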
@@ -59,23 +66,30 @@ struct Scheduler {
DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");

// Swizzle for better L2 usages
// TODO: unify these 2 branches
auto primary_num_blocks = kIsTMAMulticastOnA ? kNumNBlocks : num_m_blocks;
auto secondary_num_blocks = kIsTMAMulticastOnA ? num_m_blocks : kNumNBlocks;
auto num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup;
auto group_idx = block_idx / num_blocks_per_group;
auto first_block_idx = group_idx * kNum1DBlocksPerGroup;
auto in_group_idx = block_idx % num_blocks_per_group;
num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx);
if (kNumTMAMulticast > 1 and num_blocks_in_group % 2 != 0) {
    if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) {
        num_blocks_in_group = num_blocks_in_group ^ 1;
    }
    else {
        in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks;
        first_block_idx += num_blocks_in_group ^ 1;
        num_blocks_in_group = 1;
    }
}
if constexpr (kIsTMAMulticastOnA) {
    auto num_blocks_per_group = num_m_blocks * kNum1DBlocksPerGroup;
    auto group_idx = block_idx / num_blocks_per_group;
    auto first_n_block_idx = group_idx * kNum1DBlocksPerGroup;
    auto num_n_blocks_in_group = min(kNum1DBlocksPerGroup, kNumNBlocks - first_n_block_idx);
    auto in_group_idx = block_idx % num_blocks_per_group;
    m_block_idx = in_group_idx / num_n_blocks_in_group;
    n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
} else {
    auto num_blocks_per_group = kNumNBlocks * kNum1DBlocksPerGroup;
    auto group_idx = block_idx / num_blocks_per_group;
    auto first_m_block_idx = group_idx * kNum1DBlocksPerGroup;
    auto num_m_blocks_in_group = min(kNum1DBlocksPerGroup, num_m_blocks - first_m_block_idx);
    auto in_group_idx = block_idx % num_blocks_per_group;
    m_block_idx = first_m_block_idx + in_group_idx % num_m_blocks_in_group;
    n_block_idx = in_group_idx / num_m_blocks_in_group;
    m_block_idx = in_group_idx / num_blocks_in_group;
    n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
}
else {
    m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
    n_block_idx = in_group_idx / num_blocks_in_group;
}
}
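The rewritten swizzling is what makes indivisible `num_n_blocks` / `num_m_blocks` work: blocks are still grouped along the primary dimension for L2 reuse, but when 2-CTA multicast is enabled and the tail group holds an odd number of primary blocks, the last block is split off into its own group of size one (note that `x ^ 1 == x - 1` for odd `x`). A Python model of the new index computation (an illustrative sketch of the logic above, not the device code):

def swizzled_block_idx(block_idx: int, num_m_blocks: int, num_n_blocks: int,
                       blocks_per_group_1d: int, num_tma_multicast: int, multicast_on_a: bool):
    primary = num_n_blocks if multicast_on_a else num_m_blocks
    secondary = num_m_blocks if multicast_on_a else num_n_blocks
    num_blocks_per_group = secondary * blocks_per_group_1d
    group_idx = block_idx // num_blocks_per_group
    first_block_idx = group_idx * blocks_per_group_1d
    in_group_idx = block_idx % num_blocks_per_group
    num_blocks_in_group = min(blocks_per_group_1d, primary - first_block_idx)
    if num_tma_multicast > 1 and num_blocks_in_group % 2 != 0:
        if in_group_idx < (num_blocks_in_group ^ 1) * secondary:
            num_blocks_in_group ^= 1     # use the even part of the group
        else:
            in_group_idx -= (num_blocks_in_group ^ 1) * secondary
            first_block_idx += num_blocks_in_group ^ 1
            num_blocks_in_group = 1      # the leftover block forms its own group
    if multicast_on_a:
        m_block_idx = in_group_idx // num_blocks_in_group
        n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group
    else:
        m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group
        n_block_idx = in_group_idx // num_blocks_in_group
    return m_block_idx, n_block_idx, num_blocks_in_group

# 4 x 3 tiles, 2-CTA multicast on A: 3 n-blocks are not divisible by 2, yet every
# tile is still produced exactly once, and the leftover n-column gets group size 1.
tiles = [swizzled_block_idx(i, 4, 3, 16, 2, True) for i in range(12)]
assert {(m, n) for m, n, _ in tiles} == {(m, n) for m in range(4) for n in range(3)}
assert all(g == 1 for m, n, g in tiles if n == 2)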
@@ -38,10 +38,10 @@ gemm_t::run(out, rhs_scales, nullptr,
"""


def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int) -> bool:
def is_tma_multicast_legal(shape_dim: int, multicast_block_dim: int, num_tma_multicast: int, num_sms: int) -> bool:
    if num_tma_multicast == 1:
        return True
    return (shape_dim % (block_dim * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0
    return shape_dim % multicast_block_dim == 0 and num_sms % num_tma_multicast == 0


def get_swizzle_mode(block_n: int) -> int:
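The signature change shifts responsibility to the caller: instead of requiring `shape_dim` to be divisible by `block_dim * num_tma_multicast`, the caller now passes the exact extent one multicast must cover. Combined with the scheduler change, an odd number of n-blocks no longer rules out 2-CTA multicast for ordinary GEMMs. A quick illustrative comparison (numbers chosen only for the example):

# n = 384 with block_n = 128 gives 3 n-blocks.
n, block_n, num_sms = 384, 128, 132
old_legal = n % (block_n * 2) == 0 and num_sms % 2 == 0   # old rule: rejected
new_legal = n % block_n == 0 and num_sms % 2 == 0         # new rule: accepted
assert not old_legal and new_legal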
@@ -146,8 +146,9 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
best_tma_multicast_config = (1, True)

# Try to multicast on the larger block side first
# NOTES: Currently, grouped masked GEMM only supports multicast on A and requires the number of blocks in the n-direction to be even.
is_multicast_legal = {
    'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms),
    'A': is_tma_multicast_legal(n, best_block_n * (2 if is_grouped_masked else 1), 2, num_sms),
    'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and (not is_grouped_masked),
}
for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'):
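Putting the pieces together, the config selection computes per-side legality with the new rule (doubling the A-side extent for grouped masked GEMM, and disallowing B-side multicast there entirely), then prefers the side with the larger block. The loop body is cut off by the hunk, so the following sketch is an assumption about how the chosen side is applied (illustrative Python, hypothetical helper name):

def pick_tma_multicast_config(m, n, best_block_m, best_block_n, num_sms, is_grouped_masked):
    def legal(shape_dim, multicast_block_dim):
        return shape_dim % multicast_block_dim == 0 and num_sms % 2 == 0
    is_multicast_legal = {
        'A': legal(n, best_block_n * (2 if is_grouped_masked else 1)),
        'B': legal(m, best_block_m) and not is_grouped_masked,
    }
    # Try the larger block side first, as in the snippet above.
    for side in (('A', 'B') if best_block_m > best_block_n else ('B', 'A')):
        if is_multicast_legal[side]:
            return 2, side == 'A'    # assumed (num_tma_multicast, is_tma_multicast_on_a) usage
    return 1, True                   # default from `best_tma_multicast_config = (1, True)`

assert pick_tma_multicast_config(4096, 4096, 64, 128, 132, False) == (2, False)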