Compatible with CUDA 12.3

This commit is contained in:
Chenggang Zhao
2025-05-07 11:15:19 +08:00
parent 5373da7b28
commit ba349d9cf8
3 changed files with 30 additions and 47 deletions

View File

@@ -12,7 +12,7 @@ os.environ['DG_JIT_DISABLE_CACHE'] = os.getenv('DG_JIT_DISABLE_CACHE', '1')
class VectorAddRuntime(jit.Runtime):
def __init__(self, path: str) -> None:
-super().__init__(path, 'vector_add', [
+super().__init__(path, [
'A',
'B',
'C',
@@ -31,17 +31,15 @@ class VectorAddRuntime(jit.Runtime):
#include <cuda_fp8.h>
#include <cuda_bf16.h>
-template<typename T>
-__global__ void vector_add(T* a, T* b, T* c, uint32_t N) {{
+template <typename T>
+__global__ void vector_add(T* a, T* b, T* c, uint32_t n) {{
uint32_t i = blockDim.x * blockIdx.x + threadIdx.x;
-if (i < N) {{
+if (i < n) {{
c[i] = a[i] + b[i];
}}
}}
-__global__ void dummy_kernel() {{
-auto ptr = reinterpret_cast<void*>(&vector_add<{kwargs['T']}>);
-}}
+auto ptr = reinterpret_cast<void*>(&vector_add<float>);
"""
# noinspection PyShadowingNames,PyMethodOverriding