mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Compatible with CUDA 12.3
This commit is contained in:
@@ -142,7 +142,7 @@ def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor,
|
||||
|
||||
class FP8GemmRuntime(Runtime):
|
||||
def __init__(self, path: str) -> None:
|
||||
super().__init__(path, 'fp8_gemm', [
|
||||
super().__init__(path, [
|
||||
'NUM_TMA_MULTICAST',
|
||||
'M',
|
||||
'BLOCK_M',
|
||||
@@ -175,8 +175,7 @@ class FP8GemmRuntime(Runtime):
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
__global__ void dummy_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&fp8_gemm_kernel<
|
||||
auto ptr = reinterpret_cast<void*>(&fp8_gemm_kernel<
|
||||
{kwargs['N']},
|
||||
{kwargs['K']},
|
||||
{kwargs['BLOCK_M']},
|
||||
@@ -192,7 +191,6 @@ __global__ void dummy_kernel() {{
|
||||
{'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
|
||||
GemmType::{kwargs['GEMM_TYPE']}
|
||||
>);
|
||||
}}
|
||||
'''
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Generated FP8 GEMM code:\n{code}')
|
||||
|
||||
Reference in New Issue
Block a user