Refactor MMA with cutlass

This commit is contained in:
Zhean Xu 2025-04-14 16:56:13 +08:00
parent 37aa127451
commit 584b67eebb

View File

@ -2,871 +2,13 @@
#include <cuda.h>
#include <cute/arch/mma_sm90_gmma.hpp>
#include <cute/arch/mma_sm90_gmma_ext.hpp>
#include "utils.cuh"
namespace deep_gemm {
struct SM90_64x16x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %10, 0;\n"
"wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7},"
" %8,"
" %9,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 16;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x24x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %14, 0;\n"
"wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11},"
" %12,"
" %13,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 24;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x32x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %18, 0;\n"
"wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15},"
" %16,"
" %17,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 32;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x40x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %22, 0;\n"
"wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19},"
" %20,"
" %21,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 40;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x48x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %26, 0;\n"
"wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23},"
" %24,"
" %25,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 48;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x56x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %30, 0;\n"
"wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27}, "
" %28,"
" %29,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 56;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x64x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %34, 0;\n"
"wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31}, "
" %32,"
" %33,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 64;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x72x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %38, 0;\n"
"wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35}, "
" %36,"
" %37,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 72;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x80x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %42, 0;\n"
"wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39}, "
" %40,"
" %41,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 80;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x88x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
float& d40, float& d41, float& d42, float& d43,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %46, 0;\n"
"wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39, "
" %40, %41, %42, %43}, "
" %44,"
" %45,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
d[40], d[41], d[42], d[43],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 88;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x96x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %50, 0;\n"
"wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39, "
" %40, %41, %42, %43, %44, %45, %46, %47}, "
" %48,"
" %49,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 96;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x104x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
float& d48, float& d49, float& d50, float& d51,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %54, 0;\n"
"wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39, "
" %40, %41, %42, %43, %44, %45, %46, %47, "
" %48, %49, %50, %51}, "
" %52,"
" %53,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
d[48], d[49], d[50], d[51],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 104;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x112x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %58, 0;\n"
"wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39, "
" %40, %41, %42, %43, %44, %45, %46, %47, "
" %48, %49, %50, %51, %52, %53, %54, %55}, "
" %56,"
" %57,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 112;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x120x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
float& d56, float& d57, float& d58, float& d59,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %62, 0;\n"
"wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39, "
" %40, %41, %42, %43, %44, %45, %46, %47, "
" %48, %49, %50, %51, %52, %53, %54, %55, "
" %56, %57, %58, %59}, "
" %60,"
" %61,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
d[56], d[57], d[58], d[59],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 120;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x128x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %66, 0;\n"
"wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39, "
" %40, %41, %42, %43, %44, %45, %46, %47, "
" %48, %49, %50, %51, %52, %53, %54, %55, "
" %56, %57, %58, %59, %60, %61, %62, %63}, "
" %64,"
" %65,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 128;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x144x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %74, 0;\n"
"wgmma.mma_async.sync.aligned.m64n144k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39, "
" %40, %41, %42, %43, %44, %45, %46, %47, "
" %48, %49, %50, %51, %52, %53, %54, %55, "
" %56, %57, %58, %59, %60, %61, %62, %63, "
" %64, %65, %66, %67, %68, %69, %70, %71}, "
" %72,"
" %73,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
"+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 144;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x160x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71,
float& d72, float& d73, float& d74, float& d75, float& d76, float& d77, float& d78, float& d79,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %82, 0;\n"
"wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39, "
" %40, %41, %42, %43, %44, %45, %46, %47, "
" %48, %49, %50, %51, %52, %53, %54, %55, "
" %56, %57, %58, %59, %60, %61, %62, %63, "
" %64, %65, %66, %67, %68, %69, %70, %71, "
" %72, %73, %74, %75, %76, %77, %78, %79}, "
" %80,"
" %81,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
"+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71],
d[72], d[73], d[74], d[75], d[76], d[77], d[78], d[79],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 160;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x192x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47,
float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55,
float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63,
float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71,
float& d72, float& d73, float& d74, float& d75, float& d76, float& d77, float& d78, float& d79,
float& d80, float& d81, float& d82, float& d83, float& d84, float& d85, float& d86, float& d87,
float& d88, float& d89, float& d90, float& d91, float& d92, float& d93, float& d94, float& d95,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %98, 0;\n"
"wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39, "
" %40, %41, %42, %43, %44, %45, %46, %47, "
" %48, %49, %50, %51, %52, %53, %54, %55, "
" %56, %57, %58, %59, %60, %61, %62, %63, "
" %64, %65, %66, %67, %68, %69, %70, %71, "
" %72, %73, %74, %75, %76, %77, %78, %79, "
" %80, %81, %82, %83, %84, %85, %86, %87, "
" %88, %89, %90, %91, %92, %93, %94, %95}, "
" %96,"
" %97,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
"+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
"+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87),
"+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71],
d[72], d[73], d[74], d[75], d[76], d[77], d[78], d[79],
d[80], d[81], d[82], d[83], d[84], d[85], d[86], d[87],
d[88], d[89], d[90], d[91], d[92], d[93], d[94], d[95],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 192;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
template <typename dtype_t>
struct SM90_U32x2_STSM_N {
__device__ __forceinline__ static void
@ -987,27 +129,53 @@ __device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, int layout_type,
return desc;
}
template <int N_, typename MMA>
struct FP8MMA {
template <size_t ...Idx>
__forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, std::index_sequence<Idx...>) {
using namespace cute::SM90::GMMA;
MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero));
}
__forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
call_fma_impl(desc_a, desc_b, d, scale_d, std::make_index_sequence<N_/2>{});
}
static constexpr int M = 64;
static constexpr int N = N_;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
template <int N>
struct FP8MMASelector {
static constexpr auto select_mma() {
using namespace cute::SM90::GMMA;
if constexpr (N == 16) return MMA_64x16x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 24) return MMA_64x24x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 32) return MMA_64x32x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 40) return MMA_64x40x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 48) return MMA_64x48x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 56) return MMA_64x56x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 64) return MMA_64x64x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 72) return MMA_64x72x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 80) return MMA_64x80x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 88) return MMA_64x88x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 96) return MMA_64x96x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 104) return MMA_64x104x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 112) return MMA_64x112x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 120) return MMA_64x120x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 128) return MMA_64x128x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 144) return MMA_64x144x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 152) return MMA_64x152x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 160) return MMA_64x160x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 192) return MMA_64x192x32_F32E4M3E4M3_SS_TN();
}
static constexpr auto select_type() {
if constexpr (N == 16) return SM90_64x16x32_F32E4M3E4M3_SS();
if constexpr (N == 24) return SM90_64x24x32_F32E4M3E4M3_SS();
if constexpr (N == 32) return SM90_64x32x32_F32E4M3E4M3_SS();
if constexpr (N == 40) return SM90_64x40x32_F32E4M3E4M3_SS();
if constexpr (N == 48) return SM90_64x48x32_F32E4M3E4M3_SS();
if constexpr (N == 56) return SM90_64x56x32_F32E4M3E4M3_SS();
if constexpr (N == 64) return SM90_64x64x32_F32E4M3E4M3_SS();
if constexpr (N == 72) return SM90_64x72x32_F32E4M3E4M3_SS();
if constexpr (N == 80) return SM90_64x80x32_F32E4M3E4M3_SS();
if constexpr (N == 88) return SM90_64x88x32_F32E4M3E4M3_SS();
if constexpr (N == 96) return SM90_64x96x32_F32E4M3E4M3_SS();
if constexpr (N == 104) return SM90_64x104x32_F32E4M3E4M3_SS();
if constexpr (N == 112) return SM90_64x112x32_F32E4M3E4M3_SS();
if constexpr (N == 120) return SM90_64x120x32_F32E4M3E4M3_SS();
if constexpr (N == 128) return SM90_64x128x32_F32E4M3E4M3_SS();
if constexpr (N == 144) return SM90_64x144x32_F32E4M3E4M3_SS();
if constexpr (N == 160) return SM90_64x160x32_F32E4M3E4M3_SS();
if constexpr (N == 192) return SM90_64x192x32_F32E4M3E4M3_SS();
return FP8MMA<N, decltype(select_mma())>();
}
using type = decltype(select_type());