mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
refactor: compile to .cubin and add NVRTC option
Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
parent
27cd276e19
commit
c14cad0c06
@ -86,14 +86,14 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
|
||||
// Prefetch TMA descriptors at very beginning
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
cute::prefetch_tma_descriptor(&tensor_map_a);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_b);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_scales_a);
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_a));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_b));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_scales_a));
|
||||
|
||||
// `tensor_map_d` is only used in swizzling mode
|
||||
// For the `kSwizzleDMode == 0 and BLOCK_N_PADDING == 0` case, it will be treated as padding mode
|
||||
if constexpr (kSwizzleDMode > 0)
|
||||
cute::prefetch_tma_descriptor(&tensor_map_d);
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_d));
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
@ -448,129 +448,6 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
#endif
|
||||
}
|
||||
|
||||
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t BLOCK_N_PADDING,
|
||||
uint32_t kSwizzleDMode,
|
||||
uint32_t kNumGroups, uint32_t kNumStages,
|
||||
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
|
||||
GemmType kGemmType>
|
||||
class Gemm {
|
||||
private:
|
||||
using Barrier = cuda::barrier<cuda::thread_scope_block>;
|
||||
|
||||
public:
|
||||
Gemm() = default;
|
||||
|
||||
static void run(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
uint32_t shape_m,
|
||||
const CUtensorMap& tma_a_desc,
|
||||
const CUtensorMap& tma_b_desc,
|
||||
const CUtensorMap& tma_scales_a_desc,
|
||||
const CUtensorMap& tma_d_desc,
|
||||
cudaStream_t stream,
|
||||
int num_sms, uint32_t smem_size) {
|
||||
// NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
|
||||
constexpr uint32_t kNumTMAThreads = 128;
|
||||
constexpr uint32_t kNumMathThreadsPerGroup = 128;
|
||||
auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K,
|
||||
BLOCK_M, BLOCK_N, BLOCK_K,
|
||||
BLOCK_N_PADDING,
|
||||
kSwizzleDMode,
|
||||
kNumGroups, kNumStages,
|
||||
kNumTMAThreads, kNumMathThreadsPerGroup,
|
||||
kNumTMAMulticast, kIsTMAMulticastOnA, kGemmType>;
|
||||
DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);
|
||||
|
||||
// Cluster launch
|
||||
cudaLaunchConfig_t config;
|
||||
config.gridDim = num_sms;
|
||||
config.blockDim = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
|
||||
config.dynamicSmemBytes = smem_size;
|
||||
config.stream = stream;
|
||||
|
||||
// Clusters for TMA multicast
|
||||
// NOTES: `>= 4` cluster size will cause performance degradation
|
||||
cudaLaunchAttribute attr;
|
||||
attr.id = cudaLaunchAttributeClusterDimension;
|
||||
attr.val.clusterDim = {kNumTMAMulticast, 1, 1};
|
||||
config.attrs = &attr;
|
||||
config.numAttrs = 1;
|
||||
|
||||
// Launch
|
||||
auto status = cudaLaunchKernelEx(&config, kernel,
|
||||
gmem_d, scales_b, grouped_layout,
|
||||
shape_m,
|
||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc);
|
||||
DG_HOST_ASSERT(status == cudaSuccess);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, GemmType kGemmType>
|
||||
static CUtensorMap make_2d_tma_a_desc(T *global_address, uint32_t shape_m, uint32_t shape_k, uint32_t block_m, uint32_t block_k, uint32_t num_groups = 1) {
|
||||
return make_2d_tma_desc(
|
||||
global_address, Layout::RowMajor,
|
||||
shape_m * (kGemmType == GemmType::GroupedMasked ? num_groups : 1),
|
||||
shape_k, block_m, block_k);
|
||||
}
|
||||
|
||||
template <typename T, GemmType kGemmType>
|
||||
static CUtensorMap make_2d_tma_b_desc(T *global_address, uint32_t shape_k, uint32_t shape_n, uint32_t block_k, uint32_t block_n, uint32_t num_groups = 1) {
|
||||
return make_2d_tma_desc(global_address, Layout::ColMajor, shape_k,
|
||||
shape_n * (kGemmType != GemmType::Normal ? num_groups : 1),
|
||||
block_k, block_n);
|
||||
}
|
||||
|
||||
template <typename T, GemmType kGemmType, uint32_t kSwizzleDMode>
|
||||
static CUtensorMap make_2d_tma_d_desc(T *global_address, uint32_t shape_m, uint32_t shape_n, uint32_t block_m, uint32_t block_n, uint32_t num_groups = 1) {
|
||||
auto swizzle_mode = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE;
|
||||
if constexpr (kSwizzleDMode == 32)
|
||||
swizzle_mode = CU_TENSOR_MAP_SWIZZLE_32B;
|
||||
if constexpr (kSwizzleDMode == 64)
|
||||
swizzle_mode = CU_TENSOR_MAP_SWIZZLE_64B;
|
||||
if constexpr (kSwizzleDMode == 128)
|
||||
swizzle_mode = CU_TENSOR_MAP_SWIZZLE_128B;
|
||||
|
||||
// Swizzling requires the inner box dim less or equal than `kSwizzleDMode`
|
||||
// bytes So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
|
||||
return make_2d_tma_desc(
|
||||
global_address, Layout::RowMajor,
|
||||
shape_m * (kGemmType == GemmType::GroupedMasked ? num_groups : 1),
|
||||
shape_n, block_m,
|
||||
kSwizzleDMode == 0 ? block_n : kSwizzleDMode / sizeof(T), swizzle_mode);
|
||||
}
|
||||
|
||||
template <typename T, GemmType kGemmType>
|
||||
static CUtensorMap make_2d_tma_scales_a_desc(T *global_address, uint32_t shape_m, uint32_t shape_k, uint32_t block_m, uint32_t block_k, uint32_t num_groups = 1) {
|
||||
// Make TMA aligned to 16 bytes
|
||||
constexpr uint32_t kAlignment = 16 / sizeof(T);
|
||||
shape_m = ceil_div(shape_m, kAlignment) * kAlignment;
|
||||
|
||||
return make_2d_tma_desc(
|
||||
global_address, Layout::ColMajor, shape_m,
|
||||
ceil_div(shape_k, block_k) * (kGemmType == GemmType::GroupedMasked ? num_groups : 1),
|
||||
block_m, 1, CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap
|
||||
make_2d_tma_desc(T *global_address, Layout layout, uint32_t gmem_rows,
|
||||
uint32_t gmem_cols, uint32_t smem_rows, uint32_t smem_cols,
|
||||
CUtensorMapSwizzle swizzle_type =
|
||||
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) {
|
||||
if (layout == Layout::RowMajor) {
|
||||
uint64_t gmem_dim[2] = {gmem_cols, gmem_rows};
|
||||
uint32_t smem_dim[2] = {smem_cols, smem_rows};
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim,
|
||||
gmem_cols * sizeof(T), smem_dim, swizzle_type);
|
||||
} else {
|
||||
uint64_t gmem_dim[2] = {gmem_rows, gmem_cols};
|
||||
uint32_t smem_dim[2] = {smem_rows, smem_cols};
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim,
|
||||
gmem_rows * sizeof(T), smem_dim, swizzle_type);
|
||||
}
|
||||
}
|
||||
|
||||
}; // namespace deep_gemm
|
||||
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
|
||||
* All rights reserved. SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@ -29,41 +29,73 @@ using int64_t = signed long long;
|
||||
using uint64_t = unsigned long long;
|
||||
using cuuint64_t = unsigned long long;
|
||||
|
||||
namespace std
|
||||
{
|
||||
template <class T, T v>
|
||||
struct integral_constant
|
||||
{
|
||||
static constexpr T value = v;
|
||||
using value_type = T;
|
||||
using type = integral_constant; // using injected-class-name
|
||||
#ifndef CU_TENSOR_MAP_NUM_QWORDS
|
||||
#define CU_TENSOR_MAP_NUM_QWORDS 16
|
||||
|
||||
__device__ constexpr operator value_type() const noexcept
|
||||
{
|
||||
return value;
|
||||
}
|
||||
struct CUtensorMap_st
|
||||
{
|
||||
#if defined(__cplusplus) && (__cplusplus >= 201103L)
|
||||
alignas(64)
|
||||
#elif __STDC_VERSION__ >= 201112L
|
||||
_Alignas(64)
|
||||
#endif
|
||||
cuuint64_t opaque[CU_TENSOR_MAP_NUM_QWORDS];
|
||||
};
|
||||
|
||||
__device__ constexpr value_type operator()() const noexcept
|
||||
{
|
||||
return value;
|
||||
} // since c++14
|
||||
using CUtensorMap = CUtensorMap_st;
|
||||
#endif
|
||||
|
||||
namespace std {
|
||||
template <class T, T v> struct integral_constant {
|
||||
static constexpr T value = v;
|
||||
using value_type = T;
|
||||
using type = integral_constant; // using injected-class-name
|
||||
|
||||
__device__ constexpr operator value_type() const noexcept { return value; }
|
||||
|
||||
__device__ constexpr value_type operator()() const noexcept {
|
||||
return value;
|
||||
} // since c++14
|
||||
};
|
||||
|
||||
using false_type = integral_constant<bool, false>;
|
||||
using true_type = integral_constant<bool, true>;
|
||||
|
||||
template <class T, class U>
|
||||
struct is_same : false_type
|
||||
{
|
||||
};
|
||||
template <class T, class U> struct is_same : false_type {};
|
||||
|
||||
template <class T>
|
||||
struct is_same<T, T> : true_type
|
||||
{
|
||||
};
|
||||
template <class T> struct is_same<T, T> : true_type {};
|
||||
|
||||
template <class T, class U>
|
||||
inline constexpr bool is_same_v = is_same<T, U>::value;
|
||||
|
||||
namespace index_sequence_impl {
|
||||
// Based on https://stackoverflow.com/a/32223343/11717224
|
||||
template <size_t... Ints> struct index_sequence {
|
||||
using type = index_sequence;
|
||||
using value_type = size_t;
|
||||
static constexpr size_t size() noexcept { return sizeof...(Ints); }
|
||||
};
|
||||
|
||||
template <class Sequence1, class Sequence2> struct _merge_and_renumber;
|
||||
|
||||
template <size_t... I1, size_t... I2>
|
||||
struct _merge_and_renumber<index_sequence<I1...>, index_sequence<I2...>>
|
||||
: index_sequence<I1..., (sizeof...(I1) + I2)...> {};
|
||||
|
||||
template <size_t N>
|
||||
struct make_index_sequence
|
||||
: _merge_and_renumber<typename make_index_sequence<N / 2>::type,
|
||||
typename make_index_sequence<N - N / 2>::type> {};
|
||||
|
||||
template <> struct make_index_sequence<0> : index_sequence<> {};
|
||||
template <> struct make_index_sequence<1> : index_sequence<0> {};
|
||||
} // namespace index_sequence_impl
|
||||
|
||||
template <size_t... Ns>
|
||||
using index_sequence = index_sequence_impl::index_sequence<Ns...>;
|
||||
|
||||
template <size_t N>
|
||||
using make_index_sequence = index_sequence_impl::make_index_sequence<N>;
|
||||
} // namespace std
|
||||
|
||||
#endif
|
||||
|
||||
@ -2,87 +2,17 @@
|
||||
|
||||
#ifndef NVRTC_JIT_COMPILATION
|
||||
#include <cassert>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cudaTypedefs.h>
|
||||
#endif
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cudaTypedefs.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda/barrier>
|
||||
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
template <class T>
|
||||
constexpr CUtensorMapDataType get_CUtensorMapDataType() {
|
||||
if constexpr (std::is_same<T, uint8_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
||||
} else if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
||||
} else if constexpr (std::is_same<T, __nv_fp8_e5m2>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
||||
} else if constexpr (std::is_same<T, uint16_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT16;
|
||||
} else if constexpr (std::is_same<T, uint32_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT32;
|
||||
} else if constexpr (std::is_same<T, uint64_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT64;
|
||||
} else if constexpr (std::is_same<T, int32_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_INT32;
|
||||
} else if constexpr (std::is_same<T, int64_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_INT64;
|
||||
} else if constexpr (std::is_same<T, __half>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
|
||||
} else if constexpr (std::is_same<T, float>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
|
||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
|
||||
} else if constexpr (std::is_same<T, double>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
|
||||
}
|
||||
}
|
||||
|
||||
inline PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() {
|
||||
// Get pointer to `cuTensorMapEncodeTiled`
|
||||
cudaDriverEntryPointQueryResult driver_status;
|
||||
void* cuTensorMapEncodeTiled_ptr = nullptr;
|
||||
|
||||
#if CUDA_VERSION >= 12050
|
||||
cudaGetDriverEntryPointByVersion("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, 12000,
|
||||
cudaEnableDefault, &driver_status);
|
||||
#else
|
||||
cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr,
|
||||
cudaEnableDefault, &driver_status);
|
||||
#endif
|
||||
|
||||
if (driver_status != cudaDriverEntryPointSuccess)
|
||||
throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess");
|
||||
return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(cuTensorMapEncodeTiled_ptr);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
|
||||
uint64_t stride_in_bytes, uint32_t smem_dim[2],
|
||||
CUtensorMapSwizzle swizzle_type,
|
||||
PFN_cuTensorMapEncodeTiled encode_func = nullptr) {
|
||||
CUtensorMap tensor_map = {};
|
||||
uint64_t global_stride[1] = {stride_in_bytes};
|
||||
uint32_t elem_strides[2] = {1, 1};
|
||||
|
||||
if (encode_func == nullptr)
|
||||
encode_func = get_cuTensorMapEncodeTiled();
|
||||
|
||||
auto result = encode_func(
|
||||
&tensor_map, get_CUtensorMapDataType<std::remove_cv_t<T>>(), 2,
|
||||
global_address, gmem_dim, global_stride, smem_dim, elem_strides,
|
||||
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type,
|
||||
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
|
||||
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
|
||||
DG_HOST_ASSERT(result == CUDA_SUCCESS);
|
||||
return tensor_map;
|
||||
}
|
||||
|
||||
template <uint32_t kNumTMAMulticast = 1>
|
||||
__device__ __forceinline__ void
|
||||
tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
|
||||
|
||||
@ -1,12 +1,16 @@
|
||||
import hashlib
|
||||
import abc
|
||||
import functools
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import time
|
||||
import uuid
|
||||
from typing import List, Tuple
|
||||
|
||||
import cuda.bindings
|
||||
import cuda.bindings.nvrtc as nvrtc
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
from typing import Tuple
|
||||
|
||||
from . import interleave_ffma
|
||||
from .runtime import Runtime, RuntimeCache
|
||||
@ -29,7 +33,8 @@ def get_jit_include_dir() -> str:
|
||||
def get_deep_gemm_version() -> str:
|
||||
# Update include directories
|
||||
include_dir = f'{get_jit_include_dir()}/deep_gemm'
|
||||
assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}'
|
||||
assert os.path.exists(
|
||||
include_dir), f'Cannot find GEMM include directory {include_dir}'
|
||||
md5 = hashlib.md5()
|
||||
for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))):
|
||||
with open(f'{include_dir}/{filename}', 'rb') as f:
|
||||
@ -53,7 +58,8 @@ def get_nvcc_compiler() -> Tuple[str, str]:
|
||||
version_pattern = re.compile(r'release (\d+\.\d+)')
|
||||
for path in paths:
|
||||
if os.path.exists(path):
|
||||
match = version_pattern.search(os.popen(f'{path} --version').read())
|
||||
match = version_pattern.search(
|
||||
os.popen(f'{path} --version').read())
|
||||
version = match.group(1)
|
||||
assert match, f'Cannot get the version of NVCC compiler {path}'
|
||||
assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}'
|
||||
@ -94,64 +100,173 @@ def put(path, data, is_binary=False):
|
||||
os.replace(tmp_file_path, path)
|
||||
|
||||
|
||||
def build(name: str, code: str) -> Runtime:
|
||||
# Compiler flags
|
||||
cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20))
|
||||
nvcc_flags = [f'-std=c++{cpp_standard}', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
|
||||
'-gencode=arch=compute_90a,code=sm_90a', '-cubin',
|
||||
'--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
|
||||
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
|
||||
'--diag-suppress=39,174,177,940']
|
||||
cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts']
|
||||
flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}']
|
||||
include_dirs = [get_jit_include_dir()]
|
||||
class Compiler(abc.ABC):
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def __version__() -> Tuple[int, int]:
|
||||
pass
|
||||
|
||||
# Build signature
|
||||
enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0
|
||||
signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
|
||||
name = f'kernel.{name}.{hash_to_hex(signature)}'
|
||||
path = f'{get_cache_dir()}/{name}'
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def compile(cls, name: str, src_path: str, target_path: str):
|
||||
pass
|
||||
|
||||
# Check runtime cache or file system hit
|
||||
global runtime_cache
|
||||
if runtime_cache[path] is not None:
|
||||
@staticmethod
|
||||
def flags() -> List[str]:
|
||||
cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20))
|
||||
return [f'-std=c++{cpp_standard}',
|
||||
'--ptxas-options=--register-usage-level=10' +
|
||||
(',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
|
||||
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
|
||||
'--diag-suppress=39,174,177,940']
|
||||
|
||||
@staticmethod
|
||||
def include_dirs() -> List[str]:
|
||||
return [get_jit_include_dir()]
|
||||
|
||||
@classmethod
|
||||
def build(cls, name: str, code: str) -> Runtime:
|
||||
# Compiler flags
|
||||
flags = cls.flags()
|
||||
include_dirs = cls.include_dirs()
|
||||
|
||||
# Build signature
|
||||
enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(
|
||||
os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0
|
||||
signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
|
||||
name = f'kernel.{name}.{hash_to_hex(signature)}'
|
||||
path = f'{get_cache_dir()}/{name}'
|
||||
|
||||
# Check runtime cache or file system hit
|
||||
global runtime_cache
|
||||
if runtime_cache[path] is not None:
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(f'Using cached JIT runtime {name} during build')
|
||||
return runtime_cache[path]
|
||||
|
||||
# Write the code
|
||||
os.makedirs(path, exist_ok=True)
|
||||
src_path = f'{path}/kernel.cu'
|
||||
put(src_path, code)
|
||||
|
||||
# Compile into a temporary CU file
|
||||
cubin_path = f'{path}/kernel.cubin'
|
||||
tmp_cubin_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin'
|
||||
|
||||
start_time = time.time()
|
||||
cls.compile(name, src_path, tmp_cubin_path)
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(f'Using cached JIT runtime {name} during build')
|
||||
print(
|
||||
f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')
|
||||
|
||||
# Interleave FFMA reuse
|
||||
if enable_sass_opt:
|
||||
interleave_ffma.process(tmp_cubin_path)
|
||||
|
||||
# Atomic replace CU file
|
||||
os.replace(tmp_cubin_path, cubin_path)
|
||||
|
||||
# Put cache and return
|
||||
runtime_cache[path] = Runtime(path)
|
||||
return runtime_cache[path]
|
||||
|
||||
# Write the code
|
||||
os.makedirs(path, exist_ok=True)
|
||||
src_path = f'{path}/kernel.cu'
|
||||
put(src_path, code)
|
||||
|
||||
# Compile into a temporary CU file
|
||||
cubin_path = f'{path}/kernel.cubin'
|
||||
tmp_cubin_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin'
|
||||
class NvccCompiler(Compiler):
|
||||
@staticmethod
|
||||
def __version__() -> Tuple[int, int]:
|
||||
major, minor, _ = map(int, cuda.bindings.__version__.split('.'))
|
||||
return (major, minor)
|
||||
|
||||
# Compile
|
||||
command = [get_nvcc_compiler()[0],
|
||||
src_path, '-o', tmp_cubin_path,
|
||||
*flags,
|
||||
*[f'-I{d}' for d in include_dirs]]
|
||||
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
|
||||
print(f'Compiling JIT runtime {name} with command {command}')
|
||||
|
||||
start_time = time.time()
|
||||
return_code = subprocess.check_call(command)
|
||||
end_time = time.time()
|
||||
assert return_code == 0, f'Failed to compile {src_path}'
|
||||
@classmethod
|
||||
def flags(cls) -> List[str]:
|
||||
cxx_flags = ['-fPIC', '-O3',
|
||||
'-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts']
|
||||
return [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
|
||||
'-gencode=arch=compute_90a,code=sm_90a',
|
||||
'-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
|
||||
f'--compiler-options={",".join(cxx_flags)}']
|
||||
|
||||
# Print elapsed time if debug is enabled
|
||||
elapsed_time = end_time - start_time
|
||||
print(f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')
|
||||
@classmethod
|
||||
def compile(cls, name: str, src_path: str, target_path: str):
|
||||
command = [get_nvcc_compiler()[0],
|
||||
src_path, '-o', target_path,
|
||||
*cls.flags()]
|
||||
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
|
||||
print(f'Compiling JIT runtime {name} with command {command}')
|
||||
|
||||
# Interleave FFMA reuse
|
||||
if enable_sass_opt:
|
||||
interleave_ffma.process(tmp_cubin_path)
|
||||
return_code = subprocess.check_call(command)
|
||||
assert return_code == 0, f'Failed to compile {src_path}'
|
||||
|
||||
# Atomic replace CU file
|
||||
os.replace(tmp_cubin_path, cubin_path)
|
||||
|
||||
# Put cache and return
|
||||
runtime_cache[path] = Runtime(path)
|
||||
return runtime_cache[path]
|
||||
class NvrtcCompiler(Compiler):
|
||||
@staticmethod
|
||||
def __version__() -> Tuple[int, int]:
|
||||
_, version = get_nvcc_compiler()
|
||||
major, minor = map(int, version.split('.'))
|
||||
return (major, minor)
|
||||
|
||||
@staticmethod
|
||||
def include_dirs() -> List[str]:
|
||||
if CUDA_HOME is None:
|
||||
raise RuntimeError('CUDA_HOME is required for NVRTC compilation')
|
||||
return [get_jit_include_dir(), os.path.join(CUDA_HOME, 'include'), os.path.join(CUDA_HOME, 'targets', 'x86_64-linux', 'include')]
|
||||
|
||||
@classmethod
|
||||
def flags(cls) -> List[str]:
|
||||
base_flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
|
||||
'--gpu-architecture=sm_90a', '-default-device']
|
||||
if cls.__version__() >= (12, 8):
|
||||
base_flags += ['--pch', f'--pch-dir={get_cache_dir()}']
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
base_flags += ['--pch-verbose=true']
|
||||
return base_flags
|
||||
|
||||
@classmethod
|
||||
def compile(cls, name: str, src_path: str, target_path: str):
|
||||
code_bytes = open(src_path, 'rb').read()
|
||||
res, program = nvrtc.nvrtcCreateProgram(
|
||||
code_bytes, bytes(name, 'utf-8'), 0, [], [])
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f"Failed to create program: {res}")
|
||||
|
||||
options = [bytes(flag, 'utf-8') for flag in cls.flags()]
|
||||
compile_res = nvrtc.nvrtcCompileProgram(
|
||||
program, len(options), options)[0]
|
||||
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
res, log_size = nvrtc.nvrtcGetProgramLogSize(program)
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f"Failed to get program log size: {res}")
|
||||
log_bytes = bytes(log_size)
|
||||
res = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0]
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f"Failed to get program log: {res}")
|
||||
log_str = log_bytes.decode('utf-8')
|
||||
print(log_str)
|
||||
|
||||
if compile_res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f"Failed to compile program: {compile_res}")
|
||||
|
||||
res, cubin_size = nvrtc.nvrtcGetCUBINSize(program)
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f"Failed to get CUBIN size: {res}")
|
||||
|
||||
cubin_bytes = bytes(cubin_size)
|
||||
res = nvrtc.nvrtcGetCUBIN(program, cubin_bytes)[0]
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f"Failed to get CUBIN: {res}")
|
||||
|
||||
put(target_path, cubin_bytes, is_binary=True)
|
||||
|
||||
res = nvrtc.nvrtcDestroyProgram(program)[0]
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
raise Exception(f"Failed to destroy program: {res}")
|
||||
|
||||
|
||||
def build(name: str, code: str) -> Runtime:
|
||||
if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']:
|
||||
return NvrtcCompiler.build(name, code)
|
||||
else:
|
||||
return NvccCompiler.build(name, code)
|
||||
|
||||
@ -3,8 +3,6 @@ import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
import cuda.bindings.nvrtc as nvrtc
|
||||
import torch
|
||||
|
||||
from .utils import run_gemm
|
||||
|
||||
@ -58,8 +56,9 @@ class Runtime:
|
||||
|
||||
end_time = time.time_ns()
|
||||
elapsed_time = (end_time - start_time) / 1000
|
||||
print(
|
||||
f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(
|
||||
f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
|
||||
|
||||
return run_gemm(
|
||||
self.kernel,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user