From ca13ce0fab704c5bf9a4fc19547402b2d10e2f18 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Thu, 27 Feb 2025 17:57:21 +0800 Subject: [PATCH] Fix TMA store bugs and code format --- deep_gemm/include/deep_gemm/fp8_gemm.cuh | 3 ++- deep_gemm/jit/interleave_ffma.py | 9 ++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index 711649c..f9df785 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -406,7 +406,8 @@ public: template static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) { return make_2d_tma_desc(global_address, Layout::RowMajor, - shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N, BLOCK_M, BLOCK_N, + shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N, + min(BLOCK_M, shape_m), BLOCK_N, CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); } diff --git a/deep_gemm/jit/interleave_ffma.py b/deep_gemm/jit/interleave_ffma.py index d6b7fd5..74e8358 100644 --- a/deep_gemm/jit/interleave_ffma.py +++ b/deep_gemm/jit/interleave_ffma.py @@ -38,7 +38,7 @@ def extract_ffma(sass): current = [] if os.getenv('DG_PRINT_REG_REUSE', None): - print(f"Found {len(collected)} FFMA segments") + print(f'Found {len(collected)} FFMA segments') return collected @@ -58,7 +58,6 @@ def validate(m, offset, le_bytes, num_lines): def parse_registers(line): - import re line = re.sub(r'/\*.*?\*/', '', line) line = line.replace(';', '') tokens = line.strip().split(',') @@ -92,7 +91,7 @@ def modify_segment(m, name, ffma_lines): is_first_occurred = dst_reg not in dst_reg_set if is_first_occurred or (last_reused and dst_reg == last_dst_reg): # Modify the `reuse` and `yield` bits - assert high_hex & 0x0800200000000000, f"{hex(high_hex)}" + assert high_hex & 0x0800200000000000, f'{hex(high_hex)}' high_hex ^= 0x0800200000000000 reused = False num_changed += 1 @@ -102,7 +101,7 @@ def modify_segment(m, name, ffma_lines): new_le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little')) last_reused, last_dst_reg = reused, dst_reg if os.getenv('DG_PRINT_REG_REUSE', None): - print(f" > segment `{name}` new reused list ({num_changed} changed): {reused_list}") + print(f' > segment `{name}` new reused list ({num_changed} changed): {reused_list}') # Find the offset offsets = [] @@ -130,7 +129,7 @@ def process(path): mm.close() -if __name__ == "__main__": +if __name__ == '__main__': parser = argparse.ArgumentParser(description='Interleave FFMA reg reuse') parser.add_argument('--so', help='Path to the SO file') args = parser.parse_args()