Merge branch 'main' into mlm-full-lora-support

This commit is contained in:
Jee Jee Li 2025-12-17 12:41:12 +08:00 committed by GitHub
commit fe104bd63c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
75 changed files with 1708 additions and 809 deletions

View File

@ -654,7 +654,7 @@ steps:
- vllm/model_executor/layers/quantization
autorun_on_main: true
commands:
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt
- label: OpenAI API correctness # 22min
timeout_in_minutes: 30
@ -1064,7 +1064,7 @@ steps:
- csrc/
- vllm/model_executor/layers/quantization
commands:
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt --tp-size=1
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt
##### 1 GPU test #####
##### multi gpus test #####

14
.github/mergify.yml vendored
View File

@ -235,6 +235,20 @@ pull_request_rules:
add:
- rocm
- name: label-cpu
description: Automatically apply cpu label
conditions:
- label != stale
- files~=^(?!.*kv_offload)(?!.*cpu_offload).*\bcpu.*
actions:
label:
add:
- cpu
assign:
users:
- "fadara01"
- "aditew01"
- name: label-structured-output
description: Automatically apply structured-output label
conditions:

View File

@ -357,6 +357,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# marlin arches for fp16 output
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
# marlin has limited support for turing
cuda_archs_loose_intersection(MARLIN_SM75_ARCHS "7.5" "${CUDA_ARCHS}")
# marlin arches for bf16 output (we need 9.0 for bf16 atomicAdd PTX)
cuda_archs_loose_intersection(MARLIN_BF16_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
# marlin arches for fp8 input
@ -364,8 +366,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")
# marlin arches for other files
cuda_archs_loose_intersection(MARLIN_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}")
if (MARLIN_ARCHS)
if (MARLIN_OTHER_ARCHS)
#
# For the Marlin kernels we automatically generate sources for various
@ -406,25 +410,39 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Marlin generation script has not changed, skipping generation.")
endif()
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
if (MARLIN_ARCHS)
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_BF16_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_TEMPLATE_BF16_KERNEL_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_BF16_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_TEMPLATE_BF16_KERNEL_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC})
endif()
if (MARLIN_SM75_ARCHS)
file(GLOB MARLIN_TEMPLATE_SM75_KERNEL_SRC "csrc/quantization/gptq_marlin/sm75_kernel_*.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_SM75_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_SM75_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_TEMPLATE_SM75_KERNEL_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_SM75_KERNEL_SRC})
endif()
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC})
if (MARLIN_FP8_ARCHS)
file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/gptq_marlin/sm89_kernel_*.cu")
@ -446,14 +464,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_SRCS}"
CUDA_ARCHS "${MARLIN_ARCHS}")
CUDA_ARCHS "${MARLIN_OTHER_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties("csrc/quantization/gptq_marlin/gptq_marlin.cu"
set_source_files_properties(${MARLIN_SRCS}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}")
message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}")
message(STATUS "Building Marlin kernels for archs: ${MARLIN_OTHER_ARCHS}")
else()
message(STATUS "Not building Marlin kernels as no compatible archs found"
" in CUDA target architectures")
@ -980,12 +998,16 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# note that we always set `use_atomic_add=False` for moe marlin now,
# so we don't need 9.0 for bf16 atomicAdd PTX
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
# moe marlin has limited support for turing
cuda_archs_loose_intersection(MARLIN_MOE_SM75_ARCHS "7.5" "${CUDA_ARCHS}")
# moe marlin arches for fp8 input
# - sm80 doesn't support fp8 computation
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")
if (MARLIN_MOE_ARCHS)
# moe marlin arches for other files
cuda_archs_loose_intersection(MARLIN_MOE_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}")
if (MARLIN_MOE_OTHER_ARCHS)
#
# For the Marlin MOE kernels we automatically generate sources for various
@ -1026,16 +1048,29 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Marlin MOE generation script has not changed, skipping generation.")
endif()
file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu")
list(APPEND MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/ops.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_MOE_SRC}"
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_MOE_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
if (MARLIN_MOE_ARCHS)
file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_MOE_SRC}"
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_MOE_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC})
endif()
if (MARLIN_MOE_SM75_ARCHS)
file(GLOB MARLIN_MOE_SM75_SRC "csrc/moe/marlin_moe_wna16/sm75_kernel_*.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_MOE_SM75_SRC}"
CUDA_ARCHS "${MARLIN_MOE_SM75_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_MOE_SM75_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SM75_SRC})
endif()
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC})
if (MARLIN_MOE_FP8_ARCHS)
file(GLOB MARLIN_MOE_FP8_SRC "csrc/moe/marlin_moe_wna16/sm89_kernel_*.cu")
@ -1049,7 +1084,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_FP8_SRC})
endif()
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}")
set(MARLIN_MOE_OTHER_SRC "csrc/moe/marlin_moe_wna16/ops.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_MOE_OTHER_SRC}"
CUDA_ARCHS "${MARLIN_MOE_OTHER_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_MOE_OTHER_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_MOE_EXT_SRC "${MARLIN_MOE_OTHER_SRC}")
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_OTHER_ARCHS}")
else()
message(STATUS "Not building Marlin MOE kernels as no compatible archs found"
" in CUDA target architectures")

View File

@ -13,8 +13,8 @@ from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
batch_size_range = [1, 16, 32, 64, 128]
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
batch_size_range = [1, 16, 128]
seq_len_range = [1, 16, 64, 1024, 4096]
intermediate_size = [3072, 9728, 12288]
configs = list(itertools.product(batch_size_range, seq_len_range, intermediate_size))

View File

@ -15,19 +15,61 @@ __device__ __forceinline__ scalar_t compute(const scalar_t& x,
const scalar_t& y) {
return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
}
// Activation and gating kernel template.
// Check if all pointers are 16-byte aligned for int4 vectorized access
__device__ __forceinline__ bool is_16byte_aligned(const void* ptr) {
return (reinterpret_cast<uintptr_t>(ptr) & 15) == 0;
}
// Activation and gating kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
bool act_first>
__global__ void act_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
out[token_idx * d + idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
const scalar_t* x_ptr = input + token_idx * 2 * d;
const scalar_t* y_ptr = x_ptr + d;
scalar_t* out_ptr = out + token_idx * d;
// Check alignment for 128-bit vectorized access.
// All three pointers must be 16-byte aligned for safe int4 operations.
const bool aligned = is_16byte_aligned(x_ptr) && is_16byte_aligned(y_ptr) &&
is_16byte_aligned(out_ptr);
if (aligned && d >= VEC_SIZE) {
// Fast path: 128-bit vectorized loop
const int4* x_vec = reinterpret_cast<const int4*>(x_ptr);
const int4* y_vec = reinterpret_cast<const int4*>(y_ptr);
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
const int num_vecs = d / VEC_SIZE;
const int vec_end = num_vecs * VEC_SIZE;
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
int4 x = VLLM_LDG(&x_vec[i]), y = VLLM_LDG(&y_vec[i]), r;
auto* xp = reinterpret_cast<scalar_t*>(&x);
auto* yp = reinterpret_cast<scalar_t*>(&y);
auto* rp = reinterpret_cast<scalar_t*>(&r);
#pragma unroll
for (int j = 0; j < VEC_SIZE; j++) {
rp[j] = compute<scalar_t, ACT_FN, act_first>(xp[j], yp[j]);
}
out_vec[i] = r;
}
// Scalar cleanup for remaining elements
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
out_ptr[i] = compute<scalar_t, ACT_FN, act_first>(VLLM_LDG(&x_ptr[i]),
VLLM_LDG(&y_ptr[i]));
}
} else {
// Scalar fallback for unaligned data or small d
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&x_ptr[idx]);
const scalar_t y = VLLM_LDG(&y_ptr[idx]);
out_ptr[idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
}
}
}
@ -120,50 +162,115 @@ template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
__global__ void act_and_mul_kernel_with_param(
scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
const float param) {
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
out[token_idx * d + idx] = ACT_FN(x, param) * y;
const scalar_t* x_ptr = input + token_idx * 2 * d;
const scalar_t* y_ptr = x_ptr + d;
scalar_t* out_ptr = out + token_idx * d;
// Check alignment for 128-bit vectorized access
const bool aligned = is_16byte_aligned(x_ptr) && is_16byte_aligned(y_ptr) &&
is_16byte_aligned(out_ptr);
if (aligned && d >= VEC_SIZE) {
// Fast path: 128-bit vectorized loop
const int4* x_vec = reinterpret_cast<const int4*>(x_ptr);
const int4* y_vec = reinterpret_cast<const int4*>(y_ptr);
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
const int num_vecs = d / VEC_SIZE;
const int vec_end = num_vecs * VEC_SIZE;
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
int4 x = VLLM_LDG(&x_vec[i]), y = VLLM_LDG(&y_vec[i]), r;
auto* xp = reinterpret_cast<scalar_t*>(&x);
auto* yp = reinterpret_cast<scalar_t*>(&y);
auto* rp = reinterpret_cast<scalar_t*>(&r);
#pragma unroll
for (int j = 0; j < VEC_SIZE; j++) {
rp[j] = ACT_FN(xp[j], param) * yp[j];
}
out_vec[i] = r;
}
// Scalar cleanup for remaining elements
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
out_ptr[i] = ACT_FN(VLLM_LDG(&x_ptr[i]), param) * VLLM_LDG(&y_ptr[i]);
}
} else {
// Scalar fallback for unaligned data or small d
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&x_ptr[idx]);
const scalar_t y = VLLM_LDG(&y_ptr[idx]);
out_ptr[idx] = ACT_FN(x, param) * y;
}
}
}
template <typename T>
__device__ __forceinline__ T swigluoai_and_mul(const T& gate, const T& up,
float alpha, float limit) {
// clamp gate: min=None, max=limit
const float gate_f = (float)gate;
const float clamped_gate = gate_f > limit ? limit : gate_f;
// clamp up: min=-limit, max=limit
const float up_f = (float)up;
const float clamped_up =
up_f > limit ? limit : (up_f < -limit ? -limit : up_f);
// glu = gate * sigmoid(gate * alpha)
const float sigmoid_val = 1.0f / (1.0f + expf(-clamped_gate * alpha));
const float glu = clamped_gate * sigmoid_val;
// (up + 1) * glu
return (T)((clamped_up + 1.0f) * glu);
// Clamp gate to (-inf, limit] and up to [-limit, limit]
const float g = fminf((float)gate, limit);
const float u = fmaxf(fminf((float)up, limit), -limit);
// glu = gate * sigmoid(gate * alpha), then return (up + 1) * glu
return (T)((u + 1.0f) * g / (1.0f + expf(-g * alpha)));
}
// Interleaved gate/up: input has [gate0, up0, gate1, up1, ...].
template <typename scalar_t,
scalar_t (*ACT_FN)(const scalar_t&, const scalar_t&, const float,
const float)>
__global__ void swigluoai_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const scalar_t* __restrict__ input, // [..., 2 * d] (interleaved)
const int d, const float alpha, const float limit) {
// For interleaved data: input has 2*d elements per token (gate/up pairs)
// output has d elements per token
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
constexpr int PAIRS = VEC_SIZE / 2; // Number of gate/up pairs per int4 load
const int64_t token_idx = blockIdx.x;
// TODO: Vectorize loads and stores.
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
// gate = x[..., ::2] (even indices)
const scalar_t gate = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx]);
// up = x[..., 1::2] (odd indices)
const scalar_t up = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx + 1]);
const scalar_t* in_ptr = input + token_idx * 2 * d;
scalar_t* out_ptr = out + token_idx * d;
out[token_idx * d + idx] = ACT_FN(gate, up, alpha, limit);
// Check alignment for 128-bit vectorized access on input.
// For output we use int2 (64-bit) which has 8-byte alignment requirement.
const bool in_aligned = is_16byte_aligned(in_ptr);
const bool out_aligned =
(reinterpret_cast<uintptr_t>(out_ptr) & 7) == 0; // 8-byte for int2
if (in_aligned && out_aligned && d >= PAIRS) {
// Fast path: vectorized loop
// Each int4 load gives VEC_SIZE elements = PAIRS gate/up pairs
// Each int2 store writes PAIRS output elements
const int4* in_vec = reinterpret_cast<const int4*>(in_ptr);
int2* out_vec = reinterpret_cast<int2*>(out_ptr);
const int num_vecs = d / PAIRS;
const int vec_end = num_vecs * PAIRS;
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
int4 v = VLLM_LDG(&in_vec[i]);
int2 r;
auto* vp = reinterpret_cast<scalar_t*>(&v);
auto* rp = reinterpret_cast<scalar_t*>(&r);
#pragma unroll
for (int j = 0; j < PAIRS; j++) {
rp[j] = ACT_FN(vp[2 * j], vp[2 * j + 1], alpha, limit);
}
out_vec[i] = r;
}
// Scalar cleanup for remaining elements
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
out_ptr[i] = ACT_FN(VLLM_LDG(&in_ptr[2 * i]),
VLLM_LDG(&in_ptr[2 * i + 1]), alpha, limit);
}
} else {
// Scalar fallback for unaligned data or small d
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
// gate = x[..., ::2] (even indices)
const scalar_t gate = VLLM_LDG(&in_ptr[2 * idx]);
// up = x[..., 1::2] (odd indices)
const scalar_t up = VLLM_LDG(&in_ptr[2 * idx + 1]);
out_ptr[idx] = ACT_FN(gate, up, alpha, limit);
}
}
}
@ -217,10 +324,41 @@ __global__ void activation_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., d]
const int d) {
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
out[token_idx * d + idx] = ACT_FN(x);
const scalar_t* in_ptr = input + token_idx * d;
scalar_t* out_ptr = out + token_idx * d;
// Check alignment for 128-bit vectorized access
const bool aligned = is_16byte_aligned(in_ptr) && is_16byte_aligned(out_ptr);
if (aligned && d >= VEC_SIZE) {
// Fast path: 128-bit vectorized loop
const int4* in_vec = reinterpret_cast<const int4*>(in_ptr);
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
const int num_vecs = d / VEC_SIZE;
const int vec_end = num_vecs * VEC_SIZE;
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
int4 v = VLLM_LDG(&in_vec[i]), r;
auto* vp = reinterpret_cast<scalar_t*>(&v);
auto* rp = reinterpret_cast<scalar_t*>(&r);
#pragma unroll
for (int j = 0; j < VEC_SIZE; j++) {
rp[j] = ACT_FN(vp[j]);
}
out_vec[i] = r;
}
// Scalar cleanup for remaining elements
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
out_ptr[i] = ACT_FN(VLLM_LDG(&in_ptr[i]));
}
} else {
// Scalar fallback for unaligned data or small d
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&in_ptr[idx]);
out_ptr[idx] = ACT_FN(x);
}
}
}

View File

@ -446,9 +446,13 @@ __device__ inline T apply_sigmoid(T val) {
template <ScoringFunc SF, typename T>
__device__ inline T apply_scoring(T val) {
if constexpr (SF == SCORING_SIGMOID) {
if constexpr (SF == SCORING_NONE) {
return val;
} else if constexpr (SF == SCORING_SIGMOID) {
return apply_sigmoid(val);
} else {
static_assert(SF == SCORING_NONE || SF == SCORING_SIGMOID,
"Unsupported ScoringFunc in apply_scoring");
return val;
}
}
@ -670,10 +674,13 @@ __global__ void group_idx_and_topk_idx_kernel(
if (case_id < num_tokens) {
if (if_proceed_next_topk) {
float scale = routed_scaling_factor;
if (renormalize) {
scale /= topk_sum;
}
for (int i = lane_id; i < topk; i += WARP_SIZE) {
float base = cuda_cast<float, T>(s_topk_value[i]);
float value = renormalize ? (base / topk_sum * routed_scaling_factor)
: (base * routed_scaling_factor);
float value = base * scale;
topk_indices[i] = s_topk_idx[i];
topk_values[i] = value;
}

View File

@ -1,2 +1,3 @@
sm*_kernel_*.cu
kernel_selector.h
kernel_*.cu

View File

@ -10,6 +10,8 @@ import jinja2
ARCHS = []
SUPPORT_FP8 = False
SUPPORT_SM75 = False
SUPPORT_SM80 = False
for arch in sys.argv[1].split(","):
arch = arch[: arch.index(".") + 2].replace(".", "")
arch = int(arch)
@ -19,6 +21,10 @@ for arch in sys.argv[1].split(","):
# with FP16 MMA, so it cannot achieve any acceleration.
if arch in [89, 120]:
SUPPORT_FP8 = True
if arch >= 80:
SUPPORT_SM80 = True
if arch == 75:
SUPPORT_SM75 = True
FILE_HEAD_COMMENT = """
// auto generated by generate_kernels.py
@ -157,6 +163,7 @@ def remove_old_kernels():
def generate_new_kernels():
result_dict = {}
sm_75_result_dict = {}
for quant_config in QUANT_CONFIGS:
c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
@ -174,6 +181,8 @@ def generate_new_kernels():
s_type = quant_config.get("s_type", c_type)
if (a_type, b_type, c_type) not in result_dict:
result_dict[(a_type, b_type, c_type)] = []
if a_type in ["kFloat16", "kS8"] and c_type == "kFloat16":
sm_75_result_dict[(a_type, b_type, c_type)] = []
for group_blocks, m_blocks, thread_configs in itertools.product(
all_group_blocks, all_m_blocks, all_thread_configs
@ -197,78 +206,89 @@ def generate_new_kernels():
"thread_k_blocks": thread_k // 16,
"thread_n_blocks": thread_n // 16,
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
"stages": "pipe_stages",
"stages": 4,
"group_blocks": group_blocks,
"is_zp_float": "false",
}
result_dict[(a_type, b_type, c_type)].append(config)
if SUPPORT_SM80:
result_dict[(a_type, b_type, c_type)].append(config)
if (a_type, b_type, c_type) in sm_75_result_dict and SUPPORT_SM75:
config_sm75 = config.copy()
config_sm75["stages"] = 2
sm_75_result_dict[(a_type, b_type, c_type)].append(config_sm75)
kernel_selector_str = FILE_HEAD_COMMENT
for (a_type, b_type, c_type), config_list in result_dict.items():
all_template_str_list = []
for config in config_list:
s_type = config["s_type"]
template_str = jinja2.Template(TEMPLATE).render(
a_type_id=f"vllm::{a_type}.id()",
b_type_id=f"vllm::{b_type}.id()",
c_type_id=f"vllm::{c_type}.id()",
s_type_id=f"vllm::{s_type}.id()",
**config,
)
all_template_str_list.append(template_str)
conditions = [
f"a_type == vllm::{a_type}",
f"b_type == vllm::{b_type}",
f"c_type == vllm::{c_type}",
f"s_type == vllm::{s_type}",
f"threads == {config['threads']}",
f"thread_m_blocks == {config['thread_m_blocks']}",
f"thread_n_blocks == {config['thread_n_blocks']}",
f"thread_k_blocks == {config['thread_k_blocks']}",
f"m_block_size_8 == {config['m_block_size_8']}",
f"group_blocks == {config['group_blocks']}",
f"is_zp_float == {config['is_zp_float']}",
]
conditions = " && ".join(conditions)
if kernel_selector_str == FILE_HEAD_COMMENT:
kernel_selector_str += f"if ({conditions})\n kernel = "
else:
kernel_selector_str += f"else if ({conditions})\n kernel = "
kernel_template2 = (
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
"{{is_zp_float}}>;"
)
kernel_selector_str += (
jinja2.Template(kernel_template2).render(
for result_dict_tmp in [result_dict, sm_75_result_dict]:
for (a_type, b_type, c_type), config_list in result_dict_tmp.items():
all_template_str_list = []
if not config_list:
continue
for config in config_list:
s_type = config["s_type"]
template_str = jinja2.Template(TEMPLATE).render(
a_type_id=f"vllm::{a_type}.id()",
b_type_id=f"vllm::{b_type}.id()",
c_type_id=f"vllm::{c_type}.id()",
s_type_id=f"vllm::{s_type}.id()",
**config,
)
+ "\n"
)
all_template_str_list.append(template_str)
file_content = FILE_HEAD + "\n\n"
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
if a_type == "kFE4M3fn":
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
else:
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
conditions = [
f"a_type == vllm::{a_type}",
f"b_type == vllm::{b_type}",
f"c_type == vllm::{c_type}",
f"s_type == vllm::{s_type}",
f"threads == {config['threads']}",
f"thread_m_blocks == {config['thread_m_blocks']}",
f"thread_n_blocks == {config['thread_n_blocks']}",
f"thread_k_blocks == {config['thread_k_blocks']}",
f"m_block_size_8 == {config['m_block_size_8']}",
f"stages == {config['stages']}",
f"group_blocks == {config['group_blocks']}",
f"is_zp_float == {config['is_zp_float']}",
]
conditions = " && ".join(conditions)
filename = filename.lower()
if kernel_selector_str == FILE_HEAD_COMMENT:
kernel_selector_str += f"if ({conditions})\n kernel = "
else:
kernel_selector_str += f"else if ({conditions})\n kernel = "
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
f.write(file_content)
kernel_template2 = (
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
"{{is_zp_float}}>;"
)
kernel_selector_str += (
jinja2.Template(kernel_template2).render(
a_type_id=f"vllm::{a_type}.id()",
b_type_id=f"vllm::{b_type}.id()",
c_type_id=f"vllm::{c_type}.id()",
s_type_id=f"vllm::{s_type}.id()",
**config,
)
+ "\n"
)
file_content = FILE_HEAD + "\n\n"
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
if a_type == "kFE4M3fn":
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
elif result_dict_tmp is sm_75_result_dict:
filename = f"sm75_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
else:
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
filename = filename.lower()
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
f.write(file_content)
if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
kernel_selector_str += (

View File

@ -26,6 +26,7 @@
#include "quantization/gptq_marlin/marlin.cuh"
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "quantization/gptq_marlin/dequant.h"
#include "quantization/gptq_marlin/marlin_mma.h"
#include "core/scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
@ -35,7 +36,7 @@
namespace MARLIN_NAMESPACE_NAME {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId b_type_id, // weight MarlinScalarType id
@ -84,146 +85,6 @@ __global__ void Marlin(
#else
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
template <vllm::ScalarTypeId type_id, int k_size = 16>
__device__ inline void mma(
const typename MarlinScalarType<type_id>::FragA& a_frag,
const typename MarlinScalarType<type_id>::FragB& frag_b,
typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
if constexpr (k_size == 16) {
if constexpr (std::is_same<scalar_t, half>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]),
"f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]),
"r"(c[1]), "r"(c[2]), "r"(c[3]));
}
} else if (k_size == 32) {
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
}
}
}
template <vllm::ScalarTypeId type_id, int k_size = 16>
__device__ inline void mma_trans(
const typename MarlinScalarType<type_id>::FragA& a_frag,
const typename MarlinScalarType<type_id>::FragB& frag_b,
const typename MarlinScalarType<type_id>::FragB& frag_b2,
typename MarlinScalarType<type_id>::FragC& frag_c) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
float* c = reinterpret_cast<float*>(&frag_c);
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
if constexpr (k_size == 16) {
if constexpr (std::is_same<scalar_t, half>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
"f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
"r"(c[3]));
}
} else {
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1200
asm volatile(
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
#else
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
#endif
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
}
}
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
template <int count, vllm::ScalarTypeId type_id>
@ -439,9 +300,20 @@ __global__ void Marlin(
if constexpr (a_type_id == vllm::kFE4M3fn.id()) return;
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
// Turing TensorCore only supports fp16 and int8
if constexpr (a_type_id != vllm::kFloat16.id() && a_type_id != vllm::kS8.id())
return;
#endif
int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id();
#else
constexpr bool use_fp16_accum = false;
#endif
using Adtype = MarlinScalarType<a_type_id>;
using Cdtype = MarlinScalarType<c_type_id>;
@ -618,7 +490,22 @@ __global__ void Marlin(
}
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
if constexpr (moe_block_size >= 16)
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 16);
if constexpr (moe_block_size >= 8)
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 8);
if constexpr (moe_block_size >= 4)
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 4);
if constexpr (moe_block_size >= 2)
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 2);
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 1);
block_num_valid_tokens = local_count;
#else
block_num_valid_tokens = __reduce_add_sync(0xffffffff, local_count);
#endif
if (lane_id == 0)
reinterpret_cast<int*>(sh_new)[0] = block_num_valid_tokens;
@ -1018,10 +905,6 @@ __global__ void Marlin(
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
: (stages * s_sh_stage);
int4* sh_s = sh_zp + (stages * zp_sh_stage);
// shared memory reused by reduction should be smaller than
// shared memory used by weight.
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
stages * b_sh_stage);
int4* sh_a = sh_s + sh_s_size;
// Register storage for double buffer of shared memory reads.
@ -1545,11 +1428,13 @@ __global__ void Marlin(
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
if constexpr (m_block_size_8) {
mma_trans<a_type_id>(frag_a[k2][i], frag_b0, frag_b1,
frag_c[i][j][0]);
mma_trans<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0, frag_b1,
frag_c[i][j][0]);
} else {
mma<a_type_id>(frag_a[k2][i], frag_b0, frag_c[i][j][0]);
mma<a_type_id>(frag_a[k2][i], frag_b1, frag_c[i][j][1]);
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0,
frag_c[i][j][0]);
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b1,
frag_c[i][j][1]);
}
}
}
@ -1583,10 +1468,12 @@ __global__ void Marlin(
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
mma<a_type_id, 32>(frag_a[k2][i], frag_b[0],
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
mma<a_type_id, 32>(frag_a[k2][i], frag_b[1],
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
mma<a_type_id, false, 32>(
frag_a[k2][i], frag_b[0],
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
mma<a_type_id, false, 32>(
frag_a[k2][i], frag_b[1],
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
}
if constexpr (group_blocks != -1) {
@ -2132,6 +2019,21 @@ __global__ void Marlin(
// While this pattern may not be the most readable, other ways of writing
// the loop seemed to noticeably worse performance after compilation.
if (slice_iters == 0) {
// convert fp16 accum to fp32 for reduction
if constexpr (use_fp16_accum) {
#pragma unroll
for (int i = 0; i < (thread_m_blocks * (is_a_8bit ? 2 : 4) * 2); i++) {
float* frag_c_part_float = reinterpret_cast<float*>(frag_c) + i * 4;
scalar_t* frag_c_part_half =
reinterpret_cast<scalar_t*>(frag_c_part_float);
#pragma unroll
for (int i = 3; i >= 0; i--) {
frag_c_part_float[i] = Cdtype::num2float(frag_c_part_half[i]);
}
}
}
if constexpr (is_a_8bit) {
float frag_a_s[2 * thread_m_blocks];

View File

@ -142,7 +142,7 @@ typedef struct {
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
int prob_n, int prob_k, int num_bits, int group_size,
bool has_act_order, bool is_k_full) {
bool has_act_order, bool is_k_full, int stages) {
bool cache_scales_chunk = has_act_order && !is_k_full;
int tb_n = th_config.thread_n;
@ -160,13 +160,13 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
if (cache_scales_chunk) {
int load_groups =
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
tb_groups * stages * 2; // Chunk size is 2x pipeline over dim K
load_groups = max(load_groups, 32); // We load at least 32 scale groups
return load_groups * tb_n * 2;
} else {
int tb_scales = tb_groups * tb_n * 2;
return tb_scales * pipe_stages;
return tb_scales * stages;
}
}
@ -174,7 +174,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
int thread_m_blocks, int prob_m, int prob_n,
int prob_k, int num_bits, int group_size,
bool has_act_order, bool is_k_full, int has_zp,
int is_zp_float, bool is_a_8bit) {
int is_zp_float, bool is_a_8bit, int stages) {
int pack_factor = 32 / num_bits;
// Get B size
@ -185,8 +185,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
// shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
int sh_block_meta_size = tb_m * 16;
int sh_a_size = pipe_stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
int sh_a_size = stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
int sh_b_size = stages * (tb_k * tb_n / pack_factor) * 4;
int sh_red_size = tb_m * (tb_n + 8) * 2;
int sh_bias_size = tb_n * 2;
int tmp_size =
@ -195,8 +195,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
int sh_s_size =
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full);
int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0;
group_size, has_act_order, is_k_full, stages);
int sh_g_idx_size = has_act_order && !is_k_full ? stages * tb_k / 4 : 0;
int sh_zp_size = 0;
if (has_zp) {
if (is_zp_float)
@ -217,7 +217,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
int thread_m_blocks, int prob_m, int prob_n, int prob_k,
int num_bits, int group_size, bool has_act_order,
bool is_k_full, int has_zp, int is_zp_float,
int max_shared_mem, bool is_a_8bit) {
bool is_a_8bit, int stages, int max_shared_mem) {
// Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
th_config.num_threads == -1) {
@ -243,7 +243,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
int cache_size =
get_kernel_cache_size(th_config, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, is_a_8bit);
is_k_full, has_zp, is_zp_float, is_a_8bit, stages);
return cache_size <= max_shared_mem;
}
@ -252,7 +252,7 @@ MarlinFuncPtr get_marlin_kernel(
const vllm::ScalarType c_type, const vllm::ScalarType s_type,
int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
int threads, bool is_zp_float) {
int threads, bool is_zp_float, int stages) {
int num_bits = b_type.size_bits();
auto kernel = MarlinDefault;
@ -266,8 +266,8 @@ exec_config_t determine_exec_config(
const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
int prob_n, int prob_k, int num_experts, int top_k, int thread_m_blocks,
bool m_block_size_8, int num_bits, int group_size, bool has_act_order,
bool is_k_full, bool has_zp, bool is_zp_float, int max_shared_mem, int sms,
bool is_a_8bit) {
bool is_k_full, bool has_zp, bool is_zp_float, bool is_a_8bit, int stages,
int max_shared_mem, int sms) {
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
thread_config_t* thread_configs = thread_m_blocks > 1
? large_batch_thread_configs
@ -284,15 +284,15 @@ exec_config_t determine_exec_config(
if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, max_shared_mem - 512,
is_a_8bit)) {
is_k_full, has_zp, is_zp_float, is_a_8bit, stages,
max_shared_mem - 512)) {
continue;
}
int cache_size = get_kernel_cache_size(
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float,
is_a_8bit);
is_a_8bit, stages);
int group_blocks = 0;
if (!has_act_order) {
@ -303,7 +303,7 @@ exec_config_t determine_exec_config(
get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
th_config.thread_n / 16, th_config.thread_k / 16,
m_block_size_8, has_act_order, has_zp, group_blocks,
th_config.num_threads, is_zp_float);
th_config.num_threads, is_zp_float, stages);
if (kernel == MarlinDefault) continue;
@ -433,8 +433,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
dev);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
dev);
TORCH_CHECK(major_capability * 10 + minor_capability >= 80,
"marlin kernel only support Ampere or newer GPUs.");
TORCH_CHECK(major_capability * 10 + minor_capability >= 75,
"marlin kernel only support Turing or newer GPUs.");
int stages = 4;
if (major_capability == 7 && minor_capability == 5) {
stages = 2;
TORCH_CHECK(a_type == vllm::kFloat16 || a_type == vllm::kS8,
"Turing only support FP16 or INT8 activation.");
}
if (a_type == vllm::kFE4M3fn) {
TORCH_CHECK(major_capability * 10 + minor_capability >= 89,
"FP8 only support Ada Lovelace or newer GPUs.");
@ -461,8 +467,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
exec_cfg = determine_exec_config(
a_type, b_type, c_type, s_type, prob_m, prob_n, prob_k, num_experts,
top_k, thread_m_blocks, m_block_size_8, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float, max_shared_mem, sms,
is_a_8bit);
has_act_order, is_k_full, has_zp, is_zp_float, is_a_8bit, stages,
max_shared_mem, sms);
thread_tfg = exec_cfg.tb_cfg;
}
@ -479,7 +485,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
TORCH_CHECK(is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks,
prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float,
max_shared_mem, is_a_8bit),
is_a_8bit, stages, max_shared_mem),
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
", thread_k = ", thread_tfg.thread_k,
", thread_n = ", thread_tfg.thread_n,
@ -493,12 +499,12 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
int sh_cache_size =
get_kernel_cache_size(thread_tfg, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, is_a_8bit);
is_k_full, has_zp, is_zp_float, is_a_8bit, stages);
auto kernel = get_marlin_kernel(
a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
num_threads, is_zp_float);
num_threads, is_zp_float, stages);
if (kernel == MarlinDefault) {
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,

View File

@ -1,2 +1,3 @@
sm*_kernel_*.cu
kernel_selector.h
kernel_*.cu

View File

@ -67,7 +67,7 @@ where `scale_factor * multiplier` can be computed at weight loading.
namespace MARLIN_NAMESPACE_NAME {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.

View File

@ -10,6 +10,8 @@ import jinja2
ARCHS = []
SUPPORT_FP8 = False
SUPPORT_SM75 = False
SUPPORT_SM80 = False
for arch in sys.argv[1].split(","):
arch = arch[: arch.index(".") + 2].replace(".", "")
arch = int(arch)
@ -19,6 +21,10 @@ for arch in sys.argv[1].split(","):
# with FP16 MMA, so it cannot achieve any acceleration.
if arch in [89, 120]:
SUPPORT_FP8 = True
if arch >= 80:
SUPPORT_SM80 = True
if arch == 75:
SUPPORT_SM75 = True
FILE_HEAD_COMMENT = """
// auto generated by generate_kernels.py
@ -166,6 +172,7 @@ def remove_old_kernels():
def generate_new_kernels():
result_dict = {}
sm_75_result_dict = {}
for quant_config in QUANT_CONFIGS:
c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
@ -184,6 +191,8 @@ def generate_new_kernels():
s_type = quant_config.get("s_type", c_type)
if (a_type, b_type, c_type) not in result_dict:
result_dict[(a_type, b_type, c_type)] = []
if a_type in ["kFloat16", "kS8"] and c_type == "kFloat16":
sm_75_result_dict[(a_type, b_type, c_type)] = []
for group_blocks, m_blocks, thread_configs in itertools.product(
all_group_blocks, all_m_blocks, all_thread_configs
@ -207,78 +216,89 @@ def generate_new_kernels():
"thread_k_blocks": thread_k // 16,
"thread_n_blocks": thread_n // 16,
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
"stages": "pipe_stages",
"stages": 4,
"group_blocks": group_blocks,
"is_zp_float": "true" if is_zp_float else "false",
}
result_dict[(a_type, b_type, c_type)].append(config)
if SUPPORT_SM80:
result_dict[(a_type, b_type, c_type)].append(config)
if (a_type, b_type, c_type) in sm_75_result_dict and SUPPORT_SM75:
config_sm75 = config.copy()
config_sm75["stages"] = 2
sm_75_result_dict[(a_type, b_type, c_type)].append(config_sm75)
kernel_selector_str = FILE_HEAD_COMMENT
for (a_type, b_type, c_type), config_list in result_dict.items():
all_template_str_list = []
for config in config_list:
s_type = config["s_type"]
template_str = jinja2.Template(TEMPLATE).render(
a_type_id=f"vllm::{a_type}.id()",
b_type_id=f"vllm::{b_type}.id()",
c_type_id=f"vllm::{c_type}.id()",
s_type_id=f"vllm::{s_type}.id()",
**config,
)
all_template_str_list.append(template_str)
conditions = [
f"a_type == vllm::{a_type}",
f"b_type == vllm::{b_type}",
f"c_type == vllm::{c_type}",
f"s_type == vllm::{s_type}",
f"threads == {config['threads']}",
f"thread_m_blocks == {config['thread_m_blocks']}",
f"thread_n_blocks == {config['thread_n_blocks']}",
f"thread_k_blocks == {config['thread_k_blocks']}",
f"m_block_size_8 == {config['m_block_size_8']}",
f"group_blocks == {config['group_blocks']}",
f"is_zp_float == {config['is_zp_float']}",
]
conditions = " && ".join(conditions)
if kernel_selector_str == FILE_HEAD_COMMENT:
kernel_selector_str += f"if ({conditions})\n kernel = "
else:
kernel_selector_str += f"else if ({conditions})\n kernel = "
kernel_template2 = (
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
"{{is_zp_float}}>;"
)
kernel_selector_str += (
jinja2.Template(kernel_template2).render(
for result_dict_tmp in [result_dict, sm_75_result_dict]:
for (a_type, b_type, c_type), config_list in result_dict_tmp.items():
all_template_str_list = []
if not config_list:
continue
for config in config_list:
s_type = config["s_type"]
template_str = jinja2.Template(TEMPLATE).render(
a_type_id=f"vllm::{a_type}.id()",
b_type_id=f"vllm::{b_type}.id()",
c_type_id=f"vllm::{c_type}.id()",
s_type_id=f"vllm::{s_type}.id()",
**config,
)
+ "\n"
)
all_template_str_list.append(template_str)
file_content = FILE_HEAD + "\n\n"
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
if a_type == "kFE4M3fn":
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
else:
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
conditions = [
f"a_type == vllm::{a_type}",
f"b_type == vllm::{b_type}",
f"c_type == vllm::{c_type}",
f"s_type == vllm::{s_type}",
f"threads == {config['threads']}",
f"thread_m_blocks == {config['thread_m_blocks']}",
f"thread_n_blocks == {config['thread_n_blocks']}",
f"thread_k_blocks == {config['thread_k_blocks']}",
f"m_block_size_8 == {config['m_block_size_8']}",
f"stages == {config['stages']}",
f"group_blocks == {config['group_blocks']}",
f"is_zp_float == {config['is_zp_float']}",
]
conditions = " && ".join(conditions)
filename = filename.lower()
if kernel_selector_str == FILE_HEAD_COMMENT:
kernel_selector_str += f"if ({conditions})\n kernel = "
else:
kernel_selector_str += f"else if ({conditions})\n kernel = "
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
f.write(file_content)
kernel_template2 = (
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
"{{is_zp_float}}>;"
)
kernel_selector_str += (
jinja2.Template(kernel_template2).render(
a_type_id=f"vllm::{a_type}.id()",
b_type_id=f"vllm::{b_type}.id()",
c_type_id=f"vllm::{c_type}.id()",
s_type_id=f"vllm::{s_type}.id()",
**config,
)
+ "\n"
)
file_content = FILE_HEAD + "\n\n"
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
if a_type == "kFE4M3fn":
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
elif result_dict_tmp is sm_75_result_dict:
filename = f"sm75_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
else:
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
filename = filename.lower()
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
f.write(file_content)
if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
kernel_selector_str += (

View File

@ -37,7 +37,7 @@ __global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){};
using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
int const* __restrict__ perm_int_ptr,
@ -148,7 +148,7 @@ typedef struct {
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
int prob_n, int prob_k, int num_bits, int group_size,
bool has_act_order, bool is_k_full) {
bool has_act_order, bool is_k_full, int stages) {
bool cache_scales_chunk = has_act_order && !is_k_full;
int tb_n = th_config.thread_n;
@ -166,28 +166,29 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
if (cache_scales_chunk) {
int load_groups =
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
tb_groups * stages * 2; // Chunk size is 2x pipeline over dim K
load_groups = max(load_groups, 32); // We load at least 32 scale groups
return load_groups * tb_n * 2;
} else {
int tb_scales = tb_groups * tb_n * 2;
return tb_scales * pipe_stages;
return tb_scales * stages;
}
}
int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
int prob_m, int prob_n, int prob_k, int num_bits,
int group_size, bool has_act_order, bool is_k_full,
int has_zp, int is_zp_float) {
int has_zp, bool is_zp_float, bool is_a_8bit,
int stages) {
int pack_factor = 32 / num_bits;
// Get B size
int tb_k = th_config.thread_k;
int tb_n = th_config.thread_n;
int tb_m = thread_m_blocks * 16;
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
int sh_a_size = stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
int sh_b_size = stages * (tb_k * tb_n / pack_factor) * 4;
int sh_red_size = tb_m * (tb_n + 8) * 2;
int sh_bias_size = tb_n * 2;
int tmp_size =
@ -196,8 +197,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
int sh_s_size =
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full);
int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0;
group_size, has_act_order, is_k_full, stages);
int sh_g_idx_size = has_act_order && !is_k_full ? stages * tb_k / 4 : 0;
int sh_zp_size = 0;
if (has_zp) {
if (is_zp_float)
@ -217,7 +218,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
int prob_m, int prob_n, int prob_k, int num_bits,
int group_size, bool has_act_order, bool is_k_full,
int has_zp, int is_zp_float, int max_shared_mem) {
int has_zp, bool is_zp_float, bool is_a_8bit, int stages,
int max_shared_mem) {
// Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
th_config.num_threads == -1) {
@ -242,7 +244,7 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
// Check that pipeline fits into cache
int cache_size = get_kernel_cache_size(
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float);
has_act_order, is_k_full, has_zp, is_zp_float, is_a_8bit, stages);
return cache_size <= max_shared_mem;
}
@ -251,7 +253,7 @@ MarlinFuncPtr get_marlin_kernel(
const vllm::ScalarType c_type, const vllm::ScalarType s_type,
int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
int threads, bool is_zp_float) {
int threads, bool is_zp_float, int stages) {
int num_bits = b_type.size_bits();
auto kernel = MarlinDefault;
@ -265,7 +267,8 @@ exec_config_t determine_exec_config(
const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
int prob_n, int prob_k, int thread_m_blocks, bool m_block_size_8,
int num_bits, int group_size, bool has_act_order, bool is_k_full,
bool has_zp, bool is_zp_float, int max_shared_mem, int sms) {
bool has_zp, bool is_zp_float, int is_a_8bit, int stages,
int max_shared_mem, int sms) {
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
thread_config_t* thread_configs = thread_m_blocks > 1
? large_batch_thread_configs
@ -280,13 +283,15 @@ exec_config_t determine_exec_config(
if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp,
is_zp_float, max_shared_mem - 512)) {
is_zp_float, is_a_8bit, stages,
max_shared_mem - 512)) {
continue;
}
int cache_size = get_kernel_cache_size(
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full, has_zp, is_zp_float);
int cache_size = get_kernel_cache_size(th_config, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp,
is_zp_float, is_a_8bit, stages);
int group_blocks = 0;
if (!has_act_order) {
@ -297,14 +302,10 @@ exec_config_t determine_exec_config(
get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
th_config.thread_n / 16, th_config.thread_k / 16,
m_block_size_8, has_act_order, has_zp, group_blocks,
th_config.num_threads, is_zp_float);
th_config.num_threads, is_zp_float, stages);
if (kernel == MarlinDefault) continue;
// int m_tiles = div_ceil(prob_m, thread_m_blocks * 16);
// int n_tiles = prob_n / th_config.thread_n;
// int k_tiles = prob_k / th_config.thread_k;
return {1, th_config};
}
@ -321,6 +322,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
int group_size, int dev, cudaStream_t stream, int thread_k_init,
int thread_n_init, int sms, bool use_atomic_add,
bool use_fp32_reduce, bool is_zp_float) {
bool is_a_8bit = a_type.size_bits() == 8;
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]");
@ -389,8 +391,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
dev);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
dev);
TORCH_CHECK(major_capability * 10 + minor_capability >= 80,
"marlin kernel only support Ampere or newer GPUs.");
TORCH_CHECK(major_capability * 10 + minor_capability >= 75,
"marlin kernel only support Turing or newer GPUs.");
int stages = 4;
if (major_capability == 7 && minor_capability == 5) {
stages = 2;
TORCH_CHECK(a_type == vllm::kFloat16 || a_type == vllm::kS8,
"Turing only support FP16 or INT8 activation.");
}
if (a_type == vllm::kFE4M3fn) {
TORCH_CHECK(
major_capability * 10 + minor_capability == 89 ||
@ -431,7 +439,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
exec_cfg = determine_exec_config(
a_type, b_type, c_type, s_type, prob_m_split, prob_n, prob_k,
thread_m_blocks, m_block_size_8, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, max_shared_mem, sms);
is_k_full, has_zp, is_zp_float, is_a_8bit, stages, max_shared_mem,
sms);
thread_tfg = exec_cfg.tb_cfg;
if (thread_tfg.thread_n != -1) {
if (prob_n / thread_tfg.thread_n *
@ -440,7 +449,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
if (is_valid_config({128, 64, 128}, thread_m_blocks, prob_m_split,
prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float,
max_shared_mem_new)) {
is_a_8bit, stages, max_shared_mem_new)) {
thread_tfg = {128, 64, 128};
exec_cfg = {1, thread_tfg};
}
@ -466,7 +475,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
TORCH_CHECK(
is_valid_config(thread_tfg, thread_m_blocks, prob_m_split, prob_n,
prob_k, num_bits, group_size, has_act_order, is_k_full,
has_zp, is_zp_float, max_shared_mem_new),
has_zp, is_zp_float, is_a_8bit, stages,
max_shared_mem_new),
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
", thread_k = ", thread_tfg.thread_k,
", thread_n = ", thread_tfg.thread_n,
@ -475,12 +485,12 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
", prob_m_split = ", prob_m_split, ", group_size = ", group_size,
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
", max_shared_mem_new = ", max_shared_mem_new);
", stages = ", stages, ", max_shared_mem_new = ", max_shared_mem_new);
auto kernel = get_marlin_kernel(
a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
num_threads, is_zp_float);
num_threads, is_zp_float, stages);
if (kernel == MarlinDefault) {
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,

View File

@ -1,17 +1,19 @@
#pragma once
#include <torch/all.h>
#ifndef _marlin_cuh
#define _marlin_cuh
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin
#endif
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin
#endif
namespace MARLIN_NAMESPACE_NAME {
@ -51,9 +53,51 @@ using I4 = Vec<int, 4>;
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// No support for async
#else
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
__device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
if (pred) {
reinterpret_cast<int32_t*>(smem_ptr)[0] =
reinterpret_cast<const int32_t*>(glob_ptr)[0];
}
}
__device__ inline void cp_async2_ca_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
if (pred) {
reinterpret_cast<int64_t*>(smem_ptr)[0] =
reinterpret_cast<const int64_t*>(glob_ptr)[0];
}
}
__device__ inline void cp_async4_ca_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
if (pred) {
reinterpret_cast<int4*>(smem_ptr)[0] =
reinterpret_cast<const int4*>(glob_ptr)[0];
}
}
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
if (pred) {
reinterpret_cast<int4*>(smem_ptr)[0] =
reinterpret_cast<const int4*>(glob_ptr)[0];
}
}
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
reinterpret_cast<int4*>(smem_ptr)[0] =
reinterpret_cast<const int4*>(glob_ptr)[0];
}
__device__ inline void cp_async_fence() {}
template <int n>
__device__ inline void cp_async_wait() {}
#else
__device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
@ -126,6 +170,8 @@ __device__ inline void cp_async_wait() {
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}
#endif
#endif
} // namespace MARLIN_NAMESPACE_NAME
#endif

View File

@ -0,0 +1,269 @@
#include "marlin_dtypes.cuh"
namespace MARLIN_NAMESPACE_NAME {
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
template <vllm::ScalarTypeId type_id, bool use_fp16_accum, int k_size = 16>
__device__ inline void mma(
const typename MarlinScalarType<type_id>::FragA& a_frag,
const typename MarlinScalarType<type_id>::FragB& frag_b,
typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
if constexpr (!std::is_same<scalar_t, half>::value || k_size != 16) {
static_assert(!use_fp16_accum);
}
if constexpr (k_size == 16) {
if constexpr (std::is_same<scalar_t, half>::value && !use_fp16_accum) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(b[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
"f"(c[3]));
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[2]), "r"(a[3]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
"f"(c[3]));
#else
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
#endif
} else if constexpr (std::is_same<scalar_t, half>::value &&
use_fp16_accum) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
: "=r"(c[0]), "=r"(c[1])
: "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1]));
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
: "=r"(c[0]), "=r"(c[1])
: "r"(a[2]), "r"(a[3]), "r"(b[1]), "r"(c[0]), "r"(c[1]));
#else
uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
: "=r"(c[0]), "=r"(c[1])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"r"(c[0]), "r"(c[1]));
#endif
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]),
"f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]),
"r"(c[1]), "r"(c[2]), "r"(c[3]));
}
} else if (k_size == 32) {
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
asm volatile(
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
: "=r"(c[0]), "=r"(c[1])
: "r"(a[0]), "r"(b[0]), "r"(c[0]), "r"(c[1]));
asm volatile(
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
: "=r"(c[2]), "=r"(c[3])
: "r"(a[1]), "r"(b[0]), "r"(c[2]), "r"(c[3]));
asm volatile(
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
: "=r"(c[0]), "=r"(c[1])
: "r"(a[2]), "r"(b[1]), "r"(c[0]), "r"(c[1]));
asm volatile(
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
: "=r"(c[2]), "=r"(c[3])
: "r"(a[3]), "r"(b[1]), "r"(c[2]), "r"(c[3]));
#else
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
#endif
}
}
}
template <vllm::ScalarTypeId type_id, bool use_fp16_accum, int k_size = 16>
__device__ inline void mma_trans(
const typename MarlinScalarType<type_id>::FragA& a_frag,
const typename MarlinScalarType<type_id>::FragB& frag_b,
const typename MarlinScalarType<type_id>::FragB& frag_b2,
typename MarlinScalarType<type_id>::FragC& frag_c) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
float* c = reinterpret_cast<float*>(&frag_c);
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
if constexpr (!std::is_same<scalar_t, half>::value || k_size != 16) {
static_assert(!use_fp16_accum);
}
if constexpr (k_size == 16) {
if constexpr (std::is_same<scalar_t, half>::value && !use_fp16_accum) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
"f"(c[3]));
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[1]), "r"(b2[1]), "r"(a[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
"f"(c[3]));
#else
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
#endif
} else if constexpr (std::is_same<scalar_t, half>::value &&
use_fp16_accum) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
: "=r"(c[0]), "=r"(c[1])
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]));
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
: "=r"(c[0]), "=r"(c[1])
: "r"(b[1]), "r"(b2[1]), "r"(a[1]), "r"(c[0]), "r"(c[1]));
#else
uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
: "=r"(c[0]), "=r"(c[1])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"r"(c[0]), "r"(c[1]));
#endif
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
"f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
"r"(c[3]));
}
} else {
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
asm volatile(
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
: "=r"(c[0]), "=r"(c[1])
: "r"(b[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]));
asm volatile(
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
: "=r"(c[2]), "=r"(c[3])
: "r"(b2[1]), "r"(a[0]), "r"(c[2]), "r"(c[3]));
asm volatile(
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
: "=r"(c[0]), "=r"(c[1])
: "r"(b[0]), "r"(a[1]), "r"(c[0]), "r"(c[1]));
asm volatile(
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
: "=r"(c[2]), "=r"(c[3])
: "r"(b2[1]), "r"(a[1]), "r"(c[2]), "r"(c[3]));
#else
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
#endif
}
}
}
} // namespace MARLIN_NAMESPACE_NAME

View File

@ -26,6 +26,7 @@
#include "marlin.cuh"
#include "marlin_dtypes.cuh"
#include "dequant.h"
#include "marlin_mma.h"
#include "core/scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
@ -35,7 +36,7 @@
namespace MARLIN_NAMESPACE_NAME {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId b_type_id, // weight MarlinScalarType id
@ -75,137 +76,6 @@ __global__ void Marlin(
#else
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
template <vllm::ScalarTypeId type_id, int k_size = 16>
__device__ inline void mma(
const typename MarlinScalarType<type_id>::FragA& a_frag,
const typename MarlinScalarType<type_id>::FragB& frag_b,
typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
if constexpr (k_size == 16) {
if constexpr (std::is_same<scalar_t, half>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]),
"f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]),
"r"(c[1]), "r"(c[2]), "r"(c[3]));
}
} else if (k_size == 32) {
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
}
}
}
template <vllm::ScalarTypeId type_id, int k_size = 16>
__device__ inline void mma_trans(
const typename MarlinScalarType<type_id>::FragA& a_frag,
const typename MarlinScalarType<type_id>::FragB& frag_b,
const typename MarlinScalarType<type_id>::FragB& frag_b2,
typename MarlinScalarType<type_id>::FragC& frag_c) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
float* c = reinterpret_cast<float*>(&frag_c);
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
if constexpr (k_size == 16) {
if constexpr (std::is_same<scalar_t, half>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
"f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
"r"(c[3]));
}
} else {
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
}
}
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
template <int count, vllm::ScalarTypeId type_id>
@ -415,6 +285,17 @@ __global__ void Marlin(
if constexpr (a_type_id == vllm::kFE4M3fn.id()) return;
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
// Turing TensorCore only supports fp16 and int8
if constexpr (a_type_id != vllm::kFloat16.id() && a_type_id != vllm::kS8.id())
return;
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id();
#else
constexpr bool use_fp16_accum = false;
#endif
using Adtype = MarlinScalarType<a_type_id>;
using Cdtype = MarlinScalarType<c_type_id>;
const int4* A = A0;
@ -873,10 +754,6 @@ __global__ void Marlin(
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
: (stages * s_sh_stage);
int4* sh_s = sh_zp + (stages * zp_sh_stage);
// shared memory reused by reduction should be smaller than
// shared memory used by weight.
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
stages * b_sh_stage);
int4* sh_a = sh_s + sh_s_size;
// Register storage for double buffer of shared memory reads.
@ -1395,11 +1272,13 @@ __global__ void Marlin(
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
if constexpr (m_block_size_8) {
mma_trans<a_type_id>(frag_a[k2][i], frag_b0, frag_b1,
frag_c[i][j][0]);
mma_trans<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0, frag_b1,
frag_c[i][j][0]);
} else {
mma<a_type_id>(frag_a[k2][i], frag_b0, frag_c[i][j][0]);
mma<a_type_id>(frag_a[k2][i], frag_b1, frag_c[i][j][1]);
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0,
frag_c[i][j][0]);
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b1,
frag_c[i][j][1]);
}
}
}
@ -1433,10 +1312,12 @@ __global__ void Marlin(
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
mma<a_type_id, 32>(frag_a[k2][i], frag_b[0],
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
mma<a_type_id, 32>(frag_a[k2][i], frag_b[1],
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
mma<a_type_id, false, 32>(
frag_a[k2][i], frag_b[0],
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
mma<a_type_id, false, 32>(
frag_a[k2][i], frag_b[1],
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
}
if constexpr (group_blocks != -1) {
@ -1956,6 +1837,21 @@ __global__ void Marlin(
// While this pattern may not be the most readable, other ways of writing
// the loop seemed to noticeably worse performance after compilation.
if (slice_iters == 0) {
// convert fp16 accum to fp32 for reduction
if constexpr (use_fp16_accum) {
#pragma unroll
for (int i = 0; i < (thread_m_blocks * (is_a_8bit ? 2 : 4) * 2); i++) {
float* frag_c_part_float = reinterpret_cast<float*>(frag_c) + i * 4;
scalar_t* frag_c_part_half =
reinterpret_cast<scalar_t*>(frag_c_part_float);
#pragma unroll
for (int i = 3; i >= 0; i--) {
frag_c_part_float[i] = Cdtype::num2float(frag_c_part_half[i]);
}
}
}
if constexpr (is_a_8bit) {
float frag_a_s[2 * thread_m_blocks];

View File

@ -550,8 +550,8 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowPrefill(
int rowEnd = rowEnds[rowIdx];
// Local pointers to this block
outIndices += rowIdx * topK;
logits += rowIdx * stride0;
outIndices += static_cast<int64_t>(rowIdx) * topK;
logits += static_cast<int64_t>(rowIdx) * stride0;
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort>(
nullptr, logits, rowStart, rowEnd, outIndices, nullptr, stride1, topK);
@ -576,19 +576,21 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
// Local pointers to this block
if constexpr (!multipleBlocksPerRow && !mergeBlocks) {
outIndices += rowIdx * topK;
outIndices += static_cast<int64_t>(rowIdx) * topK;
} else if constexpr (multipleBlocksPerRow) {
const auto blockSize = rowEnd / gridDim.y; // 16384 / 2 = 8192
rowStart = blockSize * blockIdx.y; // 8192 * 1 = 8192
rowEnd = gridDim.y == blockIdx.y + 1 ? rowEnd : rowStart + blockSize;
outIndices += rowIdx * gridDim.y * topK + blockIdx.y * topK;
outLogits += rowIdx * gridDim.y * topK + blockIdx.y * topK;
outIndices +=
static_cast<int64_t>(rowIdx) * gridDim.y * topK + blockIdx.y * topK;
outLogits +=
static_cast<int64_t>(rowIdx) * gridDim.y * topK + blockIdx.y * topK;
} else if constexpr (mergeBlocks) {
rowEnd = numBlocksToMerge * topK;
indices += rowIdx * numBlocksToMerge * topK;
outIndices += rowIdx * topK;
indices += static_cast<int64_t>(rowIdx) * numBlocksToMerge * topK;
outIndices += static_cast<int64_t>(rowIdx) * topK;
}
logits += rowIdx * stride0;
logits += static_cast<int64_t>(rowIdx) * stride0;
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort,
multipleBlocksPerRow, mergeBlocks>(

View File

@ -621,7 +621,7 @@ ENV UV_HTTP_TIMEOUT=500
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=requirements/kv_connectors.txt,target=/tmp/kv_connectors.txt,ro \
if [ "$INSTALL_KV_CONNECTORS" = "true" ]; then \
uv pip install --system -r /tmp/kv_connectors.txt; \
uv pip install --system -r /tmp/kv_connectors.txt || true; \
fi
ENV VLLM_USAGE_SOURCE production-docker-image

View File

@ -50,5 +50,5 @@ ijson # Required for mistral streaming tool parser
setproctitle # Used to set process names for better debugging and monitoring
openai-harmony >= 0.0.3 # Required for gpt-oss
anthropic == 0.71.0
model-hosting-container-standards >= 0.1.9, < 1.0.0
mcp
model-hosting-container-standards >= 0.1.10, < 1.0.0
mcp

View File

@ -523,6 +523,8 @@ CUSTOM_OPS_QUANT_RMS_NORM = ["+quant_fp8,+rms_norm"]
list[tuple[Any, ...]](flat_product(MODELS_GROUP_FP8, CUSTOM_OPS_QUANT_RMS_NORM)),
)
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
# TODO: remove skip after we fix the fusion thoroughly
@pytest.mark.skipif(is_blackwell(), reason="Temporarily disabled on Blackwell")
def test_rms_group_quant(
model_name: str,
model_kwargs: dict[str, Any],
@ -562,7 +564,9 @@ def test_rms_group_quant(
splitting_ops=splitting_ops,
# Common
mode=CompilationMode.VLLM_COMPILE,
pass_config=PassConfig(eliminate_noops=True, enable_fusion=True),
pass_config=PassConfig(
fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
),
# Inductor caches custom passes by default as well via uuid
inductor_compile_config={"force_disable_caches": True},
)

View File

@ -244,3 +244,35 @@ async def test_audio_with_timestamp(mary_had_lamb, whisper_client):
)
assert transcription.segments is not None
assert len(transcription.segments) > 0
@pytest.mark.asyncio
async def test_audio_with_max_tokens(whisper_client, mary_had_lamb):
transcription = await whisper_client.audio.transcriptions.create(
model=MODEL_NAME,
file=mary_had_lamb,
language="en",
response_format="text",
temperature=0.0,
extra_body={"max_completion_tokens": 1},
)
out = json.loads(transcription)
out_text = out["text"]
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained(MODEL_NAME)
out_tokens = tok(out_text, add_special_tokens=False)["input_ids"]
assert len(out_tokens) == 1
# max_completion_tokens > max_model_len
transcription = await whisper_client.audio.transcriptions.create(
model=MODEL_NAME,
file=mary_had_lamb,
language="en",
response_format="text",
temperature=0.0,
extra_body={"max_completion_tokens": int(1e6)},
)
out = json.loads(transcription)
out_text = out["text"]
out_tokens = tok(out_text, add_special_tokens=False)["input_ids"]
assert len(out_tokens) < 450 # ~Whisper max output len

View File

@ -227,3 +227,36 @@ async def test_long_audio_request(foscolo, client_and_model):
)
out = json.loads(translation)["text"].strip().lower()
assert out.count("greek sea") == 2
@pytest.mark.asyncio
async def test_audio_with_max_tokens(mary_had_lamb, client_and_model):
client, model_name = client_and_model
transcription = await client.audio.translations.create(
model=model_name,
file=mary_had_lamb,
response_format="text",
temperature=0.0,
extra_body={"max_completion_tokens": 1},
)
out = json.loads(transcription)
out_text = out["text"]
print(out_text)
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained(model_name)
out_tokens = tok(out_text, add_special_tokens=False)["input_ids"]
assert len(out_tokens) == 1
# max_completion_tokens > max_model_len
transcription = await client.audio.transcriptions.create(
model=model_name,
file=mary_had_lamb,
response_format="text",
temperature=0.0,
extra_body={"max_completion_tokens": int(1e6)},
)
out = json.loads(transcription)
out_text = out["text"]
print(out_text)
out_tokens = tok(out_text, add_special_tokens=False)["input_ids"]
assert len(out_tokens) < 450 # ~Whisper max output len

View File

@ -7,9 +7,8 @@ This directory contains a replacement for the lm-eval-harness GSM8K evaluation,
### Run tests with pytest (like buildkite)
```bash
pytest -s -v tests/gsm8k/test_gsm8k_correctness.py \
--config-list-file=configs/models-small.txt \
--tp-size=1
pytest -s -v tests/evals/gsm8k/test_gsm8k_correctness.py \
--config-list-file=configs/models-small.txt
```
### Run standalone evaluation script
@ -31,5 +30,11 @@ model_name: "Qwen/Qwen2.5-1.5B-Instruct"
accuracy_threshold: 0.54 # Minimum expected accuracy
num_questions: 1319 # Number of questions (default: full test set)
num_fewshot: 5 # Few-shot examples from train set
max_model_len: 4096 # Model context length
server_args: "--max-model-len 4096 --tensor-parallel-size 2" # Server arguments
env: # Environment variables (optional)
VLLM_USE_FLASHINFER_MOE_FP4: "1"
```
The `server_args` field accepts any arguments that can be passed to `vllm serve`.
The `env` field accepts a dictionary of environment variables to set for the server process.

View File

@ -2,5 +2,4 @@ model_name: "RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8"
accuracy_threshold: 0.72
num_questions: 1319
num_fewshot: 5
max_model_len: 4096
server_args: "--enforce-eager --max-model-len 4096"

View File

@ -2,4 +2,4 @@ model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test"
accuracy_threshold: 0.74
num_questions: 1319
num_fewshot: 5
max_model_len: 4096
server_args: "--enforce-eager --max-model-len 4096"

View File

@ -2,4 +2,4 @@ model_name: "RedHatAI/Llama-3.2-1B-Instruct-quantized.w8a8"
accuracy_threshold: 0.31
num_questions: 1319
num_fewshot: 5
max_model_len: 4096
server_args: "--enforce-eager --max-model-len 4096"

View File

@ -2,4 +2,4 @@ model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16"
accuracy_threshold: 0.45
num_questions: 1319
num_fewshot: 5
max_model_len: 4096
server_args: "--enforce-eager --max-model-len 4096"

View File

@ -2,4 +2,4 @@ model_name: "RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic"
accuracy_threshold: 0.60
num_questions: 1319
num_fewshot: 5
max_model_len: 4096
server_args: "--enforce-eager --max-model-len 4096"

View File

@ -2,4 +2,4 @@ model_name: "Qwen/Qwen3-0.6B-FP8"
accuracy_threshold: 0.375
num_questions: 1319
num_fewshot: 5
max_model_len: 4096
server_args: "--enforce-eager --max-model-len 4096"

View File

@ -2,5 +2,4 @@ model_name: "nvidia/Qwen3-30B-A3B-FP4"
accuracy_threshold: 0.89
num_questions: 1319
num_fewshot: 5
max_model_len: 4096
server_args: "--enforce-eager --max-model-len 4096"

View File

@ -0,0 +1,12 @@
model_name: "nm-testing/Qwen3-Next-80B-A3B-Instruct-NVFP4"
accuracy_threshold: 0.75
num_questions: 1319
num_fewshot: 5
server_args: >-
--enforce-eager
--max-model-len 4096
--tensor-parallel-size 2
--enable-expert-parallel
--speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":1}'
env:
VLLM_USE_FLASHINFER_MOE_FP4: "1"

View File

@ -3,3 +3,4 @@ Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
Qwen1.5-MoE-W4A16-CT.yaml
DeepSeek-V2-Lite-Instruct-FP8.yaml
Qwen3-30B-A3B-NVFP4.yaml
Qwen3-Next-80B-A3B-NVFP4-EP2.yaml

View File

@ -11,14 +11,12 @@ def pytest_addoption(parser):
default="configs/models-small.txt",
help="File containing list of config files to test",
)
parser.addoption("--tp-size", default=1, type=int, help="Tensor parallel size")
def pytest_generate_tests(metafunc):
"""Generate test parameters from config files."""
if "config_filename" in metafunc.fixturenames:
config_list_file = metafunc.config.getoption("--config-list-file")
tp_size = metafunc.config.getoption("--tp-size")
# Handle both relative and absolute paths
config_list_path = Path(config_list_file)
@ -55,9 +53,9 @@ def pytest_generate_tests(metafunc):
# Generate test parameters
if config_files:
metafunc.parametrize(
["config_filename", "tp_size"],
[(config_file, int(tp_size)) for config_file in config_files],
ids=[f"{config_file.stem}-tp{tp_size}" for config_file in config_files],
"config_filename",
config_files,
ids=[config_file.stem for config_file in config_files],
)
else:
print("No config files found, test will be skipped")

View File

@ -5,30 +5,31 @@ GSM8K evaluation using vLLM server and isolated GSM8K script.
Replacement for lm-eval-harness with better performance and control.
Usage:
pytest -s -v test_gsm8k_correctness.py \
--config-list-file=configs/models-small.txt \
--tp-size=1
pytest -s -v tests/evals/gsm8k/test_gsm8k_correctness.py \
--config-list-file=configs/models-small.txt
"""
import shlex
import yaml
from tests.utils import RemoteOpenAIServer
from .gsm8k_eval import evaluate_gsm8k
RTOL = 0.08 # Relative tolerance for accuracy comparison
TOL = 0.08 # Absolute tolerance for accuracy comparison
def launch_gsm8k_eval(eval_config, server_url, tp_size):
"""Launch GSM8K evaluation using our isolated script."""
def run_gsm8k_eval(eval_config: dict, server_url: str) -> dict:
"""Run GSM8K evaluation using our isolated script."""
# Extract host and port from server URL
if "://" in server_url:
server_url = server_url.split("://")[1]
host_port = server_url.split("/")[0] # Remove path if present
if ":" in host_port:
host, port = host_port.split(":")
port = int(port)
host, p = host_port.split(":")
port = int(p)
else:
host = host_port
port = 8000
@ -48,46 +49,57 @@ def launch_gsm8k_eval(eval_config, server_url, tp_size):
return results
def test_gsm8k_correctness_param(config_filename, tp_size):
def test_gsm8k_correctness(config_filename):
"""Test GSM8K correctness for a given model configuration."""
eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8"))
# Server arguments
server_args = [
"--max-model-len",
str(eval_config.get("max_model_len", 4096)),
"--enforce-eager",
"--trust-remote-code",
"--tensor-parallel-size",
str(tp_size),
]
# Parse server arguments from config (use shlex to handle quoted strings)
server_args_str = eval_config.get("server_args", "")
server_args = shlex.split(server_args_str) if server_args_str else []
# Add standard server arguments
server_args.extend(
[
"--trust-remote-code",
]
)
env_dict = eval_config.get("env", None)
print(f"Starting GSM8K evaluation for model: {eval_config['model_name']}")
print(f"Expected metric threshold: {eval_config['accuracy_threshold']}")
print(f"Number of questions: {eval_config['num_questions']}")
print(f"Number of few-shot examples: {eval_config['num_fewshot']}")
print(f"Server args: {' '.join(server_args)}")
# Launch server and run evaluation
with RemoteOpenAIServer(
eval_config["model_name"], server_args, env_dict=env_dict, max_wait_seconds=480
eval_config["model_name"],
server_args,
env_dict=env_dict,
max_wait_seconds=600,
) as remote_server:
server_url = remote_server.url_for("v1")
print(f"Server started at: {server_url}")
results = launch_gsm8k_eval(eval_config, server_url, tp_size)
results = run_gsm8k_eval(eval_config, server_url)
# Check accuracy against threshold
measured_accuracy = results["accuracy"]
expected_accuracy = eval_config["accuracy_threshold"]
measured_metric = results["accuracy"]
expected_metric = eval_config["accuracy_threshold"]
print(f"GSM8K Results for {eval_config['model_name']}:")
print(f" Accuracy: {measured_accuracy:.3f}")
print(f" Expected: {expected_accuracy:.3f}")
print(f" Measured metric: {measured_metric:.4f}")
print(f" Expected metric: {expected_metric:.4f}")
print(f" Tolerance: {TOL:.4f}")
print(f" Questions: {results['num_questions']}")
print(f" Invalid rate: {results['invalid_rate']:.3f}")
print(f" Latency: {results['latency']:.1f}s")
print(f" QPS: {results['questions_per_second']:.1f}")
# Verify accuracy is within tolerance
assert measured_accuracy >= expected_accuracy - RTOL, (
f"Accuracy too low: {measured_accuracy:.3f} < "
f"{expected_accuracy:.3f} - {RTOL:.3f}"
# Verify metric is within tolerance
assert measured_metric >= expected_metric - TOL, (
f"GSM8K metric too low: {measured_metric:.4f} < "
f"{expected_metric:.4f} - {TOL:.4f} = {expected_metric - TOL:.4f}"
)
print(f"✅ GSM8K test passed for {eval_config['model_name']}")

View File

@ -60,12 +60,12 @@ def test_profiling(model_id: str, max_model_len: int):
total_num_patches.item() + num_tiles.item() + 3
) # image start, image, image end
profiled_tokens = profiler.get_mm_max_contiguous_tokens(
profiled_tokens = profiler.get_mm_max_tokens(
max_model_len,
mm_counts=mm_counts,
)
assert total_tokens == profiled_tokens["image"]
assert total_num_patches == profiled_tokens["image"]
assert total_tokens == sum(
placeholder.length
for placeholder in decoder_dummy_data.multi_modal_placeholders["image"]

View File

@ -9,6 +9,7 @@ from tempfile import NamedTemporaryFile, TemporaryDirectory
import numpy as np
import pytest
import torch
from PIL import Image, ImageChops
from vllm.multimodal.image import convert_image_mode
@ -410,6 +411,97 @@ def test_argsort_mm_positions(case):
assert modality_idxs == expected_modality_idxs
@pytest.mark.parametrize(
"is_embed,expected",
[
(None, 5),
(torch.tensor([True, True, True, True, True]), 5),
(torch.tensor([False, False, False, False, False]), 0),
(torch.tensor([True, False, True, False, True]), 3),
(torch.tensor([True]), 1),
],
)
def test_placeholder_range_get_num_embeds(is_embed, expected):
length = len(is_embed) if is_embed is not None else 5
pr = PlaceholderRange(offset=0, length=length, is_embed=is_embed)
assert pr.get_num_embeds == expected
@pytest.mark.parametrize(
"is_embed,expected",
[
(None, None),
(
torch.tensor([False, True, False, True, True]),
torch.tensor([0, 1, 1, 2, 3]),
),
(torch.tensor([True, True, True]), torch.tensor([1, 2, 3])),
],
)
def test_placeholder_range_embeds_cumsum(is_embed, expected):
length = len(is_embed) if is_embed is not None else 5
pr = PlaceholderRange(offset=0, length=length, is_embed=is_embed)
if expected is None:
assert pr.embeds_cumsum is None
return
assert torch.equal(pr.embeds_cumsum, expected)
# cached_property should return the same object on repeated access
assert pr.embeds_cumsum is pr.embeds_cumsum
@pytest.mark.parametrize(
"is_embed,start_idx,end_idx,expected",
[
(None, 2, 4, (2, 4)),
(
torch.tensor([False, True, False, True, True]),
3,
5,
(1, 3),
),
(
torch.tensor([False, True, False, True, True]),
0,
2,
(0, 1),
),
(
torch.tensor([True, False, True, False]),
2,
2,
(1, 1),
),
],
)
def test_placeholder_range_get_embeds_indices_in_range(
is_embed, start_idx, end_idx, expected
):
length = len(is_embed) if is_embed is not None else 5
pr = PlaceholderRange(offset=0, length=length, is_embed=is_embed)
assert pr.get_embeds_indices_in_range(start_idx, end_idx) == expected
@pytest.mark.parametrize(
"offset,is_embed,expected",
[
(0, None, [(0, 4)]),
(
2,
torch.tensor([False, True, False, True, True]),
[(3, 3), (5, 6)],
),
(0, torch.tensor([True, True, True, True]), [(0, 3)]),
(0, torch.tensor([False, False, False, False]), []),
],
)
def test_placeholder_range_extract_embeds_range(offset, is_embed, expected):
length = len(is_embed) if is_embed is not None else 5
pr = PlaceholderRange(offset=offset, length=length, is_embed=is_embed)
assert pr.extract_embeds_range() == expected
@pytest.mark.asyncio
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])

View File

@ -172,7 +172,7 @@ def test_local_attention_virtual_batches(test_data: LocalAttentionTestData):
)
# Call the function
result = make_local_attention_virtual_batches(
result, _ = make_local_attention_virtual_batches(
attn_chunk_size, common_attn_metadata, block_size
)

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
@ -23,7 +24,7 @@ class MockRequest:
)
self.mm_features.append(feature)
def get_num_encoder_tokens(self, input_id: int) -> int:
def get_num_encoder_embeds(self, input_id: int) -> int:
return self._token_counts[input_id]
@ -162,8 +163,8 @@ def test_schedule_request_multi_images_respect_space_limit():
num_tokens_to_schedule = 0
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
num_tokens_to_schedule += req.get_num_encoder_tokens(0)
compute_budget -= req.get_num_encoder_tokens(0)
num_tokens_to_schedule += req.get_num_encoder_embeds(0)
compute_budget -= req.get_num_encoder_embeds(0)
assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule)
@ -174,7 +175,75 @@ def test_schedule_request_multi_images_respect_compute_limit():
compute_budget = 10
num_tokens_to_schedule = 0
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
num_tokens_to_schedule += req.get_num_encoder_tokens(0)
compute_budget -= req.get_num_encoder_tokens(0)
num_tokens_to_schedule += req.get_num_encoder_embeds(0)
compute_budget -= req.get_num_encoder_embeds(0)
assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule)
def test_encoder_cache_with_is_embed_mask():
class MockRequestWithMask(MockRequest):
def get_num_encoder_embeds(self, input_id: int) -> int:
return self.mm_features[input_id].mm_position.get_num_embeds
is_embed = torch.zeros(100, dtype=torch.bool)
is_embed[torch.tensor([5, 15, 25, 35, 45, 55, 65, 75])] = True
request = MockRequestWithMask("r1", ["img1"], [100])
request.mm_features[0] = MultiModalFeatureSpec(
data=None,
modality="image",
identifier="img1",
mm_position=PlaceholderRange(offset=0, length=100, is_embed=is_embed),
)
manager = EncoderCacheManager(cache_size=100)
manager.allocate(request, 0)
assert manager.num_free_slots == 92
assert "img1" in manager.cached
old_size = 100
new_size = request.mm_features[0].mm_position.get_num_embeds
assert new_size == 8
savings_ratio = old_size / new_size
assert savings_ratio == 12.5
def test_encoder_cache_mask_based_retrieval():
class MockRequestWithMask(MockRequest):
def get_num_encoder_embeds(self, input_id: int) -> int:
return self.mm_features[input_id].mm_position.get_num_embeds
is_embed = torch.tensor(
[False, False, True, True, False, True, True, True, False, False]
)
request = MockRequestWithMask("r1", ["img1"], [10])
request.mm_features[0] = MultiModalFeatureSpec(
data=None,
modality="image",
identifier="img1",
mm_position=PlaceholderRange(offset=0, length=10, is_embed=is_embed),
)
manager = EncoderCacheManager(cache_size=50)
manager.allocate(request, 0)
assert request.mm_features[0].mm_position.get_num_embeds == 5
start_idx = 2
end_idx = 8
num_embeds_before = is_embed[:start_idx].sum().item()
num_embeds_in_range = is_embed[start_idx:end_idx].sum().item()
assert num_embeds_before == 0
assert num_embeds_in_range == 5
start_idx = 0
end_idx = 5
num_embeds_before = is_embed[:start_idx].sum().item() if start_idx > 0 else 0
num_embeds_in_range = is_embed[start_idx:end_idx].sum().item()
assert num_embeds_before == 0
assert num_embeds_in_range == 2

View File

@ -188,7 +188,6 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
llm = LLM(
model=model_name,
tensor_parallel_size=tp_size,
# enable_prefix_caching=False,
max_num_seqs=32,
max_model_len=8192,
dtype="bfloat16", # not everything is supported

View File

@ -38,7 +38,7 @@ class MockRequest:
)
self.mm_features.append(feature)
def get_num_encoder_tokens(self, input_id: int) -> int:
def get_num_encoder_embeds(self, input_id: int) -> int:
assert input_id < len(self._token_counts)
return self._token_counts[input_id]

View File

@ -4,7 +4,7 @@ import functools
import torch
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
@ -51,11 +51,19 @@ def create_chunked_local_attention_backend(
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> AttentionMetadata:
common_attn_metadata = make_local_attention_virtual_batches(
):
cm, make_virtual_batches_block_table = make_local_attention_virtual_batches(
attention_chunk_size, common_attn_metadata, block_size
)
return super().build(common_prefix_len, common_attn_metadata, fast_build)
metadata = super().build(common_prefix_len, cm, fast_build)
metadata.make_virtual_batches_block_table = make_virtual_batches_block_table
return metadata
def update_block_table(
self, metadata, blk_table: torch.Tensor, slot_mapping: torch.Tensor
):
blk_table = metadata.make_virtual_batches_block_table(blk_table)
return super().update_block_table(metadata, blk_table, slot_mapping)
attn_backend = subclass_attention_backend(
name_prefix=prefix,

View File

@ -10,6 +10,7 @@ from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
vit_torch_sdpa_wrapper,
)
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import MultiModalConfig
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
@ -101,6 +102,10 @@ class MMEncoderAttention(CustomOp):
self.attn_backend,
)
if self.is_flash_attn_backend:
assert self.flash_attn_varlen_func is not None
self._fa_version = get_flash_attn_version()
logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")
@classmethod
@ -204,6 +209,7 @@ class MMEncoderAttention(CustomOp):
max_seqlen=max_seqlen,
batch_size=bsz,
is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
fa_version=self._fa_version,
)
return output

View File

@ -16,6 +16,7 @@ import einops
import torch
import torch.nn.functional as F
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
@ -27,11 +28,15 @@ def flash_attn_maxseqlen_wrapper(
max_seqlen: torch.Tensor,
batch_size: int,
is_rocm_aiter: bool,
fa_version: int,
) -> torch.Tensor:
kwargs = {}
if is_rocm_aiter:
from aiter import flash_attn_varlen_func
else:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
kwargs["fa_version"] = fa_version
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_varlen_func(
q,
@ -43,6 +48,7 @@ def flash_attn_maxseqlen_wrapper(
max_seqlen_k=max_seqlen.item(),
dropout_p=0.0,
causal=False,
**kwargs,
)
context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size)
return context_layer
@ -56,6 +62,7 @@ def flash_attn_maxseqlen_wrapper_fake(
max_seqlen: torch.Tensor,
batch_size: int,
is_rocm_aiter: bool,
fa_version: int,
) -> torch.Tensor:
return torch.empty_like(q)
@ -75,9 +82,10 @@ def vit_flash_attn_wrapper(
max_seqlen: torch.Tensor,
batch_size: int,
is_rocm_aiter: bool,
fa_version: int,
) -> torch.Tensor:
return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter
q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, fa_version
)
@ -89,6 +97,13 @@ def torch_sdpa_wrapper(
v: torch.Tensor,
cu_seqlens: torch.Tensor,
) -> torch.Tensor:
# Never remove the contiguous logic for ROCm
# Without it, hallucinations occur with the backend
if current_platform.is_rocm():
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
outputs = []
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()

View File

@ -64,7 +64,12 @@ class NaiveAll2AllManager(All2AllManagerBase):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if extra_tensors is not None:
raise NotImplementedError(
"extra_tensors is not supported for NaiveAll2AllManager"
)
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
@ -76,6 +81,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
router_logits = self.naive_multicast(
router_logits, cu_tokens_across_sp_cpu, is_sequence_parallel
)
return hidden_states, router_logits
def combine(
@ -113,7 +119,11 @@ class AgRsAll2AllManager(All2AllManagerBase):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Gather hidden_states and router_logits from all dp ranks.
"""
@ -121,15 +131,22 @@ class AgRsAll2AllManager(All2AllManagerBase):
assert dp_metadata is not None
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
assert sizes is not None
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
hidden_states, router_logits = dist_group.all_gatherv(
[hidden_states, router_logits],
tensors_to_gather = [hidden_states, router_logits]
if extra_tensors is not None:
tensors_to_gather.extend(extra_tensors)
gathered_tensors = dist_group.all_gatherv(
tensors_to_gather,
dim=0,
sizes=sizes,
)
return hidden_states, router_logits
if extra_tensors is not None:
return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:])
return gathered_tensors[0], gathered_tensors[1]
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
@ -204,6 +221,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
@ -251,6 +269,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from typing import Any
from weakref import WeakValueDictionary
import torch
@ -68,7 +69,11 @@ class All2AllManagerBase:
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
):
extra_tensors: list[torch.Tensor] | None = None,
) -> Any:
# Subclasses should either:
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
raise NotImplementedError
def set_num_sms(self, num_sms: int):

View File

@ -318,17 +318,23 @@ class CudaCommunicator(DeviceCommunicatorBase):
return output_list
def dispatch(
def dispatch( # type: ignore[override]
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_manager.dispatch(
hidden_states, router_logits, is_sequence_parallel
return self.all2all_manager.dispatch(
hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors, # type: ignore[call-arg]
)
return hidden_states, router_logits
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False

View File

@ -144,7 +144,7 @@ class ECExampleConnector(ECConnectorBase):
Update ECConnector state after encoder cache allocation.
"""
mm_hash = request.mm_features[index].identifier
num_encoder_token = request.get_num_encoder_tokens(index)
num_encoder_token = request.get_num_encoder_embeds(index)
# Insert mm_hash only if this block has not been recorded yet.
self._mm_datas_need_loads[mm_hash] = num_encoder_token

View File

@ -1007,10 +1007,17 @@ class GroupCoordinator:
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
if self.device_communicator is not None:
return self.device_communicator.dispatch(
hidden_states, router_logits, is_sequence_parallel
return self.device_communicator.dispatch( # type: ignore[call-arg]
hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors,
)
else:
return hidden_states, router_logits

View File

@ -2054,6 +2054,9 @@ class TranscriptionRequest(OpenAIBaseModel):
presence_penalty: float | None = 0.0
"""The presence penalty to use for sampling."""
max_completion_tokens: int | None = None
"""The maximum number of tokens to generate."""
# --8<-- [end:transcription-sampling-params]
# Default sampling parameters for transcription requests.
@ -2300,6 +2303,9 @@ class TranslationRequest(OpenAIBaseModel):
# Flattened stream option to simplify form data.
stream_include_usage: bool | None = False
stream_continuous_usage_stats: bool | None = False
max_completion_tokens: int | None = None
"""The maximum number of tokens to generate."""
# --8<-- [end:translation-extra-params]
# Default sampling parameters for translation requests.

View File

@ -293,8 +293,14 @@ class OpenAISpeechToText(OpenAIServing):
try:
# Unlike most decoder-only models, whisper generation length is not
# constrained by the size of the input audio, which is mapped to a
# fixed-size log-mel-spectogram.
default_max_tokens = self.model_config.max_model_len
# fixed-size log-mel-spectogram. Still, allow for fewer tokens to be
# generated by respecting the extra completion tokens arg.
if request.max_completion_tokens is None:
default_max_tokens = self.model_config.max_model_len
else:
default_max_tokens = min(
self.model_config.max_model_len, request.max_completion_tokens
)
sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params
)

View File

@ -207,7 +207,7 @@ if TYPE_CHECKING:
VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False
VLLM_ENABLE_CUDAGRAPH_GC: bool = False
VLLM_LOOPBACK_IP: str = ""
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = True
VLLM_ENABLE_RESPONSES_API_STORE: bool = False
VLLM_USE_TRTLLM_ATTENTION: str | None = None
VLLM_NVFP4_GEMM_BACKEND: str | None = None
@ -1430,7 +1430,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# kv-cache memory usage and enable longer contexts)
# TODO(lucas): Remove this flag once latency regression is resolved.
"VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE": lambda: bool(
int(os.getenv("VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE", "0"))
int(os.getenv("VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE", "1"))
),
# Enables support for the "store" option in the OpenAI Responses API.
# When set to 1, vLLM's OpenAI server will retain the input and output

View File

@ -14,11 +14,6 @@ class LoRARequest(
"""
Request for a LoRA adapter.
Note that this class should be used internally. For online
serving, it is recommended to not allow users to use this class but
instead provide another layer of abstraction to prevent users from
accessing unauthorized LoRA adapters.
lora_int_id must be globally unique for a given adapter.
This is currently not enforced in vLLM.
"""

View File

@ -71,6 +71,18 @@ class FusedMoEMethodBase(QuantizeMethodBase):
"implementation based on the prepare_finalize"
)
def prepare_dp_allgather_tensor(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, list[torch.Tensor]]:
"""Hook to prepare tensors and extra tensors for DP allgather + EP dispatch."""
raise NotImplementedError(
"Method 'prepare_dp_allgather_tensor' is not implemented in "
f"{self.__class__.__name__}."
)
@abstractmethod
def get_fused_moe_quant_config(
self, layer: torch.nn.Module

View File

@ -44,6 +44,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
is_flashinfer_supporting_global_sf,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.torch_utils import (
aux_stream,
@ -1933,10 +1934,46 @@ class FusedMoE(CustomOp):
)
with sp_ctx:
extra_tensors = None
if do_naive_dispatch_combine:
hidden_states_combined, router_logits = get_ep_group().dispatch(
hidden_states, router_logits, self.is_sequence_parallel
# Avoid circular import
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptNvFp4FusedMoE,
)
post_quant_allgather = (
has_flashinfer_trtllm_fused_moe()
and self.quant_method is not None
and self.dp_size > 1
and self.use_ep
and isinstance(self.quant_method, ModelOptNvFp4FusedMoE)
)
if post_quant_allgather:
hidden_states_to_dispatch, extra_tensors = (
self.quant_method.prepare_dp_allgather_tensor(
self, hidden_states, router_logits
)
)
else:
hidden_states_to_dispatch = hidden_states
dispatch_res = get_ep_group().dispatch(
hidden_states_to_dispatch,
router_logits,
self.is_sequence_parallel,
extra_tensors=extra_tensors,
)
if extra_tensors is not None:
hidden_states_combined, router_logits, extra_tensors_combined = (
dispatch_res
)
hidden_states_combined = (
hidden_states_combined,
extra_tensors_combined[0],
)
else:
hidden_states_combined, router_logits = dispatch_res
# Run shared experts before matrix multiply.
# because matrix multiply maybe modify the hidden_states.
if has_separate_shared_experts and not use_shared_experts_stream:

View File

@ -121,7 +121,7 @@ class AWQMarlinConfig(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
return 80
return 75
@classmethod
def get_config_filenames(cls) -> list[str]:

View File

@ -626,17 +626,11 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
else:
# If no modular kernel is provided, use cutlass_moe_fp4 for TP case
# only (no EP).
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
assert layer.expert_map is None, (
"Expert Parallelism / expert_map "
"is currently not supported for "
"CompressedTensorsW4A4Nvfp4MoEMethod."
)
assert self.moe_quant_config is not None
# Cutlass moe takes in activations in BF16/Half precision
# and fp4 quantized weights loaded from the checkpoint
return cutlass_moe_fp4(
a=x,
w1_fp4=layer.w13_weight,
@ -644,6 +638,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
topk_weights=topk_weights,
topk_ids=topk_ids,
quant_config=self.moe_quant_config,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
# TODO(bnell): derive these from arguments
m=x.shape[0],

View File

@ -253,7 +253,7 @@ class Fp8Config(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
return 80
return 75
@classmethod
def get_config_filenames(cls) -> list[str]:

View File

@ -181,7 +181,7 @@ class GPTQMarlinConfig(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
return 80
return 75
@classmethod
def get_config_filenames(cls) -> list[str]:

View File

@ -871,7 +871,7 @@ class ModelOptNvFp4Config(ModelOptQuantConfigBase):
@classmethod
def get_min_capability(cls) -> int:
return 80
return 75
@classmethod
def override_quantization_method(
@ -1522,6 +1522,24 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
w2_blockscale_swizzled, requires_grad=False
)
def prepare_dp_allgather_tensor(
self,
layer: FusedMoE,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, list[torch.Tensor]]:
"""Optionally prepare extra tensors to carry through DP allgather/EP."""
import flashinfer
a1_gscale = layer.w13_input_scale_quant
hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize(
hidden_states,
a1_gscale,
is_sf_swizzled_layout=False,
)
extra_tensors: list[torch.Tensor] = [hidden_states_sf]
return hidden_states_fp4, extra_tensors
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
@ -1576,8 +1594,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
e_score_correction_bias=layer.e_score_correction_bias,
)
# Hidden_states in select_experts is only used to extract metadata
if isinstance(x, tuple):
x_routing, _ = x
else:
x_routing = x
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
hidden_states=x_routing,
router_logits=router_logits,
)

View File

@ -238,7 +238,7 @@ def prepare_static_weights_for_trtllm_fp4_moe(
def flashinfer_trtllm_fp4_moe(
layer: torch.nn.Module,
x: torch.Tensor,
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
router_logits: torch.Tensor,
top_k: int,
global_num_experts: int,
@ -269,12 +269,16 @@ def flashinfer_trtllm_fp4_moe(
from vllm.model_executor.models.llama4 import Llama4MoE
# Quantize input to FP4
a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
a1_gscale,
is_sf_swizzled_layout=False,
)
if isinstance(x, tuple):
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
else:
# hidden_states is the already quantized
a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
a1_gscale,
is_sf_swizzled_layout=False,
)
# Determine routing method type
use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function
@ -360,13 +364,17 @@ def flashinfer_trtllm_fp4_routed_moe(
torch.bfloat16
).view(torch.int16)
# Quantize input to FP4
a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
a1_gscale,
is_sf_swizzled_layout=False,
)
if isinstance(x, tuple):
# Hidden_states is the already quantized
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
else:
# Quantize input to FP4
a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
a1_gscale,
is_sf_swizzled_layout=False,
)
# Call TRT-LLM FP4 block-scale MoE kernel
out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(

View File

@ -169,10 +169,13 @@ class DeciLMDecoderLayer(nn.Module):
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if not self._is_no_op_ffn:
ffn_mult = block_config.ffn.ffn_mult
intermediate_size = _ffn_mult_to_intermediate_size(
ffn_mult, config.hidden_size
)
if hasattr(block_config.ffn, "ffn_mult"):
ffn_mult = block_config.ffn.ffn_mult
intermediate_size = _ffn_mult_to_intermediate_size(
ffn_mult, config.hidden_size
)
else:
intermediate_size = block_config.ffn.intermediate_size
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,

View File

@ -713,17 +713,13 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
mm_counts: Mapping[str, int],
) -> int:
target_width, target_height = self.get_image_size_with_most_features()
video_soft_tokens = self.get_num_video_tokens(
num_video_soft_tokens = self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
image_processor=None,
)
# NOTE: By default in Qwen3-VL, one video token is converted to
# "<{timestamp} seconds>" (on average 9.5 tokens) + vision_start_token + video_token + vision_end_token # noqa: E501
formatted_video_soft_tokens = video_soft_tokens * 12.5
return int(formatted_video_soft_tokens)
return num_video_soft_tokens
def _calculate_timestamps(
self, indices: list[int] | torch.Tensor, video_fps: float, merge_size: int

View File

@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from functools import partial
from functools import cached_property, partial
from itertools import accumulate
from typing import (
TYPE_CHECKING,
@ -169,11 +169,42 @@ class PlaceholderRange:
between `offset` and `offset + length` to assign embeddings to.
"""
def get_num_embeds(self) -> int:
@cached_property
def embeds_cumsum(self) -> torch.Tensor | None:
if self.is_embed is None:
return None
return self.is_embed.cumsum(dim=0)
@cached_property
def get_num_embeds(self) -> int:
if self.embeds_cumsum is None:
return self.length
return int(self.is_embed.sum().item())
return int(self.embeds_cumsum[-1])
def get_embeds_indices_in_range(
self, start_idx: int, end_idx: int
) -> tuple[int, int]:
"""
Returns the starting and ending indices of the embeddings of encoder outputs
in the range of [start_idx, end_idx) in the placeholders.
For example, given:
PlaceholderRange(offset=2, length=5, is_embed=[False, True, False, True, True])
If start_idx=3 and end_idx=5, the output is (1, 3) because we want to get
the second and the third embeddings from the encoder output.
"""
if self.embeds_cumsum is None:
return start_idx, end_idx
embeds_start_idx = (
int(self.embeds_cumsum[start_idx - 1]) if start_idx > 0 else 0
)
embeds_end_idx = int(self.embeds_cumsum[end_idx - 1])
return embeds_start_idx, embeds_end_idx
def extract_embeds_range(self) -> list[tuple[int, int]]:
"""Extract the start and end indices of the embedded region in prompt.
@ -188,7 +219,7 @@ class PlaceholderRange:
Returns full placeholder range if `is_embed` is `None`.
"""
if self.is_embed is None:
return [(self.offset, self.offset + self.length)]
return [(self.offset, self.offset + self.length - 1)]
mask_i = self.is_embed.int()
starts = torch.nonzero(

View File

@ -274,15 +274,11 @@ class MultiModalProfiler(Generic[_I]):
def _get_mm_num_tokens(
self,
mm_inputs: MultiModalInputs,
mm_embeddings_only: bool = True,
) -> Mapping[str, int]:
placeholders_by_modality = mm_inputs["mm_placeholders"]
return {
modality: sum(
item.get_num_embeds() if mm_embeddings_only else item.length
for item in placeholders
)
modality: sum(item.get_num_embeds for item in placeholders)
for modality, placeholders in placeholders_by_modality.items()
}
@ -328,12 +324,15 @@ class MultiModalProfiler(Generic[_I]):
multi_modal_placeholders=mm_inputs["mm_placeholders"],
)
def _get_mm_max_tokens(
def get_mm_max_tokens(
self,
seq_len: int,
mm_counts: Mapping[str, int] | None = None,
mm_embeddings_only: bool = True,
) -> Mapping[str, int]:
"""
Returns the maximum number of embeddings per item of each modality, excluding
any break/text tokens in-between multimodal embeddings/encoder outputs.
"""
if mm_counts is None:
mm_counts = self.get_mm_limits()
@ -349,21 +348,4 @@ class MultiModalProfiler(Generic[_I]):
}
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only)
def get_mm_max_contiguous_tokens(
self,
seq_len: int,
mm_counts: Mapping[str, int] | None = None,
) -> Mapping[str, int]:
"""
Returns the maximum length of the multimodal (image placeholders+text)
tokens, including any break/text tokens in-between image embeddings.
`<im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>`
Returns 9, even when the number of image embeddings is 6.
This is important to take into account when profiling and
initializing the encoder cache size.
"""
return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False)
return self._get_mm_num_tokens(mm_inputs)

View File

@ -164,7 +164,7 @@ class MultiModalRegistry:
profiler.get_mm_limits() if profiler_limits is None else profiler_limits
)
return profiler.get_mm_max_contiguous_tokens(
return profiler.get_mm_max_tokens(
seq_len,
{modality: 1 for modality, limit in profiler_limits.items() if limit > 0},
)

View File

@ -184,6 +184,23 @@ def has_flashinfer_cutedsl() -> bool:
)
@functools.cache
def has_flashinfer_trtllm_fused_moe() -> bool:
"""Return `True` if FlashInfer TRTLLM fused MoE is available."""
if not has_flashinfer_moe():
return False
required_functions = [
("flashinfer.fused_moe", "trtllm_fp8_block_scale_moe"),
("flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"),
("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
]
for module_name, attr_name in required_functions:
mod = _get_submodule(module_name)
if not mod or not hasattr(mod, attr_name):
return False
return True
@functools.cache
def has_flashinfer_cutlass_fused_moe() -> bool:
"""Return `True` if FlashInfer CUTLASS fused MoE is available."""

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention."""
import copy
from dataclasses import dataclass
from typing import ClassVar
@ -250,6 +251,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
if get_flash_attn_version() == 3
else AttentionCGSupport.UNIFORM_BATCH
)
supports_update_block_table: bool = True
def __init__(
self,
@ -493,6 +495,17 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
)
return attn_metadata
def update_block_table(
self,
metadata: FlashAttentionMetadata,
blk_table: torch.Tensor,
slot_mapping: torch.Tensor,
) -> FlashAttentionMetadata:
new_metadata = copy.copy(metadata)
new_metadata.block_table = blk_table
new_metadata.slot_mapping = slot_mapping
return new_metadata
def use_cascade_attention(self, *args, **kwargs) -> bool:
return use_cascade_attention(*args, **kwargs)

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import itertools
from dataclasses import dataclass
@ -134,6 +135,8 @@ class Mamba2AttentionMetadata:
class Mamba2AttentionMetadataBuilder(
BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]
):
supports_update_block_table: bool = True
def __init__(
self,
kv_cache_spec: AttentionSpec,
@ -346,3 +349,27 @@ class Mamba2AttentionMetadataBuilder(
num_computed_tokens_p=num_computed_tokens_p,
)
return attn_metadata
def update_block_table(
self,
metadata: Mamba2AttentionMetadata,
blk_table: torch.Tensor,
slot_mapping: torch.Tensor,
) -> Mamba2AttentionMetadata:
new_metadata = copy.copy(metadata)
prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
state_indices_t = blk_table if prefix_caching else blk_table[:, 0]
num_reqs = blk_table.shape[0]
# For CUDA graphs, copy to persistent buffer
if (
metadata.num_prefills == 0
and num_reqs <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
persistent_state_indices_t = self.state_indices_tensor[:num_reqs]
persistent_state_indices_t.copy_(state_indices_t, non_blocking=True)
state_indices_t = persistent_state_indices_t
new_metadata.state_indices_tensor = state_indices_t
return new_metadata

View File

@ -4,6 +4,7 @@ import abc
import enum
import functools
from abc import abstractmethod
from collections.abc import Callable
from dataclasses import dataclass, field, fields, make_dataclass
from typing import (
TYPE_CHECKING,
@ -317,6 +318,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch.
reorder_batch_threshold: int | None = None
# Does this backend/builder support updating the block table in existing
# metadata
supports_update_block_table: bool = False
@abstractmethod
def __init__(
@ -387,6 +391,21 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
"""
raise NotImplementedError
def update_block_table(
self,
metadata: M,
blk_table: torch.Tensor,
slot_mapping: torch.Tensor,
) -> M:
"""
Update the block table for the attention metadata.
Faster when theres multiple kv-cache groups that create virtually the
same metadata but just with different block tables.
Only needs to be implemented if supports_update_block_table is True.
"""
raise NotImplementedError
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
) -> M:
@ -603,7 +622,7 @@ def make_local_attention_virtual_batches(
attn_chunk_size: int,
common_attn_metadata: CommonAttentionMetadata,
block_size: int = 0,
) -> CommonAttentionMetadata:
) -> tuple[CommonAttentionMetadata, Callable[[torch.Tensor], torch.Tensor]]:
query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy()
seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy()
block_table = common_attn_metadata.block_table_tensor
@ -715,9 +734,12 @@ def make_local_attention_virtual_batches(
# tensor first, which recovers perf.
batch_indices_torch = torch.from_numpy(batch_indices)
block_indices_torch = torch.from_numpy(block_indices)
block_table_local = block_table[batch_indices_torch, block_indices_torch].view(
virtual_batches, -1
)
# Save as a lambda so we can return this for update_block_table
make_block_table = lambda block_table: block_table[
batch_indices_torch, block_indices_torch
].view(virtual_batches, -1)
block_table_local = make_block_table(block_table)
query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local)
seq_lens_cpu = torch.from_numpy(seqlens_k_local)
@ -736,7 +758,7 @@ def make_local_attention_virtual_batches(
causal=True,
_seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
)
), make_block_table
def make_kv_sharing_fast_prefill_common_attn_metadata(

View File

@ -39,20 +39,26 @@ class EncoderCacheManager:
space for new embeddings.
Oldest cached embeddings with no request referenced will be first evicted.
NOTE: The EncoderCacheManager operates on the level of multimodal embeddings
instead of encoder tokens (i.e. all tokens that represent the multimodal data
in the input sequence). This means all break/text tokens in-between multimodal
embeddings are not considered with respect to the cache size and the number
of free slots.
Args:
cache_size: Limit the size of the cache, measured by the number of
tokens from the input sequence.
encoder embeddings from the input sequence.
Attributes:
cache_size: Total cache capacity in encoder tokens.
num_free_slots: Current available cache capacity in encoder tokens.
cache_size: Total cache capacity in encoder embeddings.
num_free_slots: Current available cache capacity in encoder embeddings.
num_freeable_slots: Capacity that can be immediately reclaimed by
evicting entries with zero references (in encoder tokens).
evicting entries with zero references (in encoder embeddings).
cached: Mapping from mm_hash to a set of request IDs that currently
reference the cached entry. If the set is empty, the entry exists
but is not referenced by any request and is eligible for
reclamation.
freeable: List of tuples (mm_hash, num_tokens) representing entries
freeable: List of tuples (mm_hash, num_encoder_embeds) representing entries
whose no current running request is needed and that can be freed to
make space when needed.
freed: List of mm_hash strings that were actually evicted since the
@ -67,7 +73,7 @@ class EncoderCacheManager:
# mm_hash of mm_data => ids of requests that reference the mm_data
self.cached: dict[str, set[str]] = {}
# mm_hash of mm_data => num_encoder_tokens of the mm_data
# mm_hash of mm_data => num_encoder_embeds of the mm_data
self.freeable: OrderedDict[str, int] = OrderedDict()
self.freed: list[str] = []
@ -93,8 +99,8 @@ class EncoderCacheManager:
# Cached but currently not referenced by any request
if not self.cached[mm_hash]:
num_tokens = self.freeable.pop(mm_hash)
self.num_freeable_slots -= num_tokens
num_encoder_embeds = self.freeable.pop(mm_hash)
self.num_freeable_slots -= num_encoder_embeds
self.cached[mm_hash].add(request.request_id)
return True
@ -104,7 +110,7 @@ class EncoderCacheManager:
request: Request,
input_id: int,
encoder_compute_budget: int,
num_tokens_to_schedule: int,
num_embeds_to_schedule: int,
) -> bool:
"""Check if there's sufficient cache space for a multimodal input.
If there is, return True and update EncoderCacheManager state.
@ -121,9 +127,9 @@ class EncoderCacheManager:
Args:
request: The request containing the multimodal input.
input_id: Index of the multimodal input within the request.
encoder_compute_budget: Number of encoder tokens allowed to be
encoder_compute_budget: Number of encoder embeddings allowed to be
computed when this method is invoked.
num_tokens_to_schedule: Number of tokens already scheduled to be
num_embeds_to_schedule: Number of encoder embeddings already scheduled to be
allocated with cache space when this method is invoked.
Returns:
@ -134,30 +140,30 @@ class EncoderCacheManager:
Note: This method does not allocate physical memory for the encoder
output but only the state of EncoderCacheManager.
"""
num_tokens = request.get_num_encoder_tokens(input_id)
num_embeds = request.get_num_encoder_embeds(input_id)
# Not enough compute budget
if num_tokens > encoder_compute_budget:
if num_embeds > encoder_compute_budget:
return False
num_tokens += num_tokens_to_schedule
num_embeds += num_embeds_to_schedule
# Enough free slots
if num_tokens <= self.num_free_slots:
if num_embeds <= self.num_free_slots:
return True
# Not enough reclaimable slots
if num_tokens > self.num_freeable_slots:
if num_embeds > self.num_freeable_slots:
return False
# Not enough free slots but enough reclaimable slots
# NOTE: Eviction takes place here, but physical memory is not freed
# until model runner is notified by the scheduler output.
while num_tokens > self.num_free_slots:
mm_hash, num_free_token = self.freeable.popitem(last=False)
while num_embeds > self.num_free_slots:
mm_hash, num_free_embeds = self.freeable.popitem(last=False)
del self.cached[mm_hash]
self.freed.append(mm_hash)
self.num_free_slots += num_free_token
self.num_free_slots += num_free_embeds
return True
def allocate(self, request: Request, input_id: int) -> None:
@ -176,16 +182,16 @@ class EncoderCacheManager:
if mm_hash not in self.cached:
self.cached[mm_hash] = set()
num_encoder_tokens = request.get_num_encoder_tokens(input_id)
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
# NOTE: Encoder cache should always have enough space for encoder inputs
# that are scheduled since eviction takes place at can_allocate().
assert self.num_free_slots >= num_encoder_tokens
assert self.num_freeable_slots >= num_encoder_tokens
assert self.num_free_slots >= num_encoder_embeds
assert self.num_freeable_slots >= num_encoder_embeds
self.cached[mm_hash].add(request_id)
self.num_free_slots -= num_encoder_tokens
self.num_freeable_slots -= num_encoder_tokens
self.num_free_slots -= num_encoder_embeds
self.num_freeable_slots -= num_encoder_embeds
def get_cached_input_ids(self, request: Request) -> set[int]:
"""Get all cached multimodal input IDs for a request.
@ -206,7 +212,7 @@ class EncoderCacheManager:
When the reference set for the corresponding `mm_hash` becomes empty,
the entry is appended to `freeable` and `num_freeable_slots` is
increased by the number of encoder tokens for that input.
increased by the number of encoder embeddings for that input.
The entry is NOT physically freed until capacity is needed (e.g., by
`can_allocate`).
@ -218,9 +224,9 @@ class EncoderCacheManager:
return
self.cached[mm_hash].discard(req_id)
if not self.cached[mm_hash]:
num_tokens = request.get_num_encoder_tokens(input_id)
self.freeable[mm_hash] = num_tokens
self.num_freeable_slots += num_tokens
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
self.freeable[mm_hash] = num_encoder_embeds
self.num_freeable_slots += num_encoder_embeds
def free(self, request: Request) -> None:
"""Free all encoder input cache reference held by *request*.
@ -361,20 +367,20 @@ class EncoderDecoderCacheManager(EncoderCacheManager):
request: Request,
input_id: int,
encoder_compute_budget: int,
num_tokens_to_schedule: int,
num_embeds_to_schedule: int,
) -> bool:
num_tokens = request.get_num_encoder_tokens(input_id)
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
# Not enough compute budget
if num_tokens > encoder_compute_budget:
if num_encoder_embeds > encoder_compute_budget:
return False
num_tokens += num_tokens_to_schedule
num_encoder_embeds += num_embeds_to_schedule
# Enough free slots
return num_tokens <= self.num_free_slots
return num_encoder_embeds <= self.num_free_slots
def allocate(self, request: Request, input_id: int) -> None:
num_encoder_tokens = request.get_num_encoder_tokens(input_id)
self.num_free_slots -= num_encoder_tokens
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
self.num_free_slots -= num_encoder_embeds
mm_hash = request.mm_features[input_id].identifier
self.freed.append(mm_hash)
@ -392,5 +398,5 @@ class EncoderDecoderCacheManager(EncoderCacheManager):
return freed
def free_encoder_input(self, request: Request, input_id: int) -> None:
num_tokens = request.get_num_encoder_tokens(input_id)
self.num_free_slots += num_tokens
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
self.num_free_slots += num_encoder_embeds

View File

@ -355,11 +355,11 @@ class Scheduler(SchedulerInterface):
if preempted_encoder_inputs:
# Restore encoder compute budget if the preempted
# request had encoder inputs scheduled in this step.
num_tokens_to_restore = sum(
preempted_req.get_num_encoder_tokens(i)
num_embeds_to_restore = sum(
preempted_req.get_num_encoder_embeds(i)
for i in preempted_encoder_inputs
)
encoder_compute_budget += num_tokens_to_restore
encoder_compute_budget += num_embeds_to_restore
req_index -= 1
else:
preempted_req = self.running.pop()
@ -911,10 +911,11 @@ class Scheduler(SchedulerInterface):
# multiple encoder inputs per request), we need to create temporary
# trackers for accounting at the encoder input level.
mm_hashes_to_schedule = set()
num_tokens_to_schedule = 0
num_embeds_to_schedule = 0
for i, mm_feature in enumerate(mm_features):
start_pos = mm_feature.mm_position.offset
num_encoder_tokens = mm_feature.mm_position.length
num_encoder_embeds = mm_feature.mm_position.get_num_embeds
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
@ -970,9 +971,8 @@ class Scheduler(SchedulerInterface):
):
num_new_tokens = start_pos - num_computed_tokens
break
if not self.encoder_cache_manager.can_allocate(
request, i, encoder_compute_budget, num_tokens_to_schedule
request, i, encoder_compute_budget, num_embeds_to_schedule
):
# The encoder cache is full or the encoder budget is exhausted.
# NOTE(woosuk): We assume that the encoder input tokens should
@ -992,14 +992,31 @@ class Scheduler(SchedulerInterface):
num_new_tokens = 0
break
# Calculate the number of embeddings to schedule in the current range
# of scheduled encoder placholder tokens.
start_idx_rel = max(0, num_computed_tokens - start_pos)
end_idx_rel = min(
num_encoder_tokens, num_computed_tokens + num_new_tokens - start_pos
)
curr_embeds_start, curr_embeds_end = (
mm_feature.mm_position.get_embeds_indices_in_range(
start_idx_rel,
end_idx_rel,
)
)
# There's no embeddings in the current range of encoder placeholder tokens
# so we can skip the encoder input.
if curr_embeds_end - curr_embeds_start == 0:
continue
if self.ec_connector is not None and remote_cache_has_item[i]:
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
external_load_encoder_input.append(i)
num_tokens_to_schedule += num_encoder_tokens
num_embeds_to_schedule += num_encoder_embeds
continue
num_tokens_to_schedule += num_encoder_tokens
encoder_compute_budget -= num_encoder_tokens
num_embeds_to_schedule += num_encoder_embeds
encoder_compute_budget -= num_encoder_embeds
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
encoder_inputs_to_schedule.append(i)

View File

@ -209,10 +209,10 @@ class Request:
def get_finished_reason(self) -> FinishReason | None:
return RequestStatus.get_finished_reason(self.status)
def get_num_encoder_tokens(self, input_id: int) -> int:
def get_num_encoder_embeds(self, input_id: int) -> int:
assert input_id < len(self.mm_features)
num_tokens = self.mm_features[input_id].mm_position.length
return num_tokens
num_embeds = self.mm_features[input_id].mm_position.get_num_embeds
return num_embeds
def record_event(
self,

View File

@ -170,9 +170,7 @@ from .utils import (
MultiModalBudget,
add_kv_sharing_layers_to_kv_cache_groups,
bind_kv_cache,
gather_mm_placeholders,
sanity_check_mm_encoder_outputs,
scatter_mm_placeholders,
)
if TYPE_CHECKING:
@ -1640,6 +1638,15 @@ class GPUModelRunner(
logits_indices
)
# Cache attention metadata builds across hybrid KV-cache groups
# The only thing that changes between different hybrid KV-cache groups when the
# same metadata builder and KVCacheSpec is the same is the block table, so we
# can cache the attention metadata builds and just update the block table using
# `builder.update_block_table` if the builder supports it.
cached_attn_metadata: dict[
tuple[KVCacheSpec, type[AttentionMetadataBuilder]], AttentionMetadata
] = {}
def _build_attn_group_metadata(
kv_cache_gid: int,
attn_gid: int,
@ -1647,13 +1654,15 @@ class GPUModelRunner(
ubid: int | None = None,
) -> None:
attn_group = self.attn_groups[kv_cache_gid][attn_gid]
builder = attn_group.get_metadata_builder(ubid or 0)
cache_key = (kv_cache_groups[kv_cache_gid].kv_cache_spec, type(builder))
cascade_attn_prefix_len = (
cascade_attn_prefix_lens[kv_cache_gid][attn_gid]
if cascade_attn_prefix_lens
else 0
)
builder = attn_group.get_metadata_builder(ubid or 0)
extra_attn_metadata_args = {}
if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
assert ubid is None, "UBatching not supported with GDN yet"
@ -1668,12 +1677,23 @@ class GPUModelRunner(
attn_metadata_i = builder.build_for_cudagraph_capture(
common_attn_metadata
)
elif (
cache_key in cached_attn_metadata
and builder.supports_update_block_table
):
attn_metadata_i = builder.update_block_table(
cached_attn_metadata[cache_key],
common_attn_metadata.block_table_tensor,
common_attn_metadata.slot_mapping,
)
else:
attn_metadata_i = builder.build(
common_prefix_len=cascade_attn_prefix_len,
common_attn_metadata=common_attn_metadata,
**extra_attn_metadata_args,
)
if builder.supports_update_block_table:
cached_attn_metadata[cache_key] = attn_metadata_i
if ubid is None:
assert isinstance(attn_metadata, dict)
@ -2271,10 +2291,7 @@ class GPUModelRunner(
# Cache the encoder outputs by mm_hash
for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
self.encoder_cache[mm_hash] = scatter_mm_placeholders(
output,
is_embed=pos_info.is_embed,
)
self.encoder_cache[mm_hash] = output
logger.debug("Finish execute for mm hash %s", mm_hash)
self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)
@ -2325,6 +2342,13 @@ class GPUModelRunner(
num_encoder_tokens,
)
assert start_idx < end_idx
curr_embeds_start, curr_embeds_end = (
pos_info.get_embeds_indices_in_range(start_idx, end_idx)
)
# If there are no embeddings in the current range, we skip
# gathering the embeddings.
if curr_embeds_start == curr_embeds_end:
continue
mm_hash = mm_feature.identifier
encoder_output = self.encoder_cache.get(mm_hash, None)
@ -2332,16 +2356,14 @@ class GPUModelRunner(
if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx]
mm_embeds_item = encoder_output[curr_embeds_start:curr_embeds_end]
else:
mm_embeds_item = encoder_output[start_idx:end_idx]
req_start_pos = req_start_idx + start_pos - num_computed_tokens
is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
True if is_embed is None else is_embed
)
mm_embeds_item = gather_mm_placeholders(
encoder_output[start_idx:end_idx],
is_embed=is_embed,
)
mm_embeds_req.append(mm_embeds_item)
if self.is_multimodal_pruning_enabled and self.uses_mrope:
@ -4570,31 +4592,8 @@ class GPUModelRunner(
dummy_encoder_outputs,
expected_num_items=max_mm_items_per_batch,
)
# NOTE: This happens when encoder cache needs to store
# the embeddings that encoder outputs are scattered onto.
# In this case we create dummy embeddings of size
# (max_tokens_for_modality, hidden_size) and scatter
# encoder output into it.
encoder_output_shape = dummy_encoder_outputs[0].shape
max_mm_tokens_per_item = mm_budget.max_tokens_by_modality[
dummy_modality
]
if encoder_output_shape[0] < max_mm_tokens_per_item:
encoder_hidden_size = encoder_output_shape[-1]
expanded_outputs = []
for output in dummy_encoder_outputs:
expanded = output.new_zeros(
(max_mm_tokens_per_item, encoder_hidden_size)
)
num_tokens = output.shape[0]
expanded[:num_tokens].copy_(output)
expanded_outputs.append(expanded)
dummy_encoder_outputs = expanded_outputs
# Cache the dummy encoder outputs.
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
for i, output in enumerate(dummy_encoder_outputs):
self.encoder_cache[f"tmp_{i}"] = output
# Add `is_profile` here to pre-allocate communication buffers
hidden_states, last_hidden_states = self._dummy_run(

View File

@ -4,10 +4,12 @@ from collections import defaultdict
from dataclasses import dataclass, field
import torch
from typing_extensions import deprecated
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.cache import processor_only_cache_from_config
@ -17,6 +19,8 @@ from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
logger = init_logger(__name__)
class MultiModalBudget:
"""Helper class to calculate budget information for multi-modal models."""
@ -198,6 +202,7 @@ def sanity_check_mm_encoder_outputs(
)
@deprecated("`scatter_mm_placeholders` is deprecated and will be removed in v0.15.0.")
def scatter_mm_placeholders(
embeds: torch.Tensor,
is_embed: torch.Tensor | None,
@ -226,6 +231,7 @@ def scatter_mm_placeholders(
return placeholders
@deprecated("`gather_mm_placeholders` is deprecated and will be removed in v0.15.0.")
def gather_mm_placeholders(
placeholders: torch.Tensor,
is_embed: torch.Tensor | None,