mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Compatible with CUDA 12.3
This commit is contained in:
@@ -12,7 +12,7 @@ os.environ['DG_JIT_DISABLE_CACHE'] = os.getenv('DG_JIT_DISABLE_CACHE', '1')
|
||||
|
||||
class VectorAddRuntime(jit.Runtime):
|
||||
def __init__(self, path: str) -> None:
|
||||
super().__init__(path, 'vector_add', [
|
||||
super().__init__(path, [
|
||||
'A',
|
||||
'B',
|
||||
'C',
|
||||
@@ -31,17 +31,15 @@ class VectorAddRuntime(jit.Runtime):
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
template<typename T>
|
||||
__global__ void vector_add(T* a, T* b, T* c, uint32_t N) {{
|
||||
template <typename T>
|
||||
__global__ void vector_add(T* a, T* b, T* c, uint32_t n) {{
|
||||
uint32_t i = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if (i < N) {{
|
||||
if (i < n) {{
|
||||
c[i] = a[i] + b[i];
|
||||
}}
|
||||
}}
|
||||
|
||||
__global__ void dummy_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&vector_add<{kwargs['T']}>);
|
||||
}}
|
||||
auto ptr = reinterpret_cast<void*>(&vector_add<float>);
|
||||
"""
|
||||
|
||||
# noinspection PyShadowingNames,PyMethodOverriding
|
||||
|
||||
Reference in New Issue
Block a user