Compatible with CUDA 12.3

This commit is contained in:
Chenggang Zhao
2025-05-07 11:15:19 +08:00
parent 5373da7b28
commit ba349d9cf8
3 changed files with 30 additions and 47 deletions

View File

@@ -142,7 +142,7 @@ def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor,
class FP8GemmRuntime(Runtime):
def __init__(self, path: str) -> None:
- super().__init__(path, 'fp8_gemm', [
+ super().__init__(path, [
'NUM_TMA_MULTICAST',
'M',
'BLOCK_M',
@@ -175,8 +175,7 @@ class FP8GemmRuntime(Runtime):
using namespace deep_gemm;
__global__ void dummy_kernel() {{
- auto ptr = reinterpret_cast<void*>(&fp8_gemm_kernel<
+ auto ptr = reinterpret_cast<void*>(&fp8_gemm_kernel<
{kwargs['N']},
{kwargs['K']},
{kwargs['BLOCK_M']},
@@ -192,7 +191,6 @@ __global__ void dummy_kernel() {{
{'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
GemmType::{kwargs['GEMM_TYPE']}
>);
}}
'''
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Generated FP8 GEMM code:\n{code}')