mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
refactor: compile to .cubin and add NVRTC option
Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
@@ -86,14 +86,14 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
|||||||
|
|
||||||
// Prefetch TMA descriptors at very beginning
|
// Prefetch TMA descriptors at very beginning
|
||||||
if (threadIdx.x == kNumMathThreads) {
|
if (threadIdx.x == kNumMathThreads) {
|
||||||
cute::prefetch_tma_descriptor(&tensor_map_a);
|
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_a));
|
||||||
cute::prefetch_tma_descriptor(&tensor_map_b);
|
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_b));
|
||||||
cute::prefetch_tma_descriptor(&tensor_map_scales_a);
|
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_scales_a));
|
||||||
|
|
||||||
// `tensor_map_d` is only used in swizzling mode
|
// `tensor_map_d` is only used in swizzling mode
|
||||||
// For the `kSwizzleDMode == 0 and BLOCK_N_PADDING == 0` case, it will be treated as padding mode
|
// For the `kSwizzleDMode == 0 and BLOCK_N_PADDING == 0` case, it will be treated as padding mode
|
||||||
if constexpr (kSwizzleDMode > 0)
|
if constexpr (kSwizzleDMode > 0)
|
||||||
cute::prefetch_tma_descriptor(&tensor_map_d);
|
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_d));
|
||||||
}
|
}
|
||||||
__syncwarp();
|
__syncwarp();
|
||||||
|
|
||||||
@@ -448,129 +448,6 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
|
|
||||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
|
||||||
uint32_t BLOCK_N_PADDING,
|
|
||||||
uint32_t kSwizzleDMode,
|
|
||||||
uint32_t kNumGroups, uint32_t kNumStages,
|
|
||||||
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
|
|
||||||
GemmType kGemmType>
|
|
||||||
class Gemm {
|
|
||||||
private:
|
|
||||||
using Barrier = cuda::barrier<cuda::thread_scope_block>;
|
|
||||||
|
|
||||||
public:
|
|
||||||
Gemm() = default;
|
|
||||||
|
|
||||||
static void run(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
|
||||||
uint32_t shape_m,
|
|
||||||
const CUtensorMap& tma_a_desc,
|
|
||||||
const CUtensorMap& tma_b_desc,
|
|
||||||
const CUtensorMap& tma_scales_a_desc,
|
|
||||||
const CUtensorMap& tma_d_desc,
|
|
||||||
cudaStream_t stream,
|
|
||||||
int num_sms, uint32_t smem_size) {
|
|
||||||
// NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
|
|
||||||
constexpr uint32_t kNumTMAThreads = 128;
|
|
||||||
constexpr uint32_t kNumMathThreadsPerGroup = 128;
|
|
||||||
auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K,
|
|
||||||
BLOCK_M, BLOCK_N, BLOCK_K,
|
|
||||||
BLOCK_N_PADDING,
|
|
||||||
kSwizzleDMode,
|
|
||||||
kNumGroups, kNumStages,
|
|
||||||
kNumTMAThreads, kNumMathThreadsPerGroup,
|
|
||||||
kNumTMAMulticast, kIsTMAMulticastOnA, kGemmType>;
|
|
||||||
DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);
|
|
||||||
|
|
||||||
// Cluster launch
|
|
||||||
cudaLaunchConfig_t config;
|
|
||||||
config.gridDim = num_sms;
|
|
||||||
config.blockDim = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
|
|
||||||
config.dynamicSmemBytes = smem_size;
|
|
||||||
config.stream = stream;
|
|
||||||
|
|
||||||
// Clusters for TMA multicast
|
|
||||||
// NOTES: `>= 4` cluster size will cause performance degradation
|
|
||||||
cudaLaunchAttribute attr;
|
|
||||||
attr.id = cudaLaunchAttributeClusterDimension;
|
|
||||||
attr.val.clusterDim = {kNumTMAMulticast, 1, 1};
|
|
||||||
config.attrs = &attr;
|
|
||||||
config.numAttrs = 1;
|
|
||||||
|
|
||||||
// Launch
|
|
||||||
auto status = cudaLaunchKernelEx(&config, kernel,
|
|
||||||
gmem_d, scales_b, grouped_layout,
|
|
||||||
shape_m,
|
|
||||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc);
|
|
||||||
DG_HOST_ASSERT(status == cudaSuccess);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T, GemmType kGemmType>
|
|
||||||
static CUtensorMap make_2d_tma_a_desc(T *global_address, uint32_t shape_m, uint32_t shape_k, uint32_t block_m, uint32_t block_k, uint32_t num_groups = 1) {
|
|
||||||
return make_2d_tma_desc(
|
|
||||||
global_address, Layout::RowMajor,
|
|
||||||
shape_m * (kGemmType == GemmType::GroupedMasked ? num_groups : 1),
|
|
||||||
shape_k, block_m, block_k);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, GemmType kGemmType>
|
|
||||||
static CUtensorMap make_2d_tma_b_desc(T *global_address, uint32_t shape_k, uint32_t shape_n, uint32_t block_k, uint32_t block_n, uint32_t num_groups = 1) {
|
|
||||||
return make_2d_tma_desc(global_address, Layout::ColMajor, shape_k,
|
|
||||||
shape_n * (kGemmType != GemmType::Normal ? num_groups : 1),
|
|
||||||
block_k, block_n);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, GemmType kGemmType, uint32_t kSwizzleDMode>
|
|
||||||
static CUtensorMap make_2d_tma_d_desc(T *global_address, uint32_t shape_m, uint32_t shape_n, uint32_t block_m, uint32_t block_n, uint32_t num_groups = 1) {
|
|
||||||
auto swizzle_mode = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE;
|
|
||||||
if constexpr (kSwizzleDMode == 32)
|
|
||||||
swizzle_mode = CU_TENSOR_MAP_SWIZZLE_32B;
|
|
||||||
if constexpr (kSwizzleDMode == 64)
|
|
||||||
swizzle_mode = CU_TENSOR_MAP_SWIZZLE_64B;
|
|
||||||
if constexpr (kSwizzleDMode == 128)
|
|
||||||
swizzle_mode = CU_TENSOR_MAP_SWIZZLE_128B;
|
|
||||||
|
|
||||||
// Swizzling requires the inner box dim less or equal than `kSwizzleDMode`
|
|
||||||
// bytes So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
|
|
||||||
return make_2d_tma_desc(
|
|
||||||
global_address, Layout::RowMajor,
|
|
||||||
shape_m * (kGemmType == GemmType::GroupedMasked ? num_groups : 1),
|
|
||||||
shape_n, block_m,
|
|
||||||
kSwizzleDMode == 0 ? block_n : kSwizzleDMode / sizeof(T), swizzle_mode);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, GemmType kGemmType>
|
|
||||||
static CUtensorMap make_2d_tma_scales_a_desc(T *global_address, uint32_t shape_m, uint32_t shape_k, uint32_t block_m, uint32_t block_k, uint32_t num_groups = 1) {
|
|
||||||
// Make TMA aligned to 16 bytes
|
|
||||||
constexpr uint32_t kAlignment = 16 / sizeof(T);
|
|
||||||
shape_m = ceil_div(shape_m, kAlignment) * kAlignment;
|
|
||||||
|
|
||||||
return make_2d_tma_desc(
|
|
||||||
global_address, Layout::ColMajor, shape_m,
|
|
||||||
ceil_div(shape_k, block_k) * (kGemmType == GemmType::GroupedMasked ? num_groups : 1),
|
|
||||||
block_m, 1, CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
static CUtensorMap
|
|
||||||
make_2d_tma_desc(T *global_address, Layout layout, uint32_t gmem_rows,
|
|
||||||
uint32_t gmem_cols, uint32_t smem_rows, uint32_t smem_cols,
|
|
||||||
CUtensorMapSwizzle swizzle_type =
|
|
||||||
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) {
|
|
||||||
if (layout == Layout::RowMajor) {
|
|
||||||
uint64_t gmem_dim[2] = {gmem_cols, gmem_rows};
|
|
||||||
uint32_t smem_dim[2] = {smem_cols, smem_rows};
|
|
||||||
return make_2d_tma_copy_desc(global_address, gmem_dim,
|
|
||||||
gmem_cols * sizeof(T), smem_dim, swizzle_type);
|
|
||||||
} else {
|
|
||||||
uint64_t gmem_dim[2] = {gmem_rows, gmem_cols};
|
|
||||||
uint32_t smem_dim[2] = {smem_rows, smem_cols};
|
|
||||||
return make_2d_tma_copy_desc(global_address, gmem_dim,
|
|
||||||
gmem_rows * sizeof(T), smem_dim, swizzle_type);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}; // namespace deep_gemm
|
}; // namespace deep_gemm
|
||||||
|
|
||||||
#pragma clang diagnostic pop
|
#pragma clang diagnostic pop
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/*
|
/*
|
||||||
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* All rights reserved. SPDX-License-Identifier: Apache-2.0
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
@@ -29,41 +29,73 @@ using int64_t = signed long long;
|
|||||||
using uint64_t = unsigned long long;
|
using uint64_t = unsigned long long;
|
||||||
using cuuint64_t = unsigned long long;
|
using cuuint64_t = unsigned long long;
|
||||||
|
|
||||||
namespace std
|
#ifndef CU_TENSOR_MAP_NUM_QWORDS
|
||||||
{
|
#define CU_TENSOR_MAP_NUM_QWORDS 16
|
||||||
template <class T, T v>
|
|
||||||
struct integral_constant
|
|
||||||
{
|
|
||||||
static constexpr T value = v;
|
|
||||||
using value_type = T;
|
|
||||||
using type = integral_constant; // using injected-class-name
|
|
||||||
|
|
||||||
__device__ constexpr operator value_type() const noexcept
|
struct CUtensorMap_st
|
||||||
{
|
{
|
||||||
return value;
|
#if defined(__cplusplus) && (__cplusplus >= 201103L)
|
||||||
}
|
alignas(64)
|
||||||
|
#elif __STDC_VERSION__ >= 201112L
|
||||||
|
_Alignas(64)
|
||||||
|
#endif
|
||||||
|
cuuint64_t opaque[CU_TENSOR_MAP_NUM_QWORDS];
|
||||||
|
};
|
||||||
|
|
||||||
__device__ constexpr value_type operator()() const noexcept
|
using CUtensorMap = CUtensorMap_st;
|
||||||
{
|
#endif
|
||||||
return value;
|
|
||||||
} // since c++14
|
namespace std {
|
||||||
|
template <class T, T v> struct integral_constant {
|
||||||
|
static constexpr T value = v;
|
||||||
|
using value_type = T;
|
||||||
|
using type = integral_constant; // using injected-class-name
|
||||||
|
|
||||||
|
__device__ constexpr operator value_type() const noexcept { return value; }
|
||||||
|
|
||||||
|
__device__ constexpr value_type operator()() const noexcept {
|
||||||
|
return value;
|
||||||
|
} // since c++14
|
||||||
};
|
};
|
||||||
|
|
||||||
using false_type = integral_constant<bool, false>;
|
using false_type = integral_constant<bool, false>;
|
||||||
using true_type = integral_constant<bool, true>;
|
using true_type = integral_constant<bool, true>;
|
||||||
|
|
||||||
template <class T, class U>
|
template <class T, class U> struct is_same : false_type {};
|
||||||
struct is_same : false_type
|
|
||||||
{
|
|
||||||
};
|
|
||||||
|
|
||||||
template <class T>
|
template <class T> struct is_same<T, T> : true_type {};
|
||||||
struct is_same<T, T> : true_type
|
|
||||||
{
|
|
||||||
};
|
|
||||||
|
|
||||||
template <class T, class U>
|
template <class T, class U>
|
||||||
inline constexpr bool is_same_v = is_same<T, U>::value;
|
inline constexpr bool is_same_v = is_same<T, U>::value;
|
||||||
|
|
||||||
|
namespace index_sequence_impl {
|
||||||
|
// Based on https://stackoverflow.com/a/32223343/11717224
|
||||||
|
template <size_t... Ints> struct index_sequence {
|
||||||
|
using type = index_sequence;
|
||||||
|
using value_type = size_t;
|
||||||
|
static constexpr size_t size() noexcept { return sizeof...(Ints); }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <class Sequence1, class Sequence2> struct _merge_and_renumber;
|
||||||
|
|
||||||
|
template <size_t... I1, size_t... I2>
|
||||||
|
struct _merge_and_renumber<index_sequence<I1...>, index_sequence<I2...>>
|
||||||
|
: index_sequence<I1..., (sizeof...(I1) + I2)...> {};
|
||||||
|
|
||||||
|
template <size_t N>
|
||||||
|
struct make_index_sequence
|
||||||
|
: _merge_and_renumber<typename make_index_sequence<N / 2>::type,
|
||||||
|
typename make_index_sequence<N - N / 2>::type> {};
|
||||||
|
|
||||||
|
template <> struct make_index_sequence<0> : index_sequence<> {};
|
||||||
|
template <> struct make_index_sequence<1> : index_sequence<0> {};
|
||||||
|
} // namespace index_sequence_impl
|
||||||
|
|
||||||
|
template <size_t... Ns>
|
||||||
|
using index_sequence = index_sequence_impl::index_sequence<Ns...>;
|
||||||
|
|
||||||
|
template <size_t N>
|
||||||
|
using make_index_sequence = index_sequence_impl::make_index_sequence<N>;
|
||||||
} // namespace std
|
} // namespace std
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@@ -2,87 +2,17 @@
|
|||||||
|
|
||||||
#ifndef NVRTC_JIT_COMPILATION
|
#ifndef NVRTC_JIT_COMPILATION
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <cudaTypedefs.h>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include <cuda.h>
|
|
||||||
#include <cudaTypedefs.h>
|
|
||||||
#include <cuda_fp8.h>
|
|
||||||
#include <cuda_runtime.h>
|
|
||||||
#include <cuda/barrier>
|
#include <cuda/barrier>
|
||||||
|
|
||||||
#include "utils.cuh"
|
#include "utils.cuh"
|
||||||
|
|
||||||
namespace deep_gemm {
|
namespace deep_gemm {
|
||||||
|
|
||||||
template <class T>
|
|
||||||
constexpr CUtensorMapDataType get_CUtensorMapDataType() {
|
|
||||||
if constexpr (std::is_same<T, uint8_t>::value) {
|
|
||||||
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
|
||||||
} else if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
|
|
||||||
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
|
||||||
} else if constexpr (std::is_same<T, __nv_fp8_e5m2>::value) {
|
|
||||||
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
|
||||||
} else if constexpr (std::is_same<T, uint16_t>::value) {
|
|
||||||
return CU_TENSOR_MAP_DATA_TYPE_UINT16;
|
|
||||||
} else if constexpr (std::is_same<T, uint32_t>::value) {
|
|
||||||
return CU_TENSOR_MAP_DATA_TYPE_UINT32;
|
|
||||||
} else if constexpr (std::is_same<T, uint64_t>::value) {
|
|
||||||
return CU_TENSOR_MAP_DATA_TYPE_UINT64;
|
|
||||||
} else if constexpr (std::is_same<T, int32_t>::value) {
|
|
||||||
return CU_TENSOR_MAP_DATA_TYPE_INT32;
|
|
||||||
} else if constexpr (std::is_same<T, int64_t>::value) {
|
|
||||||
return CU_TENSOR_MAP_DATA_TYPE_INT64;
|
|
||||||
} else if constexpr (std::is_same<T, __half>::value) {
|
|
||||||
return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
|
|
||||||
} else if constexpr (std::is_same<T, float>::value) {
|
|
||||||
return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
|
|
||||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
|
||||||
return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
|
|
||||||
} else if constexpr (std::is_same<T, double>::value) {
|
|
||||||
return CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
inline PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() {
|
|
||||||
// Get pointer to `cuTensorMapEncodeTiled`
|
|
||||||
cudaDriverEntryPointQueryResult driver_status;
|
|
||||||
void* cuTensorMapEncodeTiled_ptr = nullptr;
|
|
||||||
|
|
||||||
#if CUDA_VERSION >= 12050
|
|
||||||
cudaGetDriverEntryPointByVersion("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, 12000,
|
|
||||||
cudaEnableDefault, &driver_status);
|
|
||||||
#else
|
|
||||||
cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr,
|
|
||||||
cudaEnableDefault, &driver_status);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
if (driver_status != cudaDriverEntryPointSuccess)
|
|
||||||
throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess");
|
|
||||||
return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(cuTensorMapEncodeTiled_ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
|
|
||||||
uint64_t stride_in_bytes, uint32_t smem_dim[2],
|
|
||||||
CUtensorMapSwizzle swizzle_type,
|
|
||||||
PFN_cuTensorMapEncodeTiled encode_func = nullptr) {
|
|
||||||
CUtensorMap tensor_map = {};
|
|
||||||
uint64_t global_stride[1] = {stride_in_bytes};
|
|
||||||
uint32_t elem_strides[2] = {1, 1};
|
|
||||||
|
|
||||||
if (encode_func == nullptr)
|
|
||||||
encode_func = get_cuTensorMapEncodeTiled();
|
|
||||||
|
|
||||||
auto result = encode_func(
|
|
||||||
&tensor_map, get_CUtensorMapDataType<std::remove_cv_t<T>>(), 2,
|
|
||||||
global_address, gmem_dim, global_stride, smem_dim, elem_strides,
|
|
||||||
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type,
|
|
||||||
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
|
|
||||||
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
|
|
||||||
DG_HOST_ASSERT(result == CUDA_SUCCESS);
|
|
||||||
return tensor_map;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <uint32_t kNumTMAMulticast = 1>
|
template <uint32_t kNumTMAMulticast = 1>
|
||||||
__device__ __forceinline__ void
|
__device__ __forceinline__ void
|
||||||
tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
|
tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
|
||||||
|
|||||||
@@ -1,12 +1,16 @@
|
|||||||
import hashlib
|
import abc
|
||||||
import functools
|
import functools
|
||||||
|
import hashlib
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import cuda.bindings
|
||||||
|
import cuda.bindings.nvrtc as nvrtc
|
||||||
from torch.utils.cpp_extension import CUDA_HOME
|
from torch.utils.cpp_extension import CUDA_HOME
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
from . import interleave_ffma
|
from . import interleave_ffma
|
||||||
from .runtime import Runtime, RuntimeCache
|
from .runtime import Runtime, RuntimeCache
|
||||||
@@ -29,7 +33,8 @@ def get_jit_include_dir() -> str:
|
|||||||
def get_deep_gemm_version() -> str:
|
def get_deep_gemm_version() -> str:
|
||||||
# Update include directories
|
# Update include directories
|
||||||
include_dir = f'{get_jit_include_dir()}/deep_gemm'
|
include_dir = f'{get_jit_include_dir()}/deep_gemm'
|
||||||
assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}'
|
assert os.path.exists(
|
||||||
|
include_dir), f'Cannot find GEMM include directory {include_dir}'
|
||||||
md5 = hashlib.md5()
|
md5 = hashlib.md5()
|
||||||
for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))):
|
for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))):
|
||||||
with open(f'{include_dir}/{filename}', 'rb') as f:
|
with open(f'{include_dir}/{filename}', 'rb') as f:
|
||||||
@@ -53,7 +58,8 @@ def get_nvcc_compiler() -> Tuple[str, str]:
|
|||||||
version_pattern = re.compile(r'release (\d+\.\d+)')
|
version_pattern = re.compile(r'release (\d+\.\d+)')
|
||||||
for path in paths:
|
for path in paths:
|
||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
match = version_pattern.search(os.popen(f'{path} --version').read())
|
match = version_pattern.search(
|
||||||
|
os.popen(f'{path} --version').read())
|
||||||
version = match.group(1)
|
version = match.group(1)
|
||||||
assert match, f'Cannot get the version of NVCC compiler {path}'
|
assert match, f'Cannot get the version of NVCC compiler {path}'
|
||||||
assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}'
|
assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}'
|
||||||
@@ -94,64 +100,173 @@ def put(path, data, is_binary=False):
|
|||||||
os.replace(tmp_file_path, path)
|
os.replace(tmp_file_path, path)
|
||||||
|
|
||||||
|
|
||||||
def build(name: str, code: str) -> Runtime:
|
class Compiler(abc.ABC):
|
||||||
# Compiler flags
|
@staticmethod
|
||||||
cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20))
|
@abc.abstractmethod
|
||||||
nvcc_flags = [f'-std=c++{cpp_standard}', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
|
def __version__() -> Tuple[int, int]:
|
||||||
'-gencode=arch=compute_90a,code=sm_90a', '-cubin',
|
pass
|
||||||
'--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
|
|
||||||
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
|
|
||||||
'--diag-suppress=39,174,177,940']
|
|
||||||
cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts']
|
|
||||||
flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}']
|
|
||||||
include_dirs = [get_jit_include_dir()]
|
|
||||||
|
|
||||||
# Build signature
|
@classmethod
|
||||||
enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0
|
@abc.abstractmethod
|
||||||
signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
|
def compile(cls, name: str, src_path: str, target_path: str):
|
||||||
name = f'kernel.{name}.{hash_to_hex(signature)}'
|
pass
|
||||||
path = f'{get_cache_dir()}/{name}'
|
|
||||||
|
|
||||||
# Check runtime cache or file system hit
|
@staticmethod
|
||||||
global runtime_cache
|
def flags() -> List[str]:
|
||||||
if runtime_cache[path] is not None:
|
cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20))
|
||||||
|
return [f'-std=c++{cpp_standard}',
|
||||||
|
'--ptxas-options=--register-usage-level=10' +
|
||||||
|
(',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
|
||||||
|
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
|
||||||
|
'--diag-suppress=39,174,177,940']
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def include_dirs() -> List[str]:
|
||||||
|
return [get_jit_include_dir()]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build(cls, name: str, code: str) -> Runtime:
|
||||||
|
# Compiler flags
|
||||||
|
flags = cls.flags()
|
||||||
|
include_dirs = cls.include_dirs()
|
||||||
|
|
||||||
|
# Build signature
|
||||||
|
enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(
|
||||||
|
os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0
|
||||||
|
signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
|
||||||
|
name = f'kernel.{name}.{hash_to_hex(signature)}'
|
||||||
|
path = f'{get_cache_dir()}/{name}'
|
||||||
|
|
||||||
|
# Check runtime cache or file system hit
|
||||||
|
global runtime_cache
|
||||||
|
if runtime_cache[path] is not None:
|
||||||
|
if os.getenv('DG_JIT_DEBUG', None):
|
||||||
|
print(f'Using cached JIT runtime {name} during build')
|
||||||
|
return runtime_cache[path]
|
||||||
|
|
||||||
|
# Write the code
|
||||||
|
os.makedirs(path, exist_ok=True)
|
||||||
|
src_path = f'{path}/kernel.cu'
|
||||||
|
put(src_path, code)
|
||||||
|
|
||||||
|
# Compile into a temporary CU file
|
||||||
|
cubin_path = f'{path}/kernel.cubin'
|
||||||
|
tmp_cubin_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin'
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
cls.compile(name, src_path, tmp_cubin_path)
|
||||||
|
end_time = time.time()
|
||||||
|
elapsed_time = end_time - start_time
|
||||||
if os.getenv('DG_JIT_DEBUG', None):
|
if os.getenv('DG_JIT_DEBUG', None):
|
||||||
print(f'Using cached JIT runtime {name} during build')
|
print(
|
||||||
|
f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')
|
||||||
|
|
||||||
|
# Interleave FFMA reuse
|
||||||
|
if enable_sass_opt:
|
||||||
|
interleave_ffma.process(tmp_cubin_path)
|
||||||
|
|
||||||
|
# Atomic replace CU file
|
||||||
|
os.replace(tmp_cubin_path, cubin_path)
|
||||||
|
|
||||||
|
# Put cache and return
|
||||||
|
runtime_cache[path] = Runtime(path)
|
||||||
return runtime_cache[path]
|
return runtime_cache[path]
|
||||||
|
|
||||||
# Write the code
|
|
||||||
os.makedirs(path, exist_ok=True)
|
|
||||||
src_path = f'{path}/kernel.cu'
|
|
||||||
put(src_path, code)
|
|
||||||
|
|
||||||
# Compile into a temporary CU file
|
class NvccCompiler(Compiler):
|
||||||
cubin_path = f'{path}/kernel.cubin'
|
@staticmethod
|
||||||
tmp_cubin_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin'
|
def __version__() -> Tuple[int, int]:
|
||||||
|
major, minor, _ = map(int, cuda.bindings.__version__.split('.'))
|
||||||
|
return (major, minor)
|
||||||
|
|
||||||
# Compile
|
@classmethod
|
||||||
command = [get_nvcc_compiler()[0],
|
def flags(cls) -> List[str]:
|
||||||
src_path, '-o', tmp_cubin_path,
|
cxx_flags = ['-fPIC', '-O3',
|
||||||
*flags,
|
'-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts']
|
||||||
*[f'-I{d}' for d in include_dirs]]
|
return [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
|
||||||
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
|
'-gencode=arch=compute_90a,code=sm_90a',
|
||||||
print(f'Compiling JIT runtime {name} with command {command}')
|
'-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
|
||||||
|
f'--compiler-options={",".join(cxx_flags)}']
|
||||||
start_time = time.time()
|
|
||||||
return_code = subprocess.check_call(command)
|
|
||||||
end_time = time.time()
|
|
||||||
assert return_code == 0, f'Failed to compile {src_path}'
|
|
||||||
|
|
||||||
# Print elapsed time if debug is enabled
|
@classmethod
|
||||||
elapsed_time = end_time - start_time
|
def compile(cls, name: str, src_path: str, target_path: str):
|
||||||
print(f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')
|
command = [get_nvcc_compiler()[0],
|
||||||
|
src_path, '-o', target_path,
|
||||||
|
*cls.flags()]
|
||||||
|
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
|
||||||
|
print(f'Compiling JIT runtime {name} with command {command}')
|
||||||
|
|
||||||
# Interleave FFMA reuse
|
return_code = subprocess.check_call(command)
|
||||||
if enable_sass_opt:
|
assert return_code == 0, f'Failed to compile {src_path}'
|
||||||
interleave_ffma.process(tmp_cubin_path)
|
|
||||||
|
|
||||||
# Atomic replace CU file
|
|
||||||
os.replace(tmp_cubin_path, cubin_path)
|
|
||||||
|
|
||||||
# Put cache and return
|
class NvrtcCompiler(Compiler):
|
||||||
runtime_cache[path] = Runtime(path)
|
@staticmethod
|
||||||
return runtime_cache[path]
|
def __version__() -> Tuple[int, int]:
|
||||||
|
_, version = get_nvcc_compiler()
|
||||||
|
major, minor = map(int, version.split('.'))
|
||||||
|
return (major, minor)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def include_dirs() -> List[str]:
|
||||||
|
if CUDA_HOME is None:
|
||||||
|
raise RuntimeError('CUDA_HOME is required for NVRTC compilation')
|
||||||
|
return [get_jit_include_dir(), os.path.join(CUDA_HOME, 'include'), os.path.join(CUDA_HOME, 'targets', 'x86_64-linux', 'include')]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def flags(cls) -> List[str]:
|
||||||
|
base_flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
|
||||||
|
'--gpu-architecture=sm_90a', '-default-device']
|
||||||
|
if cls.__version__() >= (12, 8):
|
||||||
|
base_flags += ['--pch', f'--pch-dir={get_cache_dir()}']
|
||||||
|
if os.getenv('DG_JIT_DEBUG', None):
|
||||||
|
base_flags += ['--pch-verbose=true']
|
||||||
|
return base_flags
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def compile(cls, name: str, src_path: str, target_path: str):
|
||||||
|
code_bytes = open(src_path, 'rb').read()
|
||||||
|
res, program = nvrtc.nvrtcCreateProgram(
|
||||||
|
code_bytes, bytes(name, 'utf-8'), 0, [], [])
|
||||||
|
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||||
|
raise Exception(f"Failed to create program: {res}")
|
||||||
|
|
||||||
|
options = [bytes(flag, 'utf-8') for flag in cls.flags()]
|
||||||
|
compile_res = nvrtc.nvrtcCompileProgram(
|
||||||
|
program, len(options), options)[0]
|
||||||
|
|
||||||
|
if os.getenv('DG_JIT_DEBUG', None):
|
||||||
|
res, log_size = nvrtc.nvrtcGetProgramLogSize(program)
|
||||||
|
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||||
|
raise Exception(f"Failed to get program log size: {res}")
|
||||||
|
log_bytes = bytes(log_size)
|
||||||
|
res = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0]
|
||||||
|
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||||
|
raise Exception(f"Failed to get program log: {res}")
|
||||||
|
log_str = log_bytes.decode('utf-8')
|
||||||
|
print(log_str)
|
||||||
|
|
||||||
|
if compile_res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||||
|
raise Exception(f"Failed to compile program: {compile_res}")
|
||||||
|
|
||||||
|
res, cubin_size = nvrtc.nvrtcGetCUBINSize(program)
|
||||||
|
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||||
|
raise Exception(f"Failed to get CUBIN size: {res}")
|
||||||
|
|
||||||
|
cubin_bytes = bytes(cubin_size)
|
||||||
|
res = nvrtc.nvrtcGetCUBIN(program, cubin_bytes)[0]
|
||||||
|
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||||
|
raise Exception(f"Failed to get CUBIN: {res}")
|
||||||
|
|
||||||
|
put(target_path, cubin_bytes, is_binary=True)
|
||||||
|
|
||||||
|
res = nvrtc.nvrtcDestroyProgram(program)[0]
|
||||||
|
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||||
|
raise Exception(f"Failed to destroy program: {res}")
|
||||||
|
|
||||||
|
|
||||||
|
def build(name: str, code: str) -> Runtime:
|
||||||
|
if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']:
|
||||||
|
return NvrtcCompiler.build(name, code)
|
||||||
|
else:
|
||||||
|
return NvccCompiler.build(name, code)
|
||||||
|
|||||||
@@ -3,8 +3,6 @@ import time
|
|||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import cuda.bindings.driver as cuda
|
import cuda.bindings.driver as cuda
|
||||||
import cuda.bindings.nvrtc as nvrtc
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from .utils import run_gemm
|
from .utils import run_gemm
|
||||||
|
|
||||||
@@ -58,8 +56,9 @@ class Runtime:
|
|||||||
|
|
||||||
end_time = time.time_ns()
|
end_time = time.time_ns()
|
||||||
elapsed_time = (end_time - start_time) / 1000
|
elapsed_time = (end_time - start_time) / 1000
|
||||||
print(
|
if os.getenv('DG_JIT_DEBUG', None):
|
||||||
f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
|
print(
|
||||||
|
f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
|
||||||
|
|
||||||
return run_gemm(
|
return run_gemm(
|
||||||
self.kernel,
|
self.kernel,
|
||||||
|
|||||||
Reference in New Issue
Block a user