diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index b9dfe9f..4357920 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -36,7 +36,7 @@ template __global__ void __launch_bounds__(get_num_threads_per_sm(BLOCK_M), 1) -fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, +fp8_gemm_kernel(float* scales_b, int* grouped_layout, uint32_t shape_m, const __grid_constant__ CUtensorMap tensor_map_a, const __grid_constant__ CUtensorMap tensor_map_b, diff --git a/deep_gemm/jit_kernels/runtime.py b/deep_gemm/jit_kernels/runtime.py index 8fb1a28..1ac0fe1 100644 --- a/deep_gemm/jit_kernels/runtime.py +++ b/deep_gemm/jit_kernels/runtime.py @@ -247,7 +247,6 @@ static void __instantiate_kernel() {{ config.hStream = stream arg_values = ( - gmem_d.data_ptr(), scales_b.data_ptr(), grouped_layout.data_ptr(), shape_m, @@ -257,7 +256,6 @@ static void __instantiate_kernel() {{ tensor_map_d, ) arg_types = ( - ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint32, diff --git a/indexing/main.cu b/indexing/main.cu index 8e86a30..5b15256 100644 --- a/indexing/main.cu +++ b/indexing/main.cu @@ -1,30 +1,8 @@ #include "deep_gemm/fp8_gemm.cuh" +#include "deep_gemm/fp8_wgrad_gemm.cuh" using namespace deep_gemm; int main() { - int m = 128; - constexpr int N = 4096; - constexpr int K = 7168; - - constexpr int BLOCK_M = 128; - constexpr int BLOCK_N = 128; - constexpr int BLOCK_K = 128; - constexpr int BLOCK_N_PADDING = 0; - constexpr int kSwizzleDMode = 0; - constexpr int kNumGroups = 1; - constexpr int kNumStages = 5; - constexpr int kNumTMAMulticast = 1; - constexpr bool kIsTMAMulticastOnA = false; - - using gemm_t = Gemm; - auto tma_a_desc = gemm_t::make_2d_tma_a_desc(reinterpret_cast<__nv_fp8_e4m3*>(0), m); - auto tma_b_desc = gemm_t::make_2d_tma_b_desc(reinterpret_cast<__nv_fp8_e4m3*>(0)); - auto tma_d_desc = gemm_t::make_2d_tma_d_desc(reinterpret_cast(0), m); - auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(reinterpret_cast(0), m); - gemm_t::run(nullptr, nullptr, nullptr, - m, - tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc, - nullptr, 132, 0); return 0; }