feat: drop support for CUDA<12.3

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
Zihua Wu
2025-04-25 18:56:40 -07:00
parent 46762b6903
commit f6198492cb
5 changed files with 46 additions and 85 deletions

View File

@@ -62,17 +62,15 @@ __global__ void vector_add(T* a, T* b, T* c, uint32_t N) {{
}}
}}
#ifndef NVRTC_JIT_COMPILATION
__global__ void dummy_kernel() {{
void *ptr = (void *)&vector_add<{kwargs['T']}>;
}}
#endif
"""
class VectorAddRuntime(jit.Runtime):
def __init__(self, path: str, kernel_name: str) -> None:
super().__init__(path, kernel_name, run_vector_add, [
def __init__(self, path: str) -> None:
super().__init__(path, 'vector_add', run_vector_add, [
'A',
'B',
'C',
@@ -87,7 +85,7 @@ if __name__ == '__main__':
code = generate_vector_add(T='float')
print(code)
print('Building ...')
func = jit.NvccCompiler.build('test_func', code, 'vector_add', VectorAddRuntime)
func = jit.NvccCompiler.build('test_func', code, VectorAddRuntime)
a = torch.randn((1024, ), dtype=torch.float32, device='cuda')
b = torch.randn((1024, ), dtype=torch.float32, device='cuda')
@@ -105,7 +103,7 @@ if __name__ == '__main__':
code = generate_vector_add(T='__nv_bfloat16')
print(code)
print('Building ...')
func = jit.NvrtcCompiler.build('test_func', code, r'vector_add<[\S\s]*?>', VectorAddRuntime)
func = jit.NvrtcCompiler.build('test_func', code, VectorAddRuntime)
a = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda')
b = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda')