refactor: compile to .cubin and add NVRTC option

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
Zihua Wu 2025-04-22 10:17:52 +00:00
parent 27cd276e19
commit c14cad0c06
5 changed files with 237 additions and 284 deletions

View File

@@ -86,14 +86,14 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
// Prefetch TMA descriptors at very beginning
if (threadIdx.x == kNumMathThreads) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_scales_a);
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_a));
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_b));
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_scales_a));
// `tensor_map_d` is only used in swizzling mode
// For the `kSwizzleDMode == 0 and BLOCK_N_PADDING == 0` case, it will be treated as padding mode
if constexpr (kSwizzleDMode > 0)
cute::prefetch_tma_descriptor(&tensor_map_d);
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_d));
}
__syncwarp();
@@ -448,129 +448,6 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
#endif
}
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t BLOCK_N_PADDING,
uint32_t kSwizzleDMode,
uint32_t kNumGroups, uint32_t kNumStages,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
GemmType kGemmType>
class Gemm {
private:
using Barrier = cuda::barrier<cuda::thread_scope_block>;
public:
Gemm() = default;
static void run(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
uint32_t shape_m,
const CUtensorMap& tma_a_desc,
const CUtensorMap& tma_b_desc,
const CUtensorMap& tma_scales_a_desc,
const CUtensorMap& tma_d_desc,
cudaStream_t stream,
int num_sms, uint32_t smem_size) {
// NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
constexpr uint32_t kNumTMAThreads = 128;
constexpr uint32_t kNumMathThreadsPerGroup = 128;
auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K,
BLOCK_M, BLOCK_N, BLOCK_K,
BLOCK_N_PADDING,
kSwizzleDMode,
kNumGroups, kNumStages,
kNumTMAThreads, kNumMathThreadsPerGroup,
kNumTMAMulticast, kIsTMAMulticastOnA, kGemmType>;
DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);
// Cluster launch
cudaLaunchConfig_t config;
config.gridDim = num_sms;
config.blockDim = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
config.dynamicSmemBytes = smem_size;
config.stream = stream;
// Clusters for TMA multicast
// NOTES: `>= 4` cluster size will cause performance degradation
cudaLaunchAttribute attr;
attr.id = cudaLaunchAttributeClusterDimension;
attr.val.clusterDim = {kNumTMAMulticast, 1, 1};
config.attrs = &attr;
config.numAttrs = 1;
// Launch
auto status = cudaLaunchKernelEx(&config, kernel,
gmem_d, scales_b, grouped_layout,
shape_m,
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc);
DG_HOST_ASSERT(status == cudaSuccess);
}
};
template <typename T, GemmType kGemmType>
static CUtensorMap make_2d_tma_a_desc(T *global_address, uint32_t shape_m, uint32_t shape_k, uint32_t block_m, uint32_t block_k, uint32_t num_groups = 1) {
return make_2d_tma_desc(
global_address, Layout::RowMajor,
shape_m * (kGemmType == GemmType::GroupedMasked ? num_groups : 1),
shape_k, block_m, block_k);
}
template <typename T, GemmType kGemmType>
static CUtensorMap make_2d_tma_b_desc(T *global_address, uint32_t shape_k, uint32_t shape_n, uint32_t block_k, uint32_t block_n, uint32_t num_groups = 1) {
return make_2d_tma_desc(global_address, Layout::ColMajor, shape_k,
shape_n * (kGemmType != GemmType::Normal ? num_groups : 1),
block_k, block_n);
}
template <typename T, GemmType kGemmType, uint32_t kSwizzleDMode>
static CUtensorMap make_2d_tma_d_desc(T *global_address, uint32_t shape_m, uint32_t shape_n, uint32_t block_m, uint32_t block_n, uint32_t num_groups = 1) {
auto swizzle_mode = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE;
if constexpr (kSwizzleDMode == 32)
swizzle_mode = CU_TENSOR_MAP_SWIZZLE_32B;
if constexpr (kSwizzleDMode == 64)
swizzle_mode = CU_TENSOR_MAP_SWIZZLE_64B;
if constexpr (kSwizzleDMode == 128)
swizzle_mode = CU_TENSOR_MAP_SWIZZLE_128B;
// Swizzling requires the inner box dim to be less than or equal to `kSwizzleDMode`
// bytes, so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
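// e.g. a BF16 D (2 bytes) with BLOCK_N = 128 and kSwizzleDMode = 128 needs 128 * 2 / 128 = 2 TMA stores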
return make_2d_tma_desc(
global_address, Layout::RowMajor,
shape_m * (kGemmType == GemmType::GroupedMasked ? num_groups : 1),
shape_n, block_m,
kSwizzleDMode == 0 ? block_n : kSwizzleDMode / sizeof(T), swizzle_mode);
}
template <typename T, GemmType kGemmType>
static CUtensorMap make_2d_tma_scales_a_desc(T *global_address, uint32_t shape_m, uint32_t shape_k, uint32_t block_m, uint32_t block_k, uint32_t num_groups = 1) {
// Make TMA aligned to 16 bytes
constexpr uint32_t kAlignment = 16 / sizeof(T);
shape_m = ceil_div(shape_m, kAlignment) * kAlignment;
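// e.g. for FP32 scales (4 bytes), kAlignment = 16 / 4 = 4, so `shape_m` rounds up to a multiple of 4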
return make_2d_tma_desc(
global_address, Layout::ColMajor, shape_m,
ceil_div(shape_k, block_k) * (kGemmType == GemmType::GroupedMasked ? num_groups : 1),
block_m, 1, CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
}
template <typename T>
static CUtensorMap
make_2d_tma_desc(T *global_address, Layout layout, uint32_t gmem_rows,
uint32_t gmem_cols, uint32_t smem_rows, uint32_t smem_cols,
CUtensorMapSwizzle swizzle_type =
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) {
if (layout == Layout::RowMajor) {
uint64_t gmem_dim[2] = {gmem_cols, gmem_rows};
uint32_t smem_dim[2] = {smem_cols, smem_rows};
return make_2d_tma_copy_desc(global_address, gmem_dim,
gmem_cols * sizeof(T), smem_dim, swizzle_type);
} else {
uint64_t gmem_dim[2] = {gmem_rows, gmem_cols};
uint32_t smem_dim[2] = {smem_rows, smem_cols};
return make_2d_tma_copy_desc(global_address, gmem_dim,
gmem_rows * sizeof(T), smem_dim, swizzle_type);
}
}
}; // namespace deep_gemm
#pragma clang diagnostic pop

View File

@@ -1,6 +1,6 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
* All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -29,41 +29,73 @@ using int64_t = signed long long;
using uint64_t = unsigned long long;
using cuuint64_t = unsigned long long;
namespace std
{
template <class T, T v>
struct integral_constant
{
static constexpr T value = v;
using value_type = T;
using type = integral_constant; // using injected-class-name
#ifndef CU_TENSOR_MAP_NUM_QWORDS
#define CU_TENSOR_MAP_NUM_QWORDS 16
__device__ constexpr operator value_type() const noexcept
{
return value;
}
struct CUtensorMap_st
{
#if defined(__cplusplus) && (__cplusplus >= 201103L)
alignas(64)
#elif __STDC_VERSION__ >= 201112L
_Alignas(64)
#endif
cuuint64_t opaque[CU_TENSOR_MAP_NUM_QWORDS];
};
__device__ constexpr value_type operator()() const noexcept
{
return value;
} // since c++14
using CUtensorMap = CUtensorMap_st;
#endif
namespace std {
template <class T, T v> struct integral_constant {
static constexpr T value = v;
using value_type = T;
using type = integral_constant; // using injected-class-name
__device__ constexpr operator value_type() const noexcept { return value; }
__device__ constexpr value_type operator()() const noexcept {
return value;
} // since c++14
};
using false_type = integral_constant<bool, false>;
using true_type = integral_constant<bool, true>;
template <class T, class U>
struct is_same : false_type
{
};
template <class T, class U> struct is_same : false_type {};
template <class T>
struct is_same<T, T> : true_type
{
};
template <class T> struct is_same<T, T> : true_type {};
template <class T, class U>
inline constexpr bool is_same_v = is_same<T, U>::value;
namespace index_sequence_impl {
// Based on https://stackoverflow.com/a/32223343/11717224
template <size_t... Ints> struct index_sequence {
using type = index_sequence;
using value_type = size_t;
static constexpr size_t size() noexcept { return sizeof...(Ints); }
};
template <class Sequence1, class Sequence2> struct _merge_and_renumber;
template <size_t... I1, size_t... I2>
struct _merge_and_renumber<index_sequence<I1...>, index_sequence<I2...>>
: index_sequence<I1..., (sizeof...(I1) + I2)...> {};
template <size_t N>
struct make_index_sequence
: _merge_and_renumber<typename make_index_sequence<N / 2>::type,
typename make_index_sequence<N - N / 2>::type> {};
template <> struct make_index_sequence<0> : index_sequence<> {};
template <> struct make_index_sequence<1> : index_sequence<0> {};
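// e.g. make_index_sequence<4> merges two make_index_sequence<2>s:
// <0, 1> with <0, 1> renumbered by sizeof...(I1) = 2 yields <0, 1, 2, 3>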
} // namespace index_sequence_impl
template <size_t... Ns>
using index_sequence = index_sequence_impl::index_sequence<Ns...>;
template <size_t N>
using make_index_sequence = index_sequence_impl::make_index_sequence<N>;
} // namespace std
#endif

View File

@@ -2,87 +2,17 @@
#ifndef NVRTC_JIT_COMPILATION
#include <cassert>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cudaTypedefs.h>
#endif
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cuda/barrier>
#include "utils.cuh"
namespace deep_gemm {
template <class T>
constexpr CUtensorMapDataType get_CUtensorMapDataType() {
if constexpr (std::is_same<T, uint8_t>::value) {
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
} else if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
} else if constexpr (std::is_same<T, __nv_fp8_e5m2>::value) {
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
} else if constexpr (std::is_same<T, uint16_t>::value) {
return CU_TENSOR_MAP_DATA_TYPE_UINT16;
} else if constexpr (std::is_same<T, uint32_t>::value) {
return CU_TENSOR_MAP_DATA_TYPE_UINT32;
} else if constexpr (std::is_same<T, uint64_t>::value) {
return CU_TENSOR_MAP_DATA_TYPE_UINT64;
} else if constexpr (std::is_same<T, int32_t>::value) {
return CU_TENSOR_MAP_DATA_TYPE_INT32;
} else if constexpr (std::is_same<T, int64_t>::value) {
return CU_TENSOR_MAP_DATA_TYPE_INT64;
} else if constexpr (std::is_same<T, __half>::value) {
return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
} else if constexpr (std::is_same<T, float>::value) {
return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
} else if constexpr (std::is_same<T, double>::value) {
return CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
}
}
inline PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() {
// Get pointer to `cuTensorMapEncodeTiled`
cudaDriverEntryPointQueryResult driver_status;
void* cuTensorMapEncodeTiled_ptr = nullptr;
#if CUDA_VERSION >= 12050
cudaGetDriverEntryPointByVersion("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, 12000,
cudaEnableDefault, &driver_status);
#else
cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr,
cudaEnableDefault, &driver_status);
#endif
if (driver_status != cudaDriverEntryPointSuccess)
throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess");
return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(cuTensorMapEncodeTiled_ptr);
}
template <typename T>
CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
uint64_t stride_in_bytes, uint32_t smem_dim[2],
CUtensorMapSwizzle swizzle_type,
PFN_cuTensorMapEncodeTiled encode_func = nullptr) {
CUtensorMap tensor_map = {};
uint64_t global_stride[1] = {stride_in_bytes};
uint32_t elem_strides[2] = {1, 1};
if (encode_func == nullptr)
encode_func = get_cuTensorMapEncodeTiled();
auto result = encode_func(
&tensor_map, get_CUtensorMapDataType<std::remove_cv_t<T>>(), 2,
global_address, gmem_dim, global_stride, smem_dim, elem_strides,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
DG_HOST_ASSERT(result == CUDA_SUCCESS);
return tensor_map;
}
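With these helpers gone from the header, descriptor encoding is expected to move to the Python runtime. A rough Python counterpart of `make_2d_tma_copy_desc` via `cuda.bindings.driver` is sketched below, assuming the bindings mirror the C enum names and the `(result, value)` return convention, and accept plain ints/lists for pointers and arrays:
import cuda.bindings.driver as cbd

def make_2d_tma_copy_desc_py(global_address: int, gmem_dim, stride_in_bytes: int, smem_dim,
                             swizzle_type=cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B):
    # 2D BF16 case; `global_address` is a device pointer, e.g. `tensor.data_ptr()` from torch
    res, tensor_map = cbd.cuTensorMapEncodeTiled(
        cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
        2, global_address, gmem_dim, [stride_in_bytes], smem_dim, [1, 1],
        cbd.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type,
        cbd.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
        cbd.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)
    assert res == cbd.CUresult.CUDA_SUCCESS
    return tensor_map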
template <uint32_t kNumTMAMulticast = 1>
__device__ __forceinline__ void
tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,

View File

@@ -1,12 +1,16 @@
import hashlib
import abc
import functools
import hashlib
import os
import re
import subprocess
import time
import uuid
from typing import List, Tuple
import cuda.bindings
import cuda.bindings.nvrtc as nvrtc
from torch.utils.cpp_extension import CUDA_HOME
from typing import Tuple
from . import interleave_ffma
from .runtime import Runtime, RuntimeCache
@@ -29,7 +33,8 @@ def get_jit_include_dir() -> str:
def get_deep_gemm_version() -> str:
# Update include directories
include_dir = f'{get_jit_include_dir()}/deep_gemm'
assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}'
assert os.path.exists(
include_dir), f'Cannot find GEMM include directory {include_dir}'
md5 = hashlib.md5()
for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))):
with open(f'{include_dir}/{filename}', 'rb') as f:
@@ -53,7 +58,8 @@ def get_nvcc_compiler() -> Tuple[str, str]:
version_pattern = re.compile(r'release (\d+\.\d+)')
for path in paths:
if os.path.exists(path):
match = version_pattern.search(os.popen(f'{path} --version').read())
match = version_pattern.search(
os.popen(f'{path} --version').read())
assert match, f'Cannot get the version of NVCC compiler {path}'
version = match.group(1)
assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}'
@@ -94,64 +100,173 @@ def put(path, data, is_binary=False):
os.replace(tmp_file_path, path)
def build(name: str, code: str) -> Runtime:
# Compiler flags
cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20))
nvcc_flags = [f'-std=c++{cpp_standard}', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
'-gencode=arch=compute_90a,code=sm_90a', '-cubin',
'--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
'--diag-suppress=39,174,177,940']
cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts']
flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}']
include_dirs = [get_jit_include_dir()]
class Compiler(abc.ABC):
@staticmethod
@abc.abstractmethod
def __version__() -> Tuple[int, int]:
pass
# Build signature
enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0
signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
name = f'kernel.{name}.{hash_to_hex(signature)}'
path = f'{get_cache_dir()}/{name}'
@classmethod
@abc.abstractmethod
def compile(cls, name: str, src_path: str, target_path: str):
pass
# Check runtime cache or file system hit
global runtime_cache
if runtime_cache[path] is not None:
@staticmethod
def flags() -> List[str]:
cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20))
return [f'-std=c++{cpp_standard}',
'--ptxas-options=--register-usage-level=10' +
(',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
'--diag-suppress=39,174,177,940']
@staticmethod
def include_dirs() -> List[str]:
return [get_jit_include_dir()]
@classmethod
def build(cls, name: str, code: str) -> Runtime:
# Compiler flags
flags = cls.flags()
include_dirs = cls.include_dirs()
# Build signature
enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(
os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0
signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
name = f'kernel.{name}.{hash_to_hex(signature)}'
path = f'{get_cache_dir()}/{name}'
# Check runtime cache or file system hit
global runtime_cache
if runtime_cache[path] is not None:
if os.getenv('DG_JIT_DEBUG', None):
print(f'Using cached JIT runtime {name} during build')
return runtime_cache[path]
# Write the code
os.makedirs(path, exist_ok=True)
src_path = f'{path}/kernel.cu'
put(src_path, code)
# Compile into a temporary cubin file
cubin_path = f'{path}/kernel.cubin'
tmp_cubin_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin'
start_time = time.time()
cls.compile(name, src_path, tmp_cubin_path)
end_time = time.time()
elapsed_time = end_time - start_time
if os.getenv('DG_JIT_DEBUG', None):
print(f'Using cached JIT runtime {name} during build')
print(
f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')
# Interleave FFMA reuse
if enable_sass_opt:
interleave_ffma.process(tmp_cubin_path)
# Atomically replace the cubin file
os.replace(tmp_cubin_path, cubin_path)
# Put cache and return
runtime_cache[path] = Runtime(path)
return runtime_cache[path]
# Write the code
os.makedirs(path, exist_ok=True)
src_path = f'{path}/kernel.cu'
put(src_path, code)
# Compile into a temporary cubin file
cubin_path = f'{path}/kernel.cubin'
tmp_cubin_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin'
class NvccCompiler(Compiler):
@staticmethod
def __version__() -> Tuple[int, int]:
_, version = get_nvcc_compiler()
major, minor = map(int, version.split('.'))
return (major, minor)
# Compile
command = [get_nvcc_compiler()[0],
src_path, '-o', tmp_cubin_path,
*flags,
*[f'-I{d}' for d in include_dirs]]
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
print(f'Compiling JIT runtime {name} with command {command}')
start_time = time.time()
return_code = subprocess.check_call(command)
end_time = time.time()
assert return_code == 0, f'Failed to compile {src_path}'
@classmethod
def flags(cls) -> List[str]:
cxx_flags = ['-fPIC', '-O3',
'-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts']
return [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
'-gencode=arch=compute_90a,code=sm_90a',
'-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
f'--compiler-options={",".join(cxx_flags)}']
# Print elapsed time if debug is enabled
elapsed_time = end_time - start_time
print(f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')
@classmethod
def compile(cls, name: str, src_path: str, target_path: str):
command = [get_nvcc_compiler()[0],
src_path, '-o', target_path,
*cls.flags()]
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
print(f'Compiling JIT runtime {name} with command {command}')
# Interleave FFMA reuse
if enable_sass_opt:
interleave_ffma.process(tmp_cubin_path)
return_code = subprocess.check_call(command)
assert return_code == 0, f'Failed to compile {src_path}'
# Atomically replace the cubin file
os.replace(tmp_cubin_path, cubin_path)
# Put cache and return
runtime_cache[path] = Runtime(path)
return runtime_cache[path]
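For reference, the command list that `NvccCompiler.compile` assembles from the flags above ends up roughly as in the sketch below (the nvcc path and include directory are illustrative; the real values come from `get_nvcc_compiler()` and `get_jit_include_dir()` at runtime):
command = ['/usr/local/cuda/bin/nvcc', 'kernel.cu', '-o', 'kernel.cubin',
           '-std=c++20',
           '--ptxas-options=--register-usage-level=10',
           '--diag-suppress=39,174,177,940',
           '-I/path/to/deep_gemm/include',
           '-gencode=arch=compute_90a,code=sm_90a',
           '-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
           '--compiler-options=-fPIC,-O3,-Wno-deprecated-declarations,-Wno-abi,-fconcepts']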
class NvrtcCompiler(Compiler):
@staticmethod
def __version__() -> Tuple[int, int]:
major, minor, _ = map(int, cuda.bindings.__version__.split('.'))
return (major, minor)
@staticmethod
def include_dirs() -> List[str]:
if CUDA_HOME is None:
raise RuntimeError('CUDA_HOME is required for NVRTC compilation')
return [get_jit_include_dir(), os.path.join(CUDA_HOME, 'include'), os.path.join(CUDA_HOME, 'targets', 'x86_64-linux', 'include')]
@classmethod
def flags(cls) -> List[str]:
base_flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
'--gpu-architecture=sm_90a', '-default-device']
if cls.__version__() >= (12, 8):
base_flags += ['--pch', f'--pch-dir={get_cache_dir()}']
if os.getenv('DG_JIT_DEBUG', None):
base_flags += ['--pch-verbose=true']
return base_flags
@classmethod
def compile(cls, name: str, src_path: str, target_path: str):
code_bytes = open(src_path, 'rb').read()
res, program = nvrtc.nvrtcCreateProgram(
code_bytes, bytes(name, 'utf-8'), 0, [], [])
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to create program: {res}")
options = [bytes(flag, 'utf-8') for flag in cls.flags()]
compile_res = nvrtc.nvrtcCompileProgram(
program, len(options), options)[0]
if os.getenv('DG_JIT_DEBUG', None):
res, log_size = nvrtc.nvrtcGetProgramLogSize(program)
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to get program log size: {res}")
log_bytes = bytes(log_size)
res = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0]
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to get program log: {res}")
log_str = log_bytes.decode('utf-8')
print(log_str)
if compile_res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to compile program: {compile_res}")
res, cubin_size = nvrtc.nvrtcGetCUBINSize(program)
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to get CUBIN size: {res}")
cubin_bytes = bytes(cubin_size)
res = nvrtc.nvrtcGetCUBIN(program, cubin_bytes)[0]
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to get CUBIN: {res}")
put(target_path, cubin_bytes, is_binary=True)
res = nvrtc.nvrtcDestroyProgram(program)[0]
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to destroy program: {res}")
def build(name: str, code: str) -> Runtime:
if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']:
return NvrtcCompiler.build(name, code)
else:
return NvccCompiler.build(name, code)
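A minimal usage sketch of this dispatch, assuming `kernel_code` holds a CUDA source string:
os.environ['DG_JIT_USE_NVRTC'] = '1'  # opt into NVRTC; unset or '0' keeps the NVCC path
runtime = build('fp8_gemm', kernel_code)  # compiles to kernel.cubin and caches the Runtime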

View File

@@ -3,8 +3,6 @@ import time
from typing import Any, Dict, Optional
import cuda.bindings.driver as cuda
import cuda.bindings.nvrtc as nvrtc
import torch
from .utils import run_gemm
@@ -58,8 +56,9 @@ class Runtime:
end_time = time.time_ns()
elapsed_time = (end_time - start_time) / 1000
print(
f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
if os.getenv('DG_JIT_DEBUG', None):
print(
f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
return run_gemm(
self.kernel,