mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
fix tma_d_offset_desc_swapAB, update unitest
This commit is contained in:
@@ -438,6 +438,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
|
||||
DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t NUM_WARPS_PER_BLOCK>
|
||||
static __device__ __forceinline__ void write_result_to_gmem(__nv_bfloat16* gmem_d_this_block,
|
||||
__nv_bfloat16 const* smem_d, uint32_t const m_offset, uint32_t const m_boundary, uint32_t const n_offset,
|
||||
@@ -638,18 +639,9 @@ __global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMat
|
||||
tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_a[s], k_idx, scheduler.get_global_m_idx(m_block_idx), kNumTMAMulticast);
|
||||
|
||||
if constexpr (SchedulerType::gemm_type == GemmType::GroupedWithOffset)
|
||||
{
|
||||
tma_copy(&tensor_map_scales_a,
|
||||
reinterpret_cast<uint64_t*>(&full_barrier), smem_scales_a[s],
|
||||
scheduler.get_global_scales_a_idx(m_block_idx), k_idx / BLOCK_K, kNumTMAMulticast);
|
||||
}
|
||||
else
|
||||
{
|
||||
tma_copy(&tensor_map_scales_a,
|
||||
reinterpret_cast<uint64_t*>(&full_barrier), smem_scales_a[s], m_block_idx * BLOCK_M,
|
||||
scheduler.get_global_scales_a_idx(k_idx / BLOCK_K), kNumTMAMulticast);
|
||||
}
|
||||
tma_copy(&tensor_map_scales_a,
|
||||
reinterpret_cast<uint64_t*>(&full_barrier), smem_scales_a[s],
|
||||
scheduler.get_global_scales_a_idx(m_block_idx), k_idx / BLOCK_K, kNumTMAMulticast);
|
||||
|
||||
// Issue TMA B without broadcasting
|
||||
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier), smem_b[s], k_idx,
|
||||
@@ -826,45 +818,28 @@ __global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMat
|
||||
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16);
|
||||
}
|
||||
|
||||
if constexpr (SchedulerType::gemm_type == GemmType::GroupedWithOffset)
|
||||
auto m_global_idx = scheduler.get_global_m_idx(m_block_idx);
|
||||
bool cross_boundary = (m_global_idx + BLOCK_M) > scheduler.m_boundary;
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
if (!cross_boundary)
|
||||
{
|
||||
auto m_global_idx = scheduler.get_global_m_idx(m_block_idx);
|
||||
bool cross_boundary = (m_global_idx + BLOCK_M) > scheduler.m_boundary;
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
if (!cross_boundary)
|
||||
{
|
||||
// Use TMA store to write back to global memory
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, m_global_idx);
|
||||
cute::tma_store_arrive();
|
||||
cute::tma_store_wait<0>();
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
__nv_bfloat16* gmem_d_this_block = gmem_d + m_global_idx * SHAPE_N;
|
||||
constexpr int NUM_WARPS
|
||||
= (get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M) - 128) / 32;
|
||||
write_result_to_gmem<BLOCK_M, BLOCK_N, NUM_WARPS>(gmem_d_this_block, smem_d, m_global_idx,
|
||||
scheduler.m_boundary, n_block_idx * BLOCK_N, SHAPE_N, SHAPE_N);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
// Use TMA store to write back to global memory
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
cute::SM90_TMA_STORE_2D::copy(
|
||||
&tensor_map_d, smem_d, n_block_idx * BLOCK_N, scheduler.get_global_m_idx(m_block_idx));
|
||||
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, m_global_idx);
|
||||
cute::tma_store_arrive();
|
||||
cute::tma_store_wait<0>();
|
||||
}
|
||||
}
|
||||
|
||||
else
|
||||
{
|
||||
__nv_bfloat16* gmem_d_this_block = gmem_d + m_global_idx * SHAPE_N;
|
||||
constexpr int NUM_WARPS
|
||||
= (get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M) - 128) / 32;
|
||||
write_result_to_gmem<BLOCK_M, BLOCK_N, NUM_WARPS>(gmem_d_this_block, smem_d, m_global_idx,
|
||||
scheduler.m_boundary, n_block_idx * BLOCK_N, SHAPE_N, SHAPE_N);
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
}
|
||||
@@ -1050,18 +1025,9 @@ __global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMat
|
||||
smem_b[s], k_idx, scheduler.get_global_n_idx(n_block_idx), kNumTMAMulticast);
|
||||
|
||||
// Issue TMA scales_b (act scales) for B matrix
|
||||
if constexpr (SchedulerType::gemm_type == GemmType::GroupedWithOffset)
|
||||
{
|
||||
tma_copy(&tensor_map_scales_b,
|
||||
reinterpret_cast<uint64_t*>(&full_barrier), smem_scales_b[s],
|
||||
scheduler.get_global_scales_b_idx(n_block_idx), k_idx / BLOCK_K, kNumTMAMulticast);
|
||||
}
|
||||
else
|
||||
{
|
||||
tma_copy(&tensor_map_scales_b,
|
||||
reinterpret_cast<uint64_t*>(&full_barrier), smem_scales_b[s], n_block_idx * BLOCK_N,
|
||||
scheduler.get_global_scales_b_idx(k_idx / BLOCK_K), kNumTMAMulticast);
|
||||
}
|
||||
tma_copy(&tensor_map_scales_b,
|
||||
reinterpret_cast<uint64_t*>(&full_barrier), smem_scales_b[s],
|
||||
scheduler.get_global_scales_b_idx(n_block_idx), k_idx / BLOCK_K, kNumTMAMulticast);
|
||||
|
||||
full_barrier.arrive_and_expect_tx(
|
||||
SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_B_SIZE_PER_STAGE);
|
||||
@@ -1246,45 +1212,28 @@ __global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMat
|
||||
smem_d + warp_idx * 16 + WGMMA::kNumAccum / 8 * 16 * BLOCK_M + tid);
|
||||
}
|
||||
|
||||
if constexpr (SchedulerType::gemm_type == GemmType::GroupedWithOffset)
|
||||
auto n_global_idx = scheduler.get_global_n_idx(n_block_idx);
|
||||
bool cross_boundary = (n_global_idx + BLOCK_N) > scheduler.n_boundary;
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
if (!cross_boundary)
|
||||
{
|
||||
auto n_global_idx = scheduler.get_global_n_idx(n_block_idx);
|
||||
bool cross_boundary = (n_global_idx + BLOCK_N) > scheduler.n_boundary;
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
if (!cross_boundary)
|
||||
{
|
||||
// Use TMA store to write back to global memory
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, m_block_idx * BLOCK_M, n_global_idx);
|
||||
cute::tma_store_arrive();
|
||||
cute::tma_store_wait<0>();
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
__nv_bfloat16* gmem_d_this_block = gmem_d + n_global_idx * SHAPE_M;
|
||||
constexpr int NUM_WARPS
|
||||
= (get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M) - 128) / 32;
|
||||
write_result_to_gmem<BLOCK_N, BLOCK_M, NUM_WARPS>(gmem_d_this_block, smem_d, n_global_idx,
|
||||
scheduler.n_boundary, m_block_idx * BLOCK_M, SHAPE_M, SHAPE_M);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
// Use TMA store to write back to global memory
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
cute::SM90_TMA_STORE_2D::copy(
|
||||
&tensor_map_d, smem_d, m_block_idx * BLOCK_M, scheduler.get_global_n_idx(n_block_idx));
|
||||
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, m_block_idx * BLOCK_M, n_global_idx);
|
||||
cute::tma_store_arrive();
|
||||
cute::tma_store_wait<0>();
|
||||
}
|
||||
}
|
||||
|
||||
else
|
||||
{
|
||||
__nv_bfloat16* gmem_d_this_block = gmem_d + n_global_idx * SHAPE_M;
|
||||
constexpr int NUM_WARPS
|
||||
= (get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M) - 128) / 32;
|
||||
write_result_to_gmem<BLOCK_N, BLOCK_M, NUM_WARPS>(gmem_d_this_block, smem_d, n_global_idx,
|
||||
scheduler.n_boundary, m_block_idx * BLOCK_M, SHAPE_M, SHAPE_M);
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,13 +67,8 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k
|
||||
|
||||
return smem_size, swizzle_mode, block_n_padding
|
||||
else:
|
||||
# Try swizzle first, as it does not waste shared memory
|
||||
swizzle_mode = get_swizzle_mode(block_n)
|
||||
block_n_padding = get_block_n_padding_for_smem_d(
|
||||
block_n) if swizzle_mode == 0 else 0
|
||||
|
||||
# NOTES: `scales_b` in a total manner or per-stage manner
|
||||
smem_d = block_m * (block_n + block_n_padding) * (4 if is_fp32_out else 2)
|
||||
smem_d = block_m * block_n * (4 if is_fp32_out else 2)
|
||||
smem_a_per_stage = block_m * block_k
|
||||
smem_scales_a_per_stage = ceil_div(k, block_k) * 4; # weight scales
|
||||
smem_b_per_stage = block_n * block_k
|
||||
@@ -87,11 +82,12 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k
|
||||
smem_size += num_stages * smem_scales_b
|
||||
smem_size += num_stages * smem_b_per_stage
|
||||
smem_size += num_stages * smem_scales_b_per_stage
|
||||
smem_size += ceil_div(smem_scales_a_per_stage * (1 if block_k % block_n == 0 else 2), 8) * 8
|
||||
smem_size += ceil_div(smem_scales_a_per_stage, 8) * 8
|
||||
smem_size += smem_barrier
|
||||
|
||||
# Swizzle and padding are not compatible
|
||||
assert int(swizzle_mode > 0) + int(block_n_padding > 0) <= 1
|
||||
|
||||
# no swizzle, no block_n_padding
|
||||
swizzle_mode = 0
|
||||
block_n_padding = 0
|
||||
|
||||
return smem_size, swizzle_mode, block_n_padding
|
||||
|
||||
@@ -105,7 +101,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
else:
|
||||
block_ms = (get_m_alignment_for_contiguous_layout(), )
|
||||
#block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, ))
|
||||
block_ns = tuple(range(16, 129, 8))
|
||||
block_ns = tuple(range(16, 129, 8))
|
||||
|
||||
# Avoid bank conflicts for FP32 output
|
||||
if is_fp32_out:
|
||||
block_ns = [x for x in block_ns if x % 16 == 8]
|
||||
|
||||
@@ -221,20 +221,12 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
num_groups, n, k_ = rhs.shape
|
||||
m_, n_ = out.shape
|
||||
|
||||
|
||||
print("expected_m: ",expected_m)
|
||||
print("A shape: ",lhs.shape)
|
||||
print("A scale shape: ",lhs_scales.shape)
|
||||
print("B shape: ",rhs.shape)
|
||||
print("B scale shape: ",rhs_scales.shape)
|
||||
print("out shape: ",out.shape)
|
||||
|
||||
|
||||
# Type and shape checks
|
||||
assert m == m_ and n == n_ and k == k_
|
||||
|
||||
max_shape_m_4_align = ceil_div(m, 4) * 4 # align 4
|
||||
max_shape_m_32_align_padded = compute_padded_offset(m, num_groups)
|
||||
max_shape_m_32_align_padded = compute_padded_offset(max_shape_m_4_align, num_groups)
|
||||
|
||||
assert expected_m > 0 and max_shape_m_4_align > 0 and n > 0 and k > 0 and num_groups > 0
|
||||
|
||||
@@ -244,12 +236,14 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128))
|
||||
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
|
||||
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
|
||||
assert offsets.dtype == torch.int64
|
||||
assert out.dtype == torch.bfloat16
|
||||
assert lhs.is_contiguous() and rhs.is_contiguous()
|
||||
assert out.is_contiguous()
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
|
||||
#lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
assert rhs_scales.is_contiguous()
|
||||
|
||||
# Auto-tuning with compilation
|
||||
@@ -273,9 +267,9 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
num_tma_threads = 128
|
||||
num_math_threads_per_group = 128
|
||||
|
||||
tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedWithOffset, lhs, m, k, k, block_m, block_k, num_groups)
|
||||
tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedWithOffset, lhs, max_shape_m_4_align, k, k, block_m, block_k, num_groups)
|
||||
tensor_map_b = make_2d_tma_b_desc(GemmType.GroupedWithOffset, rhs, n, k, k, block_n, block_k, num_groups)
|
||||
tensor_map_d = make_2d_tma_d_desc(GemmType.GroupedWithOffset, out, m, n, n, block_m, block_n, num_groups, 0) # none swizzle
|
||||
tensor_map_d = make_2d_tma_d_desc(GemmType.GroupedWithOffset, out, max_shape_m_4_align, n, n, block_m, block_n, num_groups, 0) # none swizzle
|
||||
tensor_map_scales_a = make_2d_tma_scales_a_offset_desc(GemmType.GroupedWithOffset, lhs_scales, max_shape_m_32_align_padded, k, block_m, block_k) # none swizzle
|
||||
|
||||
|
||||
@@ -287,7 +281,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
'PROBLEM_OFFSETS': offsets,
|
||||
'NUM_TMA_THREADS': num_tma_threads,
|
||||
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
|
||||
'M': m, 'N': n, 'K': k,
|
||||
'M': max_shape_m_4_align, 'N': n, 'K': k,
|
||||
'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
|
||||
'NUM_GROUPS': num_groups,
|
||||
'NUM_STAGES': num_stages,
|
||||
@@ -310,26 +304,17 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
else:
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
|
||||
n, expected_m, k, num_groups, num_sms, is_grouped_contiguous = True, is_swap_ab=True)
|
||||
|
||||
# Extra checks for TMA store
|
||||
if num_groups > 1 and n > block_m:
|
||||
assert n % block_m == 0, f'For GroupedWithOffset grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})'
|
||||
|
||||
print("is_swap_ab=True =========")
|
||||
print("num_sms: ",num_sms)
|
||||
print("block_m: ",block_m)
|
||||
print("block_n: ",block_n)
|
||||
print("num_stages: ",num_stages)
|
||||
print("tma_multicast_config: ",tma_multicast_config)
|
||||
print("smem_config: ",smem_config)
|
||||
|
||||
block_k = 128
|
||||
num_tma_threads = 128
|
||||
num_math_threads_per_group = 128
|
||||
|
||||
tensor_map_a = make_2d_tma_a_offset_desc_swapAB(GemmType.GroupedWithOffset, rhs, n, k, k, block_m, block_k, num_groups)
|
||||
tensor_map_b = make_2d_tma_b_offset_desc_swapAB(GemmType.GroupedWithOffset, lhs, m, k, k, block_n, block_k, num_groups)
|
||||
tensor_map_d = make_2d_tma_d_offset_desc_swapAB(GemmType.GroupedWithOffset, out, n, m, m, block_m, block_n, num_groups, 0) # no swizzle
|
||||
tensor_map_b = make_2d_tma_b_offset_desc_swapAB(GemmType.GroupedWithOffset, lhs, max_shape_m_4_align, k, k, block_n, block_k, num_groups)
|
||||
tensor_map_d = make_2d_tma_d_offset_desc_swapAB(GemmType.GroupedWithOffset, out, max_shape_m_4_align, n, n, block_m, block_n, num_groups, 0) # no swizzle
|
||||
tensor_map_scales_b = make_2d_tma_scales_b_offset_desc_swapAB(GemmType.GroupedWithOffset, lhs_scales, max_shape_m_32_align_padded, k, block_n, block_k) # no swizzle
|
||||
|
||||
kwargs = {
|
||||
@@ -340,7 +325,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
'PROBLEM_OFFSETS': offsets,
|
||||
'NUM_TMA_THREADS': num_tma_threads,
|
||||
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
|
||||
'M': m, 'N': n, 'K': k,
|
||||
'M': max_shape_m_4_align, 'N': n, 'K': k,
|
||||
'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
|
||||
'NUM_GROUPS': num_groups,
|
||||
'NUM_STAGES': num_stages,
|
||||
|
||||
@@ -173,8 +173,8 @@ def make_2d_tma_d_offset_desc_swapAB(gemm_type: GemmType, t: torch.Tensor,
|
||||
# Swizzling requires the inner box dim to be less or equal than `kSwizzleDMode`
|
||||
# bytes, so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
|
||||
return make_2d_tma_desc(t,
|
||||
shape_n, shape_m * (num_groups if gemm_type != GemmType.Normal else 1), m_stride,
|
||||
min(block_n, shape_n), min(block_m, shape_m),
|
||||
shape_n, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), m_stride,
|
||||
min(block_m, shape_n), min(block_n, shape_m),
|
||||
cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user