refactor: compile to .cubin and add NVRTC option

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
Zihua Wu
2025-04-22 10:17:52 +00:00
parent 27cd276e19
commit c14cad0c06
5 changed files with 237 additions and 284 deletions

View File

@@ -86,14 +86,14 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
// Prefetch TMA descriptors at very beginning // Prefetch TMA descriptors at very beginning
if (threadIdx.x == kNumMathThreads) { if (threadIdx.x == kNumMathThreads) {
cute::prefetch_tma_descriptor(&tensor_map_a); cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_a));
cute::prefetch_tma_descriptor(&tensor_map_b); cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_b));
cute::prefetch_tma_descriptor(&tensor_map_scales_a); cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_scales_a));
// `tensor_map_d` is only used in swizzling mode // `tensor_map_d` is only used in swizzling mode
// For the `kSwizzleDMode == 0 and BLOCK_N_PADDING == 0` case, it will be treated as padding mode // For the `kSwizzleDMode == 0 and BLOCK_N_PADDING == 0` case, it will be treated as padding mode
if constexpr (kSwizzleDMode > 0) if constexpr (kSwizzleDMode > 0)
cute::prefetch_tma_descriptor(&tensor_map_d); cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_d));
} }
__syncwarp(); __syncwarp();
@@ -448,129 +448,6 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
#endif #endif
} }
// Host-side launcher for one fully specialised instantiation of `fp8_gemm_kernel`.
// All tiling/shape parameters are template arguments; the runtime arguments are
// forwarded to the kernel unchanged.
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
          uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
          uint32_t BLOCK_N_PADDING,
          uint32_t kSwizzleDMode,
          uint32_t kNumGroups, uint32_t kNumStages,
          uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
          GemmType kGemmType>
class Gemm {
private:
    using Barrier = cuda::barrier<cuda::thread_scope_block>;

public:
    Gemm() = default;

    // Configures and launches the kernel on `stream`.
    //
    // - `gmem_d`, `scales_b`, `grouped_layout`, `shape_m` and the four tensor maps
    //   are passed straight through as kernel arguments.
    // - `num_sms` becomes the grid size; the block size comes from
    //   `get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M)`.
    // - `smem_size` is used both as the opted-in maximum dynamic shared memory
    //   and as the dynamic shared memory size of the launch.
    //
    // Failures of the CUDA runtime calls are reported through `DG_HOST_ASSERT`.
    static void run(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                    uint32_t shape_m,
                    const CUtensorMap& tma_a_desc,
                    const CUtensorMap& tma_b_desc,
                    const CUtensorMap& tma_scales_a_desc,
                    const CUtensorMap& tma_d_desc,
                    cudaStream_t stream,
                    int num_sms, uint32_t smem_size) {
        // NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
        constexpr uint32_t kNumTMAThreads = 128;
        constexpr uint32_t kNumMathThreadsPerGroup = 128;
        auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K,
                                      BLOCK_M, BLOCK_N, BLOCK_K,
                                      BLOCK_N_PADDING,
                                      kSwizzleDMode,
                                      kNumGroups, kNumStages,
                                      kNumTMAThreads, kNumMathThreadsPerGroup,
                                      kNumTMAMulticast, kIsTMAMulticastOnA, kGemmType>;
        DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);

        // Cluster launch
        // Value-initialise the launch structs so no indeterminate bytes (padding or
        // fields not assigned below) are ever handed to the runtime
        cudaLaunchConfig_t config = {};
        config.gridDim = num_sms;
        config.blockDim = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
        config.dynamicSmemBytes = smem_size;
        config.stream = stream;

        // Clusters for TMA multicast
        // NOTES: `>= 4` cluster size will cause performance degradation
        cudaLaunchAttribute attr = {};
        attr.id = cudaLaunchAttributeClusterDimension;
        attr.val.clusterDim = {kNumTMAMulticast, 1, 1};
        config.attrs = &attr;
        config.numAttrs = 1;

        // Launch
        auto status = cudaLaunchKernelEx(&config, kernel,
                                         gmem_d, scales_b, grouped_layout,
                                         shape_m,
                                         tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc);
        DG_HOST_ASSERT(status == cudaSuccess);
    }
};
// Creates the row-major TMA descriptor for an (M x K) operand with a
// (block_m x block_k) box. For `GemmType::GroupedMasked` the row count is
// multiplied by `num_groups`; every other GEMM type uses `shape_m` as-is.
template <typename T, GemmType kGemmType>
static CUtensorMap make_2d_tma_a_desc(T *global_address, uint32_t shape_m, uint32_t shape_k, uint32_t block_m, uint32_t block_k, uint32_t num_groups = 1) {
    const uint32_t group_factor = (kGemmType == GemmType::GroupedMasked) ? num_groups : 1;
    const uint32_t total_rows = shape_m * group_factor;
    return make_2d_tma_desc(global_address, Layout::RowMajor,
                            total_rows, shape_k,
                            block_m, block_k);
}
// Creates the column-major TMA descriptor for a (K x N) operand with a
// (block_k x block_n) box. Every GEMM type except `GemmType::Normal`
// multiplies the N extent by `num_groups`.
template <typename T, GemmType kGemmType>
static CUtensorMap make_2d_tma_b_desc(T *global_address, uint32_t shape_k, uint32_t shape_n, uint32_t block_k, uint32_t block_n, uint32_t num_groups = 1) {
    const uint32_t group_factor = (kGemmType != GemmType::Normal) ? num_groups : 1;
    const uint32_t total_cols = shape_n * group_factor;
    return make_2d_tma_desc(global_address, Layout::ColMajor,
                            shape_k, total_cols,
                            block_k, block_n);
}
// Creates the row-major TMA descriptor for the (M x N) output.
// `kSwizzleDMode` of 32 / 64 / 128 selects the matching swizzle mode; any other
// value leaves swizzling off. For `GemmType::GroupedMasked` the row count is
// multiplied by `num_groups`.
template <typename T, GemmType kGemmType, uint32_t kSwizzleDMode>
static CUtensorMap make_2d_tma_d_desc(T *global_address, uint32_t shape_m, uint32_t shape_n, uint32_t block_m, uint32_t block_n, uint32_t num_groups = 1) {
    constexpr CUtensorMapSwizzle kSwizzleType =
        kSwizzleDMode == 128 ? CU_TENSOR_MAP_SWIZZLE_128B :
        kSwizzleDMode == 64  ? CU_TENSOR_MAP_SWIZZLE_64B :
        kSwizzleDMode == 32  ? CU_TENSOR_MAP_SWIZZLE_32B :
                               CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE;

    // With swizzling the inner box may span at most `kSwizzleDMode` bytes, so a
    // full tile takes `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores; without
    // swizzling the box covers the whole `block_n`
    const uint32_t inner_box = kSwizzleDMode == 0 ? block_n : kSwizzleDMode / sizeof(T);
    const uint32_t total_rows = shape_m * (kGemmType == GemmType::GroupedMasked ? num_groups : 1);

    return make_2d_tma_desc(global_address, Layout::RowMajor,
                            total_rows, shape_n,
                            block_m, inner_box,
                            kSwizzleType);
}
// Creates the column-major, non-swizzled TMA descriptor for the scaling factors
// of A. The M extent is rounded up to a whole number of 16-byte units, the
// second extent is the number of K blocks (times `num_groups` for
// `GemmType::GroupedMasked`), and each box is `block_m` x 1.
template <typename T, GemmType kGemmType>
static CUtensorMap make_2d_tma_scales_a_desc(T *global_address, uint32_t shape_m, uint32_t shape_k, uint32_t block_m, uint32_t block_k, uint32_t num_groups = 1) {
    // Number of `T` elements that fit in 16 bytes
    constexpr uint32_t kElemsPer16Bytes = 16 / sizeof(T);
    const uint32_t aligned_shape_m = ceil_div(shape_m, kElemsPer16Bytes) * kElemsPer16Bytes;

    const uint32_t group_factor = (kGemmType == GemmType::GroupedMasked) ? num_groups : 1;
    const uint32_t num_k_blocks = ceil_div(shape_k, block_k) * group_factor;

    return make_2d_tma_desc(global_address, Layout::ColMajor,
                            aligned_shape_m, num_k_blocks,
                            block_m, 1,
                            CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
}
// Layout-aware front end for `make_2d_tma_copy_desc`.
// The contiguous (first) descriptor dimension is the column count for
// `Layout::RowMajor` and the row count otherwise; the byte stride handed on is
// that contiguous extent times `sizeof(T)`. Swizzling defaults to 128 bytes.
template <typename T>
static CUtensorMap
make_2d_tma_desc(T *global_address, Layout layout, uint32_t gmem_rows,
                 uint32_t gmem_cols, uint32_t smem_rows, uint32_t smem_cols,
                 CUtensorMapSwizzle swizzle_type =
                     CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) {
    const bool is_row_major = (layout == Layout::RowMajor);

    // Pick which extent is contiguous ("inner") for the requested layout
    const uint32_t gmem_inner = is_row_major ? gmem_cols : gmem_rows;
    const uint32_t gmem_outer = is_row_major ? gmem_rows : gmem_cols;
    const uint32_t smem_inner = is_row_major ? smem_cols : smem_rows;
    const uint32_t smem_outer = is_row_major ? smem_rows : smem_cols;

    uint64_t gmem_dim[2] = {gmem_inner, gmem_outer};
    uint32_t smem_dim[2] = {smem_inner, smem_outer};
    return make_2d_tma_copy_desc(global_address, gmem_dim,
                                 gmem_inner * sizeof(T), smem_dim, swizzle_type);
}
}; // namespace deep_gemm }; // namespace deep_gemm
#pragma clang diagnostic pop #pragma clang diagnostic pop

View File

@@ -1,6 +1,6 @@
/* /*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
* SPDX-License-Identifier: Apache-2.0 * All rights reserved. SPDX-License-Identifier: Apache-2.0
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -29,41 +29,73 @@ using int64_t = signed long long;
using uint64_t = unsigned long long; using uint64_t = unsigned long long;
using cuuint64_t = unsigned long long; using cuuint64_t = unsigned long long;
namespace std #ifndef CU_TENSOR_MAP_NUM_QWORDS
{ #define CU_TENSOR_MAP_NUM_QWORDS 16
template <class T, T v>
struct integral_constant
{
static constexpr T value = v;
using value_type = T;
using type = integral_constant; // using injected-class-name
__device__ constexpr operator value_type() const noexcept struct CUtensorMap_st
{ {
return value; #if defined(__cplusplus) && (__cplusplus >= 201103L)
} alignas(64)
#elif __STDC_VERSION__ >= 201112L
_Alignas(64)
#endif
cuuint64_t opaque[CU_TENSOR_MAP_NUM_QWORDS];
};
__device__ constexpr value_type operator()() const noexcept using CUtensorMap = CUtensorMap_st;
{ #endif
return value;
} // since c++14 namespace std {
template <class T, T v> struct integral_constant {
static constexpr T value = v;
using value_type = T;
using type = integral_constant; // using injected-class-name
__device__ constexpr operator value_type() const noexcept { return value; }
__device__ constexpr value_type operator()() const noexcept {
return value;
} // since c++14
}; };
using false_type = integral_constant<bool, false>; using false_type = integral_constant<bool, false>;
using true_type = integral_constant<bool, true>; using true_type = integral_constant<bool, true>;
template <class T, class U> template <class T, class U> struct is_same : false_type {};
struct is_same : false_type
{
};
template <class T> template <class T> struct is_same<T, T> : true_type {};
struct is_same<T, T> : true_type
{
};
template <class T, class U> template <class T, class U>
inline constexpr bool is_same_v = is_same<T, U>::value; inline constexpr bool is_same_v = is_same<T, U>::value;
// Minimal stand-ins for `index_sequence` / `make_index_sequence`, kept in a
// helper namespace and re-exported through the aliases after it.
namespace index_sequence_impl {
// Based on https://stackoverflow.com/a/32223343/11717224
// Compile-time list of `size_t` values; `size()` reports how many it holds.
template <size_t... Ints> struct index_sequence {
using type = index_sequence;
using value_type = size_t;
static constexpr size_t size() noexcept { return sizeof...(Ints); }
};
// Concatenates two sequences, shifting every index of the second one by the
// length of the first (the partial specialization below does the work).
template <class Sequence1, class Sequence2> struct _merge_and_renumber;
template <size_t... I1, size_t... I2>
struct _merge_and_renumber<index_sequence<I1...>, index_sequence<I2...>>
: index_sequence<I1..., (sizeof...(I1) + I2)...> {};
// Produces the sequence 0, 1, ..., N-1 by building the two halves
// (N / 2 and N - N / 2 elements) recursively and merging them, which keeps the
// recursion depth logarithmic in N.
template <size_t N>
struct make_index_sequence
: _merge_and_renumber<typename make_index_sequence<N / 2>::type,
typename make_index_sequence<N - N / 2>::type> {};
// Base cases that terminate the recursion.
template <> struct make_index_sequence<0> : index_sequence<> {};
template <> struct make_index_sequence<1> : index_sequence<0> {};
} // namespace index_sequence_impl
// Expose the helpers under their usual names in the enclosing namespace.
template <size_t... Ns>
using index_sequence = index_sequence_impl::index_sequence<Ns...>;
template <size_t N>
using make_index_sequence = index_sequence_impl::make_index_sequence<N>;
} // namespace std } // namespace std
#endif #endif

View File

@@ -2,87 +2,17 @@
#ifndef NVRTC_JIT_COMPILATION #ifndef NVRTC_JIT_COMPILATION
#include <cassert> #include <cassert>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cudaTypedefs.h>
#endif #endif
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cuda/barrier> #include <cuda/barrier>
#include "utils.cuh" #include "utils.cuh"
namespace deep_gemm { namespace deep_gemm {
// Maps an element type `T` to the `CUtensorMapDataType` used when encoding a
// tensor map. Both FP8 types (`__nv_fp8_e4m3`, `__nv_fp8_e5m2`) are described as
// `UINT8`, the same as plain `uint8_t`.
//
// Any type not listed below is rejected at compile time; without that final
// branch the function would flow off its end without returning a value.
template <class T>
constexpr CUtensorMapDataType get_CUtensorMapDataType() {
    if constexpr (std::is_same<T, uint8_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT8;
    } else if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT8;
    } else if constexpr (std::is_same<T, __nv_fp8_e5m2>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT8;
    } else if constexpr (std::is_same<T, uint16_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT16;
    } else if constexpr (std::is_same<T, uint32_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT32;
    } else if constexpr (std::is_same<T, uint64_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_UINT64;
    } else if constexpr (std::is_same<T, int32_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_INT32;
    } else if constexpr (std::is_same<T, int64_t>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_INT64;
    } else if constexpr (std::is_same<T, __half>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
    } else if constexpr (std::is_same<T, float>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
    } else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
    } else if constexpr (std::is_same<T, double>::value) {
        return CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
    } else {
        // `sizeof(T) == 0` is never true but depends on `T`, so the assertion
        // only fires when this branch is actually instantiated
        static_assert(sizeof(T) == 0, "Unsupported element type for CUtensorMapDataType");
    }
}
// Looks up the driver entry point `cuTensorMapEncodeTiled` through the CUDA
// runtime and returns it as a typed function pointer.
//
// Throws `std::runtime_error` when the runtime query itself fails, when the
// driver reports anything other than `cudaDriverEntryPointSuccess`, or when no
// function pointer was produced.
inline PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() {
    // Start from a failure state so the status is never read uninitialised if
    // the query returns early without writing it
    cudaDriverEntryPointQueryResult driver_status = cudaDriverEntryPointSymbolNotFound;
    void* cuTensorMapEncodeTiled_ptr = nullptr;

#if CUDA_VERSION >= 12050
    const cudaError_t query_status =
        cudaGetDriverEntryPointByVersion("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, 12000,
                                         cudaEnableDefault, &driver_status);
#else
    const cudaError_t query_status =
        cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr,
                                cudaEnableDefault, &driver_status);
#endif

    if (query_status != cudaSuccess)
        throw std::runtime_error("failed to query driver entry point `cuTensorMapEncodeTiled`");
    if (driver_status != cudaDriverEntryPointSuccess)
        throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess");
    if (cuTensorMapEncodeTiled_ptr == nullptr)
        throw std::runtime_error("driver entry point `cuTensorMapEncodeTiled` is null");

    return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(cuTensorMapEncodeTiled_ptr);
}
// Encodes a rank-2 tiled tensor map for `global_address`.
//
// `gmem_dim` / `smem_dim` are the global extents and the box extents,
// `stride_in_bytes` is the single global stride passed to the encoder, and
// element strides are fixed at {1, 1}. Interleaving and OOB fill are disabled
// and L2 promotion is set to 256 bytes. When `encode_func` is null the encoder
// is obtained from `get_cuTensorMapEncodeTiled()`. The encode result is checked
// with `DG_HOST_ASSERT`.
template <typename T>
CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
                                  uint64_t stride_in_bytes, uint32_t smem_dim[2],
                                  CUtensorMapSwizzle swizzle_type,
                                  PFN_cuTensorMapEncodeTiled encode_func = nullptr) {
    // Fall back to the driver lookup only when the caller supplied no encoder
    const PFN_cuTensorMapEncodeTiled encoder =
        (encode_func != nullptr) ? encode_func : get_cuTensorMapEncodeTiled();

    uint64_t outer_stride_bytes[1] = {stride_in_bytes};
    uint32_t unit_elem_strides[2] = {1, 1};
    CUtensorMap descriptor = {};

    const auto encode_status = encoder(
        &descriptor, get_CUtensorMapDataType<std::remove_cv_t<T>>(), 2,
        global_address, gmem_dim, outer_stride_bytes, smem_dim, unit_elem_strides,
        CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type,
        CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
        CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
    DG_HOST_ASSERT(encode_status == CUDA_SUCCESS);
    return descriptor;
}
template <uint32_t kNumTMAMulticast = 1> template <uint32_t kNumTMAMulticast = 1>
__device__ __forceinline__ void __device__ __forceinline__ void
tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr, tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,

View File

@@ -1,12 +1,16 @@
import hashlib import abc
import functools import functools
import hashlib
import os import os
import re import re
import subprocess import subprocess
import time import time
import uuid import uuid
from typing import List, Tuple
import cuda.bindings
import cuda.bindings.nvrtc as nvrtc
from torch.utils.cpp_extension import CUDA_HOME from torch.utils.cpp_extension import CUDA_HOME
from typing import Tuple
from . import interleave_ffma from . import interleave_ffma
from .runtime import Runtime, RuntimeCache from .runtime import Runtime, RuntimeCache
@@ -29,7 +33,8 @@ def get_jit_include_dir() -> str:
def get_deep_gemm_version() -> str: def get_deep_gemm_version() -> str:
# Update include directories # Update include directories
include_dir = f'{get_jit_include_dir()}/deep_gemm' include_dir = f'{get_jit_include_dir()}/deep_gemm'
assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}' assert os.path.exists(
include_dir), f'Cannot find GEMM include directory {include_dir}'
md5 = hashlib.md5() md5 = hashlib.md5()
for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))): for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))):
with open(f'{include_dir}/{filename}', 'rb') as f: with open(f'{include_dir}/{filename}', 'rb') as f:
@@ -53,7 +58,8 @@ def get_nvcc_compiler() -> Tuple[str, str]:
version_pattern = re.compile(r'release (\d+\.\d+)') version_pattern = re.compile(r'release (\d+\.\d+)')
for path in paths: for path in paths:
if os.path.exists(path): if os.path.exists(path):
match = version_pattern.search(os.popen(f'{path} --version').read()) match = version_pattern.search(
os.popen(f'{path} --version').read())
version = match.group(1) version = match.group(1)
assert match, f'Cannot get the version of NVCC compiler {path}' assert match, f'Cannot get the version of NVCC compiler {path}'
assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}' assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}'
@@ -94,64 +100,173 @@ def put(path, data, is_binary=False):
os.replace(tmp_file_path, path) os.replace(tmp_file_path, path)
def build(name: str, code: str) -> Runtime: class Compiler(abc.ABC):
# Compiler flags @staticmethod
cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20)) @abc.abstractmethod
nvcc_flags = [f'-std=c++{cpp_standard}', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda', def __version__() -> Tuple[int, int]:
'-gencode=arch=compute_90a,code=sm_90a', '-cubin', pass
'--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
'--diag-suppress=39,174,177,940']
cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts']
flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}']
include_dirs = [get_jit_include_dir()]
# Build signature @classmethod
enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0 @abc.abstractmethod
signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}' def compile(cls, name: str, src_path: str, target_path: str):
name = f'kernel.{name}.{hash_to_hex(signature)}' pass
path = f'{get_cache_dir()}/{name}'
# Check runtime cache or file system hit @staticmethod
global runtime_cache def flags() -> List[str]:
if runtime_cache[path] is not None: cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20))
return [f'-std=c++{cpp_standard}',
'--ptxas-options=--register-usage-level=10' +
(',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
'--diag-suppress=39,174,177,940']
@staticmethod
def include_dirs() -> List[str]:
return [get_jit_include_dir()]
@classmethod
def build(cls, name: str, code: str) -> Runtime:
# Compiler flags
flags = cls.flags()
include_dirs = cls.include_dirs()
# Build signature
enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(
os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0
signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
name = f'kernel.{name}.{hash_to_hex(signature)}'
path = f'{get_cache_dir()}/{name}'
# Check runtime cache or file system hit
global runtime_cache
if runtime_cache[path] is not None:
if os.getenv('DG_JIT_DEBUG', None):
print(f'Using cached JIT runtime {name} during build')
return runtime_cache[path]
# Write the code
os.makedirs(path, exist_ok=True)
src_path = f'{path}/kernel.cu'
put(src_path, code)
# Compile into a temporary cubin file
cubin_path = f'{path}/kernel.cubin'
tmp_cubin_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin'
start_time = time.time()
cls.compile(name, src_path, tmp_cubin_path)
end_time = time.time()
elapsed_time = end_time - start_time
if os.getenv('DG_JIT_DEBUG', None): if os.getenv('DG_JIT_DEBUG', None):
print(f'Using cached JIT runtime {name} during build') print(
f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')
# Interleave FFMA reuse
if enable_sass_opt:
interleave_ffma.process(tmp_cubin_path)
# Atomically replace the cubin file
os.replace(tmp_cubin_path, cubin_path)
# Put cache and return
runtime_cache[path] = Runtime(path)
return runtime_cache[path] return runtime_cache[path]
# Write the code
os.makedirs(path, exist_ok=True)
src_path = f'{path}/kernel.cu'
put(src_path, code)
# Compile into a temporary CU file class NvccCompiler(Compiler):
cubin_path = f'{path}/kernel.cubin' @staticmethod
tmp_cubin_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin' def __version__() -> Tuple[int, int]:
major, minor, _ = map(int, cuda.bindings.__version__.split('.'))
return (major, minor)
# Compile @classmethod
command = [get_nvcc_compiler()[0], def flags(cls) -> List[str]:
src_path, '-o', tmp_cubin_path, cxx_flags = ['-fPIC', '-O3',
*flags, '-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts']
*[f'-I{d}' for d in include_dirs]] return [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False): '-gencode=arch=compute_90a,code=sm_90a',
print(f'Compiling JIT runtime {name} with command {command}') '-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
f'--compiler-options={",".join(cxx_flags)}']
start_time = time.time()
return_code = subprocess.check_call(command)
end_time = time.time()
assert return_code == 0, f'Failed to compile {src_path}'
# Print elapsed time if debug is enabled @classmethod
elapsed_time = end_time - start_time def compile(cls, name: str, src_path: str, target_path: str):
print(f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.') command = [get_nvcc_compiler()[0],
src_path, '-o', target_path,
*cls.flags()]
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
print(f'Compiling JIT runtime {name} with command {command}')
# Interleave FFMA reuse return_code = subprocess.check_call(command)
if enable_sass_opt: assert return_code == 0, f'Failed to compile {src_path}'
interleave_ffma.process(tmp_cubin_path)
# Atomic replace CU file
os.replace(tmp_cubin_path, cubin_path)
# Put cache and return class NvrtcCompiler(Compiler):
runtime_cache[path] = Runtime(path) @staticmethod
return runtime_cache[path] def __version__() -> Tuple[int, int]:
_, version = get_nvcc_compiler()
major, minor = map(int, version.split('.'))
return (major, minor)
@staticmethod
def include_dirs() -> List[str]:
    """Return the include search paths for NVRTC compilation.

    The list holds the JIT include directory followed by two directories
    under ``CUDA_HOME``: ``include`` and ``targets/x86_64-linux/include``.

    Raises:
        RuntimeError: if ``CUDA_HOME`` is ``None``.
    """
    if CUDA_HOME is None:
        raise RuntimeError('CUDA_HOME is required for NVRTC compilation')

    search_paths = [get_jit_include_dir()]
    search_paths.append(os.path.join(CUDA_HOME, 'include'))
    search_paths.append(
        os.path.join(CUDA_HOME, 'targets', 'x86_64-linux', 'include'))
    return search_paths
@classmethod
def flags(cls) -> List[str]:
    """Return the NVRTC option list.

    Starts from the parent class flags, adds one ``-I`` option per entry of
    ``cls.include_dirs()``, then the architecture and ``-default-device``
    options. From version (12, 8) on, precompiled-header options are
    appended as well.
    """
    options = list(super().flags())
    options.extend(f'-I{d}' for d in cls.include_dirs())
    options.extend(['--gpu-architecture=sm_90a', '-default-device'])

    pch_supported = cls.__version__() >= (12, 8)
    if pch_supported:
        options.extend(['--pch', f'--pch-dir={get_cache_dir()}'])
        # NOTE(review): assumes the verbose switch is nested under the PCH
        # version check (it only makes sense with `--pch`) — confirm against
        # the original indentation
        if os.getenv('DG_JIT_DEBUG', None):
            options.append('--pch-verbose=true')
    return options
@classmethod
def compile(cls, name: str, src_path: str, target_path: str):
    """Compile the source at ``src_path`` with NVRTC and write the cubin.

    Args:
        name: program name handed to NVRTC.
        src_path: path of the CUDA source file to compile.
        target_path: path the resulting cubin is written to (via ``put``).

    Raises:
        Exception: if any NVRTC call reports a status other than
            ``NVRTC_SUCCESS``.

    The compile log is printed when ``DG_JIT_DEBUG`` is set and also whenever
    compilation fails, so a failure is never silent about its cause.
    """
    ok = nvrtc.nvrtcResult.NVRTC_SUCCESS

    # Use a context manager so the source handle is closed deterministically
    with open(src_path, 'rb') as src_file:
        code_bytes = src_file.read()

    res, program = nvrtc.nvrtcCreateProgram(
        code_bytes, bytes(name, 'utf-8'), 0, [], [])
    if res != ok:
        raise Exception(f"Failed to create program: {res}")

    try:
        options = [bytes(flag, 'utf-8') for flag in cls.flags()]
        compile_res = nvrtc.nvrtcCompileProgram(
            program, len(options), options)[0]
        compile_failed = compile_res != ok

        if compile_failed or os.getenv('DG_JIT_DEBUG', None):
            res, log_size = nvrtc.nvrtcGetProgramLogSize(program)
            if res != ok:
                raise Exception(f"Failed to get program log size: {res}")
            log_bytes = bytes(log_size)
            res = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0]
            if res != ok:
                raise Exception(f"Failed to get program log: {res}")
            log_str = log_bytes.decode('utf-8')
            print(log_str)

        if compile_failed:
            raise Exception(f"Failed to compile program: {compile_res}")

        res, cubin_size = nvrtc.nvrtcGetCUBINSize(program)
        if res != ok:
            raise Exception(f"Failed to get CUBIN size: {res}")

        cubin_bytes = bytes(cubin_size)
        res = nvrtc.nvrtcGetCUBIN(program, cubin_bytes)[0]
        if res != ok:
            raise Exception(f"Failed to get CUBIN: {res}")

        put(target_path, cubin_bytes, is_binary=True)
    except BaseException:
        # Best-effort cleanup so the program handle is not leaked on failure;
        # the original error is the one that propagates
        nvrtc.nvrtcDestroyProgram(program)
        raise

    res = nvrtc.nvrtcDestroyProgram(program)[0]
    if res != ok:
        raise Exception(f"Failed to destroy program: {res}")
def build(name: str, code: str) -> Runtime:
    """Build a JIT runtime for ``code`` with the selected compiler backend.

    NVRTC is used when the ``DG_JIT_USE_NVRTC`` environment variable is one of
    ``1``, ``true`` or ``True``; otherwise NVCC is used (the default).
    """
    use_nvrtc = os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']
    compiler = NvrtcCompiler if use_nvrtc else NvccCompiler
    return compiler.build(name, code)

View File

@@ -3,8 +3,6 @@ import time
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import cuda.bindings.driver as cuda import cuda.bindings.driver as cuda
import cuda.bindings.nvrtc as nvrtc
import torch
from .utils import run_gemm from .utils import run_gemm
@@ -58,8 +56,9 @@ class Runtime:
end_time = time.time_ns() end_time = time.time_ns()
elapsed_time = (end_time - start_time) / 1000 elapsed_time = (end_time - start_time) / 1000
print( if os.getenv('DG_JIT_DEBUG', None):
f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.') print(
f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
return run_gemm( return run_gemm(
self.kernel, self.kernel,