Use swizzling instead of padding (#86)

* Add swizzling params

* Add TMA D descriptor

* Always use STSMx2

* Swizzling draft

* Compatible with padding

* Fix bugs

* Optimize swizzle performance

* Optimize expression

* Optimize TMA issues

* Fix README

* Stricter assertions
This commit is contained in:
Chenggang Zhao
2025-04-14 15:20:58 +08:00
committed by GitHub
parent 2e7e58011b
commit 37aa127451
6 changed files with 147 additions and 104 deletions

View File

@@ -11,18 +11,20 @@ int main() {
constexpr int BLOCK_N = 128;
constexpr int BLOCK_K = 128;
constexpr int BLOCK_N_PADDING = 0;
constexpr int kSwizzleDMode = 0;
constexpr int kNumGroups = 1;
constexpr int kNumStages = 5;
constexpr int kNumTMAMulticast = 1;
constexpr bool kIsTMAMulticastOnA = false;
using gemm_t = Gemm<N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING, kNumGroups, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::Normal>;
using gemm_t = Gemm<N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING, kSwizzleDMode, kNumGroups, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::Normal>;
auto tma_a_desc = gemm_t::make_2d_tma_a_desc(reinterpret_cast<__nv_fp8_e4m3*>(0), m);
auto tma_b_desc = gemm_t::make_2d_tma_b_desc(reinterpret_cast<__nv_fp8_e4m3*>(0));
auto tma_d_desc = gemm_t::make_2d_tma_d_desc(reinterpret_cast<nv_bfloat16*>(0), m);
auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(reinterpret_cast<float*>(0), m);
gemm_t::run(nullptr, nullptr, nullptr,
m,
tma_a_desc, tma_b_desc, tma_scales_a_desc,
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
nullptr, 132, 0);
return 0;
}