diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
index 0d51648..386139c 100644
--- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh
+++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
@@ -271,81 +271,91 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
                 }
             };
 
-            // Launch MMAs
-            launch_k_iterations([&](int k_iter, auto type, auto num_former_iters_type) {
-                constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
-                constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
-                DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
-
-                #pragma unroll
-                for (int s = 0; s < kNumInnerStages; ++ s) {
-                    // Read B scales
-                    float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1;
-                    // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
-                    if constexpr (not kMustUseUniformedScaleB)
-                        scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES);
-
-                    // Wait TMA arrivals
-                    full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
-
-                    // TODO: remove some useless computation for unaligned Ms
+            if (!scheduler.is_valid_m(math_wg_idx * WGMMA::M, m_block_idx)) {
+                // Skip useless computation for unaligned Ms
+                launch_k_iterations([&](int k_iter, auto type, auto _) {
                     #pragma unroll
-                    for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
-                        auto m_offset = local_idx * WAVE_BLOCK_M;
-
-                        // Read A scales
-                        // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
-                        auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0 + m_offset);
-                        auto scale_a_1 = ld_shared(smem_scales_a[s] + r_1 + m_offset);
-
-                        // Commit WGMMA instructions
-                        #pragma unroll
-                        for (int i = 0; i < WGMMA::kNumAccum; ++ i)
-                            warpgroup_fence_operand(accum[i]);
-                        warpgroup_arrive();
-                        #pragma unroll
-                        for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
-                            auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1);
-                            auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
-                            WGMMA::wgmma(desc_a, desc_b, accum, k);
-                        }
-                        warpgroup_commit_batch();
-                        #pragma unroll
-                        for (int i = 0; i < WGMMA::kNumAccum; ++ i)
-                            warpgroup_fence_operand(accum[i]);
-                        warpgroup_wait<0>();
-
-                        // Notify barrier arrival at the last warpgroup wave
-                        if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
-                            empty_barrier_arrive(s);
-
-                        // Promote with scales
-                        // NOTES: making it as predicates is very important for performance, comparing to two loops
-                        float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
-                        float scale_0_1, scale_1_1;
-                        if constexpr (not kMustUseUniformedScaleB)
-                            scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
-
-                        auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
-                        #pragma unroll
-                        for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
-                            // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
-                            bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
-                            shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
-                            shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
-                            shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
-                            shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
-                        }
+                    for (uint32_t s = 0; s < kNumStages; ++ s) {
+                        full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
+                        empty_barrier_arrive(s);
                     }
-                }
+                }, num_former_iters);
+            } else {
+                // Launch MMAs
+                launch_k_iterations([&](int k_iter, auto type, auto num_former_iters_type) {
+                    constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
+                    constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
+                    DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
 
-                // Wait unaligned cases
-                #pragma unroll
-                for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
-                    full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
-                    empty_barrier_arrive(s);
-                }
-            }, num_former_iters);
+                    #pragma unroll
+                    for (int s = 0; s < kNumInnerStages; ++ s) {
+                        // Read B scales
+                        float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1;
+                        // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
+                        if constexpr (not kMustUseUniformedScaleB)
+                            scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES);
+
+                        // Wait TMA arrivals
+                        full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
+
+                        #pragma unroll
+                        for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
+                            auto m_offset = local_idx * WAVE_BLOCK_M;
+
+                            // Read A scales
+                            // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
+                            auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0 + m_offset);
+                            auto scale_a_1 = ld_shared(smem_scales_a[s] + r_1 + m_offset);
+
+                            // Commit WGMMA instructions
+                            #pragma unroll
+                            for (int i = 0; i < WGMMA::kNumAccum; ++ i)
+                                warpgroup_fence_operand(accum[i]);
+                            warpgroup_arrive();
+                            #pragma unroll
+                            for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
+                                auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1);
+                                auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
+                                WGMMA::wgmma(desc_a, desc_b, accum, k);
+                            }
+                            warpgroup_commit_batch();
+                            #pragma unroll
+                            for (int i = 0; i < WGMMA::kNumAccum; ++ i)
+                                warpgroup_fence_operand(accum[i]);
+                            warpgroup_wait<0>();
+
+                            // Notify barrier arrival at the last warpgroup wave
+                            if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
+                                empty_barrier_arrive(s);
+
+                            // Promote with scales
+                            // NOTES: making it as predicates is very important for performance, comparing to two loops
+                            float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
+                            float scale_0_1, scale_1_1;
+                            if constexpr (not kMustUseUniformedScaleB)
+                                scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
+
+                            auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
+                            #pragma unroll
+                            for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
+                                // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
+                                bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
+                                shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
+                                shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
+                                shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
+                                shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
+                            }
+                        }
+                    }
+
+                    // Wait unaligned cases
+                    #pragma unroll
+                    for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
+                        full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
+                        empty_barrier_arrive(s);
+                    }
+                }, num_former_iters);
+            }
 
         // TMA checks
         constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh
index 38c5c49..8b6dbf3 100644
--- a/deep_gemm/include/deep_gemm/scheduler.cuh
+++ b/deep_gemm/include/deep_gemm/scheduler.cuh
@@ -48,6 +48,16 @@ struct Scheduler {
         }
     }
 
+    __device__ __forceinline__ bool is_valid_m(const uint32_t m_offset, const uint32_t& m_block_idx) const {
+        if constexpr (kGemmType == GemmType::Normal) {
+            return true;
+        } else if constexpr (kGemmType == GemmType::GroupedContiguous) {
+            return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) != -1;
+        } else if constexpr (kGemmType == GemmType::GroupedMasked) {
+            return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + curr_group_idx);
+        }
+    }
+
     __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const {
         if (num_blocks_in_group == 1)
             return false;
diff --git a/tests/test_core.py b/tests/test_core.py
index 03038db..6fb1ad8 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -53,8 +53,8 @@ def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k:
     m_aligned = get_m_alignment_for_contiguous_layout()
     group_m_list = []
     for i in range(num_groups):
-        group_m = m_aligned * random.randint(int(expected_m_per_group * 0.7) // m_aligned, int(expected_m_per_group * 1.3) // m_aligned)
-        m += group_m
+        group_m = random.randint(int(expected_m_per_group * 0.7), int(expected_m_per_group * 1.3))
+        m += ceil_div(group_m, m_aligned) * m_aligned
         group_m_list.append(group_m)
     x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
     y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
@@ -64,10 +64,12 @@ def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k:
 
     start = 0
     for i, group_m in enumerate(group_m_list):
-        end = start + group_m
-        m_indices[start:end] = i
-        ref_out[start:end] = x[start:end] @ y[i].t()
-        start = end
+        actual_end = start + group_m
+        aligned_end = start + ceil_div(group_m, m_aligned) * m_aligned
+        m_indices[start:actual_end] = i
+        m_indices[actual_end:aligned_end] = -1
+        ref_out[start:aligned_end] = x[start:aligned_end] @ y[i].t()
+        start = aligned_end
 
     assert m % 4 == 0, f'TMA alignment error: {m}'
     x_fp8 = per_token_cast_to_fp8(x)
@@ -191,6 +193,8 @@ def test_m_grouped_gemm_contiguous() -> None:
         # TODO: make a stronger test
        m, x_fp8, y_fp8, m_indices, out, ref_out = construct_contiguous_grouped(num_groups, expected_m_per_group, k, n)
         deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
+        out = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(out), out)
+        ref_out = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(ref_out), ref_out)
         diff = calc_diff(out, ref_out)
         assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
 
@@ -202,24 +206,25 @@ def test_m_grouped_gemm_contiguous() -> None:
             deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
 
         t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
+        sum_m = (m_indices != -1).sum().item()
         print(f' > Performance ({num_groups=}, m={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
-              f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, '
-              f'{(m * k + num_groups * k * n + m * n * 2) / 1e9 / t:4.0f} GB/s')
+              f'throughput: {2 * sum_m * n * k / t / 1e12:4.0f} TFLOPS, '
+              f'{(sum_m * k + num_groups * k * n + sum_m * n * 2) / 1e9 / t:4.0f} GB/s')
     print()
 
 
 def test_m_grouped_gemm_masked() -> None:
     print('Testing grouped masked GEMM:')
 
-    for num_groups, m in ((1, 1024), (2, 512), (4, 256)):
+    m = 4096
+    for num_groups, expected_m_per_group in ((1, 1024), (2, 512), (4, 256)):
         for k, n in ((7168, 4096), (2048, 7168), ):
             # Test correctness
-            masked_m_candidates = list(filter(lambda candidate: candidate <= m, (64, 128, 192, 256, 320, 384)))
             for i in range(10):
                 x_fp8, y_fp8, out, ref_out = construct_masked_grouped(num_groups, m, k, n)
                 masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
                 for j in range(num_groups):
-                    masked_m[j] = random.choice(masked_m_candidates)
+                    masked_m[j] = random.randint(int(expected_m_per_group * 0.7), int(expected_m_per_group * 1.3))
                 expected_m = min(int(masked_m.float().mean()) + 1, m)
                 deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m)
                 for j in range(num_groups):
@@ -228,17 +233,20 @@ def test_m_grouped_gemm_masked() -> None:
 
             # Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2)
             x_fp8, y_fp8, out, ref_out = construct_masked_grouped(num_groups, m, k, n)
-            masked_m = torch.ones((num_groups, ), device='cuda', dtype=torch.int) * m
+            for j in range(num_groups):
+                masked_m[j] = random.randint(int(expected_m_per_group * 0.7), int(expected_m_per_group * 1.3))
+            expected_m = min(int(masked_m.float().mean()) + 1, m)
+            sum_m = masked_m.sum().item()
 
             # noinspection PyShadowingNames
             def test_func():
-                deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, m)
+                deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m)
 
             # Test performance with fixed shapes
             t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
             print(f' > Performance ({num_groups=}, m_per_group={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
-                  f'throughput: {2 * num_groups * m * n * k / t / 1e12:4.0f} TFLOPS, '
-                  f'{(num_groups * (m * k + k * n + m * n * 2)) / 1e9 / t:4.0f} GB/s')
+                  f'throughput: {2 * sum_m * n * k / t / 1e12:4.0f} TFLOPS, '
+                  f'{(sum_m * k + num_groups * k * n + sum_m * n * 2) / 1e9 / t:4.0f} GB/s')
     print()
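
Illustrative sketch (not part of the patch above): the snippet below restates, in plain Python, the contiguous-grouped layout that tests/test_core.py::construct_contiguous_grouped now builds with torch tensors. Each group is padded up to the alignment returned by get_m_alignment_for_contiguous_layout(); padding rows are marked with -1 in m_indices, which is the sentinel that Scheduler::is_valid_m checks so the kernel can skip blocks made entirely of padding, and only the unpadded rows (sum_m) are counted in the throughput figures. The helper name build_m_indices is invented here purely for illustration.

import random


def ceil_div(a: int, b: int) -> int:
    # Rounding-up integer division, mirroring the `ceil_div` helper used by the test.
    return (a + b - 1) // b


def build_m_indices(group_m_list, m_aligned):
    # Each group occupies a multiple of `m_aligned` rows: real rows carry the group
    # index, padding rows carry -1 (the sentinel `Scheduler::is_valid_m` looks for).
    m_indices = []
    for group_idx, group_m in enumerate(group_m_list):
        aligned_m = ceil_div(group_m, m_aligned) * m_aligned
        m_indices += [group_idx] * group_m + [-1] * (aligned_m - group_m)
    return m_indices


if __name__ == '__main__':
    random.seed(0)
    num_groups, expected_m_per_group, m_aligned = 4, 256, 128
    group_m_list = [random.randint(int(expected_m_per_group * 0.7), int(expected_m_per_group * 1.3))
                    for _ in range(num_groups)]
    m_indices = build_m_indices(group_m_list, m_aligned)
    # `sum_m` counts only the valid rows; the padded total is `len(m_indices)`.
    sum_m = sum(1 for i in m_indices if i != -1)
    print(f'per-group m: {group_m_list}, padded m: {len(m_indices)}, valid rows (sum_m): {sum_m}')

Running the sketch prints the padded total m next to the valid-row count sum_m, which is the quantity the updated TFLOPS and GB/s formulas in the test are based on.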