diff --git a/README.md b/README.md index d1b255c..ccc46e0 100644 --- a/README.md +++ b/README.md @@ -12,10 +12,11 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert ## Roadmap -- [ ] More correctness tests for grouped-contiguous layout +- [x] More correctness tests for grouped-contiguous layout - [x] Shared memory swizzling for output - [ ] Larger block size on N (up to 256) -- [ ] MoE scheduler with TMA multicast compatibility +- [x] MoE scheduler with TMA multicast compatibility +- [ ] Fix TMA multicast compatibility for indivisible shapes - [ ] Weight gradient kernels for dense models - [ ] Weight gradient kernels for MoE models - [ ] Utility kernels for MoE models (as a pre-built CUDA library) diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index c000afa..c2934b8 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -86,14 +86,14 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, // Prefetch TMA descriptors at very beginning if (threadIdx.x == kNumMathThreads) { - cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_a)); - cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_b)); - cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_scales_a)); + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_scales_a); // `tensor_map_d` is only used in swizzling mode // For the `kSwizzleDMode == 0 and BLOCK_N_PADDING == 0` case, it will be treated as padding mode if constexpr (kSwizzleDMode > 0) - cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_d)); + cute::prefetch_tma_descriptor(&tensor_map_d); } __syncwarp(); @@ -212,8 +212,16 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K)); // Issue TMA B - tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), - smem_b[s], k_idx, scheduler.get_global_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx)); + if (kNumTMAMulticastOnB > 1 and scheduler.is_tma_multicast_b_valid(m_block_idx)) { + // NOTES: in grouped contiguous GEMM, different `m_block_idx` values may correspond to blocks of different groups (B), + // requiring additional checks before multicast operations. + DG_STATIC_ASSERT(kNumTMAMulticastOnB <= 2, "Scheduler does not support > 2 TMA multicast"); + tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), + smem_b[s], k_idx, scheduler.get_global_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx)); + } else { + tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), + smem_b[s], k_idx, scheduler.get_global_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx)); + } full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE); } diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh index 6e3cb52..95dcd33 100644 --- a/deep_gemm/include/deep_gemm/scheduler.cuh +++ b/deep_gemm/include/deep_gemm/scheduler.cuh @@ -43,6 +43,18 @@ struct Scheduler { } } + __device__ __forceinline__ bool is_tma_multicast_b_valid(const uint32_t& m_block_idx) { + if constexpr (kGemmType == GemmType::Normal) { + return true; + } else if constexpr (kGemmType == GemmType::GroupedContiguous) { + auto group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M); + auto peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M); + return group_idx == peer_group_idx; + } else if constexpr (kGemmType == GemmType::GroupedMasked) { + return false; + } + } + __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) { DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size"); @@ -72,10 +84,10 @@ struct Scheduler { const uint32_t& block_idx, const uint32_t& m_block_idx=0) { if constexpr (kGemmType == GemmType::Normal) { return block_idx * block_size; - } else if (kGemmType == GemmType::GroupedContiguous) { + } else if constexpr (kGemmType == GemmType::GroupedContiguous) { auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M); return offset * shape_dim + block_idx * block_size; - } else if (kGemmType == GemmType::GroupedMasked) { + } else if constexpr (kGemmType == GemmType::GroupedMasked) { return curr_group_idx * shape_dim + block_idx * block_size; } } diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index eab5442..f52dc48 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -146,10 +146,9 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, best_tma_multicast_config = (1, True) # Try to multicast on the larger block side first - is_dense_gemm = (not is_grouped_contiguous) and (not is_grouped_masked) is_multicast_legal = { 'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms), - 'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and is_dense_gemm, + 'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and (not is_grouped_masked), } for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'): if m >= 512 and is_multicast_legal[i]: diff --git a/tests/test_core.py b/tests/test_core.py index 0f3c16d..bdc1841 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -4,6 +4,7 @@ from typing import Tuple import deep_gemm from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor +from deep_gemm.jit_kernels.utils import get_m_alignment_for_contiguous_layout def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: @@ -38,7 +39,38 @@ def construct(m: int, k: int, n: int) -> \ return x_fp8, y_fp8, out, ref_out -def construct_grouped(num_groups: int, m: int, k: int, n: int, is_masked: bool) -> \ +def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k: int, n: int) -> \ + Tuple[int, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: + m = 0 + m_aligned = get_m_alignment_for_contiguous_layout() + group_m_list = [] + for i in range(num_groups): + group_m = m_aligned * random.randint(int(expected_m_per_group * 0.7) // m_aligned, int(expected_m_per_group * 1.3) // m_aligned) + m += group_m + group_m_list.append(group_m) + x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) + y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) + m_indices = torch.empty(m, device='cuda', dtype=torch.int32) + out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) + ref_out = torch.randn((m, n), device='cuda', dtype=torch.bfloat16) + + start = 0 + for i, group_m in enumerate(group_m_list): + end = start + group_m + m_indices[start:end] = i + ref_out[start:end] = x[start:end] @ y[i].t() + start = end + + assert m % 4 == 0, f'TMA alignment error: {m}' + x_fp8 = per_token_cast_to_fp8(x) + y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float)) + for i in range(num_groups): + y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) + + return m, x_fp8, y_fp8, m_indices, out, ref_out + + +def construct_masked_grouped(num_groups: int, m: int, k: int, n: int) -> \ Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]: x = torch.randn((num_groups, m, k), device='cuda', dtype=torch.bfloat16) y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) @@ -52,11 +84,6 @@ def construct_grouped(num_groups: int, m: int, k: int, n: int, is_masked: bool) x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i]) y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) - # For non-masked input, we must merge the group and M dims - if not is_masked: - x_fp8 = (x_fp8[0].view(-1, k), per_token_cast_to_fp8(x.view(-1, k))[1]) - out, ref_out = out.view(-1, n), ref_out.view(-1, n) - # Transpose earlier so that the testing will not trigger transposing kernels x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) return x_fp8, y_fp8, out, ref_out @@ -88,28 +115,24 @@ def test_gemm() -> None: def test_m_grouped_gemm_contiguous() -> None: print('Testing grouped contiguous GEMM:') - for num_groups, m, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168), (8, 4096, 7168, 4096), (8, 4096, 2048, 7168)): + for num_groups, expected_m_per_group, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168), (8, 4096, 7168, 4096), (8, 4096, 2048, 7168)): # TODO: make a stronger test - x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=False) - m_indices = torch.arange(0, num_groups, device='cuda', dtype=torch.int) - m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1) + m, x_fp8, y_fp8, m_indices, out, ref_out = construct_contiguous_grouped(num_groups, expected_m_per_group, k, n) deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices) diff = calc_diff(out, ref_out) - assert diff < 0.001, f'm={m * num_groups}, {k=}, {n=}, {diff:.5f}' + assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}' # Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2) - x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=False) - m_indices = torch.arange(0, num_groups, device='cuda', dtype=torch.int) - m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1) + m, x_fp8, y_fp8, m_indices, out, ref_out = construct_contiguous_grouped(num_groups, expected_m_per_group, k, n) # noinspection PyShadowingNames def test_func(): deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices) t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) - print(f' > Performance ({num_groups=}, m_per_group={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' - f'throughput: {2 * num_groups * m * n * k / t / 1e12:4.0f} TFLOPS, ' - f'{(num_groups * (m * k + k * n + m * n * 2)) / 1e9 / t:4.0f} GB/s') + print(f' > Performance ({num_groups=}, m={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' + f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, ' + f'{(m * k + num_groups * k * n + m * n * 2) / 1e9 / t:4.0f} GB/s') print() @@ -121,7 +144,7 @@ def test_m_grouped_gemm_masked() -> None: # Test correctness masked_m_candidates = list(filter(lambda candidate: candidate <= m, (64, 128, 192, 256, 320, 384))) for i in range(10): - x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True) + x_fp8, y_fp8, out, ref_out = construct_masked_grouped(num_groups, m, k, n) masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int) for j in range(num_groups): masked_m[j] = random.choice(masked_m_candidates) @@ -132,7 +155,7 @@ def test_m_grouped_gemm_masked() -> None: assert diff < 0.001, f'{m=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}' # Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2) - x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True) + x_fp8, y_fp8, out, ref_out = construct_masked_grouped(num_groups, m, k, n) masked_m = torch.ones((num_groups, ), device='cuda', dtype=torch.int) * m # noinspection PyShadowingNames