Move flash_mla.h to kernels/params.h

This commit is contained in:
Shengyu Liu 2025-04-22 17:46:35 +08:00
parent c7123cb36e
commit 828a19c720
7 changed files with 6 additions and 5 deletions

View File

@ -13,6 +13,7 @@
#include "kernels/config.h"
#include "kernels/get_mla_metadata.h"
#include "kernels/mla_combine.h"
#include "kernels/params.h"
#include "kernels/splitkv_mla.h"
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")

View File

@ -1,5 +1,5 @@
#pragma once
#include "flash_mla.h"
#include "params.h"
void run_get_mla_metadata_kernel(Mla_metadata_params &params, cudaStream_t stream);

View File

@ -5,7 +5,7 @@
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include "flash_mla.h"
#include "params.h"
#include "utils.h"
#include "config.h" // for BLOCK_SIZE_M and HEAD_DIM_V

View File

@ -1,6 +1,6 @@
#pragma once
#include "flash_mla.h"
#include "params.h"
template<typename ElementT>
void run_flash_mla_combine_kernel(Flash_fwd_mla_params &params, cudaStream_t stream);

View File

@ -1,6 +1,6 @@
#include <cutlass/cutlass.h>
#include "flash_mla.h"
#include "params.h"
#include "utils.h"
#include "config.h"
#include "traits.h"

View File

@ -1,6 +1,6 @@
#pragma once
#include "flash_mla.h"
#include "params.h"
template<typename InputT>
void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params &params, cudaStream_t stream);