diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
index 065bd88..0611c5c 100644
--- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh
+++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
@@ -27,6 +27,17 @@ __device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
     return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads;
 }
 
+template <int kNumFormerIters, int kGap, int kEnd>
+__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, int num_former_iters) {
+    if (num_former_iters == kNumFormerIters) {
+        inner_launch_k_iterations(func, cute::Int<kNumFormerIters>{});
+        return;
+    }
+
+    if constexpr (kNumFormerIters + kGap <= kEnd)
+        outer_launch_k_iterations<kNumFormerIters + kGap, kGap, kEnd>(inner_launch_k_iterations, func, num_former_iters);
+}
+
 template <...>
@@ ... @@
+        outer_launch_k_iterations<...>([](const auto& func, auto num_former_iters_type) {
+            if constexpr (SHAPE_K % kFullKOfAllStages == 0) {
+                for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter)
+                    func(k_iter, DivisibleK{}, num_former_iters_type);
+            } else {
+                for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
+                    func(k_iter, DivisibleK{}, num_former_iters_type);
+                func(kNumIterations - 1, NotDivisibleK{}, num_former_iters_type);
+            }
+        }, func, num_former_iters);
     };
 
     // Register reconfigurations
@@ -156,7 +171,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
         if (threadIdx.x == kNumMathThreads) {
             // Persistently schedule over blocks
             while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
-                launch_k_iterations([&](int k_iter, auto type) {
+                launch_k_iterations([&](int k_iter, auto type, auto _) {
                     constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
                     constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
                     DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
@@ -193,7 +208,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                         empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
                         full_barriers[s]->arrive();
                     }
-                });
+                }, 0);
             }
 
             // To safely deconstruct distributed shared barriers, we need another round of empty waits
@@ -246,7 +261,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
             };
 
             // Launch MMAs
-            launch_k_iterations([&](int k_iter, auto type) {
+            launch_k_iterations([&](int k_iter, auto type, auto num_former_iters_type) {
                 constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
                 constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
                 DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
@@ -292,13 +307,21 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                     float scale_0_1, scale_1_1;
                     if constexpr (not kMustUseUniformedScaleB)
                         scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
+
+                    constexpr int kNumFormerIters = kMustUseUniformedScaleB ? WGMMA::kNumAccum / 4 : decltype(num_former_iters_type)::value;
                     #pragma unroll
-                    for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
-                        bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
-                        final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
-                        final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
-                        final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
-                        final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
+                    for (int i = 0; i < kNumFormerIters; ++ i) {
+                        final_accum[i * 4 + 0] += scale_0_0 * accum[i * 4 + 0];
+                        final_accum[i * 4 + 1] += scale_0_0 * accum[i * 4 + 1];
+                        final_accum[i * 4 + 2] += scale_1_0 * accum[i * 4 + 2];
+                        final_accum[i * 4 + 3] += scale_1_0 * accum[i * 4 + 3];
+                    }
+                    #pragma unroll
+                    for (int i = kNumFormerIters; i < WGMMA::kNumAccum / 4; ++ i) {
+                        final_accum[i * 4 + 0] += scale_0_1 * accum[i * 4 + 0];
+                        final_accum[i * 4 + 1] += scale_0_1 * accum[i * 4 + 1];
+                        final_accum[i * 4 + 2] += scale_1_1 * accum[i * 4 + 2];
+                        final_accum[i * 4 + 3] += scale_1_1 * accum[i * 4 + 3];
                     }
                 }
 
@@ -308,7 +331,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                     full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
                     empty_barrier_arrive(s);
                 }
-            });
+            }, num_former_iters);
 
             // Write back to shared memory using STSM
             DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py
index 0f099d8..aad8939 100644
--- a/deep_gemm/jit/compiler.py
+++ b/deep_gemm/jit/compiler.py
@@ -101,7 +101,7 @@ def build(name: str, arg_defs: tuple, code: str) -> Runtime:
                   '--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
                   # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
                   '--diag-suppress=177,174,940']
-    cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi']
+    cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts']
     flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}']
     include_dirs = [get_jit_include_dir()]
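The core trick in this patch is turning the runtime value `num_former_iters` into a compile-time constant: `outer_launch_k_iterations` walks a ladder of candidate values and, on a match, re-invokes the k-iteration launcher with `cute::Int<kNumFormerIters>{}`, so the scale-promotion loop can be split into two loops with fixed bounds instead of evaluating a per-iteration predicate. Below is a standalone host-side sketch of the same dispatch pattern; it is illustrative only, not DeepGEMM code, and the helper name `dispatch_former_iters` and the candidate values 0/2/8 are invented for the example.

// Standalone sketch (not DeepGEMM code): the runtime-to-compile-time dispatch
// pattern used by outer_launch_k_iterations above.
#include <cstdio>
#include <type_traits>

template <int kCurrent, int kGap, int kEnd>
void dispatch_former_iters(const auto& func, int num_former_iters) {
    if (num_former_iters == kCurrent) {
        // Matched: re-enter with the value encoded as a type, so it is a
        // constant expression inside the callee.
        func(std::integral_constant<int, kCurrent>{});
        return;
    }
    // Try the next candidate; `if constexpr` bounds the template recursion.
    if constexpr (kCurrent + kGap <= kEnd)
        dispatch_former_iters<kCurrent + kGap, kGap, kEnd>(func, num_former_iters);
}

int main() {
    int num_former_iters = 4;  // known only at runtime
    // Candidates are 0, 2, 4, 6, 8; the branch for 4 is instantiated with a
    // compile-time bound, so the loop below can be fully unrolled and needs
    // no per-iteration predicate (the point of the two-loop split in the diff).
    dispatch_former_iters<0, 2, 8>([](auto num_former_iters_type) {
        constexpr int kNumFormerIters = decltype(num_former_iters_type)::value;
        for (int i = 0; i < kNumFormerIters; ++ i)
            std::printf("former iter %d\n", i);
    }, num_former_iters);
    return 0;
}

The sketch compiles with `g++ -std=c++20`, or with `-std=c++17 -fconcepts`, which is likely why the compiler.py hunk appends `-fconcepts` to the host `cxx_flags`: the new `const auto&` function parameters are abbreviated function templates that GCC only accepts before C++20 with that flag, and nvcc forwards the host flags via `--compiler-options`. The `if constexpr` guard also keeps the number of instantiated branches small, so only a handful of specializations of the k-iteration launcher are generated.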