mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Refactor JIT compilation (+NVRTC support) (#94)
* [wip] refactor: compile to .cubin Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * refactor: compile to .cubin and add NVRTC option Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * fix: compiler version Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * feat: compat for old drivers Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * feat: save kernel name to file Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * feat: fix win compat Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * fix: windows compat Signed-off-by: Gabriel Wu <13583761+lucifer1004@users.noreply.github.com> * feat: make API more general Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * feat: drop support for CUDA<12.3 Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * doc: update README Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> * Some lints and refactor * Refactor runtime * Several fixes * Refactor environment variables * Code format * Add a TODO * Compatible with CUDA 12.3 * Fix indent * Fix typing * Drop support for Windows * Add a TODO --------- Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> Signed-off-by: Gabriel Wu <13583761+lucifer1004@users.noreply.github.com> Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
This commit is contained in:
@@ -84,16 +84,17 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
const uint32_t lane_idx = get_lane_id();
|
||||
|
||||
// Prefetch TMA descriptors at very beginning
|
||||
// Prefetch TMA descriptors at the very beginning
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
cute::prefetch_tma_descriptor(&tensor_map_a);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_b);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_scales_a);
|
||||
// NOTES: `reinterpret_cast` must be here, or NVRTC will fail
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_a));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_b));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_scales_a));
|
||||
|
||||
// `tensor_map_d` is only used in swizzling mode
|
||||
// For the `kSwizzleDMode == 0 and BLOCK_N_PADDING == 0` case, it will be treated as padding mode
|
||||
if constexpr (kSwizzleDMode > 0)
|
||||
cute::prefetch_tma_descriptor(&tensor_map_d);
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_d));
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
@@ -447,119 +448,6 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
#endif
|
||||
}
|
||||
|
||||
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t BLOCK_N_PADDING,
|
||||
uint32_t kSwizzleDMode,
|
||||
uint32_t kNumGroups, uint32_t kNumStages,
|
||||
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
|
||||
GemmType kGemmType>
|
||||
class Gemm {
|
||||
private:
|
||||
using Barrier = cuda::barrier<cuda::thread_scope_block>;
|
||||
|
||||
public:
|
||||
Gemm() = default;
|
||||
|
||||
static void run(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
uint32_t shape_m,
|
||||
const CUtensorMap& tma_a_desc,
|
||||
const CUtensorMap& tma_b_desc,
|
||||
const CUtensorMap& tma_scales_a_desc,
|
||||
const CUtensorMap& tma_d_desc,
|
||||
cudaStream_t stream,
|
||||
int num_sms, uint32_t smem_size) {
|
||||
// NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps
|
||||
constexpr uint32_t kNumTMAThreads = 128;
|
||||
constexpr uint32_t kNumMathThreadsPerGroup = 128;
|
||||
auto kernel = fp8_gemm_kernel<SHAPE_N, SHAPE_K,
|
||||
BLOCK_M, BLOCK_N, BLOCK_K,
|
||||
BLOCK_N_PADDING,
|
||||
kSwizzleDMode,
|
||||
kNumGroups, kNumStages,
|
||||
kNumTMAThreads, kNumMathThreadsPerGroup,
|
||||
kNumTMAMulticast, kIsTMAMulticastOnA, kGemmType>;
|
||||
DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess);
|
||||
|
||||
// Cluster launch
|
||||
cudaLaunchConfig_t config;
|
||||
config.gridDim = num_sms;
|
||||
config.blockDim = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
|
||||
config.dynamicSmemBytes = smem_size;
|
||||
config.stream = stream;
|
||||
|
||||
// Clusters for TMA multicast
|
||||
// NOTES: `>= 4` cluster size will cause performance degradation
|
||||
cudaLaunchAttribute attr;
|
||||
attr.id = cudaLaunchAttributeClusterDimension;
|
||||
attr.val.clusterDim = {kNumTMAMulticast, 1, 1};
|
||||
config.attrs = &attr;
|
||||
config.numAttrs = 1;
|
||||
|
||||
// Launch
|
||||
auto status = cudaLaunchKernelEx(&config, kernel,
|
||||
gmem_d, scales_b, grouped_layout,
|
||||
shape_m,
|
||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc);
|
||||
DG_HOST_ASSERT(status == cudaSuccess);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m) {
|
||||
return make_2d_tma_desc(global_address, Layout::RowMajor,
|
||||
shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_K, BLOCK_M, BLOCK_K);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_b_desc(T* global_address) {
|
||||
return make_2d_tma_desc(global_address, Layout::ColMajor,
|
||||
SHAPE_K, SHAPE_N * (kGemmType != GemmType::Normal ? kNumGroups : 1), BLOCK_K, BLOCK_N);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) {
|
||||
auto swizzle_mode = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE;
|
||||
if constexpr (kSwizzleDMode == 32) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_32B;
|
||||
if constexpr (kSwizzleDMode == 64) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_64B;
|
||||
if constexpr (kSwizzleDMode == 128) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_128B;
|
||||
|
||||
// Swizzling requires the inner box dim less or equal than `kSwizzleDMode` bytes
|
||||
// So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
|
||||
return make_2d_tma_desc(global_address, Layout::RowMajor,
|
||||
shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N,
|
||||
BLOCK_M, kSwizzleDMode == 0 ? BLOCK_N : kSwizzleDMode / sizeof(T),
|
||||
swizzle_mode);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) {
|
||||
// Make TMA aligned to 16 bytes
|
||||
constexpr uint32_t kAlignment = 16 / sizeof(T);
|
||||
shape_m = ceil_div(shape_m, kAlignment) * kAlignment;
|
||||
|
||||
return make_2d_tma_desc(global_address, Layout::ColMajor,
|
||||
shape_m, ceil_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1,
|
||||
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static CUtensorMap make_2d_tma_desc(
|
||||
T* global_address, Layout layout,
|
||||
uint32_t gmem_rows, uint32_t gmem_cols,
|
||||
uint32_t smem_rows, uint32_t smem_cols,
|
||||
CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) {
|
||||
if (layout == Layout::RowMajor) {
|
||||
uint64_t gmem_dim[2] = {gmem_cols, gmem_rows};
|
||||
uint32_t smem_dim[2] = {smem_cols, smem_rows};
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_cols * sizeof(T), smem_dim, swizzle_type);
|
||||
} else {
|
||||
uint64_t gmem_dim[2] = {gmem_rows, gmem_cols};
|
||||
uint32_t smem_dim[2] = {smem_rows, smem_cols};
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_rows * sizeof(T), smem_dim, swizzle_type);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
}; // namespace deep_gemm
|
||||
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef __CUDACC_RTC__
|
||||
#include <cuda.h>
|
||||
#endif
|
||||
|
||||
#include <cute/arch/mma_sm90_gmma.hpp>
|
||||
#include <cute/arch/mma_sm90_gmma_ext.hpp>
|
||||
|
||||
103
deep_gemm/include/deep_gemm/nvrtc_std.cuh
Normal file
103
deep_gemm/include/deep_gemm/nvrtc_std.cuh
Normal file
@@ -0,0 +1,103 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
|
||||
* All rights reserved. SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef __CUDACC_RTC__
|
||||
|
||||
using int8_t = signed char;
|
||||
using uint8_t = unsigned char;
|
||||
using int16_t = signed short;
|
||||
using uint16_t = unsigned short;
|
||||
using int32_t = signed int;
|
||||
using uint32_t = unsigned int;
|
||||
using int64_t = signed long long;
|
||||
using uint64_t = unsigned long long;
|
||||
using cuuint64_t = unsigned long long;
|
||||
|
||||
#ifndef CU_TENSOR_MAP_NUM_QWORDS
|
||||
#define CU_TENSOR_MAP_NUM_QWORDS 16
|
||||
|
||||
struct CUtensorMap_st {
|
||||
#if defined(__cplusplus) && (__cplusplus >= 201103L)
|
||||
alignas(64)
|
||||
#elif __STDC_VERSION__ >= 201112L
|
||||
_Alignas(64)
|
||||
#endif
|
||||
cuuint64_t opaque[CU_TENSOR_MAP_NUM_QWORDS];
|
||||
};
|
||||
|
||||
using CUtensorMap = CUtensorMap_st;
|
||||
#endif
|
||||
|
||||
namespace std {
|
||||
|
||||
template <class T, T v> struct integral_constant {
|
||||
static constexpr T value = v;
|
||||
|
||||
using value_type = T;
|
||||
using type = integral_constant;
|
||||
|
||||
__device__ constexpr operator value_type() const noexcept { return value; }
|
||||
|
||||
__device__ constexpr value_type operator()() const noexcept { return value; }
|
||||
};
|
||||
|
||||
using false_type = integral_constant<bool, false>;
|
||||
using true_type = integral_constant<bool, true>;
|
||||
|
||||
template <class T, class U> struct is_same : false_type {};
|
||||
|
||||
template <class T> struct is_same<T, T> : true_type {};
|
||||
|
||||
template <class T, class U>
|
||||
inline constexpr bool is_same_v = is_same<T, U>::value;
|
||||
|
||||
namespace index_sequence_impl {
|
||||
|
||||
// Based on https://stackoverflow.com/a/32223343/11717224
|
||||
template <size_t... Ints> struct index_sequence {
|
||||
using type = index_sequence;
|
||||
using value_type = size_t;
|
||||
static constexpr size_t size() noexcept { return sizeof...(Ints); }
|
||||
};
|
||||
|
||||
template <class Sequence1, class Sequence2> struct _merge_and_renumber;
|
||||
|
||||
template <size_t... I1, size_t... I2>
|
||||
struct _merge_and_renumber<index_sequence<I1...>, index_sequence<I2...>>
|
||||
: index_sequence<I1..., (sizeof...(I1) + I2)...> {};
|
||||
|
||||
template <size_t N>
|
||||
struct make_index_sequence
|
||||
: _merge_and_renumber<typename make_index_sequence<N / 2>::type,
|
||||
typename make_index_sequence<N - N / 2>::type> {};
|
||||
|
||||
template <> struct make_index_sequence<0> : index_sequence<> {};
|
||||
template <> struct make_index_sequence<1> : index_sequence<0> {};
|
||||
|
||||
} // namespace index_sequence_impl
|
||||
|
||||
template <size_t... Ns>
|
||||
using index_sequence = index_sequence_impl::index_sequence<Ns...>;
|
||||
|
||||
template <size_t N>
|
||||
using make_index_sequence = index_sequence_impl::make_index_sequence<N>;
|
||||
|
||||
} // namespace std
|
||||
|
||||
#endif
|
||||
@@ -46,7 +46,7 @@ struct Scheduler {
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) {
|
||||
__device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const {
|
||||
if (num_blocks_in_group == 1)
|
||||
return false;
|
||||
if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::GroupedMasked) {
|
||||
@@ -63,7 +63,8 @@ struct Scheduler {
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
|
||||
__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx,
|
||||
uint32_t& m_block_idx, uint32_t& n_block_idx) {
|
||||
DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
|
||||
|
||||
// Swizzle for better L2 usages
|
||||
|
||||
@@ -1,85 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
#include <cuda.h>
|
||||
#include <cudaTypedefs.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda/barrier>
|
||||
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
template <class T>
|
||||
constexpr CUtensorMapDataType get_CUtensorMapDataType() {
|
||||
if constexpr (std::is_same<T, uint8_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
||||
} else if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
||||
} else if constexpr (std::is_same<T, __nv_fp8_e5m2>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
||||
} else if constexpr (std::is_same<T, uint16_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT16;
|
||||
} else if constexpr (std::is_same<T, uint32_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT32;
|
||||
} else if constexpr (std::is_same<T, uint64_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_UINT64;
|
||||
} else if constexpr (std::is_same<T, int32_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_INT32;
|
||||
} else if constexpr (std::is_same<T, int64_t>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_INT64;
|
||||
} else if constexpr (std::is_same<T, __half>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
|
||||
} else if constexpr (std::is_same<T, float>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
|
||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
|
||||
} else if constexpr (std::is_same<T, double>::value) {
|
||||
return CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
|
||||
}
|
||||
}
|
||||
|
||||
inline PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() {
|
||||
// Get pointer to `cuTensorMapEncodeTiled`
|
||||
cudaDriverEntryPointQueryResult driver_status;
|
||||
void* cuTensorMapEncodeTiled_ptr = nullptr;
|
||||
|
||||
#if CUDA_VERSION >= 12050
|
||||
cudaGetDriverEntryPointByVersion("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, 12000,
|
||||
cudaEnableDefault, &driver_status);
|
||||
#else
|
||||
cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr,
|
||||
cudaEnableDefault, &driver_status);
|
||||
#endif
|
||||
|
||||
if (driver_status != cudaDriverEntryPointSuccess)
|
||||
throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess");
|
||||
return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(cuTensorMapEncodeTiled_ptr);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
|
||||
uint64_t stride_in_bytes, uint32_t smem_dim[2],
|
||||
CUtensorMapSwizzle swizzle_type,
|
||||
PFN_cuTensorMapEncodeTiled encode_func = nullptr) {
|
||||
CUtensorMap tensor_map = {};
|
||||
uint64_t global_stride[1] = {stride_in_bytes};
|
||||
uint32_t elem_strides[2] = {1, 1};
|
||||
|
||||
if (encode_func == nullptr)
|
||||
encode_func = get_cuTensorMapEncodeTiled();
|
||||
|
||||
auto result = encode_func(
|
||||
&tensor_map, get_CUtensorMapDataType<std::remove_cv_t<T>>(), 2,
|
||||
global_address, gmem_dim, global_stride, smem_dim, elem_strides,
|
||||
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type,
|
||||
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
|
||||
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
|
||||
DG_HOST_ASSERT(result == CUDA_SUCCESS);
|
||||
return tensor_map;
|
||||
}
|
||||
|
||||
// TODO: move this function to other files
|
||||
__device__ __forceinline__ void
|
||||
tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
|
||||
int32_t const& crd_0, int32_t const& crd_1, uint32_t num_tma_multicast) {
|
||||
|
||||
@@ -1,33 +1,14 @@
|
||||
#pragma once
|
||||
|
||||
#include <exception>
|
||||
|
||||
#ifdef __CLION_IDE__
|
||||
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { asm volatile("trap;"); }
|
||||
|
||||
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) {
|
||||
asm volatile("trap;");
|
||||
}
|
||||
|
||||
#define printf host_device_printf
|
||||
#endif
|
||||
|
||||
class AssertionException : public std::exception {
|
||||
private:
|
||||
std::string message{};
|
||||
|
||||
public:
|
||||
explicit AssertionException(const std::string& message) : message(message) {}
|
||||
|
||||
const char *what() const noexcept override { return message.c_str(); }
|
||||
};
|
||||
|
||||
#ifndef DG_HOST_ASSERT
|
||||
#define DG_HOST_ASSERT(cond) \
|
||||
do { \
|
||||
if (not (cond)) { \
|
||||
printf("Assertion failed: %s:%d, condition: %s\n", \
|
||||
__FILE__, __LINE__, #cond); \
|
||||
throw AssertionException("Assertion failed: " #cond); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
#ifndef DG_DEVICE_ASSERT
|
||||
#define DG_DEVICE_ASSERT(cond) \
|
||||
do { \
|
||||
|
||||
Reference in New Issue
Block a user