fix tma_d_offset_desc_swapAB, update unitest

This commit is contained in:
wangzhe_ant
2025-06-24 17:52:28 +08:00
parent 26a603f518
commit ccd63bb234
5 changed files with 348 additions and 183 deletions

View File

@@ -438,6 +438,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
#endif
}
template <uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t NUM_WARPS_PER_BLOCK>
static __device__ __forceinline__ void write_result_to_gmem(__nv_bfloat16* gmem_d_this_block,
__nv_bfloat16 const* smem_d, uint32_t const m_offset, uint32_t const m_boundary, uint32_t const n_offset,
@@ -638,18 +639,9 @@ __global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMat
tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_a[s], k_idx, scheduler.get_global_m_idx(m_block_idx), kNumTMAMulticast);
if constexpr (SchedulerType::gemm_type == GemmType::GroupedWithOffset)
{
tma_copy(&tensor_map_scales_a,
reinterpret_cast<uint64_t*>(&full_barrier), smem_scales_a[s],
scheduler.get_global_scales_a_idx(m_block_idx), k_idx / BLOCK_K, kNumTMAMulticast);
}
else
{
tma_copy(&tensor_map_scales_a,
reinterpret_cast<uint64_t*>(&full_barrier), smem_scales_a[s], m_block_idx * BLOCK_M,
scheduler.get_global_scales_a_idx(k_idx / BLOCK_K), kNumTMAMulticast);
}
tma_copy(&tensor_map_scales_a,
reinterpret_cast<uint64_t*>(&full_barrier), smem_scales_a[s],
scheduler.get_global_scales_a_idx(m_block_idx), k_idx / BLOCK_K, kNumTMAMulticast);
// Issue TMA B without broadcasting
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier), smem_b[s], k_idx,
@@ -826,45 +818,28 @@ __global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMat
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16);
}
if constexpr (SchedulerType::gemm_type == GemmType::GroupedWithOffset)
auto m_global_idx = scheduler.get_global_m_idx(m_block_idx);
bool cross_boundary = (m_global_idx + BLOCK_M) > scheduler.m_boundary;
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
if (!cross_boundary)
{
auto m_global_idx = scheduler.get_global_m_idx(m_block_idx);
bool cross_boundary = (m_global_idx + BLOCK_M) > scheduler.m_boundary;
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
if (!cross_boundary)
{
// Use TMA store to write back to global memory
if (threadIdx.x == 0)
{
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, m_global_idx);
cute::tma_store_arrive();
cute::tma_store_wait<0>();
}
}
else
{
__nv_bfloat16* gmem_d_this_block = gmem_d + m_global_idx * SHAPE_N;
constexpr int NUM_WARPS
= (get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M) - 128) / 32;
write_result_to_gmem<BLOCK_M, BLOCK_N, NUM_WARPS>(gmem_d_this_block, smem_d, m_global_idx,
scheduler.m_boundary, n_block_idx * BLOCK_N, SHAPE_N, SHAPE_N);
}
}
else
{
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Use TMA store to write back to global memory
if (threadIdx.x == 0)
{
cute::SM90_TMA_STORE_2D::copy(
&tensor_map_d, smem_d, n_block_idx * BLOCK_N, scheduler.get_global_m_idx(m_block_idx));
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, m_global_idx);
cute::tma_store_arrive();
cute::tma_store_wait<0>();
}
}
else
{
__nv_bfloat16* gmem_d_this_block = gmem_d + m_global_idx * SHAPE_N;
constexpr int NUM_WARPS
= (get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M) - 128) / 32;
write_result_to_gmem<BLOCK_M, BLOCK_N, NUM_WARPS>(gmem_d_this_block, smem_d, m_global_idx,
scheduler.m_boundary, n_block_idx * BLOCK_N, SHAPE_N, SHAPE_N);
}
__syncwarp();
}
}
@@ -1050,18 +1025,9 @@ __global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMat
smem_b[s], k_idx, scheduler.get_global_n_idx(n_block_idx), kNumTMAMulticast);
// Issue TMA scales_b (act scales) for B matrix
if constexpr (SchedulerType::gemm_type == GemmType::GroupedWithOffset)
{
tma_copy(&tensor_map_scales_b,
reinterpret_cast<uint64_t*>(&full_barrier), smem_scales_b[s],
scheduler.get_global_scales_b_idx(n_block_idx), k_idx / BLOCK_K, kNumTMAMulticast);
}
else
{
tma_copy(&tensor_map_scales_b,
reinterpret_cast<uint64_t*>(&full_barrier), smem_scales_b[s], n_block_idx * BLOCK_N,
scheduler.get_global_scales_b_idx(k_idx / BLOCK_K), kNumTMAMulticast);
}
tma_copy(&tensor_map_scales_b,
reinterpret_cast<uint64_t*>(&full_barrier), smem_scales_b[s],
scheduler.get_global_scales_b_idx(n_block_idx), k_idx / BLOCK_K, kNumTMAMulticast);
full_barrier.arrive_and_expect_tx(
SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_B_SIZE_PER_STAGE);
@@ -1246,45 +1212,28 @@ __global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMat
smem_d + warp_idx * 16 + WGMMA::kNumAccum / 8 * 16 * BLOCK_M + tid);
}
if constexpr (SchedulerType::gemm_type == GemmType::GroupedWithOffset)
auto n_global_idx = scheduler.get_global_n_idx(n_block_idx);
bool cross_boundary = (n_global_idx + BLOCK_N) > scheduler.n_boundary;
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
if (!cross_boundary)
{
auto n_global_idx = scheduler.get_global_n_idx(n_block_idx);
bool cross_boundary = (n_global_idx + BLOCK_N) > scheduler.n_boundary;
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
if (!cross_boundary)
{
// Use TMA store to write back to global memory
if (threadIdx.x == 0)
{
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, m_block_idx * BLOCK_M, n_global_idx);
cute::tma_store_arrive();
cute::tma_store_wait<0>();
}
}
else
{
__nv_bfloat16* gmem_d_this_block = gmem_d + n_global_idx * SHAPE_M;
constexpr int NUM_WARPS
= (get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M) - 128) / 32;
write_result_to_gmem<BLOCK_N, BLOCK_M, NUM_WARPS>(gmem_d_this_block, smem_d, n_global_idx,
scheduler.n_boundary, m_block_idx * BLOCK_M, SHAPE_M, SHAPE_M);
}
}
else
{
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Use TMA store to write back to global memory
if (threadIdx.x == 0)
{
cute::SM90_TMA_STORE_2D::copy(
&tensor_map_d, smem_d, m_block_idx * BLOCK_M, scheduler.get_global_n_idx(n_block_idx));
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, m_block_idx * BLOCK_M, n_global_idx);
cute::tma_store_arrive();
cute::tma_store_wait<0>();
}
}
else
{
__nv_bfloat16* gmem_d_this_block = gmem_d + n_global_idx * SHAPE_M;
constexpr int NUM_WARPS
= (get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M) - 128) / 32;
write_result_to_gmem<BLOCK_N, BLOCK_M, NUM_WARPS>(gmem_d_this_block, smem_d, n_global_idx,
scheduler.n_boundary, m_block_idx * BLOCK_M, SHAPE_M, SHAPE_M);
}
__syncwarp();
}
}

View File

@@ -67,13 +67,8 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k
return smem_size, swizzle_mode, block_n_padding
else:
# Try swizzle first, as it does not waste shared memory
swizzle_mode = get_swizzle_mode(block_n)
block_n_padding = get_block_n_padding_for_smem_d(
block_n) if swizzle_mode == 0 else 0
# NOTES: `scales_b` in a total manner or per-stage manner
smem_d = block_m * (block_n + block_n_padding) * (4 if is_fp32_out else 2)
smem_d = block_m * block_n * (4 if is_fp32_out else 2)
smem_a_per_stage = block_m * block_k
smem_scales_a_per_stage = ceil_div(k, block_k) * 4; # weight scales
smem_b_per_stage = block_n * block_k
@@ -87,11 +82,12 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k
smem_size += num_stages * smem_scales_b
smem_size += num_stages * smem_b_per_stage
smem_size += num_stages * smem_scales_b_per_stage
smem_size += ceil_div(smem_scales_a_per_stage * (1 if block_k % block_n == 0 else 2), 8) * 8
smem_size += ceil_div(smem_scales_a_per_stage, 8) * 8
smem_size += smem_barrier
# Swizzle and padding are not compatible
assert int(swizzle_mode > 0) + int(block_n_padding > 0) <= 1
# no swizzle, no block_n_padding
swizzle_mode = 0
block_n_padding = 0
return smem_size, swizzle_mode, block_n_padding
@@ -105,7 +101,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
else:
block_ms = (get_m_alignment_for_contiguous_layout(), )
#block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, ))
block_ns = tuple(range(16, 129, 8))
block_ns = tuple(range(16, 129, 8))
# Avoid bank conflicts for FP32 output
if is_fp32_out:
block_ns = [x for x in block_ns if x % 16 == 8]

View File

@@ -221,20 +221,12 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor]
num_groups, n, k_ = rhs.shape
m_, n_ = out.shape
print("expected_m: ",expected_m)
print("A shape: ",lhs.shape)
print("A scale shape: ",lhs_scales.shape)
print("B shape: ",rhs.shape)
print("B scale shape: ",rhs_scales.shape)
print("out shape: ",out.shape)
# Type and shape checks
assert m == m_ and n == n_ and k == k_
max_shape_m_4_align = ceil_div(m, 4) * 4 # align 4
max_shape_m_32_align_padded = compute_padded_offset(m, num_groups)
max_shape_m_32_align_padded = compute_padded_offset(max_shape_m_4_align, num_groups)
assert expected_m > 0 and max_shape_m_4_align > 0 and n > 0 and k > 0 and num_groups > 0
@@ -244,12 +236,14 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor]
assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128))
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
assert offsets.dtype == torch.int64
assert out.dtype == torch.bfloat16
assert lhs.is_contiguous() and rhs.is_contiguous()
assert out.is_contiguous()
# LHS scales must be transposed for TMA load, but not for RHS scales
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
#lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
assert rhs_scales.is_contiguous()
# Auto-tuning with compilation
@@ -273,9 +267,9 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor]
num_tma_threads = 128
num_math_threads_per_group = 128
tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedWithOffset, lhs, m, k, k, block_m, block_k, num_groups)
tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedWithOffset, lhs, max_shape_m_4_align, k, k, block_m, block_k, num_groups)
tensor_map_b = make_2d_tma_b_desc(GemmType.GroupedWithOffset, rhs, n, k, k, block_n, block_k, num_groups)
tensor_map_d = make_2d_tma_d_desc(GemmType.GroupedWithOffset, out, m, n, n, block_m, block_n, num_groups, 0) # none swizzle
tensor_map_d = make_2d_tma_d_desc(GemmType.GroupedWithOffset, out, max_shape_m_4_align, n, n, block_m, block_n, num_groups, 0) # none swizzle
tensor_map_scales_a = make_2d_tma_scales_a_offset_desc(GemmType.GroupedWithOffset, lhs_scales, max_shape_m_32_align_padded, k, block_m, block_k) # none swizzle
@@ -287,7 +281,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor]
'PROBLEM_OFFSETS': offsets,
'NUM_TMA_THREADS': num_tma_threads,
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
'M': m, 'N': n, 'K': k,
'M': max_shape_m_4_align, 'N': n, 'K': k,
'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
'NUM_GROUPS': num_groups,
'NUM_STAGES': num_stages,
@@ -310,26 +304,17 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor]
else:
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
n, expected_m, k, num_groups, num_sms, is_grouped_contiguous = True, is_swap_ab=True)
# Extra checks for TMA store
if num_groups > 1 and n > block_m:
assert n % block_m == 0, f'For GroupedWithOffset grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})'
print("is_swap_ab=True =========")
print("num_sms: ",num_sms)
print("block_m: ",block_m)
print("block_n: ",block_n)
print("num_stages: ",num_stages)
print("tma_multicast_config: ",tma_multicast_config)
print("smem_config: ",smem_config)
block_k = 128
num_tma_threads = 128
num_math_threads_per_group = 128
tensor_map_a = make_2d_tma_a_offset_desc_swapAB(GemmType.GroupedWithOffset, rhs, n, k, k, block_m, block_k, num_groups)
tensor_map_b = make_2d_tma_b_offset_desc_swapAB(GemmType.GroupedWithOffset, lhs, m, k, k, block_n, block_k, num_groups)
tensor_map_d = make_2d_tma_d_offset_desc_swapAB(GemmType.GroupedWithOffset, out, n, m, m, block_m, block_n, num_groups, 0) # no swizzle
tensor_map_b = make_2d_tma_b_offset_desc_swapAB(GemmType.GroupedWithOffset, lhs, max_shape_m_4_align, k, k, block_n, block_k, num_groups)
tensor_map_d = make_2d_tma_d_offset_desc_swapAB(GemmType.GroupedWithOffset, out, max_shape_m_4_align, n, n, block_m, block_n, num_groups, 0) # no swizzle
tensor_map_scales_b = make_2d_tma_scales_b_offset_desc_swapAB(GemmType.GroupedWithOffset, lhs_scales, max_shape_m_32_align_padded, k, block_n, block_k) # no swizzle
kwargs = {
@@ -340,7 +325,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor]
'PROBLEM_OFFSETS': offsets,
'NUM_TMA_THREADS': num_tma_threads,
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
'M': m, 'N': n, 'K': k,
'M': max_shape_m_4_align, 'N': n, 'K': k,
'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
'NUM_GROUPS': num_groups,
'NUM_STAGES': num_stages,

View File

@@ -173,8 +173,8 @@ def make_2d_tma_d_offset_desc_swapAB(gemm_type: GemmType, t: torch.Tensor,
# Swizzling requires the inner box dim to be less or equal than `kSwizzleDMode`
# bytes, so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
return make_2d_tma_desc(t,
shape_n, shape_m * (num_groups if gemm_type != GemmType.Normal else 1), m_stride,
min(block_n, shape_n), min(block_m, shape_m),
shape_n, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), m_stride,
min(block_m, shape_n), min(block_n, shape_m),
cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)