Support more shapes

This commit is contained in:
Chenggang Zhao
2025-02-28 10:04:59 +08:00
parent b69f630b91
commit 6c5da03ba9
2 changed files with 4 additions and 5 deletions

View File

@@ -50,7 +50,7 @@ def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k:
smem_size += num_stages * smem_a_per_stage
smem_size += num_stages * smem_scales_a_per_stage
smem_size += num_stages * smem_b_per_stage
smem_size += smem_scales_b * (1 if block_k % block_n == 0 else 2)
smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
smem_size += smem_barrier
return smem_size