mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 01:15:31 +08:00
[Kernel] Remove marlin moe templating on thread_m_blocks (#8573)
Co-authored-by: lwilkinson@neuralmagic.com
This commit is contained in:
parent
0d47bf3bf4
commit
4c34ce8916
@ -1342,9 +1342,6 @@ __device__ inline void MarlinMoESingle(
|
|||||||
|
|
||||||
template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
||||||
const int threads, // number of threads in a threadblock
|
const int threads, // number of threads in a threadblock
|
||||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
|
||||||
// dimension (batchsize) of the
|
|
||||||
// threadblock
|
|
||||||
const int thread_n_blocks, // same for n dimension (output)
|
const int thread_n_blocks, // same for n dimension (output)
|
||||||
const int thread_k_blocks, // same for k dimension (reduction)
|
const int thread_k_blocks, // same for k dimension (reduction)
|
||||||
const int stages, // number of stages for the async global->shared
|
const int stages, // number of stages for the async global->shared
|
||||||
@ -1459,9 +1456,6 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids,
|
|||||||
|
|
||||||
template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
||||||
const int threads, // number of threads in a threadblock
|
const int threads, // number of threads in a threadblock
|
||||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
|
||||||
// dimension (batchsize) of the
|
|
||||||
// threadblock
|
|
||||||
const int thread_n_blocks, // same for n dimension (output)
|
const int thread_n_blocks, // same for n dimension (output)
|
||||||
const int thread_k_blocks, // same for k dimension (reduction)
|
const int thread_k_blocks, // same for k dimension (reduction)
|
||||||
const int stages, // number of stages for the async global->shared
|
const int stages, // number of stages for the async global->shared
|
||||||
@ -1515,20 +1509,18 @@ const int STAGES = 4; // 4 pipeline stages fit into shared memory
|
|||||||
static constexpr int min_thread_n = 64;
|
static constexpr int min_thread_n = 64;
|
||||||
static constexpr int min_thread_k = 64;
|
static constexpr int min_thread_k = 64;
|
||||||
|
|
||||||
#define __CALL_IF_MOE(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
|
#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \
|
||||||
THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, \
|
GROUP_BLOCKS, NUM_THREADS) \
|
||||||
NUM_THREADS) \
|
else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \
|
||||||
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
|
|
||||||
thread_n_blocks == THREAD_N_BLOCKS && \
|
|
||||||
thread_k_blocks == THREAD_K_BLOCKS && \
|
thread_k_blocks == THREAD_K_BLOCKS && \
|
||||||
has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
|
has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
|
||||||
num_threads == NUM_THREADS) { \
|
num_threads == NUM_THREADS) { \
|
||||||
cudaFuncSetAttribute( \
|
cudaFuncSetAttribute( \
|
||||||
MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
|
MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||||
THREAD_K_BLOCKS, STAGES, HAS_ACT_ORDER, GROUP_BLOCKS>, \
|
STAGES, HAS_ACT_ORDER, GROUP_BLOCKS>, \
|
||||||
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
||||||
MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
|
MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||||
THREAD_K_BLOCKS, STAGES, HAS_ACT_ORDER, GROUP_BLOCKS> \
|
STAGES, HAS_ACT_ORDER, GROUP_BLOCKS> \
|
||||||
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
|
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
|
||||||
A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
|
A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
|
||||||
g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
|
g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
|
||||||
@ -1712,30 +1704,15 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||||
__CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
|
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
|
||||||
__CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
|
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
|
||||||
__CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
|
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
|
||||||
__CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
|
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
|
||||||
\
|
\
|
||||||
__CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
|
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
|
||||||
__CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
|
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
|
||||||
__CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
|
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
|
||||||
__CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
|
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
|
||||||
\
|
|
||||||
__CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
|
|
||||||
__CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
|
|
||||||
__CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
|
|
||||||
__CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
|
|
||||||
\
|
|
||||||
__CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
|
|
||||||
__CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
|
|
||||||
__CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
|
|
||||||
__CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
|
|
||||||
\
|
|
||||||
__CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
|
|
||||||
__CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
|
|
||||||
__CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
|
|
||||||
__CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
|
|
||||||
|
|
||||||
void marlin_mm_moe_f16i4(const void* A, const void* B, void* C,
|
void marlin_mm_moe_f16i4(const void* A, const void* B, void* C,
|
||||||
const void* sorted_ids, const void* topk_weights,
|
const void* sorted_ids, const void* topk_weights,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user