mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Clean up some useless stuff
This commit is contained in:
@@ -36,7 +36,7 @@ template <uint32_t SHAPE_N, uint32_t SHAPE_K,
|
|||||||
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
|
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
|
||||||
GemmType kGemmType>
|
GemmType kGemmType>
|
||||||
__global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
|
__global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
|
||||||
fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
fp8_gemm_kernel(float* scales_b, int* grouped_layout,
|
||||||
uint32_t shape_m,
|
uint32_t shape_m,
|
||||||
const __grid_constant__ CUtensorMap tensor_map_a,
|
const __grid_constant__ CUtensorMap tensor_map_a,
|
||||||
const __grid_constant__ CUtensorMap tensor_map_b,
|
const __grid_constant__ CUtensorMap tensor_map_b,
|
||||||
|
|||||||
@@ -247,7 +247,6 @@ static void __instantiate_kernel() {{
|
|||||||
config.hStream = stream
|
config.hStream = stream
|
||||||
|
|
||||||
arg_values = (
|
arg_values = (
|
||||||
gmem_d.data_ptr(),
|
|
||||||
scales_b.data_ptr(),
|
scales_b.data_ptr(),
|
||||||
grouped_layout.data_ptr(),
|
grouped_layout.data_ptr(),
|
||||||
shape_m,
|
shape_m,
|
||||||
@@ -257,7 +256,6 @@ static void __instantiate_kernel() {{
|
|||||||
tensor_map_d,
|
tensor_map_d,
|
||||||
)
|
)
|
||||||
arg_types = (
|
arg_types = (
|
||||||
ctypes.c_void_p,
|
|
||||||
ctypes.c_void_p,
|
ctypes.c_void_p,
|
||||||
ctypes.c_void_p,
|
ctypes.c_void_p,
|
||||||
ctypes.c_uint32,
|
ctypes.c_uint32,
|
||||||
|
|||||||
@@ -1,30 +1,8 @@
|
|||||||
#include "deep_gemm/fp8_gemm.cuh"
|
#include "deep_gemm/fp8_gemm.cuh"
|
||||||
|
#include "deep_gemm/fp8_wgrad_gemm.cuh"
|
||||||
|
|
||||||
using namespace deep_gemm;
|
using namespace deep_gemm;
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
int m = 128;
|
|
||||||
constexpr int N = 4096;
|
|
||||||
constexpr int K = 7168;
|
|
||||||
|
|
||||||
constexpr int BLOCK_M = 128;
|
|
||||||
constexpr int BLOCK_N = 128;
|
|
||||||
constexpr int BLOCK_K = 128;
|
|
||||||
constexpr int BLOCK_N_PADDING = 0;
|
|
||||||
constexpr int kSwizzleDMode = 0;
|
|
||||||
constexpr int kNumGroups = 1;
|
|
||||||
constexpr int kNumStages = 5;
|
|
||||||
constexpr int kNumTMAMulticast = 1;
|
|
||||||
constexpr bool kIsTMAMulticastOnA = false;
|
|
||||||
|
|
||||||
using gemm_t = Gemm<N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING, kSwizzleDMode, kNumGroups, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::Normal>;
|
|
||||||
auto tma_a_desc = gemm_t::make_2d_tma_a_desc(reinterpret_cast<__nv_fp8_e4m3*>(0), m);
|
|
||||||
auto tma_b_desc = gemm_t::make_2d_tma_b_desc(reinterpret_cast<__nv_fp8_e4m3*>(0));
|
|
||||||
auto tma_d_desc = gemm_t::make_2d_tma_d_desc(reinterpret_cast<nv_bfloat16*>(0), m);
|
|
||||||
auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(reinterpret_cast<float*>(0), m);
|
|
||||||
gemm_t::run(nullptr, nullptr, nullptr,
|
|
||||||
m,
|
|
||||||
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
|
|
||||||
nullptr, 132, 0);
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user