From 6366f5ad1a88450f7dbee87223b130bbb330b9a2 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Mon, 14 Apr 2025 10:26:33 +0800 Subject: [PATCH] Optimize expression --- deep_gemm/include/deep_gemm/fp8_gemm.cuh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index b4bce1c..ced5e03 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -386,7 +386,9 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, // Reshape the atom in another view and swizzle // - original: `(WGMMA_M_PER_WARP, kSwizzleDMode / kNumBankGroupBytes)` // - new: `(WGMMA_M_PER_WARP * kSwizzleDMode / kNumBankGroupBytes / 8, 8)` - auto row = bank_group_index / 8, col = bank_group_index % 8; + constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8); col ^= row % (kSwizzleDMode / 16); // Add back into the base pointer