mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-03-31 07:05:05 +00:00
138 lines
4.4 KiB
Python
138 lines
4.4 KiB
Python
import argparse
|
|
import mmap
|
|
import os
|
|
import re
|
|
import subprocess
|
|
from torch.utils.cpp_extension import CUDA_HOME
|
|
|
|
|
|
def run_cuobjdump(file_path):
    """Disassemble ``file_path`` with ``cuobjdump -sass`` and return its stdout.

    Args:
        file_path: Path to the binary (e.g. a compiled ``.so``) to disassemble.

    Returns:
        The textual SASS listing produced by ``cuobjdump``.

    Raises:
        RuntimeError: if ``CUDA_HOME`` could not be determined, or if
            ``cuobjdump`` exits with a non-zero status.
    """
    if CUDA_HOME is None:
        raise RuntimeError('CUDA_HOME is not set; cannot locate cuobjdump')
    command = [f'{CUDA_HOME}/bin/cuobjdump', '-sass', file_path]
    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    # An explicit raise (unlike ``assert``) is not stripped under ``python -O``
    # and surfaces the tool's stderr to the user.
    if result.returncode != 0:
        raise RuntimeError(
            f'cuobjdump failed on {file_path} (exit code {result.returncode}):\n{result.stderr}')
    return result.stdout
|
|
|
|
|
|
def _keep_segment(collected, name, lines):
    """Append ``(name, lines)`` to ``collected`` if the run has at least 16 lines."""
    if len(lines) >= 16:
        # Each instruction occupies two lines, so a complete run is even-length.
        assert len(lines) % 2 == 0
        collected.append((name, lines))


def extract_ffma(sass):
    """Collect runs of consecutive FFMA instructions from ``cuobjdump -sass`` text.

    Each FFMA is taken as two text lines: the line containing ``FFMA`` and the
    line immediately after it. A run ends at the first other line. Only runs
    of at least 16 lines (8 instructions) are kept.

    Args:
        sass: Full textual output of ``cuobjdump -sass``.

    Returns:
        A list of ``(name, lines)`` tuples. ``name`` is ``'<arch>::<function>'``,
        taken from the most recent ``code for`` / ``Function :`` headers, or
        ``'N/A'`` if none was seen. ``lines`` holds the raw lines of the run.
    """
    collected = []
    current = []

    arch_name, func_name = 'N/A', 'N/A'
    skip_next_line = False
    for line in sass.splitlines():
        if 'code for' in line:
            # Take the text after the marker; ``str.lstrip`` would strip a
            # character set rather than the prefix and mangle names.
            arch_name = line.split('code for', 1)[1].strip()
        elif 'Function :' in line:
            func_name = line.split('Function :', 1)[1].strip()
        elif 'FFMA' in line:
            current.append(line)
            # The following line carries the second half of this instruction.
            skip_next_line = True
        elif skip_next_line:
            current.append(line)
            skip_next_line = False
        else:
            _keep_segment(collected, f'{arch_name}::{func_name}', current)
            current = []

    # Flush a run that extends to the very end of the listing.
    _keep_segment(collected, f'{arch_name}::{func_name}', current)

    if os.getenv('DG_PRINT_REG_REUSE', None):
        print(f'Found {len(collected)} FFMA segments')
    return collected
|
|
|
|
|
|
def extract_hex_from_line(line):
    """Parse the first ``/* 0x... */`` comment in a SASS line as an integer.

    Args:
        line: One line of ``cuobjdump -sass`` output.

    Returns:
        The hexadecimal value inside the first matching comment.

    Raises:
        AssertionError: if no such comment exists (when assertions are enabled).
    """
    hex_comment = re.search(r'/\*\s*(0x[0-9a-fA-F]+)\s*\*/', line)
    assert hex_comment
    (hex_text,) = hex_comment.groups()
    return int(hex_text, 16)
|
|
|
|
|
|
def validate(m, offset, le_bytes, num_lines):
    """Check that the 16-byte chunks of ``m`` starting at ``offset`` equal ``le_bytes``.

    Args:
        m: Byte buffer supporting slicing (``process`` passes an ``mmap``).
        offset: Start position of the first 16-byte chunk in ``m``.
        le_bytes: Expected 16-byte chunks, one per instruction.
        num_lines: Number of text lines; ``num_lines // 2`` chunks are compared.

    Returns:
        True if every chunk after the first matches, otherwise False.

    Raises:
        AssertionError: if ``le_bytes`` has the wrong length or the first
            chunk does not match.
    """
    num_chunks = num_lines // 2
    assert len(le_bytes) == num_chunks
    assert m[offset:offset + 16] == le_bytes[0]
    return all(
        m[offset + 16 * k:offset + 16 * k + 16] == le_bytes[k]
        for k in range(1, num_chunks)
    )
|
|
|
|
|
|
def parse_registers(line):
    """Return the register names appearing in a SASS line, in order.

    ``/* ... */`` comments and ``;`` are removed. The remaining text is split
    on commas and whitespace. Every word starting with ``R`` is kept, with any
    dotted suffix dropped (e.g. ``R2.reuse`` -> ``R2``).

    Args:
        line: One line of ``cuobjdump -sass`` output.

    Returns:
        A list of register-name strings (``RZ`` included).
    """
    text = re.sub(r'/\*.*?\*/', '', line).replace(';', '')
    words = [word for chunk in text.strip().split(',') for word in chunk.split()]
    return [word.partition('.')[0] for word in words if word.startswith('R')]
|
|
|
|
|
|
def _rewrite_reuse_bits(ffma_lines, num_lines):
    """Compute original and patched encodings for the first ``num_lines`` lines.

    Returns ``(le_bytes, new_le_bytes, reused_list, num_changed)``:
    - ``le_bytes``: the original 16-byte little-endian encodings.
    - ``new_le_bytes``: the patched encodings.
    - ``reused_list``: indices of instructions that keep their reuse bit.
    - ``num_changed``: how many instructions were modified.
    """
    le_bytes, new_le_bytes = [], []
    reused_list = []
    dst_reg_set = set()
    last_reused, last_dst_reg = False, ''
    num_changed = 0
    for i in range(num_lines // 2):
        low_line, high_line = ffma_lines[i * 2], ffma_lines[i * 2 + 1]
        # Second-to-last register operand of the instruction line (the name
        # ``dst_reg`` is kept from the original code).
        dst_reg = parse_registers(low_line)[-2]
        low_hex, high_hex = extract_hex_from_line(low_line), extract_hex_from_line(high_line)
        le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little'))

        # Bit 59 of the high word is treated as the `reuse` flag.
        reused = (high_hex & 0x0800000000000000) != 0
        if reused:
            is_first_occurred = dst_reg not in dst_reg_set
            if is_first_occurred or (last_reused and dst_reg == last_dst_reg):
                # Modify the `reuse` and `yield` bits
                # NOTE(review): `reused` already implies bit 59 is set, so this
                # assert always passes; it does not check bit 45 (`yield`).
                assert high_hex & 0x0800200000000000, f'{hex(high_hex)}'
                high_hex ^= 0x0800200000000000
                reused = False
                num_changed += 1
            else:
                reused_list.append(i)
        dst_reg_set.add(dst_reg)
        new_le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little'))
        last_reused, last_dst_reg = reused, dst_reg
    return le_bytes, new_le_bytes, reused_list, num_changed


def _find_segment_offsets(m, le_bytes, num_lines):
    """Return every offset in ``m`` where the segment's full original encoding occurs."""
    candidates = []
    offset = m.find(le_bytes[0])
    while offset != -1:
        candidates.append(offset)
        offset = m.find(le_bytes[0], offset + 1)
    # Keep only candidates where every following instruction matches too.
    return [x for x in candidates if validate(m, x, le_bytes, num_lines)]


def modify_segment(m, name, ffma_lines):
    """Rewrite the reuse/yield bits of one FFMA segment in the writable buffer ``m``.

    Only the leading ~9/16 of the segment's lines are considered, rounded down
    to whole two-line instructions. For each of those instructions whose reuse
    bit is set, the reuse bit is cleared and bit 45 is flipped when either
    - its second-to-last register has not appeared earlier in the segment, or
    - the previous instruction kept reuse on the same register.
    The patched bytes are written at every offset where the whole original
    encoding is found.

    Args:
        m: Writable byte buffer (``process`` passes an ``mmap``); modified in place.
        name: Segment label, used only for debug output.
        ffma_lines: Raw SASS lines of the segment, two lines per instruction.
    """
    num_lines = (len(ffma_lines) * 9 // 16) // 2 * 2
    assert num_lines % 2 == 0
    if num_lines == 0:
        # Nothing to patch (and no first instruction to search for).
        return

    le_bytes, new_le_bytes, reused_list, num_changed = _rewrite_reuse_bits(ffma_lines, num_lines)
    if os.getenv('DG_PRINT_REG_REUSE', None):
        print(f' > segment `{name}` new reused list ({num_changed} changed): {reused_list}')

    # All offsets are validated before any bytes are written.
    offsets = _find_segment_offsets(m, le_bytes, num_lines)

    # Replace with `new_le_bytes`
    for offset in offsets:
        for i, new_bytes in enumerate(new_le_bytes):
            m[offset + i * 16:offset + i * 16 + 16] = new_bytes
|
|
|
|
|
|
def process(path):
    """Patch the FFMA reuse/yield bits of the binary at ``path`` in place.

    Disassembles ``path`` with ``cuobjdump``, collects its FFMA segments, then
    memory-maps the file read/write and patches each segment's bytes.

    Args:
        path: Path to the binary (e.g. a compiled ``.so``) to modify.
    """
    if os.getenv('DG_PRINT_REG_REUSE', None):
        print(f'Processing {path}')
    output = run_cuobjdump(path)
    segments = extract_ffma(output)
    # Managing the mapping with ``with`` releases it even if patching raises.
    with open(path, 'r+b') as f, mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_WRITE) as mm:
        for segment in segments:
            modify_segment(mm, *segment)
|
|
|
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Interleave FFMA reg reuse')
    # ``process`` needs a file path, so a missing --so is a usage error rather
    # than a later crash on ``None``.
    parser.add_argument('--so', required=True, help='Path to the SO file')
    args = parser.parse_args()
    process(args.so)
|