From 3b412f458a4e2cf816fc51d8d3a7e1a8549f915c Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Thu, 15 May 2025 16:53:52 +0800 Subject: [PATCH] Unify `kwargs` usages --- deep_gemm/jit/runtime.py | 7 +++---- deep_gemm/jit_kernels/gemm.py | 4 ++-- deep_gemm/jit_kernels/m_grouped_gemm.py | 8 ++++---- deep_gemm/jit_kernels/runtime.py | 4 ++-- deep_gemm/jit_kernels/wgrad_gemm.py | 4 ++-- tests/test_jit.py | 2 +- 6 files changed, 14 insertions(+), 15 deletions(-) diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py index ffcd0b3..52af8e1 100644 --- a/deep_gemm/jit/runtime.py +++ b/deep_gemm/jit/runtime.py @@ -1,4 +1,3 @@ -import copy import os import subprocess import time @@ -27,14 +26,14 @@ class Runtime: return all(os.path.exists(os.path.join(path, file)) for file in files) @staticmethod - def generate(**kwargs) -> str: + def generate(kwargs: Dict[str, Any]) -> str: raise NotImplemented @staticmethod - def launch(kernel: cbd.CUkernel, **kwargs) -> cbd.CUresult: + def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult: raise NotImplemented - def __call__(self, **kwargs) -> cbd.CUresult: + def __call__(self, kwargs: Dict[str, Any]) -> cbd.CUresult: # Load CUBIN if self.kernel is None: start_time = time.time_ns() diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index 5f7a123..9cb01c3 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -237,6 +237,6 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], } # Generate, build and run the kernel - code = FP8GemmRuntime.generate(**kwargs) + code = FP8GemmRuntime.generate(kwargs) runtime = build('gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs) - runtime(**kwargs) + runtime(kwargs) diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py index c2f2d93..b072060 100644 --- a/deep_gemm/jit_kernels/m_grouped_gemm.py +++ b/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -101,9 +101,9 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten } # Generate, build and run the kernel - code = FP8GemmRuntime.generate(**kwargs) + code = FP8GemmRuntime.generate(kwargs) runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs) - runtime(**kwargs) + runtime(kwargs) def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor], @@ -200,6 +200,6 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] } # Generate, build and run the kernel - code = FP8GemmRuntime.generate(**kwargs) + code = FP8GemmRuntime.generate(kwargs) runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs) - runtime(**kwargs) + runtime(kwargs) diff --git a/deep_gemm/jit_kernels/runtime.py b/deep_gemm/jit_kernels/runtime.py index a7b0e66..e65e85a 100644 --- a/deep_gemm/jit_kernels/runtime.py +++ b/deep_gemm/jit_kernels/runtime.py @@ -138,7 +138,7 @@ class FP8GemmRuntime(Runtime): super().__init__(path) @staticmethod - def generate(**kwargs) -> str: + def generate(kwargs: Dict[str, Any]) -> str: code = f''' #ifdef __CUDACC_RTC__ #include @@ -233,7 +233,7 @@ class FP8WGradGemmRuntime(Runtime): super().__init__(path) @staticmethod - def generate(**kwargs) -> str: + def generate(kwargs: Dict[str, Any]) -> str: code = f''' #ifdef __CUDACC_RTC__ #include diff --git a/deep_gemm/jit_kernels/wgrad_gemm.py b/deep_gemm/jit_kernels/wgrad_gemm.py index 658f005..d0655cc 100644 --- a/deep_gemm/jit_kernels/wgrad_gemm.py +++ b/deep_gemm/jit_kernels/wgrad_gemm.py @@ -108,9 +108,9 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], } # Generate, build and run the kernel - code = FP8WGradGemmRuntime.generate(**kwargs) + code = FP8WGradGemmRuntime.generate(kwargs) runtime = build('wgrad_gemm_fp8_fp8_fp32_nt', code, FP8WGradGemmRuntime, kwargs) - runtime(**kwargs) + runtime(kwargs) def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], diff --git a/tests/test_jit.py b/tests/test_jit.py index a1bf583..413bd01 100644 --- a/tests/test_jit.py +++ b/tests/test_jit.py @@ -16,7 +16,7 @@ class VectorAddRuntime(jit.Runtime): super().__init__(path) @staticmethod - def generate(**kwargs) -> str: + def generate(kwargs: Dict[str, Any]) -> str: return f""" #ifdef __CUDACC_RTC__ #include