mirror of https://github.com/deepseek-ai/DeepGEMM
synced 2025-05-05 19:34:22 +00:00

feat: make API more general

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>

parent 6c982791eb
commit 46762b6903
@@ -1,3 +1,3 @@
-from .compiler import get_nvcc_compiler, build
+from .compiler import get_nvcc_compiler, build, NvccCompiler, NvrtcCompiler
 from .template import generate
 from .runtime import Runtime
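The re-export means downstream code no longer has to go through the module-level build() helper; a minimal, hedged usage sketch (not part of the commit) of picking a backend explicitly:

    import os
    from deep_gemm.jit import NvccCompiler, NvrtcCompiler

    # Mirror the library's own DG_JIT_USE_NVRTC switch when choosing a backend.
    compiler_cls = NvrtcCompiler if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True'] else NvccCompiler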
@@ -7,14 +7,14 @@ import re
 import subprocess
 import time
 import uuid
-from typing import List, Tuple
+from typing import List, Tuple, Type

 import cuda.bindings
 import cuda.bindings.nvrtc as nvrtc
 from torch.utils.cpp_extension import CUDA_HOME

 from . import interleave_ffma
-from .runtime import Runtime, RuntimeCache, get_symbol
+from .runtime import Runtime, Fp8GemmRuntime, RuntimeCache, get_symbol

 runtime_cache = RuntimeCache()

@@ -115,7 +115,7 @@ class Compiler(abc.ABC):

     @classmethod
     @abc.abstractmethod
-    def compile(cls, name: str, code: str, target_path: str) -> str:
+    def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
         pass

     @staticmethod
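Every Compiler.compile override now receives a kernel_name_pattern whose interpretation is backend-specific, as the later hunks show. A hedged illustration of the two forms this commit uses for the FP8 GEMM kernel:

    # NvccCompiler: passed to get_symbol(), i.e. matched against `cuobjdump -symbols` output.
    nvcc_pattern = 'fp8_gemm_kernel'
    # NvrtcCompiler: compiled as a regex over the generated source to build a name expression.
    nvrtc_pattern = r'fp8_gemm_kernel<[\S\s]*?>'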
@@ -132,7 +132,7 @@ class Compiler(abc.ABC):
         return [get_jit_include_dir()]

     @classmethod
-    def build(cls, name: str, code: str) -> Runtime:
+    def build(cls, name: str, code: str, kernel_name_pattern: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Runtime:
         # Compiler flags
         flags = cls.flags()
         include_dirs = cls.include_dirs()
@@ -146,10 +146,11 @@

         # Check runtime cache or file system hit
         global runtime_cache
-        if runtime_cache[path] is not None:
+        cached_runtime = runtime_cache.get(path, runtime_cls)
+        if cached_runtime is not None:
             if os.getenv('DG_JIT_DEBUG', None):
                 print(f'Using cached JIT runtime {name} during build')
-            return runtime_cache[path]
+            return cached_runtime

         # Compile into a temporary CU file
         os.makedirs(path, exist_ok=True)
@@ -157,7 +158,7 @@
         tmp_cubin_path = os.path.join(make_tmp_dir(), f'nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin')

         start_time = time.time()
-        kernel_name = cls.compile(name, code, tmp_cubin_path)
+        kernel_name = cls.compile(name, code, tmp_cubin_path, kernel_name_pattern)
         end_time = time.time()
         elapsed_time = end_time - start_time
         if os.getenv('DG_JIT_DEBUG', None):
@@ -176,8 +177,9 @@
         os.replace(f'{tmp_cubin_path}.name', f'{cubin_path}.name')

         # Put cache and return
-        runtime_cache[path] = Runtime(path, kernel_name)
-        return runtime_cache[path]
+        runtime = runtime_cls(path, kernel_name)
+        runtime_cache[path] = runtime
+        return runtime


 class NvccCompiler(Compiler):
@@ -200,7 +202,7 @@ class NvccCompiler(Compiler):
                 f'--compiler-options={",".join(cxx_flags)}']

     @classmethod
-    def compile(cls, name: str, code: str, target_path: str) -> str:
+    def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
         # Write the code
         path = os.path.join(get_cache_dir(), name)
         src_path = os.path.join(path, 'kernel.cu')
@@ -220,7 +222,7 @@ class NvccCompiler(Compiler):
         assert result.returncode == 0, f'Failed to compile {src_path}'

         # NVCC needs to get the symbol name from the cubin file using `cuobjdump`
-        return get_symbol(target_path, 'fp8_gemm_kernel')
+        return get_symbol(target_path, kernel_name_pattern)


 class NvrtcCompiler(Compiler):
@@ -249,14 +251,14 @@ class NvrtcCompiler(Compiler):
         return base_flags

     @classmethod
-    def compile(cls, name: str, code: str, target_path: str) -> str:
+    def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
         code_bytes = bytes(code, 'utf-8')
         res, program = nvrtc.nvrtcCreateProgram(
             code_bytes, bytes(name, 'utf-8'), 0, [], [])
         if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
             raise Exception(f"Failed to create program: {res}")

-        kernel_regex = re.compile(r'fp8_gemm_kernel<[\S\s]*?>', re.MULTILINE)
+        kernel_regex = re.compile(kernel_name_pattern, re.MULTILINE)
         kernel_name = kernel_regex.search(code).group(
             0).replace('\n', '').replace(' ', '')
         res = nvrtc.nvrtcAddNameExpression(
@@ -308,6 +310,6 @@ class NvrtcCompiler(Compiler):

 def build(name: str, code: str) -> Runtime:
     if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']:
-        return NvrtcCompiler.build(name, code)
+        return NvrtcCompiler.build(name, code, kernel_name_pattern=r'fp8_gemm_kernel<[\S\s]*?>')
     else:
-        return NvccCompiler.build(name, code)
+        return NvccCompiler.build(name, code, kernel_name_pattern='fp8_gemm_kernel')
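The module-level build() keeps its old two-argument form, so existing FP8 GEMM call sites are unaffected; only the backend (chosen via DG_JIT_USE_NVRTC) and the default pattern are filled in internally. A hedged usage sketch, where the cache name and source string are placeholders:

    import os
    from deep_gemm.jit import build

    def jit_fp8_gemm(code: str):
        # 'code' stands in for the generated FP8 GEMM CUDA source.
        os.environ.setdefault('DG_JIT_USE_NVRTC', '0')  # set to '1' to compile with NVRTC
        return build('fp8_gemm', code)                  # an Fp8GemmRuntime by default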
@@ -2,7 +2,7 @@ import os
 import platform
 import time
 import subprocess
-from typing import Any, Dict, Optional
+from typing import Any, Callable, Dict, List, Optional, Type

 import cuda.bindings.driver as cuda
 from torch.utils.cpp_extension import CUDA_HOME
@@ -13,9 +13,10 @@ from .utils import run_gemm
 def get_symbol(file_path: str, pattern: str) -> Optional[str]:
     if CUDA_HOME is None:
         raise Exception("CUDA_HOME is not set")

     cuobjdump_bin = 'cuobjdump.exe' if platform.system() == 'Windows' else 'cuobjdump'
-    command = [os.path.join(CUDA_HOME, 'bin', cuobjdump_bin), '-symbols', file_path]
+    command = [os.path.join(CUDA_HOME, 'bin', cuobjdump_bin),
+               '-symbols', file_path]
     result = subprocess.run(command, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE, text=True)
     assert result.returncode == 0
@@ -26,12 +27,13 @@ def get_symbol(file_path: str, pattern: str) -> Optional[str]:


 class Runtime:
-    def __init__(self, path: str, kernel_name: str) -> None:
+    def __init__(self, path: str, kernel_name: str, caller: Callable[..., cuda.CUresult], args: List[str]) -> None:
         self.path = path
         self.lib = None
         self.kernel = None
         self.kernel_name = kernel_name
+        self.caller = caller
+        self.args = args
         assert self.is_path_valid(self.path)

     @staticmethod
@@ -62,7 +64,7 @@ class Runtime:
             if self.kernel is not None:
                 self.lib = lib
             else:
-                raise Exception("Failed to find fp8 gemm kernel")
+                raise Exception("Failed to find kernel")

             end_time = time.time_ns()
             elapsed_time = (end_time - start_time) / 1000
@@ -70,21 +72,9 @@
                 print(
                     f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')

-        return run_gemm(
+        return self.caller(
             self.kernel,
-            kwargs['NUM_TMA_MULTICAST'],
-            kwargs['M'],
-            kwargs['BLOCK_M'],
-            kwargs['GMEM_D'],
-            kwargs['SCALES_B'],
-            kwargs['GROUPED_LAYOUT'],
-            kwargs['NUM_SMS'],
-            kwargs['SMEM_SIZE'],
-            kwargs['TENSOR_MAP_A'],
-            kwargs['TENSOR_MAP_B'],
-            kwargs['TENSOR_MAP_SCALES_A'],
-            kwargs['TENSOR_MAP_D'],
-            kwargs['STREAM'],
+            *[kwargs[arg] for arg in self.args]
         )

     def __del__(self) -> None:
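The hard-coded run_gemm call becomes a generic dispatch: a Runtime now forwards the keyword arguments named in self.args to its caller, positionally and in declared order. A self-contained, CUDA-free sketch of that rule:

    def fake_caller(kernel, a, b, stream):
        # Stands in for run_gemm / run_vector_add; just records the call order.
        return (kernel, a, b, stream)

    args = ['A', 'B', 'STREAM']             # what a Runtime subclass declares
    kwargs = {'STREAM': 0, 'B': 2, 'A': 1}  # what the caller passes by name
    assert fake_caller('kernel', *[kwargs[arg] for arg in args]) == ('kernel', 1, 2, 0)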
@@ -94,22 +84,42 @@ class Runtime:
             raise Exception(f"Failed to unload library {self.path}: {res}")


+class Fp8GemmRuntime(Runtime):
+    def __init__(self, path: str, kernel_name: str) -> None:
+        super().__init__(path, kernel_name, run_gemm, [
+            'NUM_TMA_MULTICAST',
+            'M',
+            'BLOCK_M',
+            'GMEM_D',
+            'SCALES_B',
+            'GROUPED_LAYOUT',
+            'NUM_SMS',
+            'SMEM_SIZE',
+            'TENSOR_MAP_A',
+            'TENSOR_MAP_B',
+            'TENSOR_MAP_SCALES_A',
+            'TENSOR_MAP_D',
+            'STREAM',
+        ])
+
+
 class RuntimeCache:
     def __init__(self) -> None:
         self.cache = {}

-    def __getitem__(self, path: str) -> Optional[Runtime]:
+    def __setitem__(self, path, runtime) -> None:
+        self.cache[path] = runtime
+
+    def get(self, path: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Optional[Runtime]:
         # In Python runtime
         if path in self.cache:
             return self.cache[path]

         # Already compiled
         if os.path.exists(path) and Runtime.is_path_valid(path):
-            kernel_name = open(os.path.join(path, 'kernel.cubin.name'), 'r').read()
-            runtime = Runtime(path, kernel_name)
+            kernel_name = open(os.path.join(
+                path, 'kernel.cubin.name'), 'r').read()
+            runtime = runtime_cls(path, kernel_name)
             self.cache[path] = runtime
             return runtime
         return None
-
-    def __setitem__(self, path, runtime) -> None:
-        self.cache[path] = runtime
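RuntimeCache.get() replaces the old __getitem__ lookup so that a disk hit can be rehydrated with the right Runtime subclass (Fp8GemmRuntime by default). A hedged sketch of the protocol, with an illustrative cache path:

    from deep_gemm.jit.runtime import Fp8GemmRuntime, RuntimeCache

    cache = RuntimeCache()
    path = '/tmp/deep_gemm_cache/fp8_gemm'                 # illustrative path, not a real layout
    runtime = cache.get(path, runtime_cls=Fp8GemmRuntime)  # memory hit, then disk hit, else None
    if runtime is None:
        # Not in memory and not compiled on disk: Compiler.build compiles it,
        # then registers it via `cache[path] = runtime`.
        pass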
@@ -1,64 +1,118 @@
+import ctypes
 import os
 import torch
-from typing import Any
+from typing import Any, Dict

+import cuda.bindings.driver as cuda
+
 from deep_gemm import jit


-class Capture:
-    def __init__(self) -> None:
-        self.read_fd = None
-        self.write_fd = None
-        self.saved_stdout = None
-        self.captured = None
-
-    def __enter__(self) -> Any:
-        self.read_fd, self.write_fd = os.pipe()
-        self.saved_stdout = os.dup(1)
-        os.dup2(self.write_fd, 1)
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
-        os.dup2(self.saved_stdout, 1)
-        os.close(self.write_fd)
-        with os.fdopen(self.read_fd, 'r') as f:
-            self.captured = f.read()
-
-    def capture(self) -> str:
-        return self.captured
+def run_vector_add(kernel: cuda.CUkernel, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, stream: cuda.CUstream) -> cuda.CUresult:
+    assert a.shape == b.shape == c.shape
+    assert a.device == b.device == c.device
+    assert a.dim() == 1
+
+    n = a.numel()
+
+    config = cuda.CUlaunchConfig()
+    config.gridDimX = (n + 127) // 128
+    config.gridDimY = 1
+    config.gridDimZ = 1
+    config.blockDimX = 128
+    config.blockDimY = 1
+    config.blockDimZ = 1
+    config.hStream = stream
+
+    kernelValues = (
+        a.data_ptr(),
+        b.data_ptr(),
+        c.data_ptr(),
+        n,
+    )
+    kernelTypes = (
+        ctypes.c_void_p,
+        ctypes.c_void_p,
+        ctypes.c_void_p,
+        ctypes.c_uint32,
+    )
+
+    return cuda.cuLaunchKernelEx(config, kernel, (kernelValues, kernelTypes), 0)[0]
+
+
+def generate_vector_add(**kwargs: Dict[str, Any]) -> str:
+    return f"""
+#ifdef __CUDACC_RTC__
+#ifndef NVRTC_JIT_COMPILATION
+#define NVRTC_JIT_COMPILATION
+#endif
+#include <deep_gemm/nvrtc_std.cuh>
+#else
+#include <cuda.h>
+#endif
+
+#include <cuda_fp8.h>
+#include <cuda_bf16.h>
+
+template<typename T>
+__global__ void vector_add(T* a, T* b, T* c, uint32_t N) {{
+    uint32_t i = blockDim.x * blockIdx.x + threadIdx.x;
+    if (i < N) {{
+        c[i] = a[i] + b[i];
+    }}
+}}
+
+#ifndef NVRTC_JIT_COMPILATION
+__global__ void dummy_kernel() {{
+    void *ptr = (void *)&vector_add<{kwargs['T']}>;
+}}
+#endif
+"""
+
+
+class VectorAddRuntime(jit.Runtime):
+    def __init__(self, path: str, kernel_name: str) -> None:
+        super().__init__(path, kernel_name, run_vector_add, [
+            'A',
+            'B',
+            'C',
+            'STREAM',
+        ])


 if __name__ == '__main__':
-    # Runtime
-    print(f'NVCC compiler: {jit.get_nvcc_compiler()}\n')
+    # NVCC
+    print(f'NVCC compiler version: {jit.NvccCompiler.__version__()}\n')

-    # Templates
     print('Generated code:')
-    args = (('lhs', torch.float8_e4m3fn), ('rhs', torch.float8_e4m3fn), ('scale', torch.float), ('out', torch.bfloat16),
-            ('enable_double_streams', bool), ('stream', torch.cuda.Stream))
-    body = "\n"
-    body += 'std::cout << reinterpret_cast<uint64_t>(lhs) << std::endl;\n'
-    body += 'std::cout << reinterpret_cast<uint64_t>(rhs) << std::endl;\n'
-    body += 'std::cout << reinterpret_cast<uint64_t>(scale) << std::endl;\n'
-    body += 'std::cout << reinterpret_cast<uint64_t>(out) << std::endl;\n'
-    body += 'std::cout << enable_double_streams << std::endl;\n'
-    body += 'std::cout << reinterpret_cast<uint64_t>(stream) << std::endl;\n'
-    code = jit.generate((), args, body)
+    code = generate_vector_add(T='float')
     print(code)

-    # Build
     print('Building ...')
-    func = jit.build('test_func', args, code)
+    func = jit.NvccCompiler.build('test_func', code, 'vector_add', VectorAddRuntime)

-    # Test correctness
-    print('Running ...')
-    fp8_tensor = torch.empty((1, ), dtype=torch.float8_e4m3fn, device='cuda')
-    fp32_tensor = torch.empty((1, ), dtype=torch.float, device='cuda')
-    bf16_tensor = torch.empty((1, ), dtype=torch.bfloat16, device='cuda')
-    with Capture() as capture:
-        assert func(fp8_tensor, fp8_tensor, fp32_tensor, bf16_tensor, True, torch.cuda.current_stream()) == 0
-    output = capture.capture()
-    ref_output = f'{fp8_tensor.data_ptr()}\n{fp8_tensor.data_ptr()}\n{fp32_tensor.data_ptr()}\n{bf16_tensor.data_ptr()}\n1\n{torch.cuda.current_stream().cuda_stream}\n'
-    assert output == ref_output, f'{output=}, {ref_output=}'
+    a = torch.randn((1024, ), dtype=torch.float32, device='cuda')
+    b = torch.randn((1024, ), dtype=torch.float32, device='cuda')
+    c = torch.empty_like(a)
+    ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
+    assert ret == cuda.CUresult.CUDA_SUCCESS, ret
+    ref_output = a + b
+    torch.testing.assert_close(c, ref_output)

-    print('JIT test passed')
+    print('JIT test for NVCC passed\n')
+
+    # NVRTC
+    print(f'NVRTC compiler version: {jit.NvrtcCompiler.__version__()}\n')
+    print('Generated code:')
+    code = generate_vector_add(T='__nv_bfloat16')
+    print(code)
+    print('Building ...')
+    func = jit.NvrtcCompiler.build('test_func', code, r'vector_add<[\S\s]*?>', VectorAddRuntime)
+
+    a = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda')
+    b = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda')
+    c = torch.empty_like(a)
+    ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
+    assert ret == cuda.CUresult.CUDA_SUCCESS, ret
+    ref_output = a + b
+    torch.testing.assert_close(c, ref_output)
+
+    print('JIT test for NVRTC passed')