mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
Move flash_mla.h to kernels/params.h
This commit is contained in:
parent
c7123cb36e
commit
828a19c720
@ -13,6 +13,7 @@
|
||||
#include "kernels/config.h"
|
||||
#include "kernels/get_mla_metadata.h"
|
||||
#include "kernels/mla_combine.h"
|
||||
#include "kernels/params.h"
|
||||
#include "kernels/splitkv_mla.h"
|
||||
|
||||
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
|
||||
|
@ -1,5 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include "flash_mla.h"
|
||||
#include "params.h"
|
||||
|
||||
void run_get_mla_metadata_kernel(Mla_metadata_params ¶ms, cudaStream_t stream);
|
||||
|
@ -5,7 +5,7 @@
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include "flash_mla.h"
|
||||
#include "params.h"
|
||||
#include "utils.h"
|
||||
#include "config.h" // for BLOCK_SIZE_M and HEAD_DIM_V
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "flash_mla.h"
|
||||
#include "params.h"
|
||||
|
||||
template<typename ElementT>
|
||||
void run_flash_mla_combine_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream);
|
||||
|
@ -1,6 +1,6 @@
|
||||
#include <cutlass/cutlass.h>
|
||||
|
||||
#include "flash_mla.h"
|
||||
#include "params.h"
|
||||
#include "utils.h"
|
||||
#include "config.h"
|
||||
#include "traits.h"
|
||||
|
@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "flash_mla.h"
|
||||
#include "params.h"
|
||||
|
||||
template<typename InputT>
|
||||
void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream);
|
||||
|
Loading…
Reference in New Issue
Block a user