Mark header-defined helper functions as inline.

* mma_utils.cuh: warpgroup_arrive(), warpgroup_commit_batch() and
  warpgroup_fence_operand() change from `__device__` to
  `__forceinline__ __device__`.
* tma_utils.cuh: get_cuTensorMapEncodeTiled() gains the `inline` specifier.

Function bodies are untouched; only the declaration specifiers change.
NOTE(review): presumably this avoids multiple-definition errors when these
headers are included from more than one translation unit -- confirm.
NOTE(review): line breaks, blank context lines and context indentation were
reconstructed to match the hunk line counts (15 and 7); verify whitespace
against the target tree before applying.

diff --git a/deep_gemm/include/deep_gemm/mma_utils.cuh b/deep_gemm/include/deep_gemm/mma_utils.cuh
index b242261..0cc554a 100644
--- a/deep_gemm/include/deep_gemm/mma_utils.cuh
+++ b/deep_gemm/include/deep_gemm/mma_utils.cuh
@@ -888,15 +888,15 @@ struct SM90_U32x4_STSM_N {
     }
 };
 
-__device__ void warpgroup_arrive() {
+__forceinline__ __device__ void warpgroup_arrive() {
     asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
 }
 
-__device__ void warpgroup_commit_batch() {
+__forceinline__ __device__ void warpgroup_commit_batch() {
     asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory");
 }
 
-__device__ void warpgroup_fence_operand(float& reg) {
+__forceinline__ __device__ void warpgroup_fence_operand(float& reg) {
     asm volatile("" : "+f"(reg) :: "memory");
 }
 
diff --git a/deep_gemm/include/deep_gemm/tma_utils.cuh b/deep_gemm/include/deep_gemm/tma_utils.cuh
index d3db1b6..e7bc241 100644
--- a/deep_gemm/include/deep_gemm/tma_utils.cuh
+++ b/deep_gemm/include/deep_gemm/tma_utils.cuh
@@ -40,7 +40,7 @@ constexpr CUtensorMapDataType get_CUtensorMapDataType() {
     }
 }
 
-PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() {
+inline PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() {
     // Get pointer to `cuTensorMapEncodeTiled`
     cudaDriverEntryPointQueryResult driver_status;
     void* cuTensorMapEncodeTiled_ptr = nullptr;