#pragma once

#include "configs.cuh"

// Kernel-launch helpers and value-to-template dispatch macros.
//
// None of these macros are self-contained: they expand to statements that
// read names from the scope where they are used (`num_ranks`, `type`,
// `hidden`) or that add names to it (`cfg`, `attr`). They also rely on
// `CUDA_CHECK`, `EP_HOST_ASSERT` and `NUM_MAX_NVL_PEERS`, which are not defined
// in this header (presumably provided via "configs.cuh" or an earlier
// include -- TODO confirm).

// SETUP_LAUNCH_CONFIG(num_sms, num_threads, stream)
//
// Declares two locals in the enclosing scope:
//   * `cfg`  - a `cudaLaunchConfig_t` brace-initialised with `num_sms` and
//              `num_threads` as its first two fields, 0 as the third, and
//              `stream` as the fourth;
//   * `attr` - a one-element `cudaLaunchAttribute` array whose only entry
//              is `cudaLaunchAttributeCooperative` with value 1.
// `cfg.attrs` / `cfg.numAttrs` are then pointed at that array.
//
// The expansion deliberately ends without a semicolon; the caller supplies
// it. Because it declares variables, it can be used at most once per scope.
// The `#ifndef` guard lets an earlier definition take precedence.
#ifndef SETUP_LAUNCH_CONFIG
#define SETUP_LAUNCH_CONFIG(num_sms, num_threads, stream)                        \
    cudaLaunchConfig_t cfg = {(num_sms), (num_threads), 0, stream, nullptr, 0}; \
    cudaLaunchAttribute attr[1];                                                 \
    attr[0].val.cooperative = 1;                                                 \
    attr[0].id = cudaLaunchAttributeCooperative;                                 \
    cfg.numAttrs = 1;                                                            \
    cfg.attrs = attr
#endif

// LAUNCH_KERNEL(config, kernel, ...)
//
// Forwards `config`, `kernel` and any extra arguments to
// `cudaLaunchKernelEx` and passes the call to `CUDA_CHECK`.
// `##__VA_ARGS__` is the GNU-style extension that drops the preceding comma
// when no extra arguments are given. The `#ifndef` guard lets an earlier
// definition take precedence.
#ifndef LAUNCH_KERNEL
#define LAUNCH_KERNEL(config, kernel, ...) \
    CUDA_CHECK(cudaLaunchKernelEx(config, kernel, ##__VA_ARGS__))
#endif

// ---------------------------------------------------------------------------
// SWITCH_* dispatch macros
//
// Each macro switches on a runtime value taken from the enclosing scope and
// expands `case_macro(<constant>)` for the matching case, so the constant
// can be used where a compile-time value is required.
//
// Contract shared by all of them:
//   * No `break` is emitted after `case_macro(...)`. The expansion of
//     `case_macro` must therefore leave the switch itself (e.g. end in
//     `break` or `return`); otherwise control falls through the remaining
//     cases and into the `default` assertion.
//   * An unmatched value reaches `default`, which evaluates
//     `EP_HOST_ASSERT` on an always-false condition.
//   * The trailing `while (false)` absorbs the semicolon the caller writes
//     after the macro invocation.
// ---------------------------------------------------------------------------

// Dispatches on `num_ranks`; accepted values: 2, 4, 8.
#define SWITCH_RANKS(case_macro)                           \
    switch (num_ranks) {                                   \
        case 2:                                            \
            case_macro(2);                                 \
        case 4:                                            \
            case_macro(4);                                 \
        case 8:                                            \
            case_macro(8);                                 \
        default:                                           \
            EP_HOST_ASSERT(false and "Unsupported ranks"); \
    }                                                      \
    while (false)

// Dispatches on `num_ranks / NUM_MAX_NVL_PEERS`; accepted quotients:
// 2, 3, 4, 8, 16, 18, 20.
// NOTE(review): the division is an integer division, so a `num_ranks` that
// is not a multiple of NUM_MAX_NVL_PEERS is truncated rather than rejected
// here -- confirm callers validate divisibility beforehand.
#define SWITCH_RDMA_RANKS(case_macro)                           \
    switch (num_ranks / NUM_MAX_NVL_PEERS) {                    \
        case 2:                                                 \
            case_macro(2);                                      \
        case 3:                                                 \
            case_macro(3);                                      \
        case 4:                                                 \
            case_macro(4);                                      \
        case 8:                                                 \
            case_macro(8);                                      \
        case 16:                                                \
            case_macro(16);                                     \
        case 18:                                                \
            case_macro(18);                                     \
        case 20:                                                \
            case_macro(20);                                     \
        default:                                                \
            EP_HOST_ASSERT(false and "Unsupported RDMA ranks"); \
    }                                                           \
    while (false)

// Same as SWITCH_RANKS, but forwards an additional leading `dtype` argument:
// expands `case_macro(dtype, <ranks>)` for `num_ranks` in {2, 4, 8}.
#define SWITCH_RANKS_WITH_DTYPE(dtype, case_macro)        \
    switch (num_ranks) {                                  \
        case 2:                                           \
            case_macro(dtype, 2);                         \
        case 4:                                           \
            case_macro(dtype, 4);                         \
        case 8:                                           \
            case_macro(dtype, 8);                         \
        default:                                          \
            EP_HOST_ASSERT(false && "Unsupported ranks"); \
    }                                                     \
    while (false)

// Dispatches on `type`: CUDA_R_16BF expands `case_macro(nv_bfloat16)`,
// CUDA_R_32F expands `case_macro(float)`.
#define SWITCH_TYPES(case_macro)                         \
    switch (type) {                                      \
        case CUDA_R_16BF:                                \
            case_macro(nv_bfloat16);                     \
        case CUDA_R_32F:                                 \
            case_macro(float);                           \
        default:                                         \
            EP_HOST_ASSERT(false && "Unsupported type"); \
    }                                                    \
    while (false)

// Dispatches on `hidden`; accepted values: 2560, 5120, 7168.
#define SWITCH_HIDDEN(case_macro)                          \
    switch (hidden) {                                      \
        case 2560:                                         \
            case_macro(2560);                              \
        case 5120:                                         \
            case_macro(5120);                              \
        case 7168:                                         \
            case_macro(7168);                              \
        default:                                           \
            EP_HOST_ASSERT(false && "Unsupported hidden"); \
    }                                                      \
    while (false)