Fix TMA store bugs and code format

This commit is contained in:
Chenggang Zhao 2025-02-27 17:57:21 +08:00
parent b05ed2f017
commit ca13ce0fab
2 changed files with 6 additions and 6 deletions

View File

@ -406,7 +406,8 @@ public:
template <typename T> template <typename T>
static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) { static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) {
return make_2d_tma_desc(global_address, Layout::RowMajor, return make_2d_tma_desc(global_address, Layout::RowMajor,
shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N, BLOCK_M, BLOCK_N, shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N,
min(BLOCK_M, shape_m), BLOCK_N,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
} }

View File

@ -38,7 +38,7 @@ def extract_ffma(sass):
current = [] current = []
if os.getenv('DG_PRINT_REG_REUSE', None): if os.getenv('DG_PRINT_REG_REUSE', None):
print(f"Found {len(collected)} FFMA segments") print(f'Found {len(collected)} FFMA segments')
return collected return collected
@ -58,7 +58,6 @@ def validate(m, offset, le_bytes, num_lines):
def parse_registers(line): def parse_registers(line):
import re
line = re.sub(r'/\*.*?\*/', '', line) line = re.sub(r'/\*.*?\*/', '', line)
line = line.replace(';', '') line = line.replace(';', '')
tokens = line.strip().split(',') tokens = line.strip().split(',')
@ -92,7 +91,7 @@ def modify_segment(m, name, ffma_lines):
is_first_occurred = dst_reg not in dst_reg_set is_first_occurred = dst_reg not in dst_reg_set
if is_first_occurred or (last_reused and dst_reg == last_dst_reg): if is_first_occurred or (last_reused and dst_reg == last_dst_reg):
# Modify the `reuse` and `yield` bits # Modify the `reuse` and `yield` bits
assert high_hex & 0x0800200000000000, f"{hex(high_hex)}" assert high_hex & 0x0800200000000000, f'{hex(high_hex)}'
high_hex ^= 0x0800200000000000 high_hex ^= 0x0800200000000000
reused = False reused = False
num_changed += 1 num_changed += 1
@ -102,7 +101,7 @@ def modify_segment(m, name, ffma_lines):
new_le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little')) new_le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little'))
last_reused, last_dst_reg = reused, dst_reg last_reused, last_dst_reg = reused, dst_reg
if os.getenv('DG_PRINT_REG_REUSE', None): if os.getenv('DG_PRINT_REG_REUSE', None):
print(f" > segment `{name}` new reused list ({num_changed} changed): {reused_list}") print(f' > segment `{name}` new reused list ({num_changed} changed): {reused_list}')
# Find the offset # Find the offset
offsets = [] offsets = []
@ -130,7 +129,7 @@ def process(path):
mm.close() mm.close()
if __name__ == "__main__": if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Interleave FFMA reg reuse') parser = argparse.ArgumentParser(description='Interleave FFMA reg reuse')
parser.add_argument('--so', help='Path to the SO file') parser.add_argument('--so', help='Path to the SO file')
args = parser.parse_args() args = parser.parse_args()