fix tma_d_offset_desc_swapAB, update unit test

wangzhe_ant 2025-06-24 17:52:28 +08:00
parent 26a603f518
commit ccd63bb234
5 changed files with 348 additions and 183 deletions

View File

@@ -438,6 +438,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
#endif
}
template <uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t NUM_WARPS_PER_BLOCK>
static __device__ __forceinline__ void write_result_to_gmem(__nv_bfloat16* gmem_d_this_block,
__nv_bfloat16 const* smem_d, uint32_t const m_offset, uint32_t const m_boundary, uint32_t const n_offset,
@@ -638,18 +639,9 @@ __global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMat
tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_a[s], k_idx, scheduler.get_global_m_idx(m_block_idx), kNumTMAMulticast);
if constexpr (SchedulerType::gemm_type == GemmType::GroupedWithOffset)
{
tma_copy(&tensor_map_scales_a,
reinterpret_cast<uint64_t*>(&full_barrier), smem_scales_a[s],
scheduler.get_global_scales_a_idx(m_block_idx), k_idx / BLOCK_K, kNumTMAMulticast);
}
else
{
tma_copy(&tensor_map_scales_a,
reinterpret_cast<uint64_t*>(&full_barrier), smem_scales_a[s], m_block_idx * BLOCK_M,
scheduler.get_global_scales_a_idx(k_idx / BLOCK_K), kNumTMAMulticast);
}
tma_copy(&tensor_map_scales_a,
reinterpret_cast<uint64_t*>(&full_barrier), smem_scales_a[s],
scheduler.get_global_scales_a_idx(m_block_idx), k_idx / BLOCK_K, kNumTMAMulticast);
// Issue TMA B without broadcasting
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier), smem_b[s], k_idx,
@@ -826,45 +818,28 @@ __global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMat
smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16);
}
if constexpr (SchedulerType::gemm_type == GemmType::GroupedWithOffset)
auto m_global_idx = scheduler.get_global_m_idx(m_block_idx);
bool cross_boundary = (m_global_idx + BLOCK_M) > scheduler.m_boundary;
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
if (!cross_boundary)
{
auto m_global_idx = scheduler.get_global_m_idx(m_block_idx);
bool cross_boundary = (m_global_idx + BLOCK_M) > scheduler.m_boundary;
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
if (!cross_boundary)
{
// Use TMA store to write back to global memory
if (threadIdx.x == 0)
{
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, m_global_idx);
cute::tma_store_arrive();
cute::tma_store_wait<0>();
}
}
else
{
__nv_bfloat16* gmem_d_this_block = gmem_d + m_global_idx * SHAPE_N;
constexpr int NUM_WARPS
= (get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M) - 128) / 32;
write_result_to_gmem<BLOCK_M, BLOCK_N, NUM_WARPS>(gmem_d_this_block, smem_d, m_global_idx,
scheduler.m_boundary, n_block_idx * BLOCK_N, SHAPE_N, SHAPE_N);
}
}
else
{
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Use TMA store to write back to global memory
if (threadIdx.x == 0)
{
cute::SM90_TMA_STORE_2D::copy(
&tensor_map_d, smem_d, n_block_idx * BLOCK_N, scheduler.get_global_m_idx(m_block_idx));
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, m_global_idx);
cute::tma_store_arrive();
cute::tma_store_wait<0>();
}
}
else
{
__nv_bfloat16* gmem_d_this_block = gmem_d + m_global_idx * SHAPE_N;
constexpr int NUM_WARPS
= (get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M) - 128) / 32;
write_result_to_gmem<BLOCK_M, BLOCK_N, NUM_WARPS>(gmem_d_this_block, smem_d, m_global_idx,
scheduler.m_boundary, n_block_idx * BLOCK_N, SHAPE_N, SHAPE_N);
}
__syncwarp();
}
}
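Note on the epilogue change above: the boundary check is no longer gated on the GroupedWithOffset branch, so every block computes its global row index once, and only tiles that fit entirely below m_boundary take the single 2D TMA store, while tiles that cross it fall back to the bounds-checked write_result_to_gmem path. A minimal Python sketch of that decision, with illustrative values for BLOCK_M and m_boundary (not taken from a real launch):

# Illustrative sketch (not kernel code) of the store-path decision above.
BLOCK_M = 64
m_boundary = 136   # hypothetical last valid row for the current group

def choose_store_path(m_global_idx: int) -> str:
    cross_boundary = (m_global_idx + BLOCK_M) > m_boundary
    # Full tiles use one 2D TMA store; tiles crossing the boundary take the
    # bounds-checked per-warp write (write_result_to_gmem in the kernel).
    return "write_result_to_gmem" if cross_boundary else "tma_store_2d"

assert choose_store_path(64) == "tma_store_2d"            # rows [64, 128) fit
assert choose_store_path(128) == "write_result_to_gmem"   # rows [128, 192) cross 136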
@@ -1050,18 +1025,9 @@ __global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMat
smem_b[s], k_idx, scheduler.get_global_n_idx(n_block_idx), kNumTMAMulticast);
// Issue TMA scales_b (act scales) for B matrix
if constexpr (SchedulerType::gemm_type == GemmType::GroupedWithOffset)
{
tma_copy(&tensor_map_scales_b,
reinterpret_cast<uint64_t*>(&full_barrier), smem_scales_b[s],
scheduler.get_global_scales_b_idx(n_block_idx), k_idx / BLOCK_K, kNumTMAMulticast);
}
else
{
tma_copy(&tensor_map_scales_b,
reinterpret_cast<uint64_t*>(&full_barrier), smem_scales_b[s], n_block_idx * BLOCK_N,
scheduler.get_global_scales_b_idx(k_idx / BLOCK_K), kNumTMAMulticast);
}
tma_copy(&tensor_map_scales_b,
reinterpret_cast<uint64_t*>(&full_barrier), smem_scales_b[s],
scheduler.get_global_scales_b_idx(n_block_idx), k_idx / BLOCK_K, kNumTMAMulticast);
full_barrier.arrive_and_expect_tx(
SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_B_SIZE_PER_STAGE);
@@ -1246,45 +1212,28 @@ __global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMat
smem_d + warp_idx * 16 + WGMMA::kNumAccum / 8 * 16 * BLOCK_M + tid);
}
if constexpr (SchedulerType::gemm_type == GemmType::GroupedWithOffset)
auto n_global_idx = scheduler.get_global_n_idx(n_block_idx);
bool cross_boundary = (n_global_idx + BLOCK_N) > scheduler.n_boundary;
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
if (!cross_boundary)
{
auto n_global_idx = scheduler.get_global_n_idx(n_block_idx);
bool cross_boundary = (n_global_idx + BLOCK_N) > scheduler.n_boundary;
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
if (!cross_boundary)
{
// Use TMA store to write back to global memory
if (threadIdx.x == 0)
{
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, m_block_idx * BLOCK_M, n_global_idx);
cute::tma_store_arrive();
cute::tma_store_wait<0>();
}
}
else
{
__nv_bfloat16* gmem_d_this_block = gmem_d + n_global_idx * SHAPE_M;
constexpr int NUM_WARPS
= (get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M) - 128) / 32;
write_result_to_gmem<BLOCK_N, BLOCK_M, NUM_WARPS>(gmem_d_this_block, smem_d, n_global_idx,
scheduler.n_boundary, m_block_idx * BLOCK_M, SHAPE_M, SHAPE_M);
}
}
else
{
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Use TMA store to write back to global memory
if (threadIdx.x == 0)
{
cute::SM90_TMA_STORE_2D::copy(
&tensor_map_d, smem_d, m_block_idx * BLOCK_M, scheduler.get_global_n_idx(n_block_idx));
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, m_block_idx * BLOCK_M, n_global_idx);
cute::tma_store_arrive();
cute::tma_store_wait<0>();
}
}
else
{
__nv_bfloat16* gmem_d_this_block = gmem_d + n_global_idx * SHAPE_M;
constexpr int NUM_WARPS
= (get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M) - 128) / 32;
write_result_to_gmem<BLOCK_N, BLOCK_M, NUM_WARPS>(gmem_d_this_block, smem_d, n_global_idx,
scheduler.n_boundary, m_block_idx * BLOCK_M, SHAPE_M, SHAPE_M);
}
__syncwarp();
}
}

View File

@@ -67,13 +67,8 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k
return smem_size, swizzle_mode, block_n_padding
else:
# Try swizzle first, as it does not waste shared memory
swizzle_mode = get_swizzle_mode(block_n)
block_n_padding = get_block_n_padding_for_smem_d(
block_n) if swizzle_mode == 0 else 0
# NOTES: `scales_b` in a total manner or per-stage manner
smem_d = block_m * (block_n + block_n_padding) * (4 if is_fp32_out else 2)
smem_d = block_m * block_n * (4 if is_fp32_out else 2)
smem_a_per_stage = block_m * block_k
smem_scales_a_per_stage = ceil_div(k, block_k) * 4; # weight scales
smem_b_per_stage = block_n * block_k
@@ -87,11 +82,12 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k
smem_size += num_stages * smem_scales_b
smem_size += num_stages * smem_b_per_stage
smem_size += num_stages * smem_scales_b_per_stage
smem_size += ceil_div(smem_scales_a_per_stage * (1 if block_k % block_n == 0 else 2), 8) * 8
smem_size += ceil_div(smem_scales_a_per_stage, 8) * 8
smem_size += smem_barrier
# Swizzle and padding are not compatible
assert int(swizzle_mode > 0) + int(block_n_padding > 0) <= 1
# no swizzle, no block_n_padding
swizzle_mode = 0
block_n_padding = 0
return smem_size, swizzle_mode, block_n_padding
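With padding and swizzling disabled on this path, the shared-memory budget above reduces to a few products. A rough sketch under assumed tile sizes (barriers and the scales_b terms are omitted; the values are illustrative, not a tuned config):

# Rough sketch of the per-stage byte counts computed above (illustrative values).
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

block_m, block_n, block_k, k = 64, 128, 128, 7168
is_fp32_out = False

smem_d = block_m * block_n * (4 if is_fp32_out else 2)   # 16384 bytes, no block_n padding
smem_a_per_stage = block_m * block_k                     # 8192 bytes of FP8 A
smem_b_per_stage = block_n * block_k                     # 16384 bytes of FP8 B
smem_scales_a_per_stage = ceil_div(k, block_k) * 4       # 224 bytes of FP32 weight scales
print(smem_d, smem_a_per_stage, smem_b_per_stage, smem_scales_a_per_stage)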
@@ -105,7 +101,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
else:
block_ms = (get_m_alignment_for_contiguous_layout(), )
#block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, ))
block_ns = tuple(range(16, 129, 8))
block_ns = tuple(range(16, 129, 8))
# Avoid bank conflicts for FP32 output
if is_fp32_out:
block_ns = [x for x in block_ns if x % 16 == 8]

View File

@@ -221,20 +221,12 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor]
num_groups, n, k_ = rhs.shape
m_, n_ = out.shape
print("expected_m: ",expected_m)
print("A shape: ",lhs.shape)
print("A scale shape: ",lhs_scales.shape)
print("B shape: ",rhs.shape)
print("B scale shape: ",rhs_scales.shape)
print("out shape: ",out.shape)
# Type and shape checks
assert m == m_ and n == n_ and k == k_
max_shape_m_4_align = ceil_div(m, 4) * 4 # align 4
max_shape_m_32_align_padded = compute_padded_offset(m, num_groups)
max_shape_m_32_align_padded = compute_padded_offset(max_shape_m_4_align, num_groups)
assert expected_m > 0 and max_shape_m_4_align > 0 and n > 0 and k > 0 and num_groups > 0
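The two padded sizes above are plain integer arithmetic: rows are rounded up to a multiple of 4 for TMA, and the scale rows additionally reserve up to 31 extra slots per group. A sketch, assuming compute_padded_offset matches the helper shown in the unit test further down:

# Sketch of the alignment math above (shapes are illustrative).
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def compute_padded_offset(offset: int, idx_problem: int, alignment: int = 32) -> int:
    return (offset + idx_problem * (alignment - 1)) // alignment * alignment

m, num_groups = 56, 2
max_shape_m_4_align = ceil_div(m, 4) * 4                                               # 56
max_shape_m_32_align_padded = compute_padded_offset(max_shape_m_4_align, num_groups)   # 96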
@@ -244,12 +236,14 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor]
assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128))
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
assert offsets.dtype == torch.int64
assert out.dtype == torch.bfloat16
assert lhs.is_contiguous() and rhs.is_contiguous()
assert out.is_contiguous()
# LHS scales must be transposed for TMA load, but not for RHS scales
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
#lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
assert rhs_scales.is_contiguous()
# Auto-tuning with compilation
@@ -273,9 +267,9 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor]
num_tma_threads = 128
num_math_threads_per_group = 128
tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedWithOffset, lhs, m, k, k, block_m, block_k, num_groups)
tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedWithOffset, lhs, max_shape_m_4_align, k, k, block_m, block_k, num_groups)
tensor_map_b = make_2d_tma_b_desc(GemmType.GroupedWithOffset, rhs, n, k, k, block_n, block_k, num_groups)
tensor_map_d = make_2d_tma_d_desc(GemmType.GroupedWithOffset, out, m, n, n, block_m, block_n, num_groups, 0) # none swizzle
tensor_map_d = make_2d_tma_d_desc(GemmType.GroupedWithOffset, out, max_shape_m_4_align, n, n, block_m, block_n, num_groups, 0) # none swizzle
tensor_map_scales_a = make_2d_tma_scales_a_offset_desc(GemmType.GroupedWithOffset, lhs_scales, max_shape_m_32_align_padded, k, block_m, block_k) # none swizzle
@@ -287,7 +281,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor]
'PROBLEM_OFFSETS': offsets,
'NUM_TMA_THREADS': num_tma_threads,
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
'M': m, 'N': n, 'K': k,
'M': max_shape_m_4_align, 'N': n, 'K': k,
'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
'NUM_GROUPS': num_groups,
'NUM_STAGES': num_stages,
@@ -310,26 +304,17 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor]
else:
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
n, expected_m, k, num_groups, num_sms, is_grouped_contiguous = True, is_swap_ab=True)
# Extra checks for TMA store
if num_groups > 1 and n > block_m:
assert n % block_m == 0, f'For GroupedWithOffset grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})'
print("is_swap_ab=True =========")
print("num_sms: ",num_sms)
print("block_m: ",block_m)
print("block_n: ",block_n)
print("num_stages: ",num_stages)
print("tma_multicast_config: ",tma_multicast_config)
print("smem_config: ",smem_config)
block_k = 128
num_tma_threads = 128
num_math_threads_per_group = 128
tensor_map_a = make_2d_tma_a_offset_desc_swapAB(GemmType.GroupedWithOffset, rhs, n, k, k, block_m, block_k, num_groups)
tensor_map_b = make_2d_tma_b_offset_desc_swapAB(GemmType.GroupedWithOffset, lhs, m, k, k, block_n, block_k, num_groups)
tensor_map_d = make_2d_tma_d_offset_desc_swapAB(GemmType.GroupedWithOffset, out, n, m, m, block_m, block_n, num_groups, 0) # no swizzle
tensor_map_b = make_2d_tma_b_offset_desc_swapAB(GemmType.GroupedWithOffset, lhs, max_shape_m_4_align, k, k, block_n, block_k, num_groups)
tensor_map_d = make_2d_tma_d_offset_desc_swapAB(GemmType.GroupedWithOffset, out, max_shape_m_4_align, n, n, block_m, block_n, num_groups, 0) # no swizzle
tensor_map_scales_b = make_2d_tma_scales_b_offset_desc_swapAB(GemmType.GroupedWithOffset, lhs_scales, max_shape_m_32_align_padded, k, block_n, block_k) # no swizzle
kwargs = {
@@ -340,7 +325,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor]
'PROBLEM_OFFSETS': offsets,
'NUM_TMA_THREADS': num_tma_threads,
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
'M': m, 'N': n, 'K': k,
'M': max_shape_m_4_align, 'N': n, 'K': k,
'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
'NUM_GROUPS': num_groups,
'NUM_STAGES': num_stages,

View File

@@ -173,8 +173,8 @@ def make_2d_tma_d_offset_desc_swapAB(gemm_type: GemmType, t: torch.Tensor,
# Swizzling requires the inner box dim to be less or equal than `kSwizzleDMode`
# bytes, so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
return make_2d_tma_desc(t,
shape_n, shape_m * (num_groups if gemm_type != GemmType.Normal else 1), m_stride,
min(block_n, shape_n), min(block_m, shape_m),
shape_n, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), m_stride,
min(block_m, shape_n), min(block_n, shape_m),
cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
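With swapAB the output descriptor spans the (shape_n, shape_m) extent, and the second make_2d_tma_desc call above clamps each box axis with the block size that tiles it. A small sketch of that box computation, reusing the names from the snippet (the concrete values are made up):

# Illustrative evaluation of the box dimensions in the swapAB D descriptor above.
shape_m, shape_n = 1024, 4096
block_m, block_n = 64, 128

box_dims = (min(block_m, shape_n), min(block_n, shape_m))
print(box_dims)   # (64, 128): block_m bounds the shape_n axis, block_n the shape_m axis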

View File

@@ -6,6 +6,7 @@ print(f'NVRTC version: {nvrtc.nvrtcVersion()[1:]}')
import random
import torch
from typing import List, Tuple
import itertools
import deep_gemm
from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor
@@ -34,6 +35,49 @@ def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
def construct(m: int, k: int, n: int) -> \
Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
ref_out = x @ y.t()
x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y)
# Transpose earlier so that the testing will not trigger transposing kernels
x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
return x_fp8, y_fp8, out, ref_out
def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k: int, n: int) -> \
Tuple[int, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
alignment = get_m_alignment_for_contiguous_layout()
group_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)]
m = sum([ceil_div(x, alignment) * alignment for x in group_ms])
x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
m_indices = torch.empty(m, device='cuda', dtype=torch.int32)
out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
ref_out = torch.randn((m, n), device='cuda', dtype=torch.bfloat16)
start = 0
for i, group_m in enumerate(group_ms):
actual_end = start + group_m
aligned_end = start + ceil_div(group_m, alignment) * alignment
m_indices[start:actual_end] = i
m_indices[actual_end:aligned_end] = -1
ref_out[start:aligned_end] = x[start:aligned_end] @ y[i].t()
start = aligned_end
ref_out = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(ref_out), ref_out)
assert m % 4 == 0, f'TMA alignment error: {m}'
x_fp8 = per_token_cast_to_fp8(x)
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float))
for i in range(num_groups):
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
return m, x_fp8, y_fp8, m_indices, out, ref_out
def construct_masked_grouped(num_groups: int, max_m: int, expected_m_per_group: int, k: int, n: int) -> \
Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -55,10 +99,195 @@ def construct_masked_grouped(num_groups: int, max_m: int, expected_m_per_group:
# Construct mask
masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
for j in range(num_groups):
masked_m[j] = int(expected_m_per_group * random.uniform(1, 1))
masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3))
assert masked_m.amax().item() <= max_m
return x_fp8, y_fp8, masked_m, out, ref_out
def construct_wgrad(m: int, k: int, n: int) -> \
Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
residual = torch.randn((m, n), device='cuda', dtype=torch.float) * 10
out = residual.clone()
ref_out = residual + (x.float() @ y.float().t())
x_fp8 = per_token_cast_to_fp8(x)
y_fp8 = per_token_cast_to_fp8(y)
# NOTES: please do inplace add on the `out` later
return x_fp8, y_fp8, residual, out, ref_out
def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \
Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, List[int]]:
num_groups, total_k = len(k_sizes), sum(k_sizes)
x_flat = torch.empty((m * total_k,), device='cuda', dtype=torch.bfloat16)
y_flat = torch.empty((n * total_k,), device='cuda', dtype=torch.bfloat16)
out = torch.zeros((num_groups, m, n), device='cuda', dtype=torch.float)
ref_out = torch.zeros((num_groups, m, n), device='cuda', dtype=torch.float)
# Fill tensors with data and compute reference output
x_offset, y_offset = 0, 0
for idx, k in enumerate(k_sizes):
x_chunk = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
y_chunk = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
x_flat[x_offset:x_offset + m * k].copy_(x_chunk.flatten())
y_flat[y_offset:y_offset + n * k].copy_(y_chunk.flatten())
ref_out[idx] = x_chunk.float() @ y_chunk.float().t()
x_offset += m * k
y_offset += n * k
x_fp8_flat = torch.empty_like(x_flat, dtype=torch.float8_e4m3fn)
y_fp8_flat = torch.empty_like(y_flat, dtype=torch.float8_e4m3fn)
total_scale_factors = sum(ceil_div(k, 128) for k in k_sizes)
x_scales = torch.empty((total_scale_factors, m), device='cuda', dtype=torch.float)
y_scales = torch.empty((total_scale_factors, n), device='cuda', dtype=torch.float)
# Cast to FP8 and prepare scale factors
x_offset, y_offset, scale_offset = 0, 0, 0
for k in k_sizes:
x_fp8_chunk, x_scale_chunk = per_token_cast_to_fp8(x_flat[x_offset:x_offset + m * k].view(m, k))
y_fp8_chunk, y_scale_chunk = per_token_cast_to_fp8(y_flat[y_offset:y_offset + n * k].view(n, k))
x_fp8_flat[x_offset:x_offset + m * k].copy_(x_fp8_chunk.flatten())
y_fp8_flat[y_offset:y_offset + n * k].copy_(y_fp8_chunk.flatten())
num_scales = ceil_div(k, 128)
x_scales[scale_offset:scale_offset + num_scales].copy_(x_scale_chunk.T)
y_scales[scale_offset:scale_offset + num_scales].copy_(y_scale_chunk.T)
x_offset += m * k
y_offset += n * k
scale_offset += num_scales
return (x_fp8_flat, x_scales), (y_fp8_flat, y_scales), out, ref_out, k_sizes
def change_to_offset_layout(
ms: List[int],
x_fp8: torch.Tensor,
x_scale: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
x_list = []
x_scale_list = []
shape_m_total = 0
num_problems = len(ms)
m_acc = [0] + list(itertools.accumulate(ms))
# Need to keep the same as the one in cpp/include/tensorrt_llm/deep_gemm/scheduler.cuh
def compute_padded_offset(offset, idx_problem, alignment=32):
return (offset + idx_problem * (alignment - 1)) // alignment * alignment
offset = 0
for i in range(num_problems):
ms[i]
x_list.append(x_fp8[m_acc[i]:m_acc[i + 1]])
offset_next = compute_padded_offset(m_acc[i + 1], i + 1)
size_padded = (offset_next - offset) - (m_acc[i + 1] - m_acc[i])
x_scale_padded = torch.cat([
x_scale[m_acc[i]:m_acc[i + 1]],
torch.zeros(
[size_padded, *x_scale.shape[1:]],
dtype=x_scale.dtype,
device=x_scale.device,
),
])
x_scale_list.append(x_scale_padded)
offset = offset_next
shape_m_total = m_acc[-1]
ret_x = torch.cat(x_list)
ret_x_scale = torch.cat(x_scale_list)
ret_x_scale = ret_x_scale.t().contiguous()
pad_target = compute_padded_offset(shape_m_total, num_problems)
pad_target -= ret_x_scale.shape[1]
ret_x_scale = torch.nn.functional.pad(ret_x_scale, (0, pad_target),
mode='constant',
value=0)
return ret_x, ret_x_scale
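As a concrete walk-through of the padding above (illustrative shapes only): with two groups of 36 and 20 rows, group 0's scale rows are padded out to offset 64 and group 1's to offset 96 before the transpose, while the activation rows stay at 56.

# Hypothetical usage of change_to_offset_layout (dtype is irrelevant to the shapes).
k = 256
ms = [36, 20]
x = torch.randn(56, k, device='cuda', dtype=torch.bfloat16)
x_scale = torch.rand(56, k // 128, device='cuda', dtype=torch.float)

ret_x, ret_x_scale = change_to_offset_layout(ms, x, x_scale)
assert tuple(ret_x.shape) == (56, k)
assert tuple(ret_x_scale.shape) == (k // 128, 96)   # 36 -> 64 and 20 -> 32 padded scale rows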
def construct_offset_grouped(num_groups: int, expected_m_per_group: int, k: int, n: int) -> \
Tuple[int, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
alignment = 4
group_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)]
m = sum([ceil_div(x, alignment) * alignment for x in group_ms])
x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
offsets = torch.empty(num_groups+1, device='cuda', dtype=torch.int64)
out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
ref_out = torch.randn((m, n), device='cuda', dtype=torch.bfloat16)
start = 0
offsets[0] = 0
for i, group_m in enumerate(group_ms):
aligned_end = start + ceil_div(group_m, alignment) * alignment
offsets[i+1] = aligned_end
ref_out[start:aligned_end] = x[start:aligned_end] @ y[i].t()
start = aligned_end
group_ms[i] = ceil_div(group_m, alignment) * alignment
assert m % 4 == 0, f'TMA alignment error: {m}'
x_fp8 = per_token_cast_to_fp8(x)
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float))
for i in range(num_groups):
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
return group_ms, m, x_fp8, y_fp8, offsets.type(torch.int64), out, ref_out
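For reference, the offsets built above are cumulative 4-aligned group sizes; a tiny worked example with made-up group sizes:

# Worked example of the group alignment and offsets above (illustrative sizes).
alignment = 4
raw_group_ms = [34, 18]
aligned = [(g + alignment - 1) // alignment * alignment for g in raw_group_ms]  # [36, 20]
offsets = [0]
for g in aligned:
    offsets.append(offsets[-1] + g)
print(aligned, offsets, sum(aligned))   # [36, 20] [0, 36, 56] 56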
def test_gemm() -> None:
print('Testing GEMM:')
for m in (64, 128, 4096):
for k, n in [(576, 7168), (7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]:
x_fp8, y_fp8, out, ref_out = construct(m, k, n)
deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
diff = calc_diff(out, ref_out)
assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
# noinspection PyShadowingNames
def test_func():
deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
print(f' > Perf (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | '
f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, '
f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s')
print()
def test_m_grouped_gemm_contiguous() -> None:
print('Testing grouped contiguous GEMM:')
for num_groups, expected_m_per_group, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168),
(8, 4096, 7168, 4096), (8, 4096, 2048, 7168),
(32, 256, 7168, 4096), (32, 256, 2048, 7168)):
# NOTES: we should mask the unfilled part before calculating difference
m, x_fp8, y_fp8, m_indices, out, ref_out = construct_contiguous_grouped(num_groups, expected_m_per_group, k, n)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
out = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(out), out)
diff = calc_diff(out, ref_out)
assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
# noinspection PyShadowingNames
def test_func():
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
valid_m = (m_indices != -1).sum().item()
print(f' > Perf ({num_groups=:2}, {expected_m_per_group=:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, '
f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s')
print()
def test_m_grouped_gemm_masked() -> None:
print('Testing grouped masked GEMM:')
@@ -86,88 +315,87 @@ def test_m_grouped_gemm_masked() -> None:
print()
def construct_offset_grouped(num_groups: int, expected_m_per_group: int, k: int, n: int) -> \
Tuple[int, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
alignment = 32
group_ms = [int(expected_m_per_group * random.uniform(1, 1)) for _ in range(num_groups)]
m = sum([ceil_div(x, alignment) * alignment for x in group_ms])
def test_wgrad_gemm():
print('Testing weight gradient GEMM:')
x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
offsets = torch.empty(num_groups+1, device='cuda', dtype=torch.int32)
out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
ref_out = torch.randn((m, n), device='cuda', dtype=torch.bfloat16)
for k in (4096, 8192):
for m, n in ((7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)):
# Test correctness
x_fp8, y_fp8, residual, out, ref_out = construct_wgrad(m, k, n)
deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out)
diff = calc_diff(out, ref_out)
assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
start = 0
offsets[0] = 0
for i, group_m in enumerate(group_ms):
aligned_end = start + ceil_div(group_m, alignment) * alignment
offsets[i+1] = aligned_end
ref_out[start:aligned_end] = x[start:aligned_end] @ y[i].t()
start = aligned_end
# Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2)
x_fp8, y_fp8, residual, out, ref_out = construct_wgrad(m, k, n)
assert m % 4 == 0, f'TMA alignment error: {m}'
x_fp8 = per_token_cast_to_fp8(x)
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float))
for i in range(num_groups):
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
# noinspection PyShadowingNames
def test_func():
deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out)
return m, x_fp8, y_fp8, offsets, out, ref_out
t = bench_kineto(test_func, 'fp8_wgrad_gemm', suppress_kineto_output=True)
print(f' > Performance (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | '
f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, '
f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s')
print()
def test_k_grouped_wgrad_gemm():
print('Testing grouped weight gradient GEMM:')
for num_groups, base_k in ((4, 4096), (4, 8192), (8, 4096)):
for m, n in ((7168, 4096), (2048, 7168)):
# Vary k sizes around base_k
k_sizes = [base_k + random.randint(-1, 1) * 128 for _ in range(num_groups - 1)]
k_sizes.append(base_k * num_groups - sum(k_sizes))
# Test correctness
x_fp8, y_fp8, out, ref_out, k_sizes = construct_k_grouped_wgrad(m, n, k_sizes)
deep_gemm.k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out, k_sizes)
for idx in range(num_groups):
diff = calc_diff(out[idx], ref_out[idx])
assert diff < 0.001, f'{num_groups=}, {m=}, {n=}, k={k_sizes[idx]}, batch={idx}, {diff:.5f}'
# Construct new tensors to avoid L2 cache acceleration
x_fp8, y_fp8, out, ref_out, k_sizes = construct_k_grouped_wgrad(m, n, k_sizes)
total_k = sum(k_sizes)
def test_func():
deep_gemm.k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out, k_sizes)
t = bench_kineto(test_func, 'fp8_wgrad_gemm', suppress_kineto_output=True, with_multiple_kernels=True) * num_groups
print(f' > Performance ({num_groups=}, m={m:5}, n={n:5}, avg_k={total_k//num_groups:5}): {t * 1e6:4.0f} us | '
f'throughput: {2 * num_groups * m * n * (total_k/num_groups) / t / 1e12:4.0f} TFLOPS, '
f'{(m * total_k + n * total_k + num_groups * m * n * 2) / 1e9 / t:4.0f} GB/s')
print()
def test_m_grouped_gemm_offset() -> None:
print('Testing grouped contiguous GEMM:')
for num_groups, expected_m_per_group, k, n in ((9, 32, 7168, 4096),):
for num_groups, expected_m_per_group, k, n in ((8, 32, 7168, 4096),(9, 64, 7168, 4096)):
# NOTES: we should mask the unfilled part before calculating difference
x_fp8_mask, y_fp8_mask, masked_m_mask, out_mask, ref_out_mask = construct_masked_grouped(num_groups, expected_m_per_group, expected_m_per_group, k, n)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8_mask, y_fp8_mask, out_mask, masked_m_mask, expected_m_per_group)
ms, m_offset, x_fp8_offset, y_fp8_offset, offset, out_offset, ref_out_offset = construct_offset_grouped(num_groups, expected_m_per_group, k, n)
pad_x_fp8 = change_to_offset_layout(ms, x_fp8_offset[0], x_fp8_offset[1])
for j in range(num_groups):
diff = calc_diff(out_mask[j, :masked_m_mask[j].item()], ref_out_mask[j, :masked_m_mask[j].item()])
#assert diff < 0.001, f'{expected_m_per_group=}, {k=}, {n=}, {j=}, masked_m={masked_m_mask[j]}, {num_groups=}, {diff:.5f}'
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_offset(pad_x_fp8, y_fp8_offset, offset, out_offset, expected_m_per_group)
diff = calc_diff(out_offset, ref_out_offset)
assert diff < 0.1, f'{m_offset=}, {k=}, {n=}, {diff:.5f}'
# noinspection PyShadowingNames
def test_func():
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8_mask, y_fp8_mask, out_mask, masked_m_mask, expected_m_per_group)
# Test performance with fixed shapes
# noinspection PyUnboundLocalVariable
valid_m = masked_m_mask.sum().item()
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
print(f' > m_grouped_gemm_fp8_fp8_bf16_nt_masked: Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, '
f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s')
'''
m_offset, x_fp8_offset, y_fp8_offset, offset, out_offset, ref_out_offset = construct_offset_grouped(num_groups, expected_m_per_group, k, n)
#deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_offset(x_fp8_offset, y_fp8_offset, offset, out_offset, expected_m_per_group)
#diff = calc_diff(out_offset, ref_out_offset)
# assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
# noinspection PyShadowingNames
def test_func():
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_offset(x_fp8_offset, y_fp8_offset, offset, out_offset, expected_m_per_group)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_offset(pad_x_fp8, y_fp8_offset, offset, out_offset, expected_m_per_group)
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
valid_m = m_offset
print(f' > m_grouped_gemm_fp8_fp8_bf16_nt_offset: Perf ({num_groups=:2}, {expected_m_per_group=:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, '
f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s')
'''
print()
if __name__ == '__main__':
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
@@ -177,4 +405,10 @@ if __name__ == '__main__':
print('Library path:')
print(f' > {deep_gemm.__path__}\n')
test_gemm()
test_m_grouped_gemm_contiguous()
test_m_grouped_gemm_masked()
test_m_grouped_gemm_offset()
test_wgrad_gemm()
test_k_grouped_wgrad_gemm()