mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Unify kwargs usages
This commit is contained in:
@@ -1,4 +1,3 @@
-import copy
 import os
 import subprocess
 import time
@@ -27,14 +26,14 @@ class Runtime:
         return all(os.path.exists(os.path.join(path, file)) for file in files)
 
     @staticmethod
-    def generate(**kwargs) -> str:
+    def generate(kwargs: Dict[str, Any]) -> str:
         raise NotImplemented
 
     @staticmethod
-    def launch(kernel: cbd.CUkernel, **kwargs) -> cbd.CUresult:
+    def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
         raise NotImplemented
 
-    def __call__(self, **kwargs) -> cbd.CUresult:
+    def __call__(self, kwargs: Dict[str, Any]) -> cbd.CUresult:
         # Load CUBIN
         if self.kernel is None:
             start_time = time.time_ns()
|
|||||||
@@ -237,6 +237,6 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     }
 
     # Generate, build and run the kernel
-    code = FP8GemmRuntime.generate(**kwargs)
+    code = FP8GemmRuntime.generate(kwargs)
     runtime = build('gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs)
-    runtime(**kwargs)
+    runtime(kwargs)
|
|||||||
@@ -101,9 +101,9 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
     }
 
     # Generate, build and run the kernel
-    code = FP8GemmRuntime.generate(**kwargs)
+    code = FP8GemmRuntime.generate(kwargs)
     runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs)
-    runtime(**kwargs)
+    runtime(kwargs)
 
 
 def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
@@ -200,6 +200,6 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
     }
 
     # Generate, build and run the kernel
-    code = FP8GemmRuntime.generate(**kwargs)
+    code = FP8GemmRuntime.generate(kwargs)
     runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs)
-    runtime(**kwargs)
+    runtime(kwargs)
|
|||||||
@@ -138,7 +138,7 @@ class FP8GemmRuntime(Runtime):
         super().__init__(path)
 
     @staticmethod
-    def generate(**kwargs) -> str:
+    def generate(kwargs: Dict[str, Any]) -> str:
         code = f'''
 #ifdef __CUDACC_RTC__
 #include <deep_gemm/nvrtc_std.cuh>
@@ -233,7 +233,7 @@ class FP8WGradGemmRuntime(Runtime):
         super().__init__(path)
 
     @staticmethod
-    def generate(**kwargs) -> str:
+    def generate(kwargs: Dict[str, Any]) -> str:
         code = f'''
 #ifdef __CUDACC_RTC__
 #include <deep_gemm/nvrtc_std.cuh>
|
|||||||
@@ -108,9 +108,9 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     }
 
     # Generate, build and run the kernel
-    code = FP8WGradGemmRuntime.generate(**kwargs)
+    code = FP8WGradGemmRuntime.generate(kwargs)
     runtime = build('wgrad_gemm_fp8_fp8_fp32_nt', code, FP8WGradGemmRuntime, kwargs)
-    runtime(**kwargs)
+    runtime(kwargs)
 
 
 def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
|||||||
@@ -16,7 +16,7 @@ class VectorAddRuntime(jit.Runtime):
         super().__init__(path)
 
     @staticmethod
-    def generate(**kwargs) -> str:
+    def generate(kwargs: Dict[str, Any]) -> str:
         return f"""
 #ifdef __CUDACC_RTC__
 #include <deep_gemm/nvrtc_std.cuh>
|
|||||||
Reference in New Issue
Block a user