mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Several cleanups
This commit is contained in:
@@ -23,11 +23,11 @@ template <uint32_t SHAPE_M, uint32_t SHAPE_N,
|
||||
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA>
|
||||
__global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
|
||||
fp8_wgrad_gemm_kernel(uint32_t shape_k,
|
||||
const __grid_constant__ CUtensorMap tensor_map_a,
|
||||
const __grid_constant__ CUtensorMap tensor_map_b,
|
||||
const __grid_constant__ CUtensorMap tensor_map_scales_a,
|
||||
const __grid_constant__ CUtensorMap tensor_map_scales_b,
|
||||
const __grid_constant__ CUtensorMap tensor_map_d) {
|
||||
const __grid_constant__ CUtensorMap tensor_map_a,
|
||||
const __grid_constant__ CUtensorMap tensor_map_b,
|
||||
const __grid_constant__ CUtensorMap tensor_map_scales_a,
|
||||
const __grid_constant__ CUtensorMap tensor_map_scales_b,
|
||||
const __grid_constant__ CUtensorMap tensor_map_d) {
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) || defined(__CLION_IDE__)
|
||||
// Scaling checks
|
||||
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
|
||||
@@ -352,7 +352,7 @@ fp8_wgrad_gemm_kernel(uint32_t shape_k,
|
||||
}
|
||||
}
|
||||
#else
|
||||
if (blockIdx.x == 0 && threadIdx.x == 0)
|
||||
if (blockIdx.x == 0 and threadIdx.x == 0)
|
||||
DG_DEVICE_ASSERT(false && "This kernel only support sm_90a");
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -36,19 +36,20 @@ def get_block_n_padding_for_smem_d(block_n: int) -> int:
|
||||
|
||||
def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128,
|
||||
is_fp32_out: bool = False, is_wgrad: bool = False) -> Tuple[int, int, int]:
|
||||
assert block_k == 128
|
||||
|
||||
# Try swizzle first, as it does not waste shared memory
|
||||
swizzle_mode = get_swizzle_mode(block_n)
|
||||
block_n_padding = get_block_n_padding_for_smem_d(
|
||||
block_n) if swizzle_mode == 0 else 0
|
||||
|
||||
# NOTES: `scales_b` in a total manner or per-stage manner
|
||||
smem_d = block_m * (block_n + block_n_padding) * (4 if is_fp32_out else 2)
|
||||
smem_a_per_stage = block_m * block_k
|
||||
smem_scales_a_per_stage = block_m * 4
|
||||
smem_b_per_stage = block_n * block_k
|
||||
if is_wgrad:
|
||||
smem_scales_b_per_stage = ceil_div(block_n * 4, 128) * 128
|
||||
else:
|
||||
smem_scales_b = ceil_div(k, block_k) * 4
|
||||
smem_scales_b_per_stage = ceil_div(block_n * 4, block_k) * block_k if is_wgrad else 0
|
||||
smem_scales_b = ceil_div(k, block_k) * 4 if not is_wgrad else 0
|
||||
smem_barrier = num_stages * 8 * 2
|
||||
|
||||
smem_size = 0
|
||||
@@ -56,11 +57,8 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k
|
||||
smem_size += num_stages * smem_a_per_stage
|
||||
smem_size += num_stages * smem_scales_a_per_stage
|
||||
smem_size += num_stages * smem_b_per_stage
|
||||
if is_wgrad:
|
||||
smem_size += num_stages * smem_scales_b_per_stage
|
||||
else:
|
||||
smem_size += ceil_div(smem_scales_b * (1 if block_k %
|
||||
block_n == 0 else 2), 8) * 8
|
||||
smem_size += num_stages * smem_scales_b_per_stage
|
||||
smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
|
||||
smem_size += smem_barrier
|
||||
|
||||
# Swizzle and padding are not compatible
|
||||
@@ -80,7 +78,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
block_ms = (get_m_alignment_for_contiguous_layout(), )
|
||||
block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, ))
|
||||
|
||||
# Avoid bank conflicts for fp32 output
|
||||
# Avoid bank conflicts for FP32 output
|
||||
if is_fp32_out:
|
||||
block_ns = [x for x in block_ns if x % 16 == 8]
|
||||
|
||||
|
||||
@@ -269,7 +269,7 @@ static void __instantiate_kernel() {{
|
||||
return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)
|
||||
|
||||
|
||||
class FP8WgradGemmRuntime(Runtime):
|
||||
class FP8WGradGemmRuntime(Runtime):
|
||||
def __init__(self, path: str) -> None:
|
||||
super().__init__(path, [
|
||||
'NUM_TMA_MULTICAST',
|
||||
@@ -320,7 +320,7 @@ static void __instantiate_kernel() {{
|
||||
}};
|
||||
'''
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Generated FP8 Wgrad GEMM code:\n{code}')
|
||||
print(f'Generated FP8 WGrad GEMM code:\n{code}')
|
||||
return code
|
||||
|
||||
# noinspection PyMethodOverriding
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import math
|
||||
import torch
|
||||
from typing import List, Tuple
|
||||
|
||||
from .runtime import (
|
||||
FP8WgradGemmRuntime, GemmType,
|
||||
FP8WGradGemmRuntime, GemmType,
|
||||
make_2d_tma_a_desc, make_2d_tma_b_desc,
|
||||
make_2d_tma_d_desc, make_2d_tma_scales_a_desc, make_2d_tma_scales_b_desc)
|
||||
from .gemm import get_best_configs
|
||||
@@ -122,7 +121,7 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
|
||||
space=(),
|
||||
kwargs=kwargs,
|
||||
runtime_cls=FP8WgradGemmRuntime,
|
||||
runtime_cls=FP8WGradGemmRuntime,
|
||||
)
|
||||
|
||||
# Run the kernel
|
||||
|
||||
@@ -78,7 +78,8 @@ class suppress_stdout_stderr:
|
||||
|
||||
|
||||
def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False,
|
||||
trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True, is_multiple: bool = False):
|
||||
trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True,
|
||||
with_multiple_kernels: bool = False):
|
||||
# Conflict with Nsight Systems
|
||||
using_nsys = int(os.environ.get('DG_NSYS_PROFILING', 0))
|
||||
|
||||
@@ -119,7 +120,7 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
|
||||
prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
|
||||
kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
|
||||
assert all([isinstance(name, str) for name in kernel_names])
|
||||
if not is_multiple:
|
||||
if not with_multiple_kernels:
|
||||
for name in kernel_names:
|
||||
assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table'
|
||||
|
||||
@@ -131,28 +132,18 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
|
||||
units = {'ms': 1e3, 'us': 1e6}
|
||||
kernel_times = []
|
||||
for name in kernel_names:
|
||||
if not is_multiple:
|
||||
for line in prof_lines:
|
||||
if name in line:
|
||||
time_str = line.split()[-2]
|
||||
for unit, scale in units.items():
|
||||
if unit in time_str:
|
||||
kernel_times.append(float(time_str.replace(unit, '')) / scale)
|
||||
break
|
||||
break
|
||||
else:
|
||||
total_time = 0
|
||||
total_num = 0
|
||||
for line in prof_lines:
|
||||
if name in line:
|
||||
time_str = line.split()[-2]
|
||||
num_str = line.split()[-1]
|
||||
for unit, scale in units.items():
|
||||
if unit in time_str:
|
||||
total_time += float(time_str.replace(unit, '')) / scale * int(num_str)
|
||||
total_num += int(num_str)
|
||||
break
|
||||
kernel_times.append(total_time / total_num)
|
||||
total_time = 0
|
||||
total_num = 0
|
||||
for line in prof_lines:
|
||||
if name in line:
|
||||
time_str = line.split()[-2]
|
||||
num_str = line.split()[-1]
|
||||
for unit, scale in units.items():
|
||||
if unit in time_str:
|
||||
total_time += float(time_str.replace(unit, '')) / scale * int(num_str)
|
||||
total_num += int(num_str)
|
||||
break
|
||||
kernel_times.append(total_time / total_num)
|
||||
|
||||
return tuple(kernel_times) if is_tupled else kernel_times[0]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user