Mirror of https://github.com/deepseek-ai/DeepGEMM, synced 2025-06-26 23:15:49 +00:00.
Commit a2e0d68eed
@@ -109,7 +109,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
     }

     // Initialize barriers
-    DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "To many TMA multicast");
+    DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast");
     if (threadIdx.x == kNumMathThreads) {
         #pragma unroll
         for (int i = 0; i < kNumStages; ++ i) {
@@ -406,7 +406,8 @@ public:
     template <typename T>
     static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) {
         return make_2d_tma_desc(global_address, Layout::RowMajor,
-                                shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N, BLOCK_M, BLOCK_N,
+                                shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N,
+                                min(BLOCK_M, shape_m), BLOCK_N,
                                 CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
     }

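Note on the hunk above: a TMA copy box must fit inside the mapped global tensor, so for runtime shapes with M smaller than BLOCK_M the box height is now clamped. A minimal Python sketch of the new dimension logic; the function name and shapes are illustrative, not DeepGEMM's API — only the min(BLOCK_M, shape_m) clamp is taken from the diff.

# Sketch of the D-tensor descriptor dimensions after this change.
def d_desc_dims(shape_m, shape_n, block_m, block_n, num_groups, grouped_masked):
    gmem_rows = shape_m * (num_groups if grouped_masked else 1)  # global extent
    box_rows = min(block_m, shape_m)  # box may not exceed the tensor's M extent
    return (gmem_rows, shape_n), (box_rows, block_n)

# e.g. shape_m=32 with block_m=64 now yields a 32-row TMA box instead of 64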
@@ -38,7 +38,7 @@ def extract_ffma(sass):
             current = []

     if os.getenv('DG_PRINT_REG_REUSE', None):
-        print(f"Found {len(collected)} FFMA segments")
+        print(f'Found {len(collected)} FFMA segments')
     return collected

@@ -58,7 +58,6 @@ def validate(m, offset, le_bytes, num_lines):


 def parse_registers(line):
-    import re
     line = re.sub(r'/\*.*?\*/', '', line)
     line = line.replace(';', '')
     tokens = line.strip().split(',')
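For context, the three parsing steps visible in this hunk behave as below. The sketch reimplements only those steps on a made-up SASS line; the real function's return value is not shown in the diff, so the list returned here is a guess.

import re

def parse_registers_sketch(line):
    line = re.sub(r'/\*.*?\*/', '', line)  # strip /* ... */ address comments
    line = line.replace(';', '')           # strip the instruction terminator
    return [t.strip() for t in line.strip().split(',')]

print(parse_registers_sketch('/*0040*/ FFMA R16, R12, R80, R16 ;'))
# -> ['FFMA R16', 'R12', 'R80', 'R16']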
@@ -92,7 +91,7 @@ def modify_segment(m, name, ffma_lines):
         is_first_occurred = dst_reg not in dst_reg_set
         if is_first_occurred or (last_reused and dst_reg == last_dst_reg):
             # Modify the `reuse` and `yield` bits
-            assert high_hex & 0x0800200000000000, f"{hex(high_hex)}"
+            assert high_hex & 0x0800200000000000, f'{hex(high_hex)}'
             high_hex ^= 0x0800200000000000
             reused = False
             num_changed += 1
@@ -102,7 +101,7 @@ def modify_segment(m, name, ffma_lines):
         new_le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little'))
         last_reused, last_dst_reg = reused, dst_reg
     if os.getenv('DG_PRINT_REG_REUSE', None):
-        print(f" > segment `{name}` new reused list ({num_changed} changed): {reused_list}")
+        print(f' > segment `{name}` new reused list ({num_changed} changed): {reused_list}')

     # Find the offset
     offsets = []
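Per the comment in the hunk itself, the constant 0x0800200000000000 masks the `reuse` and `yield` bits in the high 64-bit word of a 128-bit SASS instruction, and the XOR flips both at once. A hedged, self-contained sketch of that toggle; the mask and assert come from the diff, the helper name is mine.

REUSE_YIELD_MASK = 0x0800200000000000

def toggle_reuse_yield(high_hex: int) -> int:
    # Mirrors the assert above: the masked bits are expected to be set
    # before the XOR clears them to interleave the register-reuse pattern.
    assert high_hex & REUSE_YIELD_MASK, hex(high_hex)
    return high_hex ^ REUSE_YIELD_MASK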
@@ -130,7 +129,7 @@ def process(path):
     mm.close()


-if __name__ == "__main__":
+if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Interleave FFMA reg reuse')
     parser.add_argument('--so', help='Path to the SO file')
     args = parser.parse_args()
@@ -79,10 +79,12 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
             elif num_waves < best_num_waves:
                 success = True
             elif num_waves == best_num_waves:
+                div_n = bool(128 % block_n)
+                best_div_n = bool(128 % best_block_n)
                 # Check last wave utilization
                 util = get_last_wave_util(block_m, block_n)
                 best_util = get_last_wave_util(best_block_m, best_block_n)
-                success = util > best_util or (util == best_util and (block_n >= best_block_n and block_m <= best_block_m))
+                success = util > best_util or (util == best_util and (block_m > best_block_m or block_m == best_block_m and (div_n < best_div_n or div_n == best_div_n and block_n < best_block_n)))
             best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n)
     assert best_block_m is not None and best_block_n is not None

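The rewritten `success` expression reads more easily as a lexicographic preference. A restatement of the equal-utilization tie-break as a sort key (my reading of the expression, not code from the library): prefer larger block_m, then a block_n that divides 128, then smaller block_n.

def tie_break_key(block_m, block_n):
    # Smaller tuple wins: larger block_m first, then block_n dividing 128
    # (bool(128 % block_n) is False exactly when it divides), then smaller block_n.
    return (-block_m, bool(128 % block_n), block_n)

assert tie_break_key(128, 64) < tie_break_key(64, 64)  # larger block_m preferred
assert tie_break_key(64, 64) < tie_break_key(64, 96)   # 64 divides 128, 96 does not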
@@ -160,6 +160,11 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
     global includes, template
     num_sms = get_num_sms()
     block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)
+
+    # Extra checks for TMA store
+    if num_groups > 1 and m > block_m:
+        assert m % block_m == 0, f'For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})'
+
     args = (lhs, lhs_scales, rhs, rhs_scales, out,
             masked_m, m,
             torch.cuda.current_stream(), num_sms, smem_size)
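A quick illustration of the new check with hypothetical shapes. My reading of the intent: with multiple groups, each group's M slice must align to the block size so that no store tile straddles two groups.

block_m, num_groups = 64, 4
m = 256        # passes: 256 % 64 == 0, every group boundary is block-aligned
# m = 200      # would trip the new assert once num_groups > 1 and m > block_m
assert m % block_m == 0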