mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Some lints and refactor
This commit is contained in:
@@ -84,8 +84,9 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
|
||||
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
const uint32_t lane_idx = get_lane_id();
|
||||
|
||||
// Prefetch TMA descriptors at very beginning
|
||||
// Prefetch TMA descriptors at the very beginning
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
// NOTES: `reinterpret_cast` must be here, or NVRTC will fail
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_a));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_b));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_scales_a));
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef NVRTC_JIT_COMPILATION
|
||||
#ifndef __CUDACC_RTC__
|
||||
#include <cuda.h>
|
||||
#endif
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef NVRTC_JIT_COMPILATION
|
||||
#ifdef __CUDACC_RTC__
|
||||
|
||||
using int8_t = signed char;
|
||||
using uint8_t = unsigned char;
|
||||
@@ -32,8 +32,7 @@ using cuuint64_t = unsigned long long;
|
||||
#ifndef CU_TENSOR_MAP_NUM_QWORDS
|
||||
#define CU_TENSOR_MAP_NUM_QWORDS 16
|
||||
|
||||
struct CUtensorMap_st
|
||||
{
|
||||
struct CUtensorMap_st {
|
||||
#if defined(__cplusplus) && (__cplusplus >= 201103L)
|
||||
alignas(64)
|
||||
#elif __STDC_VERSION__ >= 201112L
|
||||
@@ -46,16 +45,16 @@ using CUtensorMap = CUtensorMap_st;
|
||||
#endif
|
||||
|
||||
namespace std {
|
||||
|
||||
template <class T, T v> struct integral_constant {
|
||||
static constexpr T value = v;
|
||||
|
||||
using value_type = T;
|
||||
using type = integral_constant; // using injected-class-name
|
||||
using type = integral_constant;
|
||||
|
||||
__device__ constexpr operator value_type() const noexcept { return value; }
|
||||
|
||||
__device__ constexpr value_type operator()() const noexcept {
|
||||
return value;
|
||||
} // since c++14
|
||||
__device__ constexpr value_type operator()() const noexcept { return value; }
|
||||
};
|
||||
|
||||
using false_type = integral_constant<bool, false>;
|
||||
@@ -69,6 +68,7 @@ template <class T, class U>
|
||||
inline constexpr bool is_same_v = is_same<T, U>::value;
|
||||
|
||||
namespace index_sequence_impl {
|
||||
|
||||
// Based on https://stackoverflow.com/a/32223343/11717224
|
||||
template <size_t... Ints> struct index_sequence {
|
||||
using type = index_sequence;
|
||||
@@ -89,6 +89,7 @@ struct make_index_sequence
|
||||
|
||||
template <> struct make_index_sequence<0> : index_sequence<> {};
|
||||
template <> struct make_index_sequence<1> : index_sequence<0> {};
|
||||
|
||||
} // namespace index_sequence_impl
|
||||
|
||||
template <size_t... Ns>
|
||||
@@ -96,6 +97,7 @@ using index_sequence = index_sequence_impl::index_sequence<Ns...>;
|
||||
|
||||
template <size_t N>
|
||||
using make_index_sequence = index_sequence_impl::make_index_sequence<N>;
|
||||
|
||||
} // namespace std
|
||||
|
||||
#endif
|
||||
|
||||
@@ -46,7 +46,7 @@ struct Scheduler {
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) {
|
||||
__device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const {
|
||||
if (num_blocks_in_group == 1)
|
||||
return false;
|
||||
if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::GroupedMasked) {
|
||||
@@ -63,7 +63,8 @@ struct Scheduler {
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
|
||||
__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx,
|
||||
uint32_t& m_block_idx, uint32_t& n_block_idx) {
|
||||
DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
|
||||
|
||||
// Swizzle for better L2 usages
|
||||
|
||||
@@ -1,18 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef NVRTC_JIT_COMPILATION
|
||||
#include <cassert>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cudaTypedefs.h>
|
||||
#endif
|
||||
|
||||
#include <cuda/barrier>
|
||||
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// TODO: move this function to other files
|
||||
__device__ __forceinline__ void
|
||||
tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
|
||||
int32_t const& crd_0, int32_t const& crd_1, uint32_t num_tma_multicast) {
|
||||
|
||||
@@ -1,39 +1,14 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef NVRTC_JIT_COMPILATION
|
||||
#include <exception>
|
||||
|
||||
#ifdef __CLION_IDE__
|
||||
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { asm volatile("trap;"); }
|
||||
|
||||
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) {
|
||||
asm volatile("trap;");
|
||||
}
|
||||
|
||||
#define printf host_device_printf
|
||||
#endif
|
||||
|
||||
class AssertionException : public std::exception {
|
||||
private:
|
||||
std::string message{};
|
||||
|
||||
public:
|
||||
explicit AssertionException(const std::string& message) : message(message) {}
|
||||
|
||||
const char *what() const noexcept override { return message.c_str(); }
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifndef DG_HOST_ASSERT
|
||||
#ifdef NVRTC_JIT_COMPILATION
|
||||
#define DG_HOST_ASSERT(cond) ((void)0)
|
||||
#else
|
||||
#define DG_HOST_ASSERT(cond) \
|
||||
do { \
|
||||
if (not (cond)) { \
|
||||
printf("Assertion failed: %s:%d, condition: %s\n", \
|
||||
__FILE__, __LINE__, #cond); \
|
||||
throw AssertionException("Assertion failed: " #cond); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef DG_DEVICE_ASSERT
|
||||
#define DG_DEVICE_ASSERT(cond) \
|
||||
do { \
|
||||
|
||||
Reference in New Issue
Block a user