#pragma once

#ifndef __CUDACC_RTC__
#include <cuda.h>
#endif

#include <cute/arch/mma_sm90_gmma.hpp>
#include <cute/arch/mma_sm90_gmma_ext.hpp>

#include "utils.cuh"

namespace deep_gemm {

// Non-transposed `stmatrix` wrappers: each `dtype_t` operand (typically a packed
// 16-bit pair) is reinterpreted as one 32-bit register and stored as 8x8 matrix
// fragments into shared memory with a single `stmatrix.sync.aligned.{x2,x4}.m8n8.b16`.
template <typename dtype_t>
struct SM90_U32x2_STSM_N {
    __device__ __forceinline__ static void
    copy(dtype_t src_0, dtype_t src_1, void* smem_dst) {
        const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
        asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n"
                     :: "l"(smem_dst), "r"(src[0]), "r"(src[1]));
    }
};

template <typename dtype_t>
struct SM90_U32x4_STSM_N {
    __device__ __forceinline__ static void
    copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) {
        const uint32_t src[4] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1),
                                 *reinterpret_cast<uint32_t*>(&src_2), *reinterpret_cast<uint32_t*>(&src_3)};
        asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
                     :: "l"(smem_dst), "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3]));
    }
};

// Transposed variants: identical packing, but the `.trans` qualifier transposes
// the 8x8 fragments on the way into shared memory.
template <typename dtype_t>
struct SM90_U32x2_STSM_T {
    __device__ __forceinline__ static void
    copy(dtype_t src_0, dtype_t src_1, void* smem_dst) {
        const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
        asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16.trans [%0], {%1, %2};\n"
                     :: "l"(smem_dst), "r"(src[0]), "r"(src[1]));
    }
};

template <typename dtype_t>
struct SM90_U32x4_STSM_T {
    __device__ __forceinline__ static void
    copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) {
        const uint32_t src[4] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1),
                                 *reinterpret_cast<uint32_t*>(&src_2), *reinterpret_cast<uint32_t*>(&src_3)};
        asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16.trans [%0], {%1, %2, %3, %4};\n"
                     :: "l"(smem_dst), "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3]));
    }
};
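
// Illustrative usage (a sketch, not part of this header): when converting FP32
// accumulators to BF16 and staging them in shared memory, each lane of the warp
// packs two bf16 pairs per register and issues one x4 store, e.g.
//
//   SM90_U32x4_STSM_N<nv_bfloat162>::copy(
//       __float22bfloat162_rn({accum[0], accum[1]}),
//       __float22bfloat162_rn({accum[2], accum[3]}),
//       __float22bfloat162_rn({accum[4], accum[5]}),
//       __float22bfloat162_rn({accum[6], accum[7]}),
//       smem_d + lane_dependent_offset);
//
// `accum` and `lane_dependent_offset` are placeholders; all 32 lanes of the warp
// must participate in the `stmatrix` instruction.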

// WGMMA fence: orders prior accesses to the accumulator registers before the
// warpgroup MMAs are issued.
__forceinline__ __device__ void warpgroup_arrive() {
    asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
}

// Commits all previously issued (but uncommitted) WGMMA operations into a group.
__forceinline__ __device__ void warpgroup_commit_batch() {
    asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory");
}

// Compiler-level fence on a single accumulator register, preventing the compiler
// from reordering reads/writes of `reg` across the asynchronous WGMMA.
__forceinline__ __device__ void warpgroup_fence_operand(float& reg) {
    asm volatile("" : "+f"(reg) :: "memory");
}

__forceinline__ __device__ uint32_t get_lane_id() {
    uint32_t lane_id;
    asm("mov.u32 %0, %laneid;" : "=r"(lane_id));
    return lane_id;
}

// Typed shared-memory load/store wrappers emitted as inline PTX.
__device__ __forceinline__ uint32_t ld_shared(const uint32_t* __restrict__ ptr) {
    uint32_t ret;
    asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(ptr));
    return ret;
}

__device__ __forceinline__ int4 ld_shared(const int4* __restrict__ ptr) {
    int4 ret;
    asm volatile("ld.shared.v4.s32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr));
    return ret;
}

__device__ __forceinline__ float ld_shared(const float* __restrict__ ptr) {
    float ret;
    asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
    return ret;
}

__device__ __forceinline__ float2 ld_shared(const float2* __restrict__ ptr) {
    float2 ret;
    asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(ptr));
    return ret;
}

__device__ __forceinline__ void st_shared(const float* ptr, float val) {
    asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val));
}

__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) {
    asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val));
}

__device__ __forceinline__ void st_shared(const float2* ptr, float2 val) {
    asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(ptr), "f"(val.x), "f"(val.y));
}

// Waits until at most N committed WGMMA groups remain pending.
template <int N>
__device__ void warpgroup_wait() {
    DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]");
    asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory");
}
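
// Typical issue sequence (a sketch; `accum`, `desc_a`, `desc_b` and `scale_d` are
// placeholders, and `WGMMA` refers to the `FP8MMA` wrapper defined later in this
// header): fence the accumulators, arrive, issue the WGMMA, commit the group,
// then wait before reading the results back:
//
//   for (int i = 0; i < WGMMA::kNumAccum; ++ i)
//       warpgroup_fence_operand(accum[i]);
//   warpgroup_arrive();
//   WGMMA::wgmma(desc_a, desc_b, accum, scale_d);
//   warpgroup_commit_batch();
//   for (int i = 0; i < WGMMA::kNumAccum; ++ i)
//       warpgroup_fence_operand(accum[i]);
//   warpgroup_wait<0>();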

// 64-bit shared-memory matrix descriptor consumed by `wgmma.mma_async`; the
// 14-bit address-like fields are encoded in units of 16 bytes (i.e. `addr >> 4`).
union GmmaDescriptor {
    __host__ __device__ constexpr GmmaDescriptor() noexcept: desc_(0) {}

    __host__ __device__ constexpr GmmaDescriptor(uint64_t desc) noexcept: desc_(desc) {}

    __host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor const &t) noexcept: desc_(t.desc_) {}

    __host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor &&t) noexcept: desc_(t.desc_) {}

    __host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor const &t) noexcept {
        desc_ = t.desc_;
        return *this;
    }

    __host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor &&t) noexcept {
        desc_ = t.desc_;
        return *this;
    }

    uint64_t desc_;
    uint32_t reg32_[2];
    uint16_t reg16_[4];

    struct {
        uint16_t start_address_: 14, : 2;
        uint16_t leading_byte_offset_: 14, : 2;
        uint16_t stride_byte_offset_: 14, : 2;
        uint8_t : 1, base_offset_: 3, : 4;
        uint8_t : 6, layout_type_: 2;
    } bitfield;

    // Decay to a `uint64_t`
    __host__ __device__ constexpr operator uint64_t() const noexcept { return desc_; }
};

// Builds a WGMMA shared-memory descriptor from a (generic) shared-memory pointer.
// Addresses and offsets are encoded in 16-byte units, hence the `>> 4` shifts.
template <class PointerType>
__device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, int layout_type,
                                         int leading_byte_offset = 0,
                                         int stride_byte_offset = 1024) {
    GmmaDescriptor desc;
    auto uint_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    desc.bitfield.start_address_ = uint_ptr >> 4;
    desc.bitfield.layout_type_ = layout_type;
    desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
    desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
    desc.bitfield.base_offset_ = 0;
    return desc;
}
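
// Illustrative call (a sketch; `smem_a` is a placeholder shared-memory tile):
//
//   auto desc_a = make_smem_desc(smem_a, 1);
//
// where `layout_type = 1` is assumed to select the 128-byte swizzle mode of the
// PTX descriptor encoding, and the default leading/stride byte offsets are used.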

// Wraps a CUTLASS SM90 GMMA atom: expands the accumulator array into the
// argument list expected by `MMA::fma` via an index sequence.
template <int N_, typename MMA>
struct FP8MMA {

    template <size_t ...Idx>
    __forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, std::index_sequence<Idx...>) {
        using namespace cute::SM90::GMMA;
        MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero));
    }

    __forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        call_fma_impl(desc_a, desc_b, d, scale_d, std::make_index_sequence<N_/2>{});
    }

    static constexpr int M = 64;
    static constexpr int N = N_;
    static constexpr int K = 32;
    static constexpr int kNumAccum = M * N / 128;
};

// Selects the CUTLASS FP8 (E4M3 x E4M3 -> FP32 accumulate) GMMA atom for a given
// tile width N and wraps it in `FP8MMA`.
template <int N>
struct FP8MMASelector {

    static constexpr auto select_mma() {
        using namespace cute::SM90::GMMA;
        if constexpr (N == 16) return MMA_64x16x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 24) return MMA_64x24x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 32) return MMA_64x32x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 40) return MMA_64x40x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 48) return MMA_64x48x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 56) return MMA_64x56x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 64) return MMA_64x64x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 72) return MMA_64x72x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 80) return MMA_64x80x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 88) return MMA_64x88x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 96) return MMA_64x96x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 104) return MMA_64x104x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 112) return MMA_64x112x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 120) return MMA_64x120x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 128) return MMA_64x128x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 136) return MMA_64x136x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 144) return MMA_64x144x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 152) return MMA_64x152x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 160) return MMA_64x160x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 192) return MMA_64x192x32_F32E4M3E4M3_SS_TN();
    }

    static constexpr auto select_type() {
        return FP8MMA<N, decltype(select_mma())>();
    }

    using type = decltype(select_type());
};
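
// Illustrative usage (a sketch; `desc_a`, `desc_b` and `scale_d` are placeholders):
//
//   using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
//   float accum[WGMMA::kNumAccum] = {0};
//   WGMMA::wgmma(desc_a, desc_b, accum, scale_d);
//
// `BLOCK_N` must be one of the widths enumerated in `select_mma()` above.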

enum class Layout {
    RowMajor,
    ColMajor
};

// Launch-bound helpers: a 64-row block uses one math warpgroup, larger blocks use two.
__device__ __host__ constexpr int get_num_math_warpgroups(int block_m) {
    return block_m == 64 ? 1 : 2;
}

template <uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup>
__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
    DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group");
    return get_num_math_warpgroups(block_m) * kNumMathThreadsPerGroup + kNumTMAThreads;
}
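
// Example: assuming kNumTMAThreads = 128 (one TMA warpgroup) and block_m = 128,
// get_num_threads_per_sm<128, 128>(128) == 2 * 128 + 128 == 384 threads per block.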
} // namespace deep_gemm