diff --git a/.buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh b/.buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh
index 5302f524a0ae4..8106f50f18f66 100644
--- a/.buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh
+++ b/.buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh
@@ -17,7 +17,17 @@ wait_for_server() {
}
MODEL="deepseek-ai/DeepSeek-V2-lite"
-BACKENDS=("deepep_high_throughput" "deepep_low_latency")
+
+# Set BACKENDS based on platform
+if command -v rocm-smi &> /dev/null || [[ -d /opt/rocm ]] || [[ -n "${ROCM_PATH:-}" ]]; then
+ # ROCm platform
+ BACKENDS=("allgather_reducescatter")
+ # Disable MoE padding on ROCm since it causes EPLB to fail
+ export VLLM_ROCM_MOE_PADDING=0
+else
+ # Non-ROCm platform (CUDA/other)
+ BACKENDS=("deepep_high_throughput" "deepep_low_latency")
+fi
cleanup() {
if [[ -n "${SERVER_PID:-}" ]] && kill -0 "${SERVER_PID}" 2>/dev/null; then
diff --git a/.buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep.sh b/.buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh
similarity index 64%
rename from .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep.sh
rename to .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh
index a5135299297e2..6a1bef275d047 100644
--- a/.buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep.sh
+++ b/.buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh
@@ -1,10 +1,12 @@
#!/usr/bin/env bash
set -euxo pipefail
-# args: [THRESHOLD] [NUM_QUESTIONS] [START_PORT]
+# args: [THRESHOLD] [NUM_QUESTIONS] [START_PORT] [DATA_PARALLEL_SIZE] [TENSOR_PARALLEL_SIZE]
THRESHOLD=${1:-0.8}
NUM_Q=${2:-1319}
PORT=${3:-8020}
+DATA_PARALLEL_SIZE=${4:-2}
+TENSOR_PARALLEL_SIZE=${5:-2}
OUT_DIR=${OUT_DIR:-/tmp/vllm-scheduled}
mkdir -p "${OUT_DIR}"
@@ -17,7 +19,16 @@ wait_for_server() {
}
MODEL="QWen/Qwen3-30B-A3B-FP8"
-BACKENDS=("deepep_high_throughput" "deepep_low_latency")
+# Set BACKENDS based on platform
+if command -v rocm-smi &> /dev/null || [[ -d /opt/rocm ]] || [[ -n "${ROCM_PATH:-}" ]]; then
+ # ROCm platform
+ BACKENDS=("allgather_reducescatter")
+ # Disable MoE padding on ROCm since it causes EPLB to fail
+ export VLLM_ROCM_MOE_PADDING=0
+else
+ # Non-ROCm platform (CUDA/other)
+ BACKENDS=("deepep_high_throughput" "deepep_low_latency")
+fi
cleanup() {
if [[ -n "${SERVER_PID:-}" ]] && kill -0 "${SERVER_PID}" 2>/dev/null; then
@@ -36,8 +47,10 @@ for BACK in "${BACKENDS[@]}"; do
VLLM_ALL2ALL_BACKEND=$BACK \
vllm serve "$MODEL" \
--enforce-eager \
- --tensor-parallel-size 2 \
- --data-parallel-size 2 \
+ --enable-eplb \
+ --eplb-config '{"window_size":10, "step_interval":100, "num_redundant_experts":0, "log_balancedness":true}' \
+ --tensor-parallel-size ${TENSOR_PARALLEL_SIZE} \
+ --data-parallel-size ${DATA_PARALLEL_SIZE} \
--enable-expert-parallel \
--trust-remote-code \
--max-model-len 2048 \
diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml
index 4e2ff5c5a6bd5..4ddf11c0b268f 100644
--- a/.buildkite/test-amd.yaml
+++ b/.buildkite/test-amd.yaml
@@ -754,6 +754,7 @@ steps:
torch_nightly: true
source_file_dependencies:
- vllm/model_executor/models/
+ - vllm/transformers_utils/
- tests/models/test_initialization.py
commands:
# Only when vLLM model source is modified - test initialization of a large
@@ -1319,7 +1320,10 @@ steps:
- pytest -v -s -x lora/test_llama_tp.py
- pytest -v -s -x lora/test_llm_with_multi_loras.py
- pytest -v -s -x lora/test_olmoe_tp.py
- - pytest -v -s -x lora/test_gptoss_tp.py
+
# Disabled for now because the MXFP4 backend doesn't support
# LoRA on non-CUDA platforms yet
+ #- pytest -v -s -x lora/test_gptoss_tp.py
- label: Weight Loading Multiple GPU Test # 33min
@@ -1482,4 +1486,4 @@ steps:
num_gpus: 4
working_dir: "/vllm-workspace"
commands:
- - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep.sh 0.8 200 8020
+ - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index a5719d438eece..10a19c52c72dc 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -192,6 +192,7 @@ steps:
# test with internal dp
- python3 ../examples/offline_inference/data_parallel.py --enforce-eager
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
+ - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_internal_lb_dp.py
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_hybrid_lb_dp.py
@@ -691,6 +692,7 @@ steps:
torch_nightly: true
source_file_dependencies:
- vllm/model_executor/models/
+ - vllm/transformers_utils/
- tests/models/test_initialization.py
commands:
# Only when vLLM model source is modified - test initialization of a large
@@ -901,11 +903,12 @@ steps:
- label: Transformers Nightly Models Test
working_dir: "/vllm-workspace/"
optional: true
+ soft_fail: true
commands:
- pip install --upgrade git+https://github.com/huggingface/transformers
- - pytest -v -s tests/models/test_initialization.py -k 'not (Ultravox or Phi4Multimodal or MiniCPMO or Lfm2Moe or RobertaForSequenceClassification or Ovis2_5 or DeepseekOCR or KimiVL)'
+ - pytest -v -s tests/models/test_initialization.py
- pytest -v -s tests/models/test_transformers.py
- # - pytest -v -s tests/models/multimodal/processing/
+ - pytest -v -s tests/models/multimodal/processing/
- pytest -v -s tests/models/multimodal/test_mapping.py
- python3 examples/offline_inference/basic/chat.py
- python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl
@@ -969,6 +972,7 @@ steps:
- vllm/model_executor/layers/layernorm.py
- vllm/model_executor/layers/activation.py
- vllm/model_executor/layers/quantization/input_quant_fp8.py
+ - vllm/model_executor/layers/fused_moe/layer.py
- tests/compile/test_fusion_attn.py
- tests/compile/test_silu_mul_quant_fusion.py
- tests/compile/distributed/test_fusion_all_reduce.py
@@ -1115,6 +1119,7 @@ steps:
# https://github.com/NVIDIA/nccl/issues/1838
- export NCCL_CUMEM_HOST_ENABLE=0
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
+ - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
- pytest -v -s entrypoints/llm/test_collective_rpc.py
@@ -1339,11 +1344,20 @@ steps:
commands:
- bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh 0.25 200 8010
-- label: Qwen3-30B-A3B-FP8-block Accuracy
+- label: Qwen3-30B-A3B-FP8-block Accuracy (H100)
timeout_in_minutes: 60
gpu: h100
optional: true
num_gpus: 4
working_dir: "/vllm-workspace"
commands:
- - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep.sh 0.8 200 8020
+ - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020
+
+- label: Qwen3-30B-A3B-FP8-block Accuracy (B200)
+ timeout_in_minutes: 60
+ gpu: b200
+ optional: true
+ num_gpus: 2
+ working_dir: "/vllm-workspace"
+ commands:
+ - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1
\ No newline at end of file
diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index 0e834c057c401..3247408e1163e 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -9,6 +9,7 @@
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256 @pavanimajety
/vllm/model_executor/layers/mamba @tdoublep
/vllm/model_executor/model_loader @22quinn
+/vllm/model_executor/layers/batch_invariant.py @yewentao256
/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche @tjtanaa
/vllm/vllm_flash_attn @LucasWilkinson
/vllm/lora @jeejeelee
@@ -59,6 +60,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/tests/v1/kv_connector/nixl_integration @NickLucche
/tests/v1/kv_connector @ApostaC
/tests/v1/offloading @ApostaC
+/tests/v1/determinism @yewentao256
# Transformers modeling backend
/vllm/model_executor/models/transformers @hmellor
diff --git a/CMakeLists.txt b/CMakeLists.txt
index a4cf51d17e982..d88ba3aa66303 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -136,7 +136,7 @@ elseif(HIP_FOUND)
# ROCm 5.X and 6.X
if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND
- NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM})
+ Torch_VERSION VERSION_LESS ${TORCH_SUPPORTED_VERSION_ROCM})
message(WARNING "Pytorch version >= ${TORCH_SUPPORTED_VERSION_ROCM} "
"expected for ROCm build, saw ${Torch_VERSION} instead.")
endif()
@@ -604,12 +604,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(SRCS
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
- "csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu")
+ "csrc/quantization/fp4/nvfp4_experts_quant.cu"
+ "csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu"
+ "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1")
+ list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else()
message(STATUS "Not building NVFP4 as no compatible archs were found.")
diff --git a/benchmarks/kernels/deepgemm/README.md b/benchmarks/kernels/deepgemm/README.md
index 41e68e047be82..a28c6956be0e9 100644
--- a/benchmarks/kernels/deepgemm/README.md
+++ b/benchmarks/kernels/deepgemm/README.md
@@ -2,7 +2,7 @@
This directory includes benchmarks between DeepSeek's DeepGEMM block fp8 kernels against vLLM's existing triton and CUTLASS-based kernels.
-Currently this just includes dense GEMMs and only works on Hopper GPUs.
+Currently, this just includes dense GEMMs and only works on Hopper GPUs.
## Setup
diff --git a/csrc/cache.h b/csrc/cache.h
index b162a4a2bc31f..f2a5ec0acf5cd 100644
--- a/csrc/cache.h
+++ b/csrc/cache.h
@@ -41,11 +41,12 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double scale, const std::string& kv_cache_dtype);
void gather_and_maybe_dequant_cache(
- torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
- torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
- torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
- torch::Tensor const& cu_seq_lens, // [BATCH+1]
- int64_t batch_size, const std::string& kv_cache_dtype,
+ torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
+ torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
+ torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
+ torch::Tensor const& cu_seq_lens, // [BATCH+1]
+ torch::Tensor const& token_to_seq, // [MAX_TOKEN_ACROSS_CHUNKS]
+ int64_t num_tokens, const std::string& kv_cache_dtype,
torch::Tensor const& scale,
std::optional seq_starts = std::nullopt);
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 32960cc8073bb..8a5457206c706 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -905,91 +905,79 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
namespace vllm {
// grid is launched with dimensions (batch, num_splits)
-template
+template
__global__ void gather_and_maybe_dequant_cache(
- const cache_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
- // ENTRIES...]
- scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...]
- const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
- const int32_t* __restrict__ cu_seq_lens, // [BATCH+1]
- const int32_t block_size, const int32_t entry_size,
+ const cache_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
+ // ENTRIES...]
+ scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...]
+ const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
+ const int32_t* __restrict__ cu_seq_lens, // [BATCH+1]
+ const int32_t* __restrict__ token_to_seq, // [MAX_TOKEN_ACROSS_CHUNK]
+ const int32_t num_tokens, const int32_t block_size,
const int64_t block_table_stride, const int64_t cache_block_stride,
const int64_t cache_entry_stride, const int64_t dst_entry_stride,
const float* __restrict__ scale,
const int32_t* __restrict__ seq_starts) { // Optional: starting offsets per
// batch
+ constexpr int vec_size = sizeof(float4) / sizeof(scalar_t);
+ using ltype = vllm::vec_n_t;
+ using stype = vllm::vec_n_t;
+ // We are adding this for code readability; it will be optimized out when
+ // built in release mode.
+ assert(CTA_SIZE == blockDim.x);
- const int64_t bid = blockIdx.x; // Batch ID
- const int32_t num_splits = gridDim.y;
- const int32_t split = blockIdx.y;
- const int32_t seq_start = cu_seq_lens[bid];
- const int32_t seq_end = cu_seq_lens[bid + 1];
- const int32_t seq_len = seq_end - seq_start;
- const int32_t tot_blocks = cuda_utils::ceil_div(seq_len, block_size);
- const int32_t split_blocks = cuda_utils::ceil_div(tot_blocks, num_splits);
+#pragma unroll
+ for (int token_id = blockIdx.x; token_id < num_tokens;
+ token_id += gridDim.x) {
+ int64_t batch_id = token_to_seq[token_id];
+ int64_t batch_start = cu_seq_lens[batch_id];
+ int64_t batch_end = cu_seq_lens[batch_id + 1];
+ int32_t batch_offset = token_id - batch_start;
- const int32_t split_start = split * split_blocks;
- const int32_t split_end = min((split + 1) * split_blocks, tot_blocks);
+ if (token_id >= batch_end) return;
+ int32_t offset = 0;
+ if (seq_starts != nullptr) {
+ offset = seq_starts[batch_id];
+ }
+ batch_offset += offset;
+ int32_t block_table_id = batch_offset / block_size;
+ int32_t slot_id = batch_offset % block_size;
+ int32_t block_table_offset = batch_id * block_table_stride + block_table_id;
+ int32_t block_id = block_table[block_table_offset];
+ int64_t cache_offset =
+ block_id * cache_block_stride + slot_id * cache_entry_stride;
+ constexpr int32_t vec_iter_cnt = ENTRY_SIZE / vec_size;
+ scalar_t* dst_ = dst + token_id * dst_entry_stride;
+ cache_t* src_ = const_cast(src_cache) + cache_offset;
- const bool is_active_split = (split_start < tot_blocks);
- const bool is_last_split = (split_end == tot_blocks);
-
- if (!is_active_split) return;
-
- int32_t full_blocks_end = split_end;
- int32_t partial_block_size = 0;
-
- // Adjust the pointer for the block_table for this batch.
- // If seq_starts is provided, compute an offset based on (seq_starts[bid] /
- // page_size)
- const int32_t batch_offset = bid * block_table_stride;
- int32_t offset = 0;
- if (seq_starts != nullptr) {
- offset = seq_starts[bid] / block_size;
- }
- const int32_t* batch_block_table = block_table + batch_offset + offset;
-
- // Adjust dst pointer based on the cumulative sequence lengths.
- dst += seq_start * dst_entry_stride;
-
- if (is_last_split) {
- partial_block_size = seq_len % block_size;
- if (partial_block_size) full_blocks_end -= 1;
- }
-
- auto copy_entry = [&](const cache_t* __restrict__ _src,
- scalar_t* __restrict__ _dst) {
- for (int i = threadIdx.x; i < entry_size; i += blockDim.x) {
+#pragma unroll
+ for (int idx = threadIdx.x; idx < vec_iter_cnt; idx += CTA_SIZE) {
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
- _dst[i] = static_cast(_src[i]);
+ reinterpret_cast(dst_)[idx] =
+ static_cast(reinterpret_cast(src_)[idx]);
} else {
- _dst[i] =
- fp8::scaled_convert(_src[i], *scale);
+ ltype loaded_val = reinterpret_cast(src_)[idx];
+ stype store_val;
+#pragma unroll
+ for (int j = 0; j < vec_size; ++j) {
+ store_val.val[j] = fp8::scaled_convert(
+ loaded_val.val[j], *scale);
+ }
+ reinterpret_cast(dst_)[idx] = store_val;
}
}
- };
-
- const auto loop_end =
- std::min((int64_t)full_blocks_end, block_table_stride - offset);
- for (int pid = split_start; pid < loop_end; ++pid) {
- auto block_id = batch_block_table[pid];
- auto block_start_ptr = src_cache + block_id * cache_block_stride;
- auto block_dst_ptr = dst + pid * block_size * dst_entry_stride;
- for (int eid = 0; eid < block_size; ++eid) {
- copy_entry(block_start_ptr + eid * cache_entry_stride,
- block_dst_ptr + eid * dst_entry_stride);
- }
- }
-
- if (partial_block_size) {
- if (offset + full_blocks_end < block_table_stride) {
- auto block_id = batch_block_table[full_blocks_end];
- auto block_start_ptr = src_cache + block_id * cache_block_stride;
- auto block_dst_ptr =
- dst + full_blocks_end * block_size * dst_entry_stride;
- for (int eid = 0; eid < partial_block_size; ++eid) {
- copy_entry(block_start_ptr + eid * cache_entry_stride,
- block_dst_ptr + eid * dst_entry_stride);
+ // process tail
+ constexpr int32_t tail_cnt = ENTRY_SIZE % vec_size;
+ dst_ = dst_ + ENTRY_SIZE - tail_cnt;
+ src_ = src_ + ENTRY_SIZE - tail_cnt;
+#pragma unroll
+ for (int idx = threadIdx.x; idx < tail_cnt; idx += CTA_SIZE) {
+ if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+ dst_[idx] = static_cast(src_[idx]);
+ } else {
+ dst_[idx] =
+ fp8::scaled_convert(src_[idx], *scale);
}
}
}
@@ -1001,34 +989,38 @@ __global__ void gather_and_maybe_dequant_cache(
// SCALAR_T is the data type of the destination tensor.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
-#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \
- vllm::gather_and_maybe_dequant_cache \
- <<>>( \
- reinterpret_cast(src_cache.data_ptr()), \
- reinterpret_cast(dst.data_ptr()), \
- block_table.data_ptr(), cu_seq_lens.data_ptr(), \
- block_size, entry_size, block_table_stride, cache_block_stride, \
- cache_entry_stride, dst_entry_stride, \
- reinterpret_cast(scale.data_ptr()), seq_starts_ptr);
+#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \
+ vllm::gather_and_maybe_dequant_cache \
+ <<>>( \
+ reinterpret_cast(src_cache.data_ptr()), \
+ reinterpret_cast(dst.data_ptr()), \
+ block_table.data_ptr(), cu_seq_lens.data_ptr(), \
+ token_to_seq.data_ptr(), num_tokens, block_size, \
+ block_table_stride, cache_block_stride, cache_entry_stride, \
+ dst_entry_stride, reinterpret_cast(scale.data_ptr()), \
+ seq_starts_ptr);
// Gather sequences from the cache into the destination tensor.
// - cu_seq_lens contains the cumulative sequence lengths for each batch
// - block_table contains the cache block indices for each sequence
+// - token_to_seq contains the back mapping from token_id to batch_id
// - Optionally, seq_starts (if provided) offsets the starting block index by
// (seq_starts[bid] / page_size)
void gather_and_maybe_dequant_cache(
- torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
- torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
- torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
- torch::Tensor const& cu_seq_lens, // [BATCH+1]
- int64_t batch_size, const std::string& kv_cache_dtype,
+ torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
+ torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
+ torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
+ torch::Tensor const& cu_seq_lens, // [BATCH+1]
+ torch::Tensor const& token_to_seq, // [MAX_TOKEN_ACROSS_CHUNKS]
+ int64_t num_tokens, const std::string& kv_cache_dtype,
torch::Tensor const& scale,
std::optional seq_starts = std::nullopt) {
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int32_t block_size = src_cache.size(1);
- int32_t entry_size = src_cache.flatten(2, -1).size(2);
+ int32_t head_dim = dst.size(-1);
TORCH_CHECK(block_table.dtype() == torch::kInt32,
"block_table must be int32");
@@ -1038,6 +1030,9 @@ void gather_and_maybe_dequant_cache(
TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
"seq_starts must be int32");
}
+ TORCH_CHECK(head_dim == 576,
+ "gather_and_maybe_dequant_cache only support the head_dim to 576 "
+ "for better performance")
TORCH_CHECK(src_cache.device() == dst.device(),
"src_cache and dst must be on the same device");
@@ -1055,10 +1050,9 @@ void gather_and_maybe_dequant_cache(
int64_t cache_entry_stride = src_cache.stride(1);
int64_t dst_entry_stride = dst.stride(0);
- // Decide on the number of splits based on the batch size.
- int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
- dim3 grid(batch_size, num_splits);
- dim3 block(1024);
+ constexpr int32_t thread_block_size = 64;
+ dim3 grid(num_tokens);
+ dim3 block(thread_block_size);
const int32_t* seq_starts_ptr =
seq_starts.has_value() ? seq_starts.value().data_ptr() : nullptr;
diff --git a/csrc/cpu/cpu_attn.cpp b/csrc/cpu/cpu_attn.cpp
index 50f17c758c148..92f8bee5a47a0 100644
--- a/csrc/cpu/cpu_attn.cpp
+++ b/csrc/cpu/cpu_attn.cpp
@@ -13,6 +13,18 @@
#define AMX_DISPATCH(...) case cpu_attention::ISA::AMX:
#endif
+#ifdef __aarch64__
+ #include "cpu_attn_neon.hpp"
+ #define NEON_DISPATCH(...) \
+ case cpu_attention::ISA::NEON: { \
+ using attn_impl = cpu_attention::AttentionImpl; \
+ return __VA_ARGS__(); \
+ }
+#else
+ #define NEON_DISPATCH(...) case cpu_attention::ISA::NEON:
+#endif // #ifdef __aarch64__
+
#define CPU_ATTN_DISPATCH_CASE(HEAD_DIM, ...) \
case HEAD_DIM: { \
constexpr size_t head_dim = HEAD_DIM; \
@@ -41,6 +53,7 @@
[&] { \
switch (ISA_TYPE) { \
AMX_DISPATCH(__VA_ARGS__) \
+ NEON_DISPATCH(__VA_ARGS__) \
case cpu_attention::ISA::VEC: { \
using attn_impl = \
cpu_attention::AttentionImpl
class AttentionImpl {};
@@ -143,6 +143,12 @@ struct AttentionMetadata {
case ISA::VEC:
ss << "VEC, ";
break;
+ case ISA::VEC16:
+ ss << "VEC16, ";
+ break;
+ case ISA::NEON:
+ ss << "NEON, ";
+ break;
}
ss << "workitem_group_num: " << workitem_group_num
<< ", reduction_item_num: " << reduction_item_num
@@ -841,7 +847,7 @@ struct VecTypeTrait {
};
#endif
-#if !defined(__powerpc__)
+#if !defined(__powerpc__) && !defined(__s390x__)
template <>
struct VecTypeTrait {
using vec_t = vec_op::FP16Vec16;
diff --git a/csrc/cpu/cpu_attn_neon.hpp b/csrc/cpu/cpu_attn_neon.hpp
new file mode 100644
index 0000000000000..827f0cfbc718e
--- /dev/null
+++ b/csrc/cpu/cpu_attn_neon.hpp
@@ -0,0 +1,386 @@
+#ifndef CPU_ATTN_NEON_HPP
+#define CPU_ATTN_NEON_HPP
+
+#include "cpu_attn_impl.hpp"
+#include
+#include
+namespace cpu_attention {
+
+namespace {
+
+#define BLOCK_SIZE_ALIGNMENT 32
+#define HEAD_SIZE_ALIGNMENT 32
+#define MAX_Q_HEAD_NUM_PER_ITER 16
+
+// These do not use vectorized class for loading / converting
+// because csrc/cpu/cpu_types_arm.hpp does not have fallback options
+// for vec_op::BF16Vec* loads / conversions on Arm HW that
+// doesn't support BF16.
+// We don't use vec_op::FP32Vec* or vec_op::FP16Vec* for consistency.
+template
+FORCE_INLINE void load_row8_B_as_f32(const kv_cache_t* p, float32x4_t& b0,
+ float32x4_t& b1);
+
+template <>
+FORCE_INLINE void load_row8_B_as_f32(const float* p, float32x4_t& b0,
+ float32x4_t& b1) {
+ b0 = vld1q_f32(p + 0);
+ b1 = vld1q_f32(p + 4);
+}
+
+template <>
+FORCE_INLINE void load_row8_B_as_f32(const c10::Half* p,
+ float32x4_t& b0,
+ float32x4_t& b1) {
+ const float16_t* h = reinterpret_cast(p);
+ float16x8_t v = vld1q_f16(h);
+ b0 = vcvt_f32_f16(vget_low_f16(v));
+ b1 = vcvt_f32_f16(vget_high_f16(v));
+}
+
+template <>
+FORCE_INLINE void load_row8_B_as_f32(const c10::BFloat16* p,
+ float32x4_t& b0,
+ float32x4_t& b1) {
+ const uint16_t* u = reinterpret_cast(p);
+#ifdef ARM_BF16_SUPPORT
+ uint16x8_t u0 = vld1q_u16(u);
+ bfloat16x8_t bf0 = vreinterpretq_bf16_u16(u0);
+ b0 = vcvtq_low_f32_bf16(bf0);
+ b1 = vcvtq_high_f32_bf16(bf0);
+#else
+ uint16x8_t x0 = vld1q_u16(u);
+ uint32x4_t lo = vshlq_n_u32(vmovl_u16(vget_low_u16(x0)), 16);
+ uint32x4_t hi = vshlq_n_u32(vmovl_u16(vget_high_u16(x0)), 16);
+ b0 = vreinterpretq_f32_u32(lo);
+ b1 = vreinterpretq_f32_u32(hi);
+#endif
+}
+
+// Mx8, with 1 <= M <= 8 , K streamed, unroll-by-4 with NEON FMLAs
+// #Loads = (K // 4) * (M + 4 * sizeof(kv_cache_t) / 2)
+// #FMLAs = (K // 4) * (4 * 2 * M)
+// We have (4 * 2 * M) FMLAs for (M + 4 * sizeof(kv_cache_t) / 2) loads
+template
+FORCE_INLINE void gemm_micro_neon_fmla_Mx8_Ku4(
+ const float* __restrict A, // [M x K],
+ const kv_cache_t* __restrict B, // [K x 8],
+ float* __restrict C, // [M x 8],
+ int64_t lda, int64_t ldb, int64_t ldc, int32_t K, bool accumulate) {
+ // kernel supports max M of 8, as it'd spill for larger M
+ static_assert(1 <= M && M <= 8, "M must be in [1,8]");
+
+// helpers for per-M codegen
+#define ROWS_APPLY(OP) OP(0) OP(1) OP(2) OP(3) OP(4) OP(5) OP(6) OP(7)
+#define IF_M(i) if constexpr (M > (i))
+
+ // A row base pointers
+#define DECL_A(i) const float* a##i = A + (i) * lda;
+ ROWS_APPLY(DECL_A)
+#undef DECL_A
+
+ // declare 2 accumulators per row of M
+#define DECL_ACC(i) float32x4_t acc##i##_0, acc##i##_1;
+ ROWS_APPLY(DECL_ACC)
+#undef DECL_ACC
+
+ // initialize accumulators
+#define INIT_ACC(i) \
+ IF_M(i) { \
+ if (accumulate) { \
+ acc##i##_0 = vld1q_f32(C + (i) * ldc + 0); \
+ acc##i##_1 = vld1q_f32(C + (i) * ldc + 4); \
+ } else { \
+ acc##i##_0 = vdupq_n_f32(0.f); \
+ acc##i##_1 = vdupq_n_f32(0.f); \
+ } \
+ }
+ ROWS_APPLY(INIT_ACC)
+#undef INIT_ACC
+
+ int32_t k = 0;
+
+ // K unrolled by 4
+ for (; k + 3 < K; k += 4) {
+ // load A[k..k+3] for each active row (M)
+#define LOAD_A4(i) \
+ float32x4_t a##i##v; \
+ IF_M(i) a##i##v = vld1q_f32(a##i + k);
+ ROWS_APPLY(LOAD_A4)
+#undef LOAD_A4
+
+ // helper: FMA lane L from aiv
+#define FMAS_LANE(i, aiv, L) \
+ IF_M(i) { \
+ acc##i##_0 = vfmaq_laneq_f32(acc##i##_0, b0, aiv, L); \
+ acc##i##_1 = vfmaq_laneq_f32(acc##i##_1, b1, aiv, L); \
+ }
+
+ // k + 0
+ {
+ float32x4_t b0, b1;
+ load_row8_B_as_f32(B + (int64_t)(k + 0) * ldb, b0, b1);
+#define STEP_K0(i) FMAS_LANE(i, a##i##v, 0)
+ ROWS_APPLY(STEP_K0)
+#undef STEP_K0
+ }
+ // k + 1
+ {
+ float32x4_t b0, b1;
+ load_row8_B_as_f32(B + (int64_t)(k + 1) * ldb, b0, b1);
+#define STEP_K1(i) FMAS_LANE(i, a##i##v, 1)
+ ROWS_APPLY(STEP_K1)
+#undef STEP_K1
+ }
+ // k + 2
+ {
+ float32x4_t b0, b1;
+ load_row8_B_as_f32(B + (int64_t)(k + 2) * ldb, b0, b1);
+#define STEP_K2(i) FMAS_LANE(i, a##i##v, 2)
+ ROWS_APPLY(STEP_K2)
+#undef STEP_K2
+ }
+ // k + 3
+ {
+ float32x4_t b0, b1;
+ load_row8_B_as_f32(B + (int64_t)(k + 3) * ldb, b0, b1);
+#define STEP_K3(i) FMAS_LANE(i, a##i##v, 3)
+ ROWS_APPLY(STEP_K3)
+#undef STEP_K3
+ }
+#undef FMAS_LANE
+ }
+
+ // K tail
+ for (; k < K; ++k) {
+ float32x4_t b0, b1;
+ load_row8_B_as_f32(B + (int64_t)k * ldb, b0, b1);
+#define TAIL_ROW(i) \
+ IF_M(i) { \
+ float32x4_t ai = vdupq_n_f32(*(a##i + k)); \
+ acc##i##_0 = vfmaq_f32(acc##i##_0, b0, ai); \
+ acc##i##_1 = vfmaq_f32(acc##i##_1, b1, ai); \
+ }
+ ROWS_APPLY(TAIL_ROW)
+#undef TAIL_ROW
+ }
+
+ // store accumulators to C
+#define STORE_ROW(i) \
+ IF_M(i) { \
+ vst1q_f32(C + (i) * ldc + 0, acc##i##_0); \
+ vst1q_f32(C + (i) * ldc + 4, acc##i##_1); \
+ }
+ ROWS_APPLY(STORE_ROW)
+#undef STORE_ROW
+
+#undef ROWS_APPLY
+#undef IF_M
+}
+
+template
+FORCE_INLINE void gemm_macro_neon_fmla_Mx8_Ku4(const float* __restrict A,
+ const kv_cache_t* __restrict B,
+ float* __restrict C, int32_t M,
+ int32_t K, int64_t lda,
+ int64_t ldb, int64_t ldc,
+ bool accumulate) {
+ // micro kernel is Mx8
+ static_assert(N % 8 == 0, "N must be a multiple of 8");
+ for (int32_t m = 0; m < M;) {
+ int32_t mb = (M - m >= 8) ? 8 : (M - m >= 4) ? 4 : (M - m >= 2) ? 2 : 1;
+ const float* Ab = A + m * lda;
+ float* Cb = C + m * ldc;
+
+ for (int32_t n = 0; n < N; n += 8) {
+ const kv_cache_t* Bn = B + n;
+ float* Cn = Cb + n;
+ switch (mb) {
+ case 8:
+ gemm_micro_neon_fmla_Mx8_Ku4<8, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc,
+ K, accumulate);
+ break;
+ case 4:
+ gemm_micro_neon_fmla_Mx8_Ku4<4, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc,
+ K, accumulate);
+ break;
+ case 2:
+ gemm_micro_neon_fmla_Mx8_Ku4<2, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc,
+ K, accumulate);
+ break;
+ default:
+ gemm_micro_neon_fmla_Mx8_Ku4<1, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc,
+ K, accumulate);
+ break;
+ }
+ }
+ // no tail loop for N as it's guaranteed to be a multiple of 8
+ m += mb;
+ }
+}
+
+template
+class TileGemmNeonFMLA {
+ public:
+ template
+ FORCE_INLINE static void gemm(const int32_t m_size,
+ float* __restrict__ a_tile,
+ kv_cache_t* __restrict__ b_tile,
+ float* __restrict__ c_tile, const int64_t lda,
+ const int64_t ldb, const int64_t ldc,
+ const int32_t block_size,
+ const int32_t dynamic_k_size,
+ const bool accum_c) {
+ if constexpr (phase == AttentionGemmPhase::QK) {
+ gemm_macro_neon_fmla_Mx8_Ku4(
+ a_tile, b_tile, c_tile, m_size, k_size, lda, ldb, ldc, accum_c);
+ } else {
+ gemm_macro_neon_fmla_Mx8_Ku4(
+ a_tile, b_tile, c_tile, m_size, dynamic_k_size, lda, ldb, ldc,
+ accum_c);
+ }
+ }
+};
+
+} // namespace
+
+// this is similar to "ISA::VEC" at the moment
+template
+class AttentionImpl {
+ public:
+ using query_t = scalar_t;
+ using q_buffer_t = float;
+ using kv_cache_t = scalar_t;
+ using logits_buffer_t = float;
+ using partial_output_buffer_t = float;
+ using prob_buffer_t = float;
+
+ constexpr static int64_t BlockSizeAlignment =
+ BLOCK_SIZE_ALIGNMENT; // KV token num unit of QK and PV phases
+ constexpr static int64_t HeadDimAlignment =
+ HEAD_SIZE_ALIGNMENT; // headdim num unit of PV phase
+ constexpr static int64_t MaxQHeadNumPerIteration = MAX_Q_HEAD_NUM_PER_ITER;
+ constexpr static int64_t HeadDim = head_dim;
+ constexpr static ISA ISAType = ISA::NEON;
+ constexpr static bool scale_on_logits = false; // apply scale on q_buffer
+
+ static_assert(HeadDim % HeadDimAlignment == 0);
+ // the gemm micro kernel is Mx8
+ static_assert(HeadDimAlignment % 8 == 0);
+ static_assert(BlockSizeAlignment % 8 == 0);
+
+ public:
+ template typename attention>
+ FORCE_INLINE void execute_attention(DEFINE_CPU_ATTENTION_PARAMS) {
+ attention> attention_iteration;
+ attention_iteration(CPU_ATTENTION_PARAMS);
+ }
+
+ // k_cache_token_group_stride: stride of K cache when move to next
+ // BlockSizeAlignment tokens in a block
+ constexpr static int64_t k_cache_token_group_stride(
+ const int32_t block_size) {
+ return BlockSizeAlignment; // layout of k_cache block is [head_dim,
+ // block_size], row-major
+ }
+
+ // v_cache_token_group_stride: stride of V cache when move to next
+ // BlockSizeAlignment tokens in a block
+ constexpr static int64_t v_cache_token_group_stride(
+ const int32_t block_size) {
+ return head_dim * BlockSizeAlignment; // layout of v_cache is [block_size,
+ // head_dim], row-major
+ }
+
+ // v_cache_head_group_stride: stride of V cache when move to next
+ // HeadDimAlignment head dims in a block
+ constexpr static int64_t v_cache_head_group_stride(const int32_t block_size) {
+ return HeadDimAlignment; // layout of v_cache is [block_size, head_dim],
+ // row-major
+ }
+
+ // Copy q to q_buffer and cast it to fp32
+ static void copy_q_heads_tile(
+ scalar_t* __restrict__ src, // [q_num, q_heads_per_kv, head_size]
+ float* __restrict__ q_buffer, const int32_t q_num,
+ const int32_t q_heads_per_kv, const int64_t q_num_stride,
+ const int64_t q_head_stride, float scale) {
+ static_assert(head_dim % 16 == 0);
+ constexpr int32_t unroll_size = head_dim / 16;
+ using load_vec_t = typename VecTypeTrait::vec_t;
+
+ vec_op::FP32Vec16 scale_vec(scale);
+ for (int32_t q_num_idx = 0; q_num_idx < q_num; ++q_num_idx) {
+ for (int32_t q_head_idx = 0; q_head_idx < q_heads_per_kv; ++q_head_idx) {
+ scalar_t* __restrict__ curr_q =
+ src + q_num_idx * q_num_stride + q_head_idx * q_head_stride;
+ float* __restrict__ curr_q_buffer =
+ q_buffer + q_num_idx * q_heads_per_kv * head_dim +
+ q_head_idx * head_dim;
+
+ vec_op::unroll_loop([&](int32_t i) {
+ load_vec_t vec(curr_q);
+ vec_op::FP32Vec16 fp32_vec(vec);
+ fp32_vec = fp32_vec * scale_vec;
+ fp32_vec.save(curr_q_buffer);
+
+ curr_q += 16;
+ curr_q_buffer += 16;
+ });
+ }
+ }
+ }
+
+ // reshape K as column-major and V as row-major
+ static void reshape_and_cache(
+ const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
+ scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
+ const int64_t* __restrict__ slot_mapping, const int64_t token_num,
+ const int64_t key_token_num_stride, const int64_t value_token_num_stride,
+ const int64_t head_num, const int64_t key_head_num_stride,
+ const int64_t value_head_num_stride, const int64_t num_blocks,
+ const int64_t num_blocks_stride, const int64_t cache_head_num_stride,
+ const int64_t block_size, const int64_t block_size_stride) {
+#pragma omp parallel for collapse(2)
+ for (int64_t token_idx = 0; token_idx < token_num; ++token_idx) {
+ for (int64_t head_idx = 0; head_idx < head_num; ++head_idx) {
+ const int64_t pos = slot_mapping[token_idx];
+ if (pos < 0) {
+ // skip
+ continue;
+ }
+
+ const int64_t block_idx = pos / block_size;
+ const int64_t block_offset = pos % block_size;
+ {
+ // Write Key
+ const scalar_t* key_start_ptr = key +
+ token_idx * key_token_num_stride +
+ head_idx * key_head_num_stride;
+ scalar_t* key_cache_start_ptr =
+ key_cache + block_idx * num_blocks_stride +
+ head_idx * cache_head_num_stride + block_offset;
+
+#pragma GCC unroll 8
+ for (int64_t i = 0, j = 0; i < head_dim; ++i, j += block_size) {
+ key_cache_start_ptr[j] = key_start_ptr[i];
+ }
+ }
+ {
+ // Write Value
+ const scalar_t* value_start_ptr = value +
+ token_idx * value_token_num_stride +
+ head_idx * value_head_num_stride;
+ scalar_t* value_cache_start_ptr =
+ value_cache + block_idx * num_blocks_stride +
+ head_idx * cache_head_num_stride + block_offset * head_dim;
+ std::memcpy(value_cache_start_ptr, value_start_ptr,
+ sizeof(scalar_t) * head_dim);
+ }
+ }
+ }
+ }
+};
+} // namespace cpu_attention
+
+#endif // #ifndef CPU_ATTN_NEON_HPP
diff --git a/csrc/cpu/cpu_types_vxe.hpp b/csrc/cpu/cpu_types_vxe.hpp
index 51bca37e699b9..9efd8b7ec14a4 100644
--- a/csrc/cpu/cpu_types_vxe.hpp
+++ b/csrc/cpu/cpu_types_vxe.hpp
@@ -4,6 +4,7 @@
#include
#include
+#include <cmath>
#include
namespace vec_op {
@@ -174,8 +175,9 @@ struct FP32Vec8 : public Vec {
}
explicit FP32Vec8(const BF16Vec8& v) {
- reg.val[0] = (__vector float)vec_mergeh(zero, v.reg);
- reg.val[1] = (__vector float)vec_mergel(zero, v.reg);
+ // On big-endian s390x, place BF16 first to get correct byte order
+ reg.val[0] = (__vector float)vec_mergeh(v.reg, zero);
+ reg.val[1] = (__vector float)vec_mergel(v.reg, zero);
}
float reduce_sum() const {
@@ -189,51 +191,257 @@ struct FP32Vec8 : public Vec {
}
FP32Vec8 exp() const {
- // TODO: Vectorize this
- AliasReg ar;
- ar.reg = reg;
- f32x4x4_t ret;
- ret.val[0][0] = std::exp(ar.values[0]);
- ret.val[0][1] = std::exp(ar.values[1]);
- ret.val[0][2] = std::exp(ar.values[2]);
- ret.val[0][3] = std::exp(ar.values[3]);
- ret.val[1][0] = std::exp(ar.values[4]);
- ret.val[1][1] = std::exp(ar.values[5]);
- ret.val[1][2] = std::exp(ar.values[6]);
- ret.val[1][3] = std::exp(ar.values[7]);
- return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
+ f32x4x2_t out;
+
+ const __vector float log2e = vec_splats(1.44269504088896341f);
+ const __vector float one = vec_splats(1.0f);
+ const __vector float min_x = vec_splats(-87.3f);
+ const __vector float max_x = vec_splats(88.7f);
+
+ // 5th-degree minimax polynomial for 2^r (r in [0,1))
+ const __vector float c1 = vec_splats(0.6931471805599453f);
+ const __vector float c2 = vec_splats(0.240226506959101f);
+ const __vector float c3 = vec_splats(0.05550410866482158f);
+ const __vector float c4 = vec_splats(0.009618129107628477f);
+ const __vector float c5 = vec_splats(0.0013333558146428443f);
+
+ for (int i = 0; i < 2; i++) {
+ __vector float x = reg.val[i];
+
+ x = vec_max(x, min_x);
+ x = vec_min(x, max_x);
+
+ __vector float y = vec_mul(x, log2e);
+
+ __vector float kf = vec_floor(y);
+ __vector float r = vec_sub(y, kf);
+
+ __vector signed int k = vec_signed(kf);
+ const __vector signed int min_k = vec_splats((signed int)-126);
+ const __vector signed int max_k = vec_splats((signed int)127);
+ k = vec_min(vec_max(k, min_k), max_k);
+
+ // Build 2^k from exponent bits
+ __vector signed int exp_int = vec_add(k, vec_splats((signed int)127));
+ __vector unsigned int bits = (__vector unsigned int)exp_int;
+ bits = vec_sl(bits, vec_splats((unsigned int)23));
+ __vector float pow2k = (__vector float)bits;
+
+ // Improved minimax polynomial
+ __vector float poly = vec_madd(c5, r, c4);
+ poly = vec_madd(poly, r, c3);
+ poly = vec_madd(poly, r, c2);
+ poly = vec_madd(poly, r, c1);
+ poly = vec_madd(poly, r, one);
+
+ out.val[i] = vec_mul(pow2k, poly);
+ }
+
+ return FP32Vec8(out);
}
FP32Vec8 tanh() const {
- // TODO: Vectorize this
- AliasReg ar;
- ar.reg = reg;
- f32x4x4_t ret;
- ret.val[0][0] = std::tanh(ar.values[0]);
- ret.val[0][1] = std::tanh(ar.values[1]);
- ret.val[0][2] = std::tanh(ar.values[2]);
- ret.val[0][3] = std::tanh(ar.values[3]);
- ret.val[1][0] = std::tanh(ar.values[4]);
- ret.val[1][1] = std::tanh(ar.values[5]);
- ret.val[1][2] = std::tanh(ar.values[6]);
- ret.val[1][3] = std::tanh(ar.values[7]);
- return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
+ // tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
+ const __vector float one = vec_splats(1.0f);
+ const __vector float two = vec_splats(2.0f);
+ const __vector float zero = vec_splats(0.0f);
+ const __vector float sat =
+ vec_splats(9.0f); // beyond this, tanh(x) ~ sign(x)
+
+ f32x4x2_t out;
+
+ for (int i = 0; i < 2; i++) {
+ __vector float x = reg.val[i];
+ __vector float ax = vec_abs(x);
+
+ // sign(x): +1 or -1
+ __vector float sign = vec_sel(vec_splats(-1.0f), one, vec_cmpgt(x, zero));
+
+ // saturation mask: |x| > sat
+ __vector __bool int saturated = vec_cmpgt(ax, sat);
+
+ // 2x
+ __vector float two_x = vec_mul(x, two);
+
+ // Build a temporary FP32Vec8 with both lanes = 2x, reuse exp()
+ f32x4x2_t tmp;
+ tmp.val[0] = two_x;
+ tmp.val[1] = two_x;
+ FP32Vec8 exp_2x_vec(tmp);
+
+ FP32Vec8 e2x = exp_2x_vec.exp();
+ __vector float e = e2x.reg.val[i];
+
+ // tanh(x) = (e - 1) / (e + 1)
+ __vector float num = vec_sub(e, one);
+ __vector float den = vec_add(e, one);
+
+ __vector float t = vec_div(num, den);
+
+ // For large |x|, clamp to sign(x)
+ out.val[i] = vec_sel(t, sign, saturated);
+ }
+
+ return FP32Vec8(out);
}
FP32Vec8 er() const {
- // TODO: Vectorize this
- AliasReg ar;
- ar.reg = reg;
- f32x4x4_t ret;
- ret.val[0][0] = std::erf(ar.values[0]);
- ret.val[0][1] = std::erf(ar.values[1]);
- ret.val[0][2] = std::erf(ar.values[2]);
- ret.val[0][3] = std::erf(ar.values[3]);
- ret.val[1][0] = std::erf(ar.values[4]);
- ret.val[1][1] = std::erf(ar.values[5]);
- ret.val[1][2] = std::erf(ar.values[6]);
- ret.val[1][3] = std::erf(ar.values[7]);
- return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
+ // A&S 7.1.26 approximation:
+ // erf(x) = sign(x) * (1 - ((((a5*t + a4)*t + a3)*t + a2)*t + a1) * t *
+ // exp(-x^2)) t = 1 / (1 + p*|x|), p = 0.3275911
+
+ const __vector float one = vec_splats(1.0f);
+ const __vector float zero = vec_splats(0.0f);
+ const __vector float p = vec_splats(0.3275911f);
+
+ // Polynomial coeffs
+ const __vector float a1 = vec_splats(0.254829592f);
+ const __vector float a2 = vec_splats(-0.284496736f);
+ const __vector float a3 = vec_splats(1.421413741f);
+ const __vector float a4 = vec_splats(-1.453152027f);
+ const __vector float a5 = vec_splats(1.061405429f);
+
+ // Threshold where erf(x) ~ sign(x)
+ const __vector float sat = vec_splats(6.0f);
+
+ f32x4x2_t out;
+
+ for (int lane = 0; lane < 2; lane++) {
+ __vector float x = reg.val[lane];
+ __vector float ax = vec_abs(x);
+
+ // sign(x)
+ __vector float sign = vec_sel(vec_splats(-1.0f), one, vec_cmpgt(x, zero));
+
+ // |x| > 6 → erf(x) = ±1
+ __vector __bool int saturated = vec_cmpgt(ax, sat);
+
+ // t = 1 / (1 + p * |x|)
+ __vector float t = vec_madd(p, ax, one);
+ t = vec_div(one, t);
+
+ // poly = a5
+ __vector float poly = a5;
+ poly = vec_madd(poly, t, a4);
+ poly = vec_madd(poly, t, a3);
+ poly = vec_madd(poly, t, a2);
+ poly = vec_madd(poly, t, a1);
+
+ // full polynomial: poly = poly * t
+ poly = vec_mul(poly, t);
+
+ // Compute exp(-x^2)
+ __vector float x2 = vec_mul(x, x);
+ __vector float neg_x2 = vec_neg(x2);
+
+ f32x4x2_t tmp;
+ tmp.val[0] = neg_x2;
+ tmp.val[1] = neg_x2;
+ FP32Vec8 exp_neg_x2(tmp);
+
+ FP32Vec8 e = exp_neg_x2.exp();
+ __vector float ex = e.reg.val[lane];
+
+ // erf(x) = sign * (1 - poly * exp(-x^2))
+ __vector float term = vec_mul(poly, ex);
+ __vector float y = vec_sub(one, term);
+ y = vec_mul(y, sign);
+
+ // saturated → ±1
+ __vector float sat_val = vec_mul(sign, one);
+ out.val[lane] = vec_sel(y, sat_val, saturated);
+ }
+
+ return FP32Vec8(out);
+ }
+ // Elementwise sigmoid(x) = 1 / (1 + exp(-x))
+ FP32Vec8 sigmoid() const {
+ const __vector float one = vec_splats(1.0f);
+
+ f32x4x2_t neg;
+ for (int i = 0; i < 2; ++i) {
+ neg.val[i] = vec_neg(reg.val[i]);
+ }
+
+ FP32Vec8 neg_x(neg);
+ FP32Vec8 e = neg_x.exp(); // exp(-x)
+
+ f32x4x2_t denom;
+ for (int i = 0; i < 2; ++i) {
+ denom.val[i] = vec_add(one, e.reg.val[i]);
+ }
+
+ FP32Vec8 denom_vec(denom);
+ FP32Vec8 one_vec(1.0f);
+
+ return one_vec / denom_vec;
+ }
+
+ // Tanh-based GELU:
+ // gelu(x) = 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x^3)))
+ FP32Vec8 gelu_tanh() const {
+ const __vector float k_s2pi = vec_splats(0.7978845608028654f); // √(2/π)
+ const __vector float k_0_0447 = vec_splats(0.044715f);
+
+ f32x4x2_t x2, x3, inner;
+ for (int i = 0; i < 2; ++i) {
+ __vector float x = reg.val[i];
+ x2.val[i] = vec_mul(x, x); // x^2
+ x3.val[i] = vec_mul(x2.val[i], x); // x^3
+ __vector float t = vec_madd(k_0_0447, x3.val[i], x); // x + 0.044715*x^3
+ inner.val[i] = vec_mul(k_s2pi, t); // √(2/π)*(...)
+ }
+
+ FP32Vec8 inner_vec(inner);
+ FP32Vec8 t = inner_vec.tanh(); // tanh part
+
+ FP32Vec8 one_vec(1.0f);
+ FP32Vec8 half_vec(0.5f);
+
+ FP32Vec8 x_vec(*this);
+ return x_vec * half_vec * (one_vec + t);
+ }
+
+ // Erf-based GELU:
+ // gelu(x) = 0.5 * x * (1 + erf(x / √2))
+ FP32Vec8 gelu_erf() const {
+ const __vector float inv_sqrt2 = vec_splats(0.7071067811865476f); // 1/√2
+ FP32Vec8 x_vec(*this);
+
+ f32x4x2_t scaled;
+ for (int i = 0; i < 2; ++i) {
+ scaled.val[i] = vec_mul(reg.val[i], inv_sqrt2);
+ }
+ FP32Vec8 x_scaled(scaled);
+
+ FP32Vec8 erf_x = x_scaled.er();
+
+ FP32Vec8 one_vec(1.0f);
+ FP32Vec8 half_vec(0.5f);
+
+ return x_vec * half_vec * (one_vec + erf_x);
+ }
+
+ // Elementwise reciprocal: 1/x (scalar per lane, for correctness)
+ FP32Vec8 rcp() const {
+ AliasReg in, out;
+ in.reg = reg;
+
+ for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+ out.values[i] = 1.0f / in.values[i];
+ }
+ return FP32Vec8(out.reg);
+ }
+
+ // Elementwise rsqrt(x) = 1 / sqrt(x) (scalar per lane, for correctness)
+ FP32Vec8 rsqrt() const {
+ AliasReg in, out;
+ in.reg = reg;
+
+ for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+ out.values[i] = 1.0f / std::sqrt(in.values[i]);
+ }
+ return FP32Vec8(out.reg);
}
FP32Vec8 operator*(const FP32Vec8& b) const {
@@ -316,10 +524,11 @@ struct FP32Vec16 : public Vec {
}
explicit FP32Vec16(const BF16Vec16& v) {
- reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]);
- reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]);
- reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]);
- reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]);
+ // On big-endian s390x, place BF16 first to get correct byte order
+ reg.val[0] = (__vector float)vec_mergeh(v.reg.val[0], zero);
+ reg.val[1] = (__vector float)vec_mergel(v.reg.val[0], zero);
+ reg.val[2] = (__vector float)vec_mergeh(v.reg.val[1], zero);
+ reg.val[3] = (__vector float)vec_mergel(v.reg.val[1], zero);
}
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
@@ -376,6 +585,23 @@ struct FP32Vec16 : public Vec {
return result;
}
+ FP32Vec16 max(const FP32Vec16& b) const {
+ return FP32Vec16(f32x4x4_t({vec_max(reg.val[0], b.reg.val[0]),
+ vec_max(reg.val[1], b.reg.val[1]),
+ vec_max(reg.val[2], b.reg.val[2]),
+ vec_max(reg.val[3], b.reg.val[3])}));
+ }
+
+ float reduce_max() const {
+ AliasReg ar;
+ ar.reg = reg;
+ float result = ar.values[0];
+    unroll_loop<int, VEC_ELEM_NUM>([&result, &ar](int i) {
+ if (ar.values[i] > result) result = ar.values[i];
+ });
+ return result;
+ }
+
void save(float* ptr) const {
vec_xst(reg.val[0], 0, ptr);
vec_xst(reg.val[1], 16, ptr);
@@ -402,15 +628,14 @@ struct VecType {
using vec_type = BF16Vec8;
};
+// On s390x, FP16 (Half) is not natively supported, use FP32 vectors instead
+using FP16Vec16 = FP32Vec16;
+
template <typename T>
void storeFP32(float v, T* ptr) {
*ptr = v;
}
-inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
- acc = acc + a * b;
-}
-
namespace c10 {
struct BFloat16 {
uint16_t value; // Assume BFloat16 is defined as a struct containing a 16-bit
@@ -429,6 +654,79 @@ inline void storeFP32(float v, c10::BFloat16* ptr) {
#define __VEC_CLASS_FP_NAN (1 << 6)
#endif
+// Optimized FMA (Fused Multiply-Add) implementations using IBM Z vector
+// intrinsics
+
+// FP32Vec4 FMA: acc = acc + (a * b) or equivalently acc = fma(a, b, acc)
+FORCE_INLINE void fma(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) {
+ acc.reg = vec_madd(a.reg, b.reg, acc.reg);
+}
+
+// FP32Vec8 FMA: acc = acc + (a * b)
+FORCE_INLINE void fma(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) {
+ acc.reg.val[0] = vec_madd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_madd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+}
+
+// FP32Vec16 FMA: acc = acc + (a * b)
+FORCE_INLINE void fma(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
+ acc.reg.val[0] = vec_madd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_madd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+ acc.reg.val[2] = vec_madd(a.reg.val[2], b.reg.val[2], acc.reg.val[2]);
+ acc.reg.val[3] = vec_madd(a.reg.val[3], b.reg.val[3], acc.reg.val[3]);
+}
+
+// Multiply-Subtract: acc = acc - (a * b)
+FORCE_INLINE void fms(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) {
+ acc.reg = vec_msub(a.reg, b.reg, acc.reg);
+}
+
+FORCE_INLINE void fms(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) {
+ acc.reg.val[0] = vec_msub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_msub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+}
+
+FORCE_INLINE void fms(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
+ acc.reg.val[0] = vec_msub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_msub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+ acc.reg.val[2] = vec_msub(a.reg.val[2], b.reg.val[2], acc.reg.val[2]);
+ acc.reg.val[3] = vec_msub(a.reg.val[3], b.reg.val[3], acc.reg.val[3]);
+}
+
+// Negative Multiply-Add: acc = -(a * b) + acc
+FORCE_INLINE void nfma(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) {
+ acc.reg = vec_nmadd(a.reg, b.reg, acc.reg);
+}
+
+FORCE_INLINE void nfma(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) {
+ acc.reg.val[0] = vec_nmadd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_nmadd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+}
+
+FORCE_INLINE void nfma(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
+ acc.reg.val[0] = vec_nmadd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_nmadd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+ acc.reg.val[2] = vec_nmadd(a.reg.val[2], b.reg.val[2], acc.reg.val[2]);
+ acc.reg.val[3] = vec_nmadd(a.reg.val[3], b.reg.val[3], acc.reg.val[3]);
+}
+
+// Negative Multiply-Subtract: acc = -(a * b) - acc
+FORCE_INLINE void nfms(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) {
+ acc.reg = vec_nmsub(a.reg, b.reg, acc.reg);
+}
+
+FORCE_INLINE void nfms(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) {
+ acc.reg.val[0] = vec_nmsub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_nmsub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+}
+
+FORCE_INLINE void nfms(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
+ acc.reg.val[0] = vec_nmsub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_nmsub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+ acc.reg.val[2] = vec_nmsub(a.reg.val[2], b.reg.val[2], acc.reg.val[2]);
+ acc.reg.val[3] = vec_nmsub(a.reg.val[3], b.reg.val[3], acc.reg.val[3]);
+}
+
const static __vector unsigned char omask = {2, 3, 6, 7, 10, 11, 14, 15,
18, 19, 22, 23, 26, 27, 30, 31};
const static __vector unsigned int bias = {0x00007fff, 0x00007fff, 0x00007fff,
@@ -441,13 +739,24 @@ const static __vector unsigned int one = {1, 1, 1, 1};
inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
__vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
__vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
+ __vector unsigned int lsb0 = inp0 >> sh16;
+ __vector unsigned int lsb1 = inp1 >> sh16;
+ lsb0 = lsb0 & one;
+ lsb1 = lsb1 & one;
+ __vector unsigned int rnd0 = lsb0 + bias;
+ __vector unsigned int rnd1 = lsb1 + bias;
+ inp0 = inp0 + rnd0;
+ inp1 = inp1 + rnd1;
int cc;
__vector __bool int sel0 =
vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc);
__vector __bool int sel1 =
vec_fp_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN, &cc);
- inp0 = vec_sel(inp0, nan, sel0) >> sh16;
- inp1 = vec_sel(inp1, nan, sel1) >> sh16;
+ inp0 = vec_sel(inp0, nan, sel0);
+ inp1 = vec_sel(inp1, nan, sel1);
+ inp0 = inp0 >> sh16;
+ inp1 = inp1 >> sh16;
+
reg = (__vector signed short)vec_perm(inp0, inp1, omask);
}
@@ -456,6 +765,22 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
__vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
__vector unsigned int inp2 = (__vector unsigned int)(v.reg.val[2]);
__vector unsigned int inp3 = (__vector unsigned int)(v.reg.val[3]);
+ __vector unsigned int lsb0 = inp0 >> sh16;
+ __vector unsigned int lsb1 = inp1 >> sh16;
+ __vector unsigned int lsb2 = inp2 >> sh16;
+ __vector unsigned int lsb3 = inp3 >> sh16;
+ lsb0 = lsb0 & one;
+ lsb1 = lsb1 & one;
+ lsb2 = lsb2 & one;
+ lsb3 = lsb3 & one;
+ __vector unsigned int rnd0 = lsb0 + bias;
+ __vector unsigned int rnd1 = lsb1 + bias;
+ __vector unsigned int rnd2 = lsb2 + bias;
+ __vector unsigned int rnd3 = lsb3 + bias;
+ inp0 = inp0 + rnd0;
+ inp1 = inp1 + rnd1;
+ inp2 = inp2 + rnd2;
+ inp3 = inp3 + rnd3;
int cc;
__vector __bool int sel0 =
vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc);
@@ -465,15 +790,164 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
vec_fp_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN, &cc);
__vector __bool int sel3 =
vec_fp_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN, &cc);
- inp0 = vec_sel(inp0, nan, sel0) >> sh16;
- inp1 = vec_sel(inp1, nan, sel1) >> sh16;
- inp2 = vec_sel(inp2, nan, sel2) >> sh16;
- inp3 = vec_sel(inp3, nan, sel3) >> sh16;
+ inp0 = vec_sel(inp0, nan, sel0);
+ inp1 = vec_sel(inp1, nan, sel1);
+ inp2 = vec_sel(inp2, nan, sel2);
+ inp3 = vec_sel(inp3, nan, sel3);
+ inp0 = inp0 >> sh16;
+ inp1 = inp1 >> sh16;
+ inp2 = inp2 >> sh16;
+ inp3 = inp3 >> sh16;
+
reg.val[0] = (__vector signed short)vec_perm(inp0, inp1, omask);
reg.val[1] = (__vector signed short)vec_perm(inp2, inp3, omask);
}
-inline void prefetch(const void* addr) { void __dcbt(const void* addr); }
+// 1D softmax over `n` elements in `input`, writes result to `output`.
+// Uses FP32Vec8 for main body, scalar tail handling.
+// Requirement: n > 0
+FORCE_INLINE void softmax_fp32vec8(float* output, const float* input, int n) {
+ if (n <= 0) return;
+
+ // ---------- Pass 1: find max ----------
+  float max_val = -std::numeric_limits<float>::infinity();
+ int i = 0;
+
+ for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
+ FP32Vec8 v(input + i);
+ FP32Vec8::AliasReg ar;
+ ar.reg = v.reg;
+ for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
+ if (ar.values[j] > max_val) max_val = ar.values[j];
+ }
+ }
+ for (; i < n; ++i) {
+ if (input[i] > max_val) max_val = input[i];
+ }
+
+ // ---------- Pass 2: compute exp(x - max) and sum ----------
+ float sum = 0.0f;
+ i = 0;
+
+ for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
+ float tmp[FP32Vec8::VEC_ELEM_NUM];
+ for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
+ tmp[j] = input[i + j] - max_val;
+ }
+
+ FP32Vec8 v(tmp);
+ FP32Vec8 e = v.exp();
+
+ FP32Vec8::AliasReg ar;
+ ar.reg = e.reg;
+ for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
+ output[i + j] = ar.values[j];
+ sum += ar.values[j];
+ }
+ }
+
+ // Tail
+ for (; i < n; ++i) {
+ float x = input[i] - max_val;
+ float ex = std::exp(x); // scalar tail
+ output[i] = ex;
+ sum += ex;
+ }
+
+ // ---------- Pass 3: normalize ----------
+ float inv_sum = 1.0f / sum;
+ i = 0;
+
+ for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
+ float tmp[FP32Vec8::VEC_ELEM_NUM];
+ for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
+ tmp[j] = output[i + j] * inv_sum;
+ }
+ FP32Vec8 v(tmp);
+ v.save(output + i);
+ }
+
+ for (; i < n; ++i) {
+ output[i] *= inv_sum;
+ }
+}
+
+// 1D RMSNorm kernel:
+// input: x[0..n-1]
+// weight: w[0..n-1] (gamma), may be nullptr
+// output: y[i] = x[i] * inv_rms * (weight[i] if weight != nullptr else 1)
+// eps: small epsilon for numerical stability
+FORCE_INLINE void rmsnorm_fp32vec8(float* output, const float* input,
+ const float* weight, int n, float eps) {
+ if (n <= 0) return;
+
+ // ---------- Pass 1: compute sum of squares ----------
+ float sum_sq = 0.0f;
+ int i = 0;
+
+ for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
+ FP32Vec8 x_vec(input + i);
+
+ FP32Vec8 sq = x_vec * x_vec;
+
+ FP32Vec8::AliasReg ar;
+ ar.reg = sq.reg;
+ for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
+ sum_sq += ar.values[j];
+ }
+ }
+
+ // Tail
+ for (; i < n; ++i) {
+ float v = input[i];
+ sum_sq += v * v;
+ }
+
+  float mean_sq = sum_sq / static_cast<float>(n);
+ float inv_rms = 1.0f / std::sqrt(mean_sq + eps);
+
+ // ---------- Pass 2: scale (and apply weight if given) ----------
+ const float inv_rms_f = inv_rms;
+ i = 0;
+
+ if (weight) {
+ // with gamma
+ for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
+ FP32Vec8 x_vec(input + i);
+
+ float wtmp[FP32Vec8::VEC_ELEM_NUM];
+ for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
+ wtmp[j] = weight[i + j];
+ }
+ FP32Vec8 w_vec(wtmp);
+
+ FP32Vec8 scale_vec(inv_rms_f);
+ FP32Vec8 y = x_vec * scale_vec * w_vec;
+ y.save(output + i);
+ }
+
+ for (; i < n; ++i) {
+ output[i] = input[i] * inv_rms_f * weight[i];
+ }
+ } else {
+ // without gamma
+ for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
+ FP32Vec8 x_vec(input + i);
+ FP32Vec8 scale_vec(inv_rms_f);
+ FP32Vec8 y = x_vec * scale_vec;
+ y.save(output + i);
+ }
+
+ for (; i < n; ++i) {
+ output[i] = input[i] * inv_rms_f;
+ }
+ }
+}
+
+// Prefetch data to cache for better memory access performance
+FORCE_INLINE void prefetch(const void* addr) {
+ __builtin_prefetch(addr, 0, 3); // 0=read, 3=high temporal locality
+}
}; // namespace vec_op
diff --git a/csrc/cuda_view.cu b/csrc/cuda_view.cu
index 938bd4ab7fc62..9853fc942bab7 100644
--- a/csrc/cuda_view.cu
+++ b/csrc/cuda_view.cu
@@ -22,15 +22,10 @@ torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) {
auto strides = cpu_tensor.strides();
auto options = cpu_tensor.options().device(torch::kCUDA);
- // from_blob signature: from_blob(void *data, IntArrayRef sizes, ..., Deleter,
- // const TensorOptions &) Provide a no-op deleter. The CPU tensor holds the
- // memory, so we don't free it here.
- auto deleter = [](void*) {
- // no-op, since the memory is owned by the original CPU tensor
- };
-
+ // use default no-op deleter, since the memory is owned by the original CPU
+ // tensor
torch::Tensor cuda_tensor =
- torch::from_blob(device_ptr, sizes, strides, deleter, options);
+ torch::from_blob(device_ptr, sizes, strides, options);
TORCH_CHECK(cuda_tensor.device().is_cuda(),
"Resulting tensor is not on CUDA device");
diff --git a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
index 5b007e5ea3283..6744402783832 100644
--- a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
+++ b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
@@ -22,6 +22,7 @@
#include
#include
#include
+#include "cutlass_extensions/common.hpp"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
@@ -173,7 +174,7 @@ void run_get_group_gemm_starts(
}
template <typename OutType>
-void run_fp4_blockwise_scaled_group_mm(
+void run_fp4_blockwise_scaled_group_mm_sm100(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
@@ -343,17 +344,225 @@ void run_fp4_blockwise_scaled_group_mm(
auto can_implement_status = gemm_op.can_implement(args);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
- "Failed to implement GEMM");
+ "Failed to implement GEMM: status=", (int)can_implement_status);
// Run the GEMM
auto status = gemm_op.initialize(args, workspace.data_ptr());
- TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
+ TORCH_CHECK(status == cutlass::Status::kSuccess,
+ "Failed to initialize GEMM: status=", (int)status,
+ " workspace_size=", workspace_size, " num_experts=", num_experts,
+ " M=", M, " N=", N, " K=", K);
status = gemm_op.run(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
+void run_fp4_blockwise_scaled_group_mm_sm120(
+ torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
+ const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
+ const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
+ const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
+ int N, int K) {
+ using ProblemShape =
+      cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
+ using ElementType = cutlass::float_e2m1_t;
+ using ElementSFType = cutlass::float_ue4m3_t;
+  using ElementA = cutlass::nv_float4_t<ElementType>;
+  using ElementB = cutlass::nv_float4_t<ElementType>;
+
+ // NOTE: For SM120 it seems templating the output type is not supported and
+ // we need to hardcode the output type to bfloat16
+ using ElementC = cutlass::bfloat16_t;
+ using ElementD = ElementC;
+ using ElementAccumulator = float;
+ // Layout definitions
+ using LayoutA = cutlass::layout::RowMajor;
+ using LayoutB = cutlass::layout::ColumnMajor;
+ using LayoutC = cutlass::layout::RowMajor;
+ using LayoutD = LayoutC;
+
+ // Alignment constraints
+ static constexpr int AlignmentA = 32;
+ static constexpr int AlignmentB = 32;
+  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
+  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
+
+ // Architecture definitions
+ using ArchTag = cutlass::arch::Sm120;
+ using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
+
+ using ClusterShape = Shape<_1, _1, _1>;
+ using MmaTileShape = Shape<_128, _128, _128>;
+
+ using FusionOperation = cutlass::epilogue::fusion::LinearCombination<
+ ElementD, ElementAccumulator, ElementC, ElementAccumulator>;
+
+ using CollectiveEpilogue =
+ typename cutlass::epilogue::collective::CollectiveBuilder<
+ ArchTag, OperatorClass, MmaTileShape, ClusterShape,
+ cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
+ ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD,
+ LayoutD*, AlignmentD,
+ cutlass::epilogue::collective::EpilogueScheduleAuto,
+ FusionOperation>::CollectiveOp;
+
+ using CollectiveMainloop =
+ typename cutlass::gemm::collective::CollectiveBuilder<
+ ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB,
+ LayoutB*, AlignmentB, ElementAccumulator, MmaTileShape, ClusterShape,
+          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
+ sizeof(typename CollectiveEpilogue::SharedStorage))>,
+ cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;
+
+ using GemmKernel =
+      cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+ using StrideA = typename Gemm::GemmKernel::InternalStrideA;
+ using StrideB = typename Gemm::GemmKernel::InternalStrideB;
+ using StrideC = typename Gemm::GemmKernel::InternalStrideC;
+ using StrideD = typename Gemm::GemmKernel::InternalStrideD;
+
+ using LayoutSFA =
+ typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
+ using LayoutSFB =
+ typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
+ using ScaleConfig =
+ typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
+
+ using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
+  int num_experts = static_cast<int>(expert_offsets.size(0));
+ auto options_int =
+ torch::TensorOptions().dtype(torch::kInt64).device(a.device());
+
+ torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
+ torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
+ torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
+ torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
+ torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
+ torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
+ torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
+ torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
+ torch::Tensor c_strides1 =
+ torch::full({num_experts}, output.stride(0), options_int);
+ torch::Tensor a_strides1 =
+ torch::full({num_experts}, a.stride(0) * 2, options_int);
+ torch::Tensor b_strides1 =
+ torch::full({num_experts}, b.stride(1) * 2, options_int);
+
+ run_get_group_gemm_starts(
+ a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
+ layout_sfa, layout_sfb, a, b, output, a_blockscale, b_blockscales, alphas,
+ expert_offsets, sf_offsets, problem_sizes, M, N, K);
+
+ // Create an instance of the GEMM
+ Gemm gemm_op;
+
+ // Initialize problem_sizes_as_shapes correctly
+ UnderlyingProblemShape* problem_sizes_as_shapes =
+      static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
+
+ // Set the Scheduler info
+ cutlass::KernelHardwareInfo hw_info;
+ using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
+ typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
+ scheduler.raster_order = RasterOrderOptions::AlongM;
+ hw_info.device_id = a.get_device();
+  static std::unordered_map<int, int> cached_sm_counts;
+ if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
+ cached_sm_counts[hw_info.device_id] =
+ cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
+ hw_info.device_id);
+ }
+ hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX);
+
+ // Mainloop Arguments
+ typename GemmKernel::MainloopArguments mainloop_args{
+      static_cast<const ElementType**>(a_ptrs.data_ptr()),
+      static_cast<StrideA*>(a_strides1.data_ptr()),
+      static_cast<const ElementType**>(b_ptrs.data_ptr()),
+      static_cast<StrideB*>(b_strides1.data_ptr()),
+      static_cast<const ElementSFType**>(a_scales_ptrs.data_ptr()),
+      reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),
+      static_cast<const ElementSFType**>(b_scales_ptrs.data_ptr()),
+      reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr())};
+
+ // Epilogue Arguments
+ typename GemmKernel::EpilogueArguments epilogue_args{
+ {}, // epilogue.thread
+ nullptr,
+      static_cast<StrideC*>(c_strides1.data_ptr()),
+      static_cast<ElementD**>(out_ptrs.data_ptr()),
+      static_cast<StrideD*>(c_strides1.data_ptr())};
+ auto& fusion_args = epilogue_args.thread;
+ fusion_args.alpha_ptr_array =
+      reinterpret_cast<float**>(alpha_ptrs.data_ptr());
+ fusion_args.dAlpha = {_0{}, _0{}, 1};
+ fusion_args.beta = 0.0f;
+
+ // Gemm Arguments
+ typename GemmKernel::Arguments args{
+ cutlass::gemm::GemmUniversalMode::kGrouped,
+ {num_experts, problem_sizes_as_shapes, nullptr},
+ mainloop_args,
+ epilogue_args,
+ hw_info,
+ scheduler};
+
+ size_t workspace_size = Gemm::get_workspace_size(args);
+ auto const workspace_options =
+ torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
+ auto workspace = torch::empty(workspace_size, workspace_options);
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
+
+ auto can_implement_status = gemm_op.can_implement(args);
+ TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
+ "Failed to implement GEMM: status=", (int)can_implement_status);
+
+ // Run the GEMM
+ auto status = gemm_op.initialize(args, workspace.data_ptr());
+ TORCH_CHECK(status == cutlass::Status::kSuccess,
+ "Failed to initialize GEMM: status=", (int)status,
+ " workspace_size=", workspace_size, " num_experts=", num_experts,
+ " M=", M, " N=", N, " K=", K);
+
+ status = gemm_op.run(args, workspace.data_ptr(), stream);
+ TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
+}
+
+template <typename OutType>
+void run_fp4_blockwise_scaled_group_mm(
+ torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
+ const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
+ const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
+ const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
+ int N, int K) {
+ int32_t version_num = get_sm_version_num();
+#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
+ if (version_num >= 120 && version_num < 130) {
+ run_fp4_blockwise_scaled_group_mm_sm120(
+ output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
+ expert_offsets, sf_offsets, M, N, K);
+ return;
+ }
+#endif
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
+ if (version_num >= 100 && version_num < 120) {
+    run_fp4_blockwise_scaled_group_mm_sm100<OutType>(
+ output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
+ expert_offsets, sf_offsets, M, N, K);
+ return;
+ }
+#endif
+ TORCH_CHECK_NOT_IMPLEMENTED(
+ false,
+ "No compiled cutlass_fp4_group_mm kernel for CUDA device capability: ",
+ version_num, ". Required capability: 100 or 120");
+}
+
+#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
+ (defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
#endif
@@ -374,7 +583,8 @@ void cutlass_fp4_group_mm(
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) {
-#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
+#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
+ (defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
// Input validation
CHECK_INPUT(a, FLOAT4_E2M1X2, "a");
CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
@@ -408,6 +618,14 @@ void cutlass_fp4_group_mm(
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
expert_offsets, sf_offsets, M, N, K);
} else {
+ #if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
+ int32_t version_num = get_sm_version_num();
+ if (version_num >= 120 && version_num < 130) {
+ TORCH_CHECK_NOT_IMPLEMENTED(
+ false, "SM120 NVFP4 MOE only supports bfloat16 output, got: ",
+ output.scalar_type());
+ }
+ #endif
run_fp4_blockwise_scaled_group_mm(
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
expert_offsets, sf_offsets, M, N, K);
@@ -416,8 +634,8 @@ void cutlass_fp4_group_mm(
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_fp4_group_mm kernel, vLLM must "
- "be compiled with ENABLE_NVFP4_SM100 for SM100+ and CUDA "
- "12.8 or above.");
+ "be compiled with ENABLE_NVFP4_SM100 or ENABLE_NVFP4_SM120 for SM100/120 "
+ "and CUDA 12.8 or above.");
#endif
}
diff --git a/csrc/quantization/fp4/nvfp4_experts_quant.cu b/csrc/quantization/fp4/nvfp4_experts_quant.cu
index 6d385e0dd94e7..82c53c2375a31 100644
--- a/csrc/quantization/fp4/nvfp4_experts_quant.cu
+++ b/csrc/quantization/fp4/nvfp4_experts_quant.cu
@@ -307,7 +307,7 @@ constexpr auto FLOAT = at::ScalarType::Float;
constexpr auto INT = at::ScalarType::Int;
constexpr auto UINT8 = at::ScalarType::Byte;
-void scaled_fp4_experts_quant_sm100a(
+void scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
diff --git a/csrc/quantization/fp4/nvfp4_quant_entry.cu b/csrc/quantization/fp4/nvfp4_quant_entry.cu
index c2b39e5438805..fb6d22f035b99 100644
--- a/csrc/quantization/fp4/nvfp4_quant_entry.cu
+++ b/csrc/quantization/fp4/nvfp4_quant_entry.cu
@@ -24,8 +24,9 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
torch::Tensor const& input_sf);
#endif
-#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
-void scaled_fp4_experts_quant_sm100a(
+#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
+ (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
+void scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
@@ -54,8 +55,9 @@ void scaled_fp4_experts_quant(
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
-#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
- return scaled_fp4_experts_quant_sm100a(
+#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
+ (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
+ return scaled_fp4_experts_quant_sm1xxa(
output, output_scale, input, input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts);
#endif
diff --git a/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
index 1001af05ff003..c5012a8669317 100644
--- a/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
+++ b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
@@ -67,9 +67,9 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
std::optional const& bias);
#endif
-#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \
- defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100 || \
- defined(ENABLE_SCALED_MM_SM120) && ENABLE_SCALED_MM_SM120
+#if (defined(ENABLE_CUTLASS_MOE_SM90) && ENABLE_CUTLASS_MOE_SM90) || \
+ (defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100) || \
+ (defined(ENABLE_CUTLASS_MOE_SM120) && ENABLE_CUTLASS_MOE_SM120)
void get_cutlass_moe_mm_data_caller(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
@@ -284,8 +284,9 @@ void get_cutlass_moe_mm_data(
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
-#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
- (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
+#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
+ (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
+ (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, input_permutation,
output_permutation, num_experts, n, k,
@@ -296,7 +297,7 @@ void get_cutlass_moe_mm_data(
false,
"No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
"CUDA device capability: ",
- version_num, ". Required capability: 90 or 100");
+ version_num, ". Required capability: 90, 100, or 120");
}
void get_cutlass_moe_mm_problem_sizes(
@@ -304,8 +305,9 @@ void get_cutlass_moe_mm_problem_sizes(
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
const int64_t k, const std::optional& blockscale_offsets) {
int32_t version_num = get_sm_version_num();
-#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
- (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
+#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
+ (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
+ (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1,
problem_sizes2, num_experts, n, k,
blockscale_offsets);
@@ -315,7 +317,7 @@ void get_cutlass_moe_mm_problem_sizes(
false,
"No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm "
"kernel for CUDA device capability: ",
- version_num, ". Required capability: 90 or 100");
+ version_num, ". Required capability: 90, 100, or 120");
}
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
@@ -328,8 +330,9 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
-#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
- (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
+#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
+ (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
+ (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1,
problem_sizes2, expert_num_tokens,
num_local_experts, padded_m, n, k);
@@ -339,7 +342,7 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
false,
"No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
"for CUDA device capability: ",
- version_num, ". Required capability: 90 or 100");
+ version_num, ". Required capability: 90, 100, or 120");
}
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 5af74c2c2a6b0..14913bef13125 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -695,7 +695,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
cache_ops.def(
"gather_and_maybe_dequant_cache(Tensor src_cache, Tensor! dst, "
" Tensor block_table, Tensor cu_seq_lens, "
- " int batch_size, "
+ " Tensor token_to_seq, "
+ " int num_tokens, "
" str kv_cache_dtype, "
" Tensor scale, Tensor? seq_starts) -> ()");
cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA,
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 709b79e84fbbc..84a1802dbe03a 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -20,8 +20,8 @@ ARG PYTHON_VERSION=3.12
# glibc version is baked into the distro, and binaries built with one glibc
# version are not backwards compatible with OSes that use an earlier version.
ARG BUILD_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
-# TODO: Restore to base image after FlashInfer AOT wheel fixed
-ARG FINAL_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04
+# Using CUDA base image with minimal dependencies necessary for JIT compilation (FlashInfer, DeepGEMM, EP kernels)
+ARG FINAL_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-base-ubuntu22.04
# By parameterizing the Deadsnakes repository URL, we allow third-party to use
# their own mirror. When doing so, we don't benefit from the transparent
@@ -85,7 +85,20 @@ ARG GET_PIP_URL
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
&& apt-get update -y \
- && apt-get install -y ccache software-properties-common git curl sudo python3-pip \
+ && apt-get install -y --no-install-recommends \
+ ccache \
+ software-properties-common \
+ git \
+ curl \
+ sudo \
+ python3-pip \
+ libibverbs-dev \
+ # Upgrade to GCC 10 to avoid https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92519
+ # as it was causing spam when compiling the CUTLASS kernels
+ gcc-10 \
+ g++-10 \
+ && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 110 --slave /usr/bin/g++ g++ /usr/bin/g++-10 \
+ && rm -rf /var/lib/apt/lists/* \
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
&& $HOME/.local/bin/uv venv /opt/venv --python ${PYTHON_VERSION} \
&& rm -f /usr/bin/python3 /usr/bin/python3-config /usr/bin/pip \
@@ -110,10 +123,6 @@ ENV UV_INDEX_STRATEGY="unsafe-best-match"
# Use copy mode to avoid hardlink failures with Docker cache mounts
ENV UV_LINK_MODE=copy
-# Upgrade to GCC 10 to avoid https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92519
-# as it was causing spam when compiling the CUTLASS kernels
-RUN apt-get install -y gcc-10 g++-10
-RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 110 --slave /usr/bin/g++ g++ /usr/bin/g++-10
RUN </dev/null 2>&1; then \
+ uv pip install --system /tmp/deepgemm/dist/*.whl; \
+ else \
+ echo "No DeepGEMM wheels to install; skipping."; \
+ fi'
+
+# Pytorch now installs NVSHMEM, setting LD_LIBRARY_PATH (https://github.com/pytorch/pytorch/blob/d38164a545b4a4e4e0cf73ce67173f70574890b6/.ci/manywheel/build_cuda.sh#L141C14-L141C36)
+ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
+
+# Install EP kernels wheels (pplx-kernels and DeepEP) that have been built in the `build` stage
+RUN --mount=type=bind,from=build,src=/tmp/ep_kernels_workspace/dist,target=/vllm-workspace/ep_kernels/dist \
+ --mount=type=cache,target=/root/.cache/uv \
+ uv pip install --system ep_kernels/dist/*.whl --verbose \
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
-# Install DeepGEMM from source
-ARG DEEPGEMM_GIT_REF
-COPY tools/install_deepgemm.sh /tmp/install_deepgemm.sh
-RUN --mount=type=cache,target=/root/.cache/uv \
- VLLM_DOCKER_BUILD_CONTEXT=1 TORCH_CUDA_ARCH_LIST="9.0a 10.0a" /tmp/install_deepgemm.sh --cuda-version "${CUDA_VERSION}" ${DEEPGEMM_GIT_REF:+--ref "$DEEPGEMM_GIT_REF"}
-
-COPY tools/install_gdrcopy.sh install_gdrcopy.sh
-RUN set -eux; \
+RUN --mount=type=bind,source=tools/install_gdrcopy.sh,target=/tmp/install_gdrcopy.sh,ro \
+ set -eux; \
case "${TARGETPLATFORM}" in \
linux/arm64) UUARCH="aarch64" ;; \
linux/amd64) UUARCH="x64" ;; \
*) echo "Unsupported TARGETPLATFORM: ${TARGETPLATFORM}" >&2; exit 1 ;; \
esac; \
- ./install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "${GDRCOPY_CUDA_VERSION}" "${UUARCH}"; \
- rm ./install_gdrcopy.sh
-
-# Install EP kernels(pplx-kernels and DeepEP)
-COPY tools/ep_kernels/install_python_libraries.sh install_python_libraries.sh
-ENV CUDA_HOME=/usr/local/cuda
-RUN export TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-9.0a 10.0a+PTX}" \
- && bash install_python_libraries.sh
+ /tmp/install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "${GDRCOPY_CUDA_VERSION}" "${UUARCH}"
# CUDA image changed from /usr/local/nvidia to /usr/local/cuda in 12.8 but will
# return to /usr/local/nvidia in 13.0 to allow container providers to mount drivers
@@ -415,6 +460,11 @@ ENV UV_INDEX_STRATEGY="unsafe-best-match"
# Use copy mode to avoid hardlink failures with Docker cache mounts
ENV UV_LINK_MODE=copy
+RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
+ && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
+ && apt-get update -y \
+ && apt-get install -y git
+
# install development dependencies (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \
CUDA_MAJOR="${CUDA_VERSION%%.*}"; \
@@ -455,12 +505,11 @@ ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
# Reference: https://github.com/astral-sh/uv/pull/1694
ENV UV_HTTP_TIMEOUT=500
-COPY requirements/kv_connectors.txt requirements/kv_connectors.txt
-
# install additional dependencies for openai api server
RUN --mount=type=cache,target=/root/.cache/uv \
+ --mount=type=bind,source=requirements/kv_connectors.txt,target=/tmp/kv_connectors.txt,ro \
if [ "$INSTALL_KV_CONNECTORS" = "true" ]; then \
- uv pip install --system -r requirements/kv_connectors.txt; \
+ uv pip install --system -r /tmp/kv_connectors.txt; \
fi; \
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
BITSANDBYTES_VERSION="0.42.0"; \
diff --git a/docker/Dockerfile.nightly_torch b/docker/Dockerfile.nightly_torch
index b88b9c4992200..d663c82c3885e 100644
--- a/docker/Dockerfile.nightly_torch
+++ b/docker/Dockerfile.nightly_torch
@@ -76,34 +76,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/common.txt
-# must put before installing xformers, so it can install the correct version of xfomrers.
-ARG torch_cuda_arch_list='8.0;8.6;8.9;9.0'
-ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
-
-# Build xformers with cuda and torch nightly
-# following official xformers guidance: https://github.com/facebookresearch/xformers#build
-# todo(elainewy): cache xformers build result for faster build
-ARG max_jobs=16
-ENV MAX_JOBS=${max_jobs}
-ARG XFORMERS_COMMIT=f2de641ef670510cadab099ce6954031f52f191c
-
-ENV CCACHE_DIR=/root/.cache/ccache
-RUN --mount=type=cache,target=/root/.cache/ccache \
- --mount=type=cache,target=/root/.cache/uv \
- echo 'git clone xformers...' \
- && git clone https://github.com/facebookresearch/xformers.git --recursive \
- && cd xformers \
- && git checkout ${XFORMERS_COMMIT} \
- && git submodule update --init --recursive \
- && echo 'finish git clone xformers...' \
- && rm -rf build \
- && python3 setup.py bdist_wheel --dist-dir=../xformers-dist --verbose \
- && cd .. \
- && rm -rf xformers
-
-RUN --mount=type=cache,target=/root/.cache/uv \
- uv pip install --system xformers-dist/*.whl --verbose
-
# build can take a long time, and the torch nightly version fetched from url can be different in next docker stage.
# track the nightly torch version used in the build, when we set up runtime environment we can make sure the version is the same
RUN uv pip freeze | grep -i '^torch\|^torchvision\|^torchaudio' > torch_build_versions.txt
@@ -233,11 +205,6 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/vllm
--mount=type=cache,target=/root/.cache/uv \
uv pip install --system vllm-dist/*.whl --verbose
-# install xformers again for the new environment
-RUN --mount=type=bind,from=base,src=/workspace/xformers-dist,target=/vllm-workspace/xformers-dist \
- --mount=type=cache,target=/root/.cache/uv \
- uv pip install --system /vllm-workspace/xformers-dist/*.whl --verbose
-
ARG torch_cuda_arch_list='8.0;8.6;8.9;9.0'
# install package for build flashinfer
@@ -307,7 +274,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/nightly_torch_test.txt
# Logging to confirm the torch versions
-RUN pip freeze | grep -E 'torch|xformers|vllm|flashinfer'
+RUN pip freeze | grep -E 'torch|vllm|flashinfer'
# Logging to confirm all the packages are installed
RUN pip freeze
diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu
index 5d5b82c4fa5af..adac43c6accbe 100644
--- a/docker/Dockerfile.xpu
+++ b/docker/Dockerfile.xpu
@@ -1,4 +1,4 @@
-FROM intel/deep-learning-essentials:2025.1.3-0-devel-ubuntu24.04 AS vllm-base
+FROM intel/deep-learning-essentials:2025.2.2-0-devel-ubuntu24.04 AS vllm-base
RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \
echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list && \
@@ -25,10 +25,14 @@ RUN apt clean && apt-get update -y && \
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.12 1
-RUN apt install -y libze1 libze-dev libze-intel-gpu1 intel-opencl-icd libze-intel-gpu-raytracing
+RUN apt install -y libze1 libze-dev libze-intel-gpu1 intel-opencl-icd libze-intel-gpu-raytracing intel-ocloc
+
+# This oneCCL release includes BMG support, which the default oneAPI 2025.2 version lacks.
+RUN wget https://github.com/uxlfoundation/oneCCL/releases/download/2021.15.6/intel-oneccl-2021.15.6.9_offline.sh
+RUN bash intel-oneccl-2021.15.6.9_offline.sh -a --silent --eula accept && \
+ echo "source /opt/intel/oneapi/setvars.sh --force" >> /root/.bashrc && \
+ echo "source /opt/intel/oneapi/ccl/2021.15/env/vars.sh --force" >> /root/.bashrc
-RUN wget https://github.com/uxlfoundation/oneCCL/releases/download/2021.15.4/intel-oneccl-2021.15.4.11_offline.sh
-RUN bash intel-oneccl-2021.15.4.11_offline.sh -a --silent --eula accept && echo "source /opt/intel/oneapi/setvars.sh --force" >> /root/.bashrc
SHELL ["bash", "-c"]
CMD ["bash", "-c", "source /root/.bashrc && exec bash"]
@@ -72,6 +76,7 @@ RUN python3 -m pip install -e tests/vllm_test_utils
ENV NIXL_VERSION=0.7.0
RUN python3 /workspace/vllm/tools/install_nixl_from_source_ubuntu.py
+# remove torch bundled oneccl to avoid conflicts
RUN --mount=type=cache,target=/root/.cache/pip \
pip uninstall oneccl oneccl-devel -y
diff --git a/docs/assets/contributing/dockerfile-stages-dependency.png b/docs/assets/contributing/dockerfile-stages-dependency.png
index f8c104ba14259..b327eb2151f50 100644
Binary files a/docs/assets/contributing/dockerfile-stages-dependency.png and b/docs/assets/contributing/dockerfile-stages-dependency.png differ
diff --git a/docs/configuration/conserving_memory.md b/docs/configuration/conserving_memory.md
index 5ce43c7984057..0aa89a89eae5c 100644
--- a/docs/configuration/conserving_memory.md
+++ b/docs/configuration/conserving_memory.md
@@ -49,9 +49,6 @@ llm = LLM(model="adept/fuyu-8b", max_model_len=2048, max_num_seqs=2)
By default, we optimize model inference using CUDA graphs which take up extra memory in the GPU.
-!!! warning
- CUDA graph capture takes up more memory in V1 than in V0.
-
You can adjust `compilation_config` to achieve a better balance between inference speed and memory usage:
??? code
diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md
index b0d390d7e1cbb..fdd9c317b022f 100644
--- a/docs/configuration/optimization.md
+++ b/docs/configuration/optimization.md
@@ -31,9 +31,7 @@ In vLLM V1, the default preemption mode is `RECOMPUTE` rather than `SWAP`, as re
Chunked prefill allows vLLM to process large prefills in smaller chunks and batch them together with decode requests. This feature helps improve both throughput and latency by better balancing compute-bound (prefill) and memory-bound (decode) operations.
-In vLLM V1, **chunked prefill is always enabled by default**. This is different from vLLM V0, where it was conditionally enabled based on model characteristics.
-
-With chunked prefill enabled, the scheduling policy prioritizes decode requests. It batches all pending decode requests before scheduling any prefill operations. When there are available tokens in the `max_num_batched_tokens` budget, it schedules pending prefills. If a pending prefill request cannot fit into `max_num_batched_tokens`, it automatically chunks it.
+In V1, **chunked prefill is enabled by default whenever possible**. With chunked prefill enabled, the scheduling policy prioritizes decode requests. It batches all pending decode requests before scheduling any prefill operations. When there are available tokens in the `max_num_batched_tokens` budget, it schedules pending prefills. If a pending prefill request cannot fit into `max_num_batched_tokens`, it automatically chunks it.
This policy has two benefits:
diff --git a/docs/contributing/ci/update_pytorch_version.md b/docs/contributing/ci/update_pytorch_version.md
index 09fd85a466eed..735bb2e205332 100644
--- a/docs/contributing/ci/update_pytorch_version.md
+++ b/docs/contributing/ci/update_pytorch_version.md
@@ -98,21 +98,6 @@ to warm it up so that future builds are faster.
-## Update dependencies
-
-Several vLLM dependencies like xFormers depend on PyTorch and need
-to be updated accordingly. Rather than waiting for all of them to publish new
-releases (which would take too much time), they can be built from
-source to unblock the update process.
-
-### xFormers
-
-```bash
-export TORCH_CUDA_ARCH_LIST='7.5 8.0+PTX 9.0a'
-MAX_JOBS=16 uv pip install --system \
- --no-build-isolation "git+https://github.com/facebookresearch/xformers@v0.0.32.post2"
-```
-
## Update all the different vLLM platforms
Rather than attempting to update all vLLM platforms in a single pull request, it's more manageable
diff --git a/docs/deployment/integrations/kserve.md b/docs/deployment/integrations/kserve.md
index edf79fca4f93e..37b29aa1a4876 100644
--- a/docs/deployment/integrations/kserve.md
+++ b/docs/deployment/integrations/kserve.md
@@ -2,4 +2,4 @@
vLLM can be deployed with [KServe](https://github.com/kserve/kserve) on Kubernetes for highly scalable distributed model serving.
-Please see [this guide](https://kserve.github.io/website/latest/modelserving/v1beta1/llm/huggingface/) for more details on using vLLM with KServe.
+Please see [this guide](https://kserve.github.io/website/docs/model-serving/generative-inference/overview) for more details on using vLLM with KServe.
diff --git a/docs/design/debug_vllm_compile.md b/docs/design/debug_vllm_compile.md
index 3b454e851b54e..408d2878309dd 100644
--- a/docs/design/debug_vllm_compile.md
+++ b/docs/design/debug_vllm_compile.md
@@ -9,7 +9,7 @@ TL;DR:
|----------|----------|-------------|
| --enforce-eager | enforce_eager=True | Turn off torch.compile and CUDAGraphs |
| -O.mode=0 | mode=CompilationMode.NONE | Turn off torch.compile only |
-| -O.cudagraph_mode=NONE | compilation_config=CompilationConfig(mode=CompilationMode.NONE) | Turn off CUDAGraphs only |
+| -O.cudagraph_mode=NONE | compilation_config=CompilationConfig(cudagraph_mode=CUDAGraphMode.NONE) | Turn off CUDAGraphs only |
| -O.backend=eager | compilation_config=CompilationConfig(backend='eager') | Turn off TorchInductor |
## vLLM-torch.compile overview
@@ -151,6 +151,76 @@ To avoid this, please either:
2. wrap the branching logic into a custom operator. TorchDynamo does not
trace into custom operators.
+## Debugging constraint violations and dynamic shapes guards issues
+
+Dynamic-shape guards are a specific category of Dynamo guards. They are constraints that `torch.compile`
+attaches to dynamic dimensions (e.g., `seq_len`) to ensure the compiled artifact remains valid.
+These guards typically appear when framework code, custom passes, or user code branches based on
+dynamic shape values.
+
+**Example:**
+
+```python
+if x > 10:
+ # path A
+else:
+ # path B
+```
+
+This creates a guard `x > 10` or `x <= 10` depending on which path was traced.
+
+**vLLM's Assumption:**
+vLLM assumes that all guards added by torch.compile are safe to drop and will not
+constrain the compiled graph to specific input shapes. When this assumption is violated,
+it can cause issues that users need to debug.
+Some side effects that indicate this assumption is violated are runtime errors
+or a `ConstraintViolationError`.
+
+A `ConstraintViolationError` will be thrown if a dynamic shape gets constrained to
+a single value. If you encounter a constraint violation error or suspect that a dynamic
+shapes guard is being added incorrectly, you can use stricter dynamic shape modes to
+help debug the issue:
+
+```sh
+# Online - using unbacked mode
+vllm serve meta-llama/Llama-3.2-1B -O.dynamic_shapes_config.type=unbacked
+
+# Online - using backed_size_oblivious mode
+vllm serve meta-llama/Llama-3.2-1B -O.dynamic_shapes_config.type=backed_size_oblivious
+```
+
+```py
+# Offline - using unbacked mode
+from vllm.config.compilation import CompilationConfig, DynamicShapesConfig, DynamicShapesType
+LLM(model, compilation_config=CompilationConfig(
+ dynamic_shapes_config=DynamicShapesConfig(type=DynamicShapesType.UNBACKED)
+))
+
+# Offline - using backed_size_oblivious mode
+from vllm.config.compilation import CompilationConfig, DynamicShapesConfig, DynamicShapesType
+LLM(model, compilation_config=CompilationConfig(
+ dynamic_shapes_config=DynamicShapesConfig(type=DynamicShapesType.BACKED_SIZE_OBLIVIOUS)
+))
+```
+
+These modes are stricter and reduce or eliminate the need for dynamic shape guards, which can help isolate issues:
+
+- `unbacked`: Uses unbacked symints which don't allow guards, making it easier to identify where guards are being incorrectly added
+- `backed_size_oblivious`: Uses a mode that is more strict about guarding.
+
+For more details on dynamic shapes modes, see [Dynamic shapes and vLLM guard dropping](torch_compile.md#dynamic-shapes-and-vllm-guard-dropping).
+
+### Printing guards
+
+To see all guards that are being added during compilation, you can use `TORCH_LOGS=+dynamic`:
+
+```sh
+TORCH_LOGS=+dynamic vllm serve meta-llama/Llama-3.2-1B
+```
+
+Look for `[guard added]` in the logs to see where guards are being added. This can help you identify which operations are
+causing guards to be added incorrectly.
+
## Debugging TorchInductor
TorchInductor takes a captured graph and then compiles it down to some Python code
diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md
index f0d5a3e934f39..e54a9e2bc5e77 100644
--- a/docs/design/moe_kernel_features.md
+++ b/docs/design/moe_kernel_features.md
@@ -60,7 +60,7 @@ Modular kernels are supported by the following `FusedMoEMethodBase` classes.
- [`ModelOptFp8MoEMethod`][vllm.model_executor.layers.quantization.modelopt.ModelOptFp8MoEMethod]
- [`Fp8MoEMethod`][vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod]
-- [`CompressedTensorsW4A4MoeMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW4A4MoeMethod]
+- [`CompressedTensorsW4A4Nvfp4MoeMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW4A4Nvfp4MoeMethod]
- [`CompressedTensorsW8A8Fp8MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW8A8Fp8MoEMethod]
- [`Mxfp4MoEMethod`][vllm.model_executor.layers.quantization.mxfp4.Mxfp4MoEMethod]
- [`UnquantizedFusedMoEMethod`][vllm.model_executor.layers.fused_moe.layer.UnquantizedFusedMoEMethod]
diff --git a/docs/design/torch_compile.md b/docs/design/torch_compile.md
index 27edc4f89201d..7b0b2c1e96978 100644
--- a/docs/design/torch_compile.md
+++ b/docs/design/torch_compile.md
@@ -29,6 +29,109 @@ A unique aspect of vLLM's `torch.compile` integration, is that we guarantee all
By default, the cache saves compiled artifacts as binary files. If you would like to interact with the generated code for debugging purposes, set the field `compile_cache_save_format=unpacked` in the compilation config, or omit this and set the env variable `VLLM_COMPILE_CACHE_SAVE_FORMAT=unpacked`.
+## Dynamic shapes and vLLM guard dropping
+
+`torch.compile` is designed to guard on dynamic shapes with no hesitation
+when needed. This conflicts with vLLM's `torch.compile` approach of
+dropping the guards since many of those guards could be material.
+
+`torch.compile` provides two kinds of dynamic shapes: `backed` and `unbacked`.
+`torch.compile` guards on `backed` dynamic shapes and does not provide a
+guarantee that no guards will be added to them. User code, dynamo,
+inductor, and autograd all can add guards. Moreover, for 0/1
+specializations, backed symbols are specialized unconditionally to 0, 1,
+or >=2 even without encountering a branching on those ranges.
+
+On the contrary, `unbacked` dynamic shapes are guaranteed not to be guarded
+on and are not 0/1 specialized. However, there is a possibility of
+throwing a data dependent error when a branch that requires their value is
+encountered and no explicit unbacked handling is defined. The framework is
+converging to a state where it won't throw DDE but rather pick general
+paths. One downside of using unbacked is missed optimization opportunities
+due to either perf bugs or picking general paths, as well as using a fixed
+hint that is not based on example inputs (this will be fixed soon with the
+override_hint API). An example of picking a general path is assuming the input
+is not contiguous in calls such as contiguous() and reshape() when contiguity
+cannot be symbolically proven, at the cost of introducing a clone.
+
+`backed_size_oblivious` is a flag that enables treating backed symbols as
+unbacked wherever explicit handling for unbacked is defined. With this
+mode, 0/1 specializations are mostly avoided in framework code and the
+default 0/1 specialization does not happen. However, there is still no
+guarantee that torch.compile won't guard, especially due to user code or
+custom passes. `backed_size_oblivious` is experimental in PyTorch compile
+and could be deprecated. That said, it's a safer option to use than
+`backed` and the probability of reducing performance is lower than
+`unbacked`.
+
+### Configuring Dynamic Shapes
+
+The `DynamicShapesConfig` allows you to control the dynamic shapes behavior by
+setting the `type` field. You can choose between three modes:
+`BACKED` (default), `UNBACKED`, and `BACKED_SIZE_OBLIVIOUS`.
+
+#### Offline Inference Example (Using LLM class)
+
+When using the `LLM` class for offline inference, you can configure dynamic
+shapes through the `compilation_config` parameter:
+
+```python
+from vllm import LLM, SamplingParams
+from vllm.config.compilation import CompilationConfig, DynamicShapesConfig, DynamicShapesType
+
+# Example: Using backed_size_oblivious (experimental, safer than backed)
+llm = LLM(
+ model="meta-llama/Llama-3.2-1B",
+ compilation_config=CompilationConfig(
+ dynamic_shapes_config=DynamicShapesConfig(
+ type=DynamicShapesType.BACKED_SIZE_OBLIVIOUS
+ )
+ )
+)
+
+# Example: Using unbacked (strongest guarantee against guards)
+llm = LLM(
+ model="meta-llama/Llama-3.2-1B",
+ compilation_config=CompilationConfig(
+ dynamic_shapes_config=DynamicShapesConfig(
+ type=DynamicShapesType.UNBACKED
+ )
+ )
+)
+
+# Generate outputs
+prompts = ["Hello, my name is", "The future of AI is"]
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+outputs = llm.generate(prompts, sampling_params)
+```
+
+#### Online Serving Example (Using vllm serve)
+
+When using `vllm serve` for online serving, you can configure dynamic shapes
+through the `--compilation-config` flag:
+
+```bash
+# Example: Using unbacked
+vllm serve meta-llama/Llama-3.2-1B \
+ --compilation-config '{"dynamic_shapes_config": {"type": "unbacked"}}'
+
+
+# Alternative: Using dot notation (simpler for single values)
+vllm serve meta-llama/Llama-3.2-1B -O.dynamic_shapes_config.type=unbacked
+```
+
+#### Choosing the Right Mode
+
+- **BACKED** (default): Use when you're willing to accept potential unsafe dropping of guards
+for maximal performance. Guards could be unsoundly added and then ignored.
+
+- **UNBACKED**: Use when you need the strongest guarantee against guards.
+ This is the most conservative option but may miss some optimization opportunities.
+
+- **BACKED_SIZE_OBLIVIOUS**: Use when you want a balance between avoiding guards
+ and performance. This experimental mode is safer than BACKED but still not as
+ conservative as UNBACKED.
+
## Python Code Compilation
In the very verbose logs, we can see:
@@ -122,7 +225,7 @@ When all the shapes are known, `torch.compile` can compare different configs, an
triton_mm_4 0.0130 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=2
triton_mm_8 0.0134 ms 97.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
triton_mm_12 0.0148 ms 87.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4
- mm 0.0160 ms 81.6%
+ mm 0.0160 ms 81.6%
triton_mm_16 0.0165 ms 78.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8
triton_mm_3 0.0199 ms 65.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=2
triton_mm_1 0.0203 ms 64.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=2
diff --git a/docs/features/quantization/inc.md b/docs/features/quantization/inc.md
index 5e86e9388f328..9875bc44c9144 100644
--- a/docs/features/quantization/inc.md
+++ b/docs/features/quantization/inc.md
@@ -22,9 +22,6 @@ export QUANT_CONFIG=/path/to/quant/config/inc/meta-llama-3.1-405b-instruct/maxab
vllm serve meta-llama/Llama-3.1-405B-Instruct --quantization inc --kv-cache-dtype fp8_inc --tensor_paralel_size 8
```
-!!! tip
- If you are just prototyping or testing your model with FP8, you can use the `VLLM_SKIP_WARMUP=true` environment variable to disable the warmup stage, which can take a long time. However, we do not recommend disabling this feature in production environments as it causes a significant performance drop.
-
!!! tip
When using FP8 models, you may experience timeouts caused by the long compilation time of FP8 operations. To mitigate this problem, you can use the below environment variables:
`VLLM_ENGINE_ITERATION_TIMEOUT_S` - to adjust the vLLM server timeout. You can set the value in seconds, e.g., 600 equals 10 minutes.
diff --git a/docs/features/structured_outputs.md b/docs/features/structured_outputs.md
index e38627c707884..7d52891bea7b9 100644
--- a/docs/features/structured_outputs.md
+++ b/docs/features/structured_outputs.md
@@ -7,7 +7,7 @@ This document shows you some examples of the different options that are
available to generate structured outputs.
!!! warning
- If you are still using the following deprecated API fields, please update your code to use `structured_outputs` as demonstrated in the rest of this document:
+ If you are still using the following deprecated API fields which were removed in v0.12.0, please update your code to use `structured_outputs` as demonstrated in the rest of this document:
- `guided_json` -> `{"structured_outputs": {"json": ...}}` or `StructuredOutputsParams(json=...)`
- `guided_regex` -> `{"structured_outputs": {"regex": ...}}` or `StructuredOutputsParams(regex=...)`
diff --git a/docs/getting_started/quickstart.md b/docs/getting_started/quickstart.md
index 9e86f785b10c7..94920dc5306b3 100644
--- a/docs/getting_started/quickstart.md
+++ b/docs/getting_started/quickstart.md
@@ -283,7 +283,7 @@ Currently, vLLM supports multiple backends for efficient Attention computation a
If desired, you can also manually set the backend of your choice by configuring the environment variable `VLLM_ATTENTION_BACKEND` to one of the following options:
-- On NVIDIA CUDA: `FLASH_ATTN`, `FLASHINFER` or `XFORMERS`.
+- On NVIDIA CUDA: `FLASH_ATTN` or `FLASHINFER`.
- On AMD ROCm: `TRITON_ATTN`, `ROCM_ATTN`, `ROCM_AITER_FA` or `ROCM_AITER_UNIFIED_ATTN`.
For AMD ROCm, you can further control the specific Attention implementation using the following variables:
diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 626904a974155..25579835faf63 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -680,6 +680,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + IE+ + VE+ | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ |
| `H2OVLChatModel` | H2OVL | T + IE+ | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ |
+| `HunYuanVLForConditionalGeneration` | HunyuanOCR | T + IE+ | `tencent/HunyuanOCR`, etc. | ✅︎ | ✅︎ |
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | |
| `InternS1ForConditionalGeneration` | Intern-S1 | T + IE+ + VE+ | `internlm/Intern-S1`, `internlm/Intern-S1-mini`, etc. | ✅︎ | ✅︎ |
| `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + IE+ + (VE+) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ |
@@ -701,6 +702,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I+ | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ |
| `MolmoForCausalLM` | Molmo | T + I+ | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ |
| `NVLM_D_Model` | NVLM-D 1.0 | T + I+ | `nvidia/NVLM-D-72B`, etc. | | ✅︎ |
+| `OpenCUAForConditionalGeneration` | OpenCUA-7B | T + IE+ | `xlangai/OpenCUA-7B` | ✅︎ | ✅︎ |
| `Ovis` | Ovis2, Ovis1.6 | T + I+ | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ |
| `Ovis2_5` | Ovis2.5 | T + I+ + V | `AIDC-AI/Ovis2.5-9B`, etc. | | |
| `PaddleOCRVLForConditionalGeneration` | Paddle-OCR | T + I+ | `PaddlePaddle/PaddleOCR-VL`, etc. | | |
diff --git a/docs/serving/openai_compatible_server.md b/docs/serving/openai_compatible_server.md
index 23df3963823aa..e3280bd15b55c 100644
--- a/docs/serving/openai_compatible_server.md
+++ b/docs/serving/openai_compatible_server.md
@@ -49,7 +49,8 @@ We currently support the following OpenAI APIs:
- *Note: `suffix` parameter is not supported.*
- [Chat Completions API](#chat-api) (`/v1/chat/completions`)
- Only applicable to [text generation models](../models/generative_models.md) with a [chat template](../serving/openai_compatible_server.md#chat-template).
- - *Note: `parallel_tool_calls` and `user` parameters are ignored.*
+ - *Note: `user` parameter is ignored.*
+ - *Note:* Setting the `parallel_tool_calls` parameter to `false` ensures vLLM only returns zero or one tool call per request. Setting it to `true` (the default) allows returning more than one tool call per request. There is no guarantee more than one tool call will be returned if this is set to `true`, as that behavior is model dependent and not all models are designed to support parallel tool calls.
- [Embeddings API](#embeddings-api) (`/v1/embeddings`)
- Only applicable to [embedding models](../models/pooling_models.md).
- [Transcriptions API](#transcriptions-api) (`/v1/audio/transcriptions`)
diff --git a/docs/serving/parallelism_scaling.md b/docs/serving/parallelism_scaling.md
index 14cd3b057791c..a32840ea73b9a 100644
--- a/docs/serving/parallelism_scaling.md
+++ b/docs/serving/parallelism_scaling.md
@@ -118,14 +118,16 @@ The common practice is to set the tensor parallel size to the number of GPUs in
```bash
vllm serve /path/to/the/model/in/the/container \
--tensor-parallel-size 8 \
- --pipeline-parallel-size 2
+ --pipeline-parallel-size 2 \
+ --distributed-executor-backend ray
```
Alternatively, you can set `tensor_parallel_size` to the total number of GPUs in the cluster:
```bash
vllm serve /path/to/the/model/in/the/container \
- --tensor-parallel-size 16
+ --tensor-parallel-size 16 \
+ --distributed-executor-backend ray
```
## Optimizing network communication for tensor parallelism
diff --git a/docs/usage/reproducibility.md b/docs/usage/reproducibility.md
index afc25b63902e2..a8e49d0a3398f 100644
--- a/docs/usage/reproducibility.md
+++ b/docs/usage/reproducibility.md
@@ -1,21 +1,23 @@
# Reproducibility
vLLM does not guarantee the reproducibility of the results by default, for the sake of performance. To achieve
-reproducible results, you need to turn off multiprocessing to make the scheduling deterministic by setting `VLLM_ENABLE_V1_MULTIPROCESSING=0`.
+reproducible results:
+
+- In offline mode, you can either set `VLLM_ENABLE_V1_MULTIPROCESSING=0` which makes scheduling deterministic,
+ or enable [batch invariance](../features/batch_invariance.md) to make the outputs insensitive to scheduling.
+- In online mode, you can only enable [batch invariance](../features/batch_invariance.md).
Example: [examples/offline_inference/reproducibility.py](../../examples/offline_inference/reproducibility.py)
!!! warning
- Applying the above settings [changes the random state in user code](#locality-of-random-state).
+ Setting `VLLM_ENABLE_V1_MULTIPROCESSING=0` will change the random state of user code
+ (i.e. the code that constructs [LLM][vllm.LLM] class).
!!! note
Even with the above settings, vLLM only provides reproducibility
when it runs on the same hardware and the same vLLM version.
- Also, the online serving API (`vllm serve`) does not support reproducibility
- because it is almost impossible to make the scheduling deterministic in the
- online setting.
## Setting the global seed
@@ -23,25 +25,17 @@ The `seed` parameter in vLLM is used to control the random states for various ra
If a specific seed value is provided, the random states for `random`, `np.random`, and `torch.manual_seed` will be set accordingly.
-However, in some cases, setting the seed will also [change the random state in user code](#locality-of-random-state).
-
### Default Behavior
In V1, the `seed` parameter defaults to `0` which sets the random state for each worker, so the results will remain consistent for each vLLM run even if `temperature > 0`.
+It is impossible to un-specify a seed for V1 because different workers need to sample the same outputs
+for workflows such as speculative decoding.
+
!!! note
- It is impossible to un-specify a seed for V1 because different workers need to sample the same outputs
- for workflows such as speculative decoding.
-
- For more information, see:
+ The random state in user code (i.e. the code that constructs [LLM][vllm.LLM] class) is updated by vLLM
+ only if the workers are run in the same process as user code, i.e.: `VLLM_ENABLE_V1_MULTIPROCESSING=0`.
-### Locality of random state
-
-The random state in user code (i.e. the code that constructs [LLM][vllm.LLM] class) is updated by vLLM under the following conditions:
-
-- For V0: The seed is specified.
-- For V1: The workers are run in the same process as user code, i.e.: `VLLM_ENABLE_V1_MULTIPROCESSING=0`.
-
-By default, these conditions are not active so you can use vLLM without having to worry about
-accidentally making deterministic subsequent operations that rely on random state.
+ By default, `VLLM_ENABLE_V1_MULTIPROCESSING=1` so you can use vLLM without having to worry about
+ accidentally making deterministic subsequent operations that rely on random state.
diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md
index 22f4e6761ea9a..5f647aafd61d4 100644
--- a/docs/usage/v1_guide.md
+++ b/docs/usage/v1_guide.md
@@ -4,9 +4,7 @@
We have fully deprecated V0. Please read [RFC #18571](https://github.com/vllm-project/vllm/issues/18571) for more details.
-V1 is now enabled by default for all supported use cases, and we will gradually enable it for every use case we plan to support. Please share any feedback on [GitHub](https://github.com/vllm-project/vllm) or in the [vLLM Slack](https://inviter.co/vllm-slack).
-
-## Why vLLM V1?
+If you have a use case that works on V0 Engine but not V1, please share it on [GitHub](https://github.com/vllm-project/vllm) or in the [vLLM Slack](https://inviter.co/vllm-slack).
vLLM V0 successfully supported a wide range of models and hardware, but as new features were developed independently, the system grew increasingly complex. This complexity made it harder to integrate new capabilities and introduced technical debt, revealing the need for a more streamlined and unified design.
@@ -32,16 +30,44 @@ Upgrade to vLLM’s Core Architecture](https://blog.vllm.ai/2025/01/27/v1-alpha-
This living user guide outlines a few known **important changes and limitations** introduced by vLLM V1. The team has been working actively to bring V1 as the default engine, therefore this guide will be updated constantly as more features get supported on vLLM V1.
-## Current Status
+## Differences from V0
-For each item, our progress towards V1 support falls into one of the following states:
+This section lists some differences in behavior between V0 and V1.
-- **🚀 Optimized**: Nearly fully optimized, with no further work currently planned.
-- **🟢 Functional**: Fully operational, with ongoing optimizations.
-- **🚧 WIP**: Under active development.
-- **🟡 Planned**: Scheduled for future implementation (some may have open PRs/RFCs).
-- **🟠 Delayed**: Temporarily dropped in V1 but planned to be re-introduced later.
-- **🔴 Deprecated**: Not planned for V1 unless there is strong demand.
+### Chunked Prefill
+
+Chunked prefill is enabled by default whenever possible, unlike in V0 where it was conditionally enabled based on model characteristics.
+
+### CUDA Graphs
+
+CUDA graph capture takes up more memory in V1 than in V0.
+
+### Semantic Changes to Logprobs
+
+#### Logprobs Calculation
+
+By default, logprobs in V1 are now returned immediately once computed from the model’s raw output (i.e.
+before applying any logits post-processing such as temperature scaling or penalty
+adjustments). As a result, the returned logprobs do not reflect the final adjusted
+probabilities used during sampling.
+
+You can adjust this behavior by setting the `--logprobs-mode` flag.
+Four modes are supported: `raw_logprobs` (default), `processed_logprobs`, `raw_logits`, `processed_logits`.
+Raw means the values before applying any logit processors, like bad words.
+Processed means the values after applying all processors, including temperature and top_k/top_p.
+
+#### Prompt Logprobs with Prefix Caching
+
+While V1 supports passing prompt logprobs with prefix caching enabled, it no longer caches the logprobs.
+For a request requiring prompt logprobs, the engine will ignore the prefix cache and recompute the prefill of full prompt to generate the logprobs.
+
+## Feature Support
+
+For each item, its support in vLLM V1 falls into one of the following states:
+
+- **🟢 Functional**: Fully operational with optimizations comparable to or better than V0.
+- **🟡 In Progress**: Planned to be in vLLM V1, with open PRs/RFCs.
+- **🔴 Removed**: Dropped from vLLM V1. Will only consider re-introducing if there is strong demand.
!!! note
vLLM V1’s unified scheduler treats both prompt and output tokens the same
@@ -57,13 +83,13 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the
### Hardware
-| Hardware | Status |
-|------------|-----------------------------------------------|
-| **NVIDIA** | 🚀 |
-| **AMD** | 🟢 |
+| Hardware | Status |
+|------------------|-----------------------------------------------|
+| **NVIDIA** | 🟢 |
+| **AMD** | 🟢 |
| **INTEL GPU** | 🟢 |
-| **TPU** | 🟢 |
-| **CPU** | 🟢 (x86\_64/aarch64) 🟡 (MacOS) |
+| **TPU** | 🟢 |
+| **CPU** | 🟢 |
!!! note
@@ -78,23 +104,21 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the
### Models
-| Model Type | Status |
-|-----------------------------|------------------------------------------------------------------------------------|
-| **Decoder-only Models** | 🚀 Optimized |
-| **Encoder-Decoder Models** | 🟢 Whisper only |
-| **Embedding Models** | 🟢 Functional |
-| **Mamba Models** | 🟢 (Mamba-2), 🟢 (Mamba-1) |
-| **Multimodal Models** | 🟢 Functional |
+| Model Type | Status |
+|-----------------------------|-------------------------------------------------------------------------|
+| **Decoder-only Models** | 🟢 |
+| **Encoder-Decoder Models** | 🟢 (Whisper), 🔴 (Others) |
+| **Pooling Models** | 🟢 |
+| **Mamba Models** | 🟢 |
+| **Multimodal Models** | 🟢 |
See below for the status of models that are not yet supported or have more features planned in V1.
-#### Embedding Models
+#### Pooling Models
-The initial basic support is now functional.
+Now fully supported, with prefix caching and chunked prefill newly available for last-pooling models.
-Later, we will consider using [hidden states processor](https://github.com/vllm-project/vllm/issues/12249),
-which is based on [global logits processor](https://github.com/vllm-project/vllm/pull/13360)
-to enable simultaneous generation and embedding using the same engine instance in V1.
+We are working on enabling prefix caching and chunked prefill for more categories of pooling models.
#### Mamba Models
@@ -112,24 +136,25 @@ Please note that prefix caching is not yet supported for any of the above models
Whisper is supported. Other models requiring cross-attention between separate
encoder and decoder (e.g., `BartForConditionalGeneration`,
-`MllamaForConditionalGeneration`) are not supported.
+`MllamaForConditionalGeneration`) are no longer supported.
### Features
| Feature | Status |
|---------------------------------------------|-----------------------------------------------------------------------------------|
-| **Prefix Caching** | 🚀 Optimized |
-| **Chunked Prefill** | 🚀 Optimized |
-| **LoRA** | 🚀 Optimized |
+| **Prefix Caching** | 🟢 Functional |
+| **Chunked Prefill** | 🟢 Functional |
+| **LoRA** | 🟢 Functional |
| **Logprobs Calculation** | 🟢 Functional |
-| **FP8 KV Cache** | 🟢 Functional on Hopper devices ()|
-| **Spec Decode** | 🚀 Optimized |
-| **Prompt Logprobs with Prefix Caching** | 🟡 Planned ([RFC #13414](https://github.com/vllm-project/vllm/issues/13414))|
+| **FP8 KV Cache** | 🟢 Functional |
+| **Spec Decode** | 🟢 Functional |
+| **Prompt Logprobs with Prefix Caching** | 🟢 Functional |
| **Structured Output Alternative Backends** | 🟢 Functional |
-| **Request-level Structured Output Backend** | 🔴 Deprecated |
-| **best_of** | 🔴 Deprecated ([RFC #13361](https://github.com/vllm-project/vllm/issues/13361))|
-| **Per-Request Logits Processors** | 🔴 Deprecated ([RFC #13360](https://github.com/vllm-project/vllm/pull/13360)) |
-| **GPU <> CPU KV Cache Swapping** | 🔴 Deprecated |
+| **Concurrent Partial Prefills** | 🟡 [In Progress](https://github.com/vllm-project/vllm/issues/14003) |
+| **best_of** | 🔴 [Removed](https://github.com/vllm-project/vllm/issues/13361) |
+| **Per-Request Logits Processors** | 🔴 [Removed](https://github.com/vllm-project/vllm/pull/13360) |
+| **GPU <> CPU KV Cache Swapping** | 🔴 Removed |
+| **Request-level Structured Output Backend** | 🔴 Removed |
!!! note
@@ -139,37 +164,16 @@ encoder and decoder (e.g., `BartForConditionalGeneration`,
prefix caching, and speculative decoding without a strict separation between prefill
and decode phases.
-#### Semantic Changes to Logprobs
+#### Removed Features
-vLLM V1 supports logprobs and prompt logprobs. However, there are some important semantic
-differences compared to V0:
-
-##### Logprobs Calculation
-
-By default, logprobs in V1 are now returned immediately once computed from the model’s raw output (i.e.
-before applying any logits post-processing such as temperature scaling or penalty
-adjustments). As a result, the returned logprobs do not reflect the final adjusted
-probabilities used during sampling.
-
-You can adjust this behavior by setting the `--logprobs-mode` flag.
-Four modes are supported: `raw_logprobs` (default), `processed_logprobs`, `raw_logits`, `processed_logits`.
-Raw means the values before applying any logit processors, like bad words.
-Processed means the values after applying all processors, including temperature and top_k/top_p.
-
-##### Prompt Logprobs with Prefix Caching
-
-Logprobs are not cached. For a request requiring prompt logprobs, the engine will ignore the prefix cache and recompute the prefill of full prompt to generate the logprobs.
-
-#### Deprecated Features
-
-As part of the major architectural rework in vLLM V1, several legacy features have been deprecated.
+As part of the major architectural rework in vLLM V1, several legacy features have been removed.
##### Sampling features
-- **best_of**: This feature has been deprecated due to limited usage. See details at [RFC #13361](https://github.com/vllm-project/vllm/issues/13361).
+- **best_of**: This feature has been removed due to limited usage. See details at [RFC #13361](https://github.com/vllm-project/vllm/issues/13361).
- **Per-Request Logits Processors**: In V0, users could pass custom
processing functions to adjust logits on a per-request basis. In vLLM V1, this
- feature has been deprecated. Instead, we now support **global logits processors**
+ feature has been removed. Instead, we now support **global logits processors**
which are set at startup time, see [RFC #17799](https://github.com/vllm-project/vllm/issues/17799).
##### KV Cache features
@@ -179,4 +183,4 @@ to handle request preemptions.
##### Structured Output features
-- **Request-level Structured Output Backend**: Deprecated, alternative backends (outlines, guidance) with fallbacks is supported now.
+- **Request-level Structured Output Backend**: Removed; alternative backends (outlines, guidance) with fallbacks are supported now.
diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py
old mode 100644
new mode 100755
index 04e6f99f8957e..df6e96ca375fc
--- a/examples/offline_inference/audio_language.py
+++ b/examples/offline_inference/audio_language.py
@@ -425,6 +425,13 @@ def parse_args():
default=None,
help="Set the seed when initializing `vllm.LLM`.",
)
+ parser.add_argument(
+ "--tensor-parallel-size",
+ "-tp",
+ type=int,
+ default=None,
+ help="Tensor parallel size to override the model's default setting. ",
+ )
return parser.parse_args()
@@ -434,6 +441,12 @@ def main(args):
if model not in model_example_map:
raise ValueError(f"Model type {model} is not supported.")
+ if args.tensor_parallel_size is not None and args.tensor_parallel_size < 1:
+ raise ValueError(
+ f"tensor_parallel_size must be a positive integer, "
+ f"got {args.tensor_parallel_size}"
+ )
+
audio_count = args.num_audios
req_data = model_example_map[model](
question_per_audio_count[audio_count], audio_count
@@ -446,6 +459,8 @@ def main(args):
)
engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
+ if args.tensor_parallel_size is not None:
+ engine_args["tensor_parallel_size"] = args.tensor_parallel_size
llm = LLM(**engine_args)
# We set temperature to 0.2 so that outputs can be different
diff --git a/examples/offline_inference/qwen3_omni/only_thinker.py b/examples/offline_inference/qwen3_omni/only_thinker.py
new file mode 100644
index 0000000000000..88a61ed694c2e
--- /dev/null
+++ b/examples/offline_inference/qwen3_omni/only_thinker.py
@@ -0,0 +1,170 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+This example shows how to use vLLM for running offline inference
+with the correct prompt format on Qwen3-Omni (thinker only).
+"""
+
+from typing import NamedTuple
+
+from vllm import LLM, SamplingParams
+from vllm.assets.audio import AudioAsset
+from vllm.assets.image import ImageAsset
+from vllm.assets.video import VideoAsset
+from vllm.multimodal.image import convert_image_mode
+from vllm.utils.argparse_utils import FlexibleArgumentParser
+
+
+class QueryResult(NamedTuple):
+ inputs: dict
+ limit_mm_per_prompt: dict[str, int]
+
+
+# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
+# lower-end GPUs.
+# Unless specified, these settings have been tested to work on a single L4.
+
+default_system = (
+ "You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
+ "Group, capable of perceiving auditory and visual inputs, as well as "
+ "generating text and speech."
+)
+
+
+def get_mixed_modalities_query() -> QueryResult:
+ question = (
+ "What is recited in the audio? "
+ "What is the content of this image? Why is this video funny?"
+ )
+ prompt = (
+ f"<|im_start|>system\n{default_system}<|im_end|>\n"
+ "<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|>"
+ "<|vision_start|><|image_pad|><|vision_end|>"
+ "<|vision_start|><|video_pad|><|vision_end|>"
+ f"{question}<|im_end|>\n"
+ f"<|im_start|>assistant\n"
+ )
+ return QueryResult(
+ inputs={
+ "prompt": prompt,
+ "multi_modal_data": {
+ "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
+ "image": convert_image_mode(
+ ImageAsset("cherry_blossom").pil_image, "RGB"
+ ),
+ "video": VideoAsset(name="baby_reading", num_frames=16).np_ndarrays,
+ },
+ },
+ limit_mm_per_prompt={"audio": 1, "image": 1, "video": 1},
+ )
+
+
+def get_use_audio_in_video_query() -> QueryResult:
+ question = (
+ "Describe the content of the video in details, then convert what the "
+ "baby say into text."
+ )
+ prompt = (
+ f"<|im_start|>system\n{default_system}<|im_end|>\n"
+ "<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>"
+ f"{question}<|im_end|>\n"
+ f"<|im_start|>assistant\n"
+ )
+ asset = VideoAsset(name="baby_reading", num_frames=16)
+ audio = asset.get_audio(sampling_rate=16000)
+ return QueryResult(
+ inputs={
+ "prompt": prompt,
+ "multi_modal_data": {
+ "video": asset.np_ndarrays,
+ "audio": audio,
+ },
+ "mm_processor_kwargs": {
+ "use_audio_in_video": True,
+ },
+ },
+ limit_mm_per_prompt={"audio": 1, "video": 1},
+ )
+
+
+def get_multi_audios_query() -> QueryResult:
+ question = "Are these two audio clips the same?"
+ prompt = (
+ f"<|im_start|>system\n{default_system}<|im_end|>\n"
+ "<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|>"
+ "<|audio_start|><|audio_pad|><|audio_end|>"
+ f"{question}<|im_end|>\n"
+ f"<|im_start|>assistant\n"
+ )
+ return QueryResult(
+ inputs={
+ "prompt": prompt,
+ "multi_modal_data": {
+ "audio": [
+ AudioAsset("winning_call").audio_and_sample_rate,
+ AudioAsset("mary_had_lamb").audio_and_sample_rate,
+ ],
+ },
+ },
+ limit_mm_per_prompt={
+ "audio": 2,
+ },
+ )
+
+
+query_map = {
+ "mixed_modalities": get_mixed_modalities_query,
+ "use_audio_in_video": get_use_audio_in_video_query,
+ "multi_audios": get_multi_audios_query,
+}
+
+
+def main(args):
+ model_name = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
+ query_result = query_map[args.query_type]()
+
+ llm = LLM(
+ model=model_name,
+ max_model_len=12800,
+ max_num_seqs=5,
+ limit_mm_per_prompt=query_result.limit_mm_per_prompt,
+ seed=args.seed,
+ )
+
+ # We set temperature to 0.2 so that outputs can be different
+ # even when all prompts are identical when running batch inference.
+ sampling_params = SamplingParams(temperature=0.2, max_tokens=256)
+
+ outputs = llm.generate(query_result.inputs, sampling_params=sampling_params)
+
+ for o in outputs:
+ generated_text = o.outputs[0].text
+ print(generated_text)
+
+
+def parse_args():
+ parser = FlexibleArgumentParser(
+ description="Demo on using vLLM for offline inference with "
+ "audio language models"
+ )
+ parser.add_argument(
+ "--query-type",
+ "-q",
+ type=str,
+ default="mixed_modalities",
+ choices=query_map.keys(),
+ help="Query type.",
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=None,
+ help="Set the seed when initializing `vllm.LLM`.",
+ )
+
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/offline_inference/reproducibility.py b/examples/offline_inference/reproducibility.py
index e135bc1b2abb7..72c1e841dca45 100644
--- a/examples/offline_inference/reproducibility.py
+++ b/examples/offline_inference/reproducibility.py
@@ -11,8 +11,11 @@ import random
from vllm import LLM, SamplingParams
-# Turn off multiprocessing to make the scheduling deterministic.
+# For reproducible results, apply either (or both) of the following:
+## Turn off multiprocessing to make the scheduling deterministic, or
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
+## Enable batch invariance to get consistent results regardless of scheduling.
+os.environ["VLLM_BATCH_INVARIANT"] = "1"
prompts = [
"Hello, my name is",
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
old mode 100644
new mode 100755
index 624de2a2debc3..8f72bf6f0b0d1
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -538,6 +538,31 @@ def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
)
+# HunyuanOCR
+def run_hunyuan_vl(questions: list[str], modality: str) -> ModelRequestData:
+ assert modality == "image"
+
+ model_name = "tencent/HunyuanOCR"
+
+ engine_args = EngineArgs(
+ model=model_name,
+ max_model_len=8192,
+ limit_mm_per_prompt={modality: 1},
+ )
+
+ placeholder = "<|hy_place▁holder▁no▁100|><|hy_place▁holder▁no▁102|><|hy_place▁holder▁no▁101|>" # noqa: E501
+ prompts = [
+ f"<|hy_begin▁of▁sentence|>{placeholder}{question}<|hy_User|>"
+ for question in questions
+ ]
+
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompts=prompts,
+ stop_token_ids=None,
+ )
+
+
# naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B
def run_hyperclovax_seed_vision(
questions: list[str], modality: str
@@ -1820,6 +1845,7 @@ model_example_map = {
"glm4_5v": run_glm4_5v,
"glm4_5v_fp8": run_glm4_5v_fp8,
"h2ovl_chat": run_h2ovl,
+ "hunyuan_vl": run_hunyuan_vl,
"hyperclovax_seed_vision": run_hyperclovax_seed_vision,
"idefics3": run_idefics3,
"interns1": run_interns1,
@@ -2038,6 +2064,13 @@ def parse_args():
help="If True, will send all requests in a second batch with empty mm "
"data to verify cache hits with UUIDs.",
)
+ parser.add_argument(
+ "--tensor-parallel-size",
+ "-tp",
+ type=int,
+ default=None,
+ help="Tensor parallel size to override the model's default setting. ",
+ )
return parser.parse_args()
@@ -2046,6 +2079,12 @@ def main(args):
if model not in model_example_map:
raise ValueError(f"Model type {model} is not supported.")
+ if args.tensor_parallel_size is not None and args.tensor_parallel_size < 1:
+ raise ValueError(
+ f"tensor_parallel_size must be a positive integer, "
+ f"got {args.tensor_parallel_size}"
+ )
+
modality = args.modality
mm_input = get_multi_modal_input(args)
data = mm_input["data"]
@@ -2063,6 +2102,8 @@ def main(args):
"seed": args.seed,
"mm_processor_cache_gb": 0 if args.disable_mm_processor_cache else 4,
}
+ if args.tensor_parallel_size is not None:
+ engine_args["tensor_parallel_size"] = args.tensor_parallel_size
llm = LLM(**engine_args)
# Don't want to check the flag multiple times, so just hijack `prompts`.
diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py
old mode 100644
new mode 100755
index d6e169548f15b..7ba4e64b567de
--- a/examples/offline_inference/vision_language_multi_image.py
+++ b/examples/offline_inference/vision_language_multi_image.py
@@ -1110,6 +1110,7 @@ def load_r_vl(question: str, image_urls: list[str]) -> ModelRequestData:
model=model_name,
max_model_len=16384,
max_num_seqs=16,
+ trust_remote_code=True,
limit_mm_per_prompt={"image": len(image_urls)},
)
@@ -1351,10 +1352,18 @@ model_example_map = {
}
-def run_generate(model, question: str, image_urls: list[str], seed: int | None):
+def run_generate(
+ model,
+ question: str,
+ image_urls: list[str],
+ seed: int | None,
+ tensor_parallel_size: int | None,
+):
req_data = model_example_map[model](question, image_urls)
- engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
+ engine_args = asdict(req_data.engine_args) | {"seed": seed}
+ if tensor_parallel_size is not None:
+ engine_args["tensor_parallel_size"] = tensor_parallel_size
llm = LLM(**engine_args)
sampling_params = SamplingParams(
@@ -1377,7 +1386,13 @@ def run_generate(model, question: str, image_urls: list[str], seed: int | None):
print("-" * 50)
-def run_chat(model: str, question: str, image_urls: list[str], seed: int | None):
+def run_chat(
+ model: str,
+ question: str,
+ image_urls: list[str],
+ seed: int | None,
+ tensor_parallel_size: int | None,
+):
req_data = model_example_map[model](question, image_urls)
# Disable other modalities to save memory
@@ -1387,6 +1402,8 @@ def run_chat(model: str, question: str, image_urls: list[str], seed: int | None)
)
engine_args = asdict(req_data.engine_args) | {"seed": seed}
+ if tensor_parallel_size is not None:
+ engine_args["tensor_parallel_size"] = tensor_parallel_size
llm = LLM(**engine_args)
sampling_params = (
@@ -1462,6 +1479,13 @@ def parse_args():
default=2,
help="Number of images to use for the demo.",
)
+ parser.add_argument(
+ "--tensor-parallel-size",
+ "-tp",
+ type=int,
+ default=None,
+ help="Tensor parallel size to override the model's default setting. ",
+ )
return parser.parse_args()
@@ -1469,13 +1493,20 @@ def main(args: Namespace):
model = args.model_type
method = args.method
seed = args.seed
+ tensor_parallel_size = args.tensor_parallel_size
+
+ if tensor_parallel_size is not None and tensor_parallel_size < 1:
+ raise ValueError(
+ f"tensor_parallel_size must be a positive integer, "
+ f"got {tensor_parallel_size}"
+ )
image_urls = IMAGE_URLS[: args.num_images]
if method == "generate":
- run_generate(model, QUESTION, image_urls, seed)
+ run_generate(model, QUESTION, image_urls, seed, tensor_parallel_size)
elif method == "chat":
- run_chat(model, QUESTION, image_urls, seed)
+ run_chat(model, QUESTION, image_urls, seed, tensor_parallel_size)
else:
raise ValueError(f"Invalid method: {method}")
diff --git a/examples/online_serving/gradio_openai_chatbot_webserver.py b/examples/online_serving/gradio_openai_chatbot_webserver.py
index d5d0a07a29183..c76c60cc4472d 100644
--- a/examples/online_serving/gradio_openai_chatbot_webserver.py
+++ b/examples/online_serving/gradio_openai_chatbot_webserver.py
@@ -25,25 +25,17 @@ import gradio as gr
from openai import OpenAI
-def format_history_to_openai(history):
- history_openai_format = [
- {"role": "system", "content": "You are a great AI assistant."}
- ]
- for human, assistant in history:
- history_openai_format.append({"role": "user", "content": human})
- history_openai_format.append({"role": "assistant", "content": assistant})
- return history_openai_format
-
-
def predict(message, history, client, model_name, temp, stop_token_ids):
- # Format history to OpenAI chat format
- history_openai_format = format_history_to_openai(history)
- history_openai_format.append({"role": "user", "content": message})
+ messages = [
+ {"role": "system", "content": "You are a great AI assistant."},
+ *history,
+ {"role": "user", "content": message},
+ ]
# Send request to OpenAI API (vLLM server)
stream = client.chat.completions.create(
model=model_name,
- messages=history_openai_format,
+ messages=messages,
temperature=temp,
stream=True,
extra_body={
diff --git a/examples/online_serving/openai_embedding_long_text/service.sh b/examples/online_serving/openai_embedding_long_text/service.sh
index 1577de85f7ff2..b5c92749466b0 100644
--- a/examples/online_serving/openai_embedding_long_text/service.sh
+++ b/examples/online_serving/openai_embedding_long_text/service.sh
@@ -22,7 +22,6 @@ API_KEY=${API_KEY:-"your-api-key"}
POOLING_TYPE=${POOLING_TYPE:-"auto"} # auto, MEAN, CLS, LAST
export VLLM_ENABLE_CHUNKED_PROCESSING=true
export CUDA_VISIBLE_DEVICES=2,3,4,5
-# export VLLM_ATTENTION_BACKEND=XFORMERS
echo "🚀 Starting vLLM Embedding Server with Enhanced Chunked Processing"
echo "=================================================================="
diff --git a/examples/online_serving/openai_responses_client.py b/examples/online_serving/openai_responses_client.py
new file mode 100644
index 0000000000000..b4eb24671507a
--- /dev/null
+++ b/examples/online_serving/openai_responses_client.py
@@ -0,0 +1,44 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Set up this example by starting a vLLM OpenAI-compatible server.
+Reasoning models can be used through the Responses API as seen here
+https://platform.openai.com/docs/api-reference/responses
+For example:
+vllm serve Qwen/Qwen3-8B --reasoning-parser qwen3
+
+"""
+
+from openai import OpenAI
+
+input_messages = [{"role": "user", "content": "What model are you?"}]
+
+
+def main():
+ base_url = "http://localhost:8000/v1"
+ client = OpenAI(base_url=base_url, api_key="empty")
+ model = "Qwen/Qwen3-8B" # get_first_model(client)
+ response = client.responses.create(
+ model=model,
+ input=input_messages,
+ )
+
+ for message in response.output:
+ if message.type == "reasoning":
+ # append reasoning message
+ input_messages.append(message)
+
+ response_2 = client.responses.create(
+ model=model,
+ input=input_messages,
+ )
+ print(response_2.output_text)
+ # I am Qwen, a large language model developed by Alibaba Cloud.
+ # I am designed to assist with a wide range of tasks, including
+ # answering questions, creating content, coding, and engaging in
+ # conversations. I can help with various topics and provide
+ # information or support in multiple languages. How can I assist you today?
+
+
+if __name__ == "__main__":
+ main()
diff --git a/requirements/cuda.txt b/requirements/cuda.txt
index d63fe9e1e77c1..15e8aadc56f47 100644
--- a/requirements/cuda.txt
+++ b/requirements/cuda.txt
@@ -9,6 +9,5 @@ torch==2.9.0
torchaudio==2.9.0
# These must be updated alongside torch
torchvision==0.24.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
-xformers==0.0.33.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.9
# FlashInfer should be updated together with the Dockerfile
flashinfer-python==0.5.2
diff --git a/requirements/kv_connectors.txt b/requirements/kv_connectors.txt
index b1f3269cd3813..083230c171096 100644
--- a/requirements/kv_connectors.txt
+++ b/requirements/kv_connectors.txt
@@ -1,2 +1,2 @@
lmcache
-nixl >= 0.6.0 # Required for disaggregated prefill
+nixl >= 0.7.1 # Required for disaggregated prefill
diff --git a/requirements/rocm-test.txt b/requirements/rocm-test.txt
index 432e11977872d..8a91b59de6f72 100644
--- a/requirements/rocm-test.txt
+++ b/requirements/rocm-test.txt
@@ -39,3 +39,13 @@ mteb[bm25s]>=1.38.11, <2
# Required for eval tests
lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d
+
+# Required for multiprocessed tests that use spawn method
+multiprocess==0.70.16
+
+# Plugins test
+terratorch @ git+https://github.com/IBM/terratorch.git@07184fcf91a1324f831ff521dd238d97fe350e3e
+torchgeo==0.7.0
+
+# Required for suffix decoding test
+arctic-inference == 0.1.1
diff --git a/requirements/xpu.txt b/requirements/xpu.txt
index 59ea710684a2c..c1dc4195b5231 100644
--- a/requirements/xpu.txt
+++ b/requirements/xpu.txt
@@ -10,9 +10,9 @@ wheel
jinja2>=3.1.6
datasets # for benchmark scripts
numba == 0.61.2 # Required for N-gram speculative decoding
-torch==2.8.0+xpu
+--extra-index-url=https://download.pytorch.org/whl/xpu
+torch==2.9.0+xpu
torchaudio
torchvision
---extra-index-url=https://download.pytorch.org/whl/xpu
-intel-extension-for-pytorch @ https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.8.10.post1%2Bxpu-cp312-cp312-linux_x86_64.whl
+intel-extension-for-pytorch @ https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.9.10.post0%2Bxpu-cp312-cp312-linux_x86_64.whl
diff --git a/setup.py b/setup.py
index 5591bcb132447..8871b04d8fc46 100644
--- a/setup.py
+++ b/setup.py
@@ -74,18 +74,6 @@ def is_ninja_available() -> bool:
return which("ninja") is not None
-def is_url_available(url: str) -> bool:
- from urllib.request import urlopen
-
- status = None
- try:
- with urlopen(url) as f:
- status = f.status
- except Exception:
- return False
- return status == 200
-
-
class CMakeExtension(Extension):
def __init__(self, name: str, cmake_lists_dir: str = ".", **kwa) -> None:
super().__init__(name, sources=[], py_limited_api=True, **kwa)
@@ -533,28 +521,6 @@ def get_nvcc_cuda_version() -> Version:
return nvcc_cuda_version
-def get_gaudi_sw_version():
- """
- Returns the driver version.
- """
- # Enable console printing for `hl-smi` check
- output = subprocess.run(
- "hl-smi",
- shell=True,
- text=True,
- capture_output=True,
- env={"ENABLE_CONSOLE": "true"},
- )
- if output.returncode == 0 and output.stdout:
- return (
- output.stdout.split("\n")[2]
- .replace(" ", "")
- .split(":")[1][:-1]
- .split("-")[0]
- )
- return "0.0.0" # when hl-smi is not available
-
-
def get_vllm_version() -> str:
# Allow overriding the version. This is useful to build platform-specific
# wheels (e.g. CPU, TPU) without modifying the source.
diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py
index 0cf1e85d4e8ee..521d6c33dd390 100644
--- a/tests/basic_correctness/test_basic_correctness.py
+++ b/tests/basic_correctness/test_basic_correctness.py
@@ -74,9 +74,6 @@ def test_models(
model_executor: str,
enable_prompt_embeds: bool,
) -> None:
- if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
- pytest.skip(f"{backend} does not support gemma2 with full context length.")
-
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", backend)
diff --git a/tests/compile/distributed/test_fusions_e2e.py b/tests/compile/distributed/test_fusions_e2e.py
index 661172e1965b5..53c3f875d2003 100644
--- a/tests/compile/distributed/test_fusions_e2e.py
+++ b/tests/compile/distributed/test_fusions_e2e.py
@@ -111,6 +111,17 @@ if current_platform.is_cuda():
async_tp=96, # MLP is MoE, half the fusions of dense
),
),
+ ModelBackendTestCase(
+ model_name="openai/gpt-oss-20b",
+ model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
+ backend=AttentionBackendEnum.FLASHINFER,
+ matches=Matches(
+ attention_fusion=0,
+ allreduce_fusion=49,
+ sequence_parallel=49,
+ async_tp=48,
+ ),
+ ),
]
elif current_platform.is_rocm():
diff --git a/tests/compile/fullgraph/test_simple.py b/tests/compile/fullgraph/test_simple.py
index e258133ab50a7..36cc1510ed798 100644
--- a/tests/compile/fullgraph/test_simple.py
+++ b/tests/compile/fullgraph/test_simple.py
@@ -55,7 +55,7 @@ class SillyModel(nn.Module):
def _run_simple_model(
splitting_ops,
use_inductor_graph_partition,
- use_inductor,
+ backend,
expected_num_piecewise_graphs_seen,
expected_num_piecewise_capturable_graphs_seen,
expected_num_backend_compilations,
@@ -64,7 +64,7 @@ def _run_simple_model(
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
- use_inductor=use_inductor,
+ backend=backend,
splitting_ops=splitting_ops,
use_inductor_graph_partition=use_inductor_graph_partition,
cudagraph_copy_inputs=True,
@@ -124,14 +124,14 @@ def _run_simple_model(
assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))
-@pytest.mark.parametrize("use_inductor", [True, False])
+@pytest.mark.parametrize("backend", ["inductor", "eager"])
@torch.inference_mode()
@create_new_process_for_each_test("spawn")
-def test_simple_piecewise_compile(use_inductor):
+def test_simple_piecewise_compile(backend):
_run_simple_model(
splitting_ops=["silly::attention"],
use_inductor_graph_partition=False,
- use_inductor=use_inductor,
+ backend=backend,
# 2 * num_layers + 1
expected_num_piecewise_graphs_seen=5,
# 1 + num_layers
@@ -155,7 +155,7 @@ def test_simple_inductor_graph_partition(monkeypatch):
_run_simple_model(
splitting_ops=["silly::attention"],
use_inductor_graph_partition=True,
- use_inductor=True,
+ backend="inductor",
# Since not splitting at fx graph level
expected_num_piecewise_graphs_seen=1,
# Since not splitting at fx graph level
diff --git a/tests/compile/test_dynamic_shapes_compilation.py b/tests/compile/test_dynamic_shapes_compilation.py
new file mode 100644
index 0000000000000..c20aea822fe81
--- /dev/null
+++ b/tests/compile/test_dynamic_shapes_compilation.py
@@ -0,0 +1,88 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import gc
+
+import pytest
+import torch
+
+from vllm import LLM, SamplingParams
+from vllm.config.compilation import CompilationMode, DynamicShapesType
+from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.utils.torch_utils import is_torch_equal_or_newer
+
+
+def get_test_models():
+    """Get the list of models to test."""
+    # TODO: "Qwen/Qwen3-4B-Instruct-2507" fails; fix the issue and re-enable it.
+ return ["gpt2", "Qwen/Qwen2-7B-Instruct", "meta-llama/Llama-3.1-8B"]
+
+
+@pytest.mark.parametrize("model_name", get_test_models())
+@pytest.mark.parametrize(
+ "shapes_type",
+ [
+ DynamicShapesType.BACKED,
+ DynamicShapesType.UNBACKED,
+ DynamicShapesType.BACKED_SIZE_OBLIVIOUS,
+ ],
+)
+@pytest.mark.parametrize("use_aot_compile", ["0"])
+@pytest.mark.parametrize("use_bytecode_hook", [True, False])
+@pytest.mark.skipif(
+ not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
+)
+def test_dynamic_shapes_compilation(
+ monkeypatch, model_name, shapes_type, use_aot_compile, use_bytecode_hook
+):
+ """Test that all dynamic shapes types compile successfully"""
+ print(
+ f"\nTesting model: {model_name} with {shapes_type.name}, "
+ f"AOT compile: {use_aot_compile}, "
+ f"Bytecode hook: {use_bytecode_hook}"
+ )
+ if use_bytecode_hook and shapes_type == DynamicShapesType.UNBACKED:
+ pytest.skip("UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0")
+
+ monkeypatch.setenv("VLLM_USE_AOT_COMPILE", use_aot_compile)
+ monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")
+
+ prompt = "Hello, my name is"
+
+ print(f"Testing {shapes_type.name} dynamic shapes...")
+
+ # Initialize the model with specific dynamic shapes configuration
+ model = LLM(
+ model=model_name,
+ compilation_config={
+ "mode": CompilationMode.VLLM_COMPILE,
+ "dynamic_shapes_config": {
+ "type": shapes_type.value,
+ },
+ },
+ )
+
+ output = model.generate(prompt)
+ result = output[0].outputs[0].text
+ # Example of setting the sampling parameters
+ tokenizer = get_tokenizer(model_name)
+ yes_tokens = tokenizer.encode("yes", add_special_tokens=False)
+ no_tokens = tokenizer.encode("no", add_special_tokens=False)
+ allowed_ids = list(set(yes_tokens + no_tokens))
+ sampling_params = SamplingParams(
+ max_tokens=1, temperature=0, allowed_token_ids=allowed_ids
+ )
+
+ output = model.generate(
+ "answer with yes or no is " + result + " rubbish for prompt " + prompt + "?",
+ sampling_params=sampling_params,
+ )
+ result = output[0].outputs[0].text
+ assert result == "yes"
+
+ # Clean up GPU memory
+ del model
+ gc.collect()
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+ print("GPU memory cleared")
diff --git a/tests/conftest.py b/tests/conftest.py
index b17081352edcf..163593eb3f14f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -748,6 +748,14 @@ class VllmRunner:
# being captured which can trigger edge cases that we don't handle yet.
kwargs["compilation_config"] = {"cudagraph_capture_sizes": [4]}
+    # Make sure we have at least one cudagraph large enough for a single decode.
+ if (speculative_config := kwargs.get("speculative_config")) and (
+ num_speculative_tokens := speculative_config["num_speculative_tokens"]
+ ):
+ kwargs["compilation_config"]["cudagraph_capture_sizes"].append(
+ num_speculative_tokens + 1
+ )
+
with init_ctx:
self.llm = LLM(
model=model_name,
@@ -845,6 +853,7 @@ class VllmRunner:
@staticmethod
def _final_steps_generate_w_logprobs(
req_outputs: list[RequestOutput],
+ include_prompt_token_ids: bool = False,
) -> list[TokensTextLogprobsPromptLogprobs]:
outputs: list[TokensTextLogprobsPromptLogprobs] = []
for req_output in req_outputs:
@@ -853,9 +862,26 @@ class VllmRunner:
output_str = sample.text
output_ids = list(sample.token_ids)
output_logprobs = sample.logprobs
- outputs.append(
- (output_ids, output_str, output_logprobs, req_output.prompt_logprobs)
- )
+ if include_prompt_token_ids:
+ outputs.append(
+ ( # type: ignore[arg-type]
+ output_ids,
+ output_str,
+ output_logprobs,
+ req_output.prompt_token_ids,
+ req_output.prompt_logprobs,
+ )
+ )
+ else:
+ outputs.append(
+ (
+ output_ids,
+ output_str,
+ output_logprobs,
+ req_output.prompt_logprobs,
+ )
+ )
+
return outputs
def generate_w_logprobs(
@@ -865,6 +891,7 @@ class VllmRunner:
images: PromptImageInput | None = None,
audios: PromptAudioInput | None = None,
videos: PromptVideoInput | None = None,
+ include_prompt_token_ids: bool = False,
**kwargs: Any,
) -> list[TokensTextLogprobs] | list[TokensTextLogprobsPromptLogprobs]:
inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
@@ -874,7 +901,7 @@ class VllmRunner:
)
toks_str_logsprobs_prompt_logprobs = self._final_steps_generate_w_logprobs(
- req_outputs
+ req_outputs, include_prompt_token_ids
)
# Omit prompt logprobs if not required by sampling params
return (
diff --git a/tests/distributed/eplb_utils.py b/tests/distributed/eplb_utils.py
new file mode 100644
index 0000000000000..27a63e0215148
--- /dev/null
+++ b/tests/distributed/eplb_utils.py
@@ -0,0 +1,49 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import os
+import random
+
+import torch
+import torch.multiprocessing as mp
+
+from vllm.distributed.parallel_state import (
+ init_distributed_environment,
+)
+from vllm.utils.system_utils import update_environment_variables
+
+mp.set_start_method("spawn", force=True)
+
+
+def distributed_run(fn, world_size, *args):
+ number_of_processes = world_size
+ processes: list[mp.Process] = []
+ for i in range(number_of_processes):
+ env: dict[str, str] = {}
+ env["RANK"] = str(i)
+ env["LOCAL_RANK"] = str(i)
+ env["WORLD_SIZE"] = str(number_of_processes)
+ env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
+ env["MASTER_ADDR"] = "localhost"
+ env["MASTER_PORT"] = "12345"
+ p = mp.Process(target=fn, args=(env, world_size, *args))
+ processes.append(p)
+ p.start()
+
+ for p in processes:
+ p.join()
+
+ for p in processes:
+ assert p.exitcode == 0
+
+
+def set_env_vars_and_device(env: dict[str, str]) -> None:
+ update_environment_variables(env)
+ local_rank = os.environ["LOCAL_RANK"]
+ device = torch.device(f"cuda:{local_rank}")
+ torch.cuda.set_device(device)
+ init_distributed_environment()
+
+ # Ensure each worker process has the same random seed
+ random.seed(42)
+ torch.manual_seed(42)
diff --git a/tests/distributed/test_eplb_execute.py b/tests/distributed/test_eplb_execute.py
index 0a97749ac318c..781dfd44c1ef6 100644
--- a/tests/distributed/test_eplb_execute.py
+++ b/tests/distributed/test_eplb_execute.py
@@ -1,57 +1,24 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import os
+import asyncio
import random
import pytest
import torch
import torch.distributed
-import torch.multiprocessing as mp
-from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
+from vllm.distributed.eplb.rebalance_execute import (
+ move_from_buffer,
+ rearrange_expert_weights_inplace,
+ transfer_layer,
+)
from vllm.distributed.parallel_state import (
ensure_model_parallel_initialized,
get_tp_group,
- init_distributed_environment,
)
-from vllm.utils.system_utils import update_environment_variables
-mp.set_start_method("spawn", force=True)
-
-
-def distributed_run(fn, world_size, *args):
- number_of_processes = world_size
- processes: list[mp.Process] = []
- for i in range(number_of_processes):
- env: dict[str, str] = {}
- env["RANK"] = str(i)
- env["LOCAL_RANK"] = str(i)
- env["WORLD_SIZE"] = str(number_of_processes)
- env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
- env["MASTER_ADDR"] = "localhost"
- env["MASTER_PORT"] = "12345"
- p = mp.Process(target=fn, args=(env, world_size, *args))
- processes.append(p)
- p.start()
-
- for p in processes:
- p.join()
-
- for p in processes:
- assert p.exitcode == 0
-
-
-def set_env_vars_and_device(env: dict[str, str]) -> None:
- update_environment_variables(env)
- local_rank = os.environ["LOCAL_RANK"]
- device = torch.device(f"cuda:{local_rank}")
- torch.cuda.set_device(device)
- init_distributed_environment()
-
- # Ensure each worker process has the same random seed
- random.seed(42)
- torch.manual_seed(42)
+from .eplb_utils import distributed_run, set_env_vars_and_device
def create_expert_indices_with_redundancy(
@@ -269,6 +236,100 @@ def verify_redundant_experts_have_same_weights(
)
+def _test_async_transfer_layer_without_mtp_worker(
+ env,
+ world_size: int,
+ num_layers: int,
+ num_local_experts: int,
+ num_logical_experts: int,
+) -> None:
+ set_env_vars_and_device(env)
+ ensure_model_parallel_initialized(
+ tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
+ )
+
+ tp_group = get_tp_group()
+ ep_group = tp_group.device_group
+ ep_rank = torch.distributed.get_rank()
+ device = torch.device(f"cuda:{ep_rank}")
+
+ total_physical_experts = world_size * num_local_experts
+ hidden_sizes = [16, 32]
+
+ redundancy_config = create_redundancy_config(
+ num_logical_experts,
+ total_physical_experts,
+ )
+ old_indices = create_expert_indices_with_redundancy(
+ num_layers,
+ num_logical_experts,
+ total_physical_experts,
+ redundancy_config,
+ )
+
+ new_redundancy_config = create_redundancy_config(
+ num_logical_experts,
+ total_physical_experts,
+ )
+ new_indices = create_expert_indices_with_redundancy(
+ num_layers,
+ num_logical_experts,
+ total_physical_experts,
+ new_redundancy_config,
+ )
+
+ expert_weights = create_expert_weights(
+ num_layers,
+ num_local_experts,
+ hidden_sizes,
+ ep_rank,
+ device,
+ old_indices,
+ )
+
+ expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
+ cuda_stream = torch.cuda.Stream(device=device)
+
+ for layer_idx in range(num_layers):
+ is_unchanged, is_received_locally, experts_recv_loc = asyncio.run(
+ transfer_layer(
+ old_global_expert_indices=old_indices,
+ new_global_expert_indices=new_indices,
+ expert_weights=expert_weights,
+ expert_weights_buffer=expert_buffer,
+ ep_group=ep_group,
+ layer=layer_idx,
+ cuda_stream=cuda_stream,
+ )
+ )
+
+ cuda_stream.synchronize()
+ move_from_buffer(
+ expert_weights=expert_weights[layer_idx],
+ expert_weights_buffer=expert_buffer,
+ is_unchanged=is_unchanged,
+ is_received_locally=is_received_locally,
+ experts_recv_loc=experts_recv_loc,
+ new_indices=new_indices[layer_idx].tolist(),
+ ep_group=ep_group,
+ )
+
+ verify_expert_weights_after_shuffle(
+ expert_weights,
+ new_indices,
+ hidden_sizes,
+ ep_rank,
+ num_local_experts,
+ )
+ verify_redundant_experts_have_same_weights(
+ expert_weights,
+ new_indices,
+ hidden_sizes,
+ world_size,
+ num_local_experts,
+ )
+
+
def _test_rearrange_expert_weights_with_redundancy(
env, world_size, num_layers, num_local_experts, num_logical_experts
) -> None:
@@ -437,6 +498,32 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
)
+@pytest.mark.parametrize(
+ "world_size,num_layers,num_local_experts,num_logical_experts",
+ [
+ (2, 2, 2, 3),
+ ],
+)
+def test_async_transfer_layer_without_mtp(
+ world_size: int,
+ num_layers: int,
+ num_local_experts: int,
+ num_logical_experts: int,
+):
+ """Exercise async EPLB transfer path without MTP/spec decode."""
+
+ if torch.cuda.device_count() < world_size:
+ pytest.skip(f"Need at least {world_size} GPUs to run the test")
+
+ distributed_run(
+ _test_async_transfer_layer_without_mtp_worker,
+ world_size,
+ num_layers,
+ num_local_experts,
+ num_logical_experts,
+ )
+
+
@pytest.mark.parametrize("world_size", [2, 4])
def test_rearrange_expert_weights_no_change(world_size):
"""
diff --git a/tests/distributed/test_eplb_fused_moe_layer.py b/tests/distributed/test_eplb_fused_moe_layer.py
new file mode 100644
index 0000000000000..55f26519887a1
--- /dev/null
+++ b/tests/distributed/test_eplb_fused_moe_layer.py
@@ -0,0 +1,285 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Test that the interaction between EPLB and FusedMoE Layer is okay
+
+from dataclasses import dataclass
+
+import pytest
+import torch
+
+from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
+from vllm.distributed.parallel_state import (
+ ensure_model_parallel_initialized,
+ get_tp_group,
+)
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+
+from .eplb_utils import distributed_run, set_env_vars_and_device
+
+
+@dataclass
+class TestConfig:
+ num_layers: int
+ num_experts: int
+ num_local_experts: int
+ num_topk: int
+ hidden_size: int
+ intermediate_size: int
+ weight_dtype: torch.dtype
+ weight_scale_dtype: torch.dtype | None
+ column_major_scales: bool
+
+
+def make_expert_weights(
+ layer_idx: int,
+ global_expert_idx: int,
+ global_num_experts: int,
+ tensor_shape: tuple[int, ...],
+ tensor_dtype: torch.dtype,
+ tensor_device: torch.device,
+ is_column_major: bool,
+) -> torch.Tensor:
+ assert len(tensor_shape) == 2
+
+ if is_column_major:
+ tensor_shape = (tensor_shape[1], tensor_shape[0])
+
+ x = torch.empty(tensor_shape, dtype=tensor_dtype, device=tensor_device)
+ value_offset = (layer_idx * global_num_experts + global_expert_idx) * x.numel()
+ x.view(-1).copy_(
+ torch.arange(
+ value_offset,
+ value_offset + x.numel(),
+ dtype=tensor_dtype,
+ device=tensor_device,
+ )
+ )
+
+ if is_column_major:
+ x = torch.transpose(x, 1, 0)
+ assert not x.is_contiguous()
+ return x
+
+
+def make_fused_moe_layer(
+ rank: int,
+ layer_idx: int,
+ test_config: TestConfig,
+) -> FusedMoE:
+ fml = FusedMoE(
+ num_experts=test_config.num_experts,
+ top_k=test_config.num_topk,
+ hidden_size=test_config.hidden_size,
+ intermediate_size=test_config.intermediate_size,
+ prefix=f"dummy_layer_{layer_idx}",
+ activation="silu",
+ is_act_and_mul=True,
+ params_dtype=test_config.weight_dtype,
+ )
+
+ device = torch.device(f"cuda:{rank}")
+
+ from functools import partial
+
+ _make_expert_weights = partial(
+ make_expert_weights,
+ layer_idx=layer_idx,
+ global_num_experts=test_config.num_experts,
+ tensor_device=device,
+ )
+
+ assert isinstance(fml.w13_weight.data, torch.Tensor)
+ assert isinstance(fml.w2_weight.data, torch.Tensor)
+ fml.w13_weight.data = fml.w13_weight.data.to(device=device)
+ fml.w2_weight.data = fml.w2_weight.data.to(device=device)
+ w13_weight = fml.w13_weight.data
+ w2_weight = fml.w2_weight.data
+ assert w13_weight.size(0) == test_config.num_local_experts
+ for i in range(test_config.num_local_experts):
+ g_i = rank * test_config.num_local_experts + i
+ w13_weight_e = w13_weight[i]
+ w2_weight_e = w2_weight[i]
+ w13_weight_e.copy_(
+ _make_expert_weights(
+ global_expert_idx=g_i,
+ tensor_shape=w13_weight_e.shape,
+ tensor_dtype=w13_weight_e.dtype,
+ is_column_major=False,
+ )
+ )
+ w2_weight_e.copy_(
+ _make_expert_weights(
+ global_expert_idx=g_i,
+ tensor_shape=w2_weight_e.shape,
+ tensor_dtype=w2_weight_e.dtype,
+ is_column_major=False,
+ )
+ )
+
+ block_size = 16
+
+ def block_quant_scales_shape(
+ shape: tuple[int, ...], is_column_major: bool
+ ) -> tuple[int, ...]:
+ assert len(shape) == 3
+ if not is_column_major:
+ return (shape[0], shape[1] // block_size, shape[2] // block_size)
+ else:
+ return (shape[0], shape[2] // block_size, shape[1] // block_size)
+
+ is_column_major = test_config.column_major_scales
+ w13_weight_scale_inv = torch.empty(
+ block_quant_scales_shape(w13_weight.shape, is_column_major),
+ dtype=test_config.weight_dtype,
+ device=device,
+ )
+ w2_weight_scale_inv = torch.empty(
+ block_quant_scales_shape(w2_weight.shape, is_column_major),
+ dtype=test_config.weight_dtype,
+ device=device,
+ )
+
+ for i in range(test_config.num_local_experts):
+ g_i = rank * test_config.num_local_experts + i
+ w13_s_e = w13_weight_scale_inv[i]
+ w2_s_e = w2_weight_scale_inv[i]
+ w13_s_e.copy_(
+ _make_expert_weights(
+ global_expert_idx=g_i,
+ tensor_shape=w13_s_e.shape,
+ tensor_dtype=w13_s_e.dtype,
+ # Fill data in row-major and then
+ # transpose if test_config requires col-major.
+ is_column_major=False,
+ )
+ )
+ w2_s_e.copy_(
+ _make_expert_weights(
+ global_expert_idx=g_i,
+ tensor_shape=w2_s_e.shape,
+ tensor_dtype=w2_s_e.dtype,
+ is_column_major=False,
+ )
+ )
+ if is_column_major:
+ w13_weight_scale_inv = torch.transpose(w13_weight_scale_inv, 1, 2)
+ w2_weight_scale_inv = torch.transpose(w2_weight_scale_inv, 1, 2)
+ assert not w13_weight_scale_inv.is_contiguous()
+ assert not w2_weight_scale_inv.is_contiguous()
+
+ # Add scales to the parameter list
+ fml.w13_weight_scale_inv = torch.nn.Parameter(
+ w13_weight_scale_inv, requires_grad=False
+ )
+ fml.w2_weight_scale_inv = torch.nn.Parameter(
+ w2_weight_scale_inv, requires_grad=False
+ )
+
+ return fml
+
+
+def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
+ # Initialize model parallel (using tensor parallel as an entrypoint
+ # to expert parallel)
+ set_env_vars_and_device(env)
+
+ vllm_config = VllmConfig()
+ vllm_config.parallel_config.tensor_parallel_size = world_size
+ vllm_config.parallel_config.enable_expert_parallel = True
+
+ with set_current_vllm_config(vllm_config):
+ ensure_model_parallel_initialized(
+ tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
+ )
+
+ ep_group = get_tp_group().cpu_group
+ ep_rank = torch.distributed.get_rank()
+
+ fml_layers = [
+ make_fused_moe_layer(ep_rank, layer_idx, test_config)
+ for layer_idx in range(test_config.num_layers)
+ ]
+ rank_expert_weights = [fml.get_expert_weights() for fml in fml_layers]
+
+ indices = torch.zeros(
+ test_config.num_layers, test_config.num_experts, dtype=torch.long
+ )
+ for lidx in range(test_config.num_layers):
+ indices[lidx] = torch.Tensor(range(test_config.num_experts))
+
+ shuffled_indices = torch.zeros_like(indices)
+ for lidx in range(test_config.num_layers):
+ shuffled_indices[lidx] = torch.randperm(test_config.num_experts)
+
+ rearrange_expert_weights_inplace(
+ indices,
+ shuffled_indices,
+ rank_expert_weights,
+ ep_group,
+ is_profile=False,
+ )
+
+ num_local_experts = test_config.num_local_experts
+ num_global_experts = test_config.num_experts
+ for lidx, fml in enumerate(fml_layers):
+ for name, w in fml.named_parameters():
+ for e in range(num_local_experts):
+ g_e = shuffled_indices[lidx][ep_rank * num_local_experts + e]
+ ref = make_expert_weights(
+ layer_idx=lidx,
+ global_expert_idx=int(g_e.item()),
+ global_num_experts=num_global_experts,
+ tensor_shape=w[e].shape,
+ tensor_dtype=w[e].dtype,
+ tensor_device=w[e].device,
+ is_column_major=not w[e].is_contiguous(),
+ )
+ assert w[e].shape == ref.shape and w[e].stride() == ref.stride(), (
+ f"w[{e}] {w[e].size()} {w[e].stride()} vs "
+ f"ref {ref.size()} {ref.stride()}"
+ )
+ torch.testing.assert_close(w[e], ref)
+
+
+@pytest.mark.parametrize("world_size", [2])
+@pytest.mark.parametrize("num_layers", [4])
+@pytest.mark.parametrize("num_experts", [16])
+@pytest.mark.parametrize("hidden_size", [256])
+@pytest.mark.parametrize("intermediate_size", [256])
+@pytest.mark.parametrize("column_major_scales", [True, False])
+def test_eplb_fml(
+ world_size: int,
+ num_layers: int,
+ num_experts: int,
+ hidden_size: int,
+ intermediate_size: int,
+ column_major_scales: bool,
+):
+ if torch.cuda.device_count() < world_size:
+ pytest.skip(f"Need at least {world_size} GPUs to run the test")
+
+ num_local_experts = num_experts // world_size
+ num_topk = 4
+ # The dtypes are fine as we are essentially just checking data-copies
+ weight_dtype = torch.bfloat16
+ weight_scale_dtype = torch.bfloat16
+
+ test_config = TestConfig(
+ num_layers=num_layers,
+ num_experts=num_experts,
+ num_local_experts=num_local_experts,
+ num_topk=num_topk,
+ hidden_size=hidden_size,
+ intermediate_size=intermediate_size,
+ weight_dtype=weight_dtype,
+ weight_scale_dtype=weight_scale_dtype,
+ column_major_scales=column_major_scales,
+ )
+
+ distributed_run(
+ _test_eplb_fml,
+ world_size,
+ test_config,
+ )
diff --git a/tests/distributed/test_eplb_spec_decode.py b/tests/distributed/test_eplb_spec_decode.py
index 11e23f128f331..c055b7a3f6dd7 100644
--- a/tests/distributed/test_eplb_spec_decode.py
+++ b/tests/distributed/test_eplb_spec_decode.py
@@ -10,10 +10,11 @@ from tests.utils import large_gpu_mark
def get_model_args(
model_name: str,
- spec_model_name: str,
+ spec_model_name: str | None,
spec_method: str,
tp_size: int,
model_max_len: int,
+ use_async: bool = False,
) -> dict:
speculative_config = {
"method": spec_method,
@@ -37,6 +38,8 @@ def get_model_args(
"enable_eplb": True,
"max_model_len": model_max_len,
}
+ if use_async:
+ model_args["eplb_config"] = {"use_async": True}
return model_args
@@ -94,3 +97,37 @@ def test_eplb_spec_decode(
measured_value - RTOL < expected_gsm8k_value
and measured_value + RTOL > expected_gsm8k_value
), f"Expected: {expected_gsm8k_value} | Measured: {measured_value}"
+
+
+@large_gpu_mark(min_gb=80)
+def test_eplb_spec_decode_qwen3_next_mtp_async() -> None:
+ """
+ Ensure async EPLB works with MTP speculative decoding for Qwen3-Next.
+ """
+
+ TASK = "gsm8k"
+ FILTER = "exact_match,strict-match"
+ RTOL = 0.03
+ expected_gsm8k_value = 0.86
+
+ model_args = get_model_args(
+ model_name="Qwen/Qwen3-Next-80B-A3B-Instruct",
+ spec_model_name=None,
+ spec_method="mtp",
+ tp_size=4,
+ model_max_len=4096,
+ use_async=True,
+ )
+
+ results = lm_eval.simple_evaluate(
+ model="vllm",
+ model_args=model_args,
+ tasks=TASK,
+ batch_size=64,
+ num_fewshot=8,
+ )
+ measured_value = results["results"][TASK][FILTER]
+ assert (
+ measured_value - RTOL < expected_gsm8k_value
+ and measured_value + RTOL > expected_gsm8k_value
+ ), f"Expected: {expected_gsm8k_value} | Measured: {measured_value}"
diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py
index c3085beeb3564..c7c9d0602def0 100644
--- a/tests/distributed/test_pynccl.py
+++ b/tests/distributed/test_pynccl.py
@@ -1,9 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import multiprocessing
import os
+import multiprocess as mp
import numpy as np
import pytest
import torch
@@ -20,10 +20,12 @@ from vllm.distributed.parallel_state import (
)
from vllm.utils.system_utils import update_environment_variables
+mp.set_start_method("spawn", force=True)
+
def distributed_run(fn, world_size):
number_of_processes = world_size
- processes: list[multiprocessing.Process] = []
+ processes: list[mp.Process] = []
for i in range(number_of_processes):
env: dict[str, str] = {}
env["RANK"] = str(i)
@@ -32,7 +34,7 @@ def distributed_run(fn, world_size):
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
env["MASTER_ADDR"] = "localhost"
env["MASTER_PORT"] = "12345"
- p = multiprocessing.Process(target=fn, args=(env,))
+ p = mp.Process(target=fn, args=(env,))
processes.append(p)
p.start()
diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py
index 472b1487ef440..be926764e4948 100644
--- a/tests/engine/test_arg_utils.py
+++ b/tests/engine/test_arg_utils.py
@@ -249,14 +249,13 @@ def test_compilation_config():
args = parser.parse_args(
[
"-O",
- '{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
- '"use_inductor": false}',
+ '{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], "backend": "eager"}',
]
)
assert (
args.compilation_config.mode == 3
and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
- and not args.compilation_config.use_inductor
+ and args.compilation_config.backend == "eager"
)
# set to string form of a dict
@@ -264,13 +263,13 @@ def test_compilation_config():
[
"--compilation-config="
'{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
- '"use_inductor": true}',
+ '"backend": "inductor"}',
]
)
assert (
args.compilation_config.mode == 3
and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
- and args.compilation_config.use_inductor
+ and args.compilation_config.backend == "inductor"
)
@@ -278,8 +277,9 @@ def test_prefix_cache_default():
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args([])
+ # should be None by default (depends on model).
engine_args = EngineArgs.from_cli_args(args=args)
- assert not engine_args.enable_prefix_caching, "prefix caching defaults to off."
+ assert engine_args.enable_prefix_caching is None
# with flag to turn it on.
args = parser.parse_args(["--enable-prefix-caching"])
diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py
index 4e7b765d7713f..65a6fd20bd0d1 100644
--- a/tests/entrypoints/openai/test_metrics.py
+++ b/tests/entrypoints/openai/test_metrics.py
@@ -183,9 +183,6 @@ async def test_metrics_counts(
EXPECTED_METRICS_V1 = [
"vllm:num_requests_running",
"vllm:num_requests_waiting",
- "vllm:gpu_cache_usage_perc",
- "vllm:gpu_prefix_cache_queries",
- "vllm:gpu_prefix_cache_hits",
"vllm:kv_cache_usage_perc",
"vllm:prefix_cache_queries",
"vllm:prefix_cache_hits",
diff --git a/tests/entrypoints/openai/test_response_api_simple.py b/tests/entrypoints/openai/test_response_api_simple.py
new file mode 100644
index 0000000000000..425b8199a0fd0
--- /dev/null
+++ b/tests/entrypoints/openai/test_response_api_simple.py
@@ -0,0 +1,71 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+
+import pytest
+import pytest_asyncio
+from openai import OpenAI
+
+from ...utils import RemoteOpenAIServer
+
+MODEL_NAME = "Qwen/Qwen3-8B"
+
+
+@pytest.fixture(scope="module")
+def server():
+ args = ["--reasoning-parser", "qwen3", "--max_model_len", "5000"]
+ env_dict = dict(
+ VLLM_ENABLE_RESPONSES_API_STORE="1",
+ # uncomment for tool calling
+ # PYTHON_EXECUTION_BACKEND="dangerously_use_uv",
+ )
+
+ with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server:
+ yield remote_server
+
+
+@pytest_asyncio.fixture
+async def client(server):
+ async with server.get_async_client() as async_client:
+ yield async_client
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_basic(client: OpenAI, model_name: str):
+ response = await client.responses.create(
+ model=model_name,
+ input="What is 13 * 24?",
+ )
+ assert response is not None
+ print("response: ", response)
+ assert response.status == "completed"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_reasoning_item(client: OpenAI, model_name: str):
+ response = await client.responses.create(
+ model=model_name,
+ input=[
+ {"type": "message", "content": "Hello.", "role": "user"},
+ {
+ "type": "reasoning",
+ "id": "lol",
+ "content": [
+ {
+ "type": "reasoning_text",
+ "text": "We need to respond: greeting.",
+ }
+ ],
+ "summary": [],
+ },
+ ],
+ temperature=0.0,
+ )
+ assert response is not None
+ assert response.status == "completed"
+ # make sure we get a reasoning and text output
+ assert response.output[0].type == "reasoning"
+ assert response.output[1].type == "message"
+ assert type(response.output[1].content[0].text) is str
diff --git a/tests/entrypoints/openai/test_response_api_with_harmony.py b/tests/entrypoints/openai/test_response_api_with_harmony.py
index dea8d2d28f61a..8fd3545eccffa 100644
--- a/tests/entrypoints/openai/test_response_api_with_harmony.py
+++ b/tests/entrypoints/openai/test_response_api_with_harmony.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
+import importlib.util
import json
import time
@@ -35,7 +35,11 @@ GET_WEATHER_SCHEMA = {
@pytest.fixture(scope="module")
def server():
- args = ["--enforce-eager", "--tool-server", "demo"]
+ assert importlib.util.find_spec("gpt_oss") is not None, (
+ "Harmony tests require gpt_oss package to be installed"
+ )
+
+ args = ["--enforce-eager", "--tool-server", "demo", "--max_model_len", "5000"]
env_dict = dict(
VLLM_ENABLE_RESPONSES_API_STORE="1",
PYTHON_EXECUTION_BACKEND="dangerously_use_uv",
@@ -550,6 +554,31 @@ def call_function(name, args):
raise ValueError(f"Unknown function: {name}")
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_reasoning_item(client: OpenAI, model_name: str):
+ response = await client.responses.create(
+ model=model_name,
+ input=[
+ {"type": "message", "content": "Hello.", "role": "user"},
+ {
+ "type": "reasoning",
+ "id": "lol",
+ "content": [
+ {
+ "type": "reasoning_text",
+ "text": "We need to respond: greeting.",
+ }
+ ],
+ "summary": [],
+ },
+ ],
+ temperature=0.0,
+ )
+ assert response is not None
+ assert response.status == "completed"
+
+
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_function_calling(client: OpenAI, model_name: str):
diff --git a/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py
index 2b68a653f4600..37e52d2cdf609 100644
--- a/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py
+++ b/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from unittest.mock import MagicMock, patch
+
import pytest
from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation
@@ -132,3 +134,129 @@ def test_extract_tool_calls_multiple_json_with_surrounding_text(parser):
assert result.tool_calls[0].function.name == "searchTool"
assert result.tool_calls[1].function.name == "getOpenIncidentsTool"
assert result.tool_calls[2].function.name == "searchTool"
+
+
+def test_extract_tool_calls_deeply_nested_json(parser):
+ # Test with deeply nested JSON parameters (5 levels)
+ model_output = (
+ '{"name": "complexTool", '
+ '"parameters": {'
+ '"level1": {'
+ '"level2": {'
+ '"level3": {'
+ '"level4": {'
+ '"value": "deep"'
+ "}}}}}}"
+ )
+ result = parser.extract_tool_calls(model_output, None)
+
+ assert result.tools_called is True
+ assert len(result.tool_calls) == 1
+ assert result.tool_calls[0].function.name == "complexTool"
+ # Verify the nested structure is preserved in the arguments
+ import json
+
+ args = json.loads(result.tool_calls[0].function.arguments)
+ assert args["level1"]["level2"]["level3"]["level4"]["value"] == "deep"
+
+
+def test_extract_tool_calls_multiple_with_deep_nesting(parser):
+ # Test with multiple tool calls where some have deeply nested parameters
+ model_output = (
+ '{"name": "simpleTool", "parameters": {"value": "test"}}; '
+ '{"name": "complexTool", "parameters": '
+ '{"config": {"database": {"connection": {"pool": {"size": 10}}}}}}'
+ )
+ result = parser.extract_tool_calls(model_output, None)
+
+ assert result.tools_called is True
+ assert len(result.tool_calls) == 2
+
+ # Check first tool call
+ assert result.tool_calls[0].function.name == "simpleTool"
+ import json
+
+ args0 = json.loads(result.tool_calls[0].function.arguments)
+ assert args0["value"] == "test"
+
+ # Check second tool call with deep nesting
+ assert result.tool_calls[1].function.name == "complexTool"
+ args1 = json.loads(result.tool_calls[1].function.arguments)
+ assert args1["config"]["database"]["connection"]["pool"]["size"] == 10
+
+
+def test_extract_tool_calls_with_quotes_and_brackets_in_string(parser):
+ # Test with quotes and brackets inside quoted string values
+ model_output = (
+ '{"name": "searchTool", '
+ '"parameters": {'
+ '"query": "test {value} [complex]",'
+ '"nested": {"inner": "more {brackets}"}'
+ "}}"
+ )
+ result = parser.extract_tool_calls(model_output, None)
+
+ assert result.tools_called is True
+ assert len(result.tool_calls) == 1
+ assert result.tool_calls[0].function.name == "searchTool"
+ # Verify the string values are preserved including brackets and quotes
+ import json
+
+ args = json.loads(result.tool_calls[0].function.arguments)
+ assert args["query"] == "test {value} [complex]"
+ assert args["nested"]["inner"] == "more {brackets}"
+
+
+def test_extract_tool_calls_with_escaped_quotes_in_nested_json(parser):
+ # Test with escaped quotes in deeply nested JSON
+ model_output = (
+ '{"name": "parserTool", "parameters": {"text": "He said \\"Hello {world}\\""}}'
+ )
+ result = parser.extract_tool_calls(model_output, None)
+
+ assert result.tools_called is True
+ assert len(result.tool_calls) == 1
+ assert result.tool_calls[0].function.name == "parserTool"
+ # Verify escaped quotes are preserved
+ import json
+
+ args = json.loads(result.tool_calls[0].function.arguments)
+ assert args["text"] == 'He said "Hello {world}"'
+
+
+def test_extract_tool_calls_missing_name_key(parser):
+ # Test that missing "name" key returns content
+ model_output = '{"parameters": {}}'
+ result = parser.extract_tool_calls(model_output, None)
+
+ assert result.tools_called is False
+ assert len(result.tool_calls) == 0
+ assert result.content == model_output
+
+
+def test_extract_tool_calls_missing_parameters_and_arguments_key(parser):
+ # Test that missing both "parameters" and "arguments" keys returns content
+ model_output = '{"name": "toolWithoutParams"}'
+ result = parser.extract_tool_calls(model_output, None)
+
+ assert result.tools_called is False
+ assert len(result.tool_calls) == 0
+ assert result.content == model_output
+
+
+def test_regex_timeout_handling(parser):
+ """Test regex timeout is handled gracefully"""
+ fake_problematic_input = "{hello world[A(A=" + "\t)A(A=,\t" * 2
+
+ # create a mock regex that raises TimeoutError
+ mock_regex = MagicMock()
+ mock_regex.finditer.side_effect = TimeoutError("Regex timeout")
+
+ with patch.object(parser, "tool_call_start_regex", mock_regex):
+ result = parser.extract_tool_calls(fake_problematic_input, None)
+
+ # should treat as regular text when regex times out
+ assert result.content == fake_problematic_input
+ assert result.tools_called is False
+ assert len(result.tool_calls) == 0
+ mock_regex.finditer.assert_called_once()
diff --git a/tests/entrypoints/pooling/correctness/__init__.py b/tests/entrypoints/pooling/basic/__init__.py
similarity index 100%
rename from tests/entrypoints/pooling/correctness/__init__.py
rename to tests/entrypoints/pooling/basic/__init__.py
diff --git a/tests/entrypoints/pooling/llm/test_encode.py b/tests/entrypoints/pooling/basic/test_encode.py
similarity index 92%
rename from tests/entrypoints/pooling/llm/test_encode.py
rename to tests/entrypoints/pooling/basic/test_encode.py
index ca85d2758fce4..f86ecef2e4744 100644
--- a/tests/entrypoints/pooling/llm/test_encode.py
+++ b/tests/entrypoints/pooling/basic/test_encode.py
@@ -7,6 +7,12 @@ import pytest
from vllm import LLM, PoolingParams
from vllm.distributed import cleanup_dist_env_and_memory
+from vllm.platforms import current_platform
+
+if current_platform.is_rocm():
+ pytest.skip(
+ "Encoder self-attention is not implemented on ROCm.", allow_module_level=True
+ )
MODEL_NAME = "intfloat/multilingual-e5-small"
diff --git a/tests/entrypoints/pooling/openai/test_truncation.py b/tests/entrypoints/pooling/basic/test_truncation.py
similarity index 95%
rename from tests/entrypoints/pooling/openai/test_truncation.py
rename to tests/entrypoints/pooling/basic/test_truncation.py
index 6889628dc9145..0d2d385840402 100644
--- a/tests/entrypoints/pooling/openai/test_truncation.py
+++ b/tests/entrypoints/pooling/basic/test_truncation.py
@@ -7,6 +7,12 @@ import pytest
import pytest_asyncio
from tests.utils import RemoteOpenAIServer
+from vllm.platforms import current_platform
+
+if current_platform.is_rocm():
+ pytest.skip(
+ "Encoder self-attention is not implemented on ROCm.", allow_module_level=True
+ )
MODEL_NAME = "sentence-transformers/all-MiniLM-L12-v2"
max_model_len = 128
diff --git a/tests/entrypoints/pooling/llm/__init__.py b/tests/entrypoints/pooling/classify/__init__.py
similarity index 100%
rename from tests/entrypoints/pooling/llm/__init__.py
rename to tests/entrypoints/pooling/classify/__init__.py
diff --git a/tests/entrypoints/pooling/llm/test_classify.py b/tests/entrypoints/pooling/classify/test_offline.py
similarity index 100%
rename from tests/entrypoints/pooling/llm/test_classify.py
rename to tests/entrypoints/pooling/classify/test_offline.py
diff --git a/tests/entrypoints/pooling/openai/test_classification.py b/tests/entrypoints/pooling/classify/test_online.py
similarity index 100%
rename from tests/entrypoints/pooling/openai/test_classification.py
rename to tests/entrypoints/pooling/classify/test_online.py
diff --git a/tests/entrypoints/pooling/openai/test_vision_classification.py b/tests/entrypoints/pooling/classify/test_online_vision.py
similarity index 100%
rename from tests/entrypoints/pooling/openai/test_vision_classification.py
rename to tests/entrypoints/pooling/classify/test_online_vision.py
diff --git a/tests/entrypoints/pooling/openai/__init__.py b/tests/entrypoints/pooling/embed/__init__.py
similarity index 100%
rename from tests/entrypoints/pooling/openai/__init__.py
rename to tests/entrypoints/pooling/embed/__init__.py
diff --git a/tests/entrypoints/pooling/correctness/test_mteb_embed.py b/tests/entrypoints/pooling/embed/test_correctness_mteb.py
similarity index 87%
rename from tests/entrypoints/pooling/correctness/test_mteb_embed.py
rename to tests/entrypoints/pooling/embed/test_correctness_mteb.py
index 7f16638e51e2c..64673534fd32a 100644
--- a/tests/entrypoints/pooling/correctness/test_mteb_embed.py
+++ b/tests/entrypoints/pooling/embed/test_correctness_mteb.py
@@ -11,6 +11,12 @@ from tests.models.language.pooling_mteb_test.mteb_utils import (
run_mteb_embed_task,
)
from tests.utils import RemoteOpenAIServer
+from vllm.platforms import current_platform
+
+if current_platform.is_rocm():
+ pytest.skip(
+ "Encoder self-attention is not implemented on ROCm.", allow_module_level=True
+ )
os.environ["VLLM_LOGGING_LEVEL"] = "WARNING"
diff --git a/tests/entrypoints/pooling/llm/test_embedding.py b/tests/entrypoints/pooling/embed/test_offline.py
similarity index 90%
rename from tests/entrypoints/pooling/llm/test_embedding.py
rename to tests/entrypoints/pooling/embed/test_offline.py
index 5455b5f91fc09..f5eab4c29ae18 100644
--- a/tests/entrypoints/pooling/llm/test_embedding.py
+++ b/tests/entrypoints/pooling/embed/test_offline.py
@@ -9,6 +9,12 @@ import torch.nn.functional as F
from vllm import LLM, PoolingParams
from vllm.distributed import cleanup_dist_env_and_memory
+from vllm.platforms import current_platform
+
+if current_platform.is_rocm():
+ pytest.skip(
+ "Encoder self-attention is not implemented on ROCm.", allow_module_level=True
+ )
MODEL_NAME = "intfloat/multilingual-e5-small"
diff --git a/tests/entrypoints/pooling/openai/test_embedding.py b/tests/entrypoints/pooling/embed/test_online.py
similarity index 99%
rename from tests/entrypoints/pooling/openai/test_embedding.py
rename to tests/entrypoints/pooling/embed/test_online.py
index e971b23e8f1a0..0c88d800e2f99 100644
--- a/tests/entrypoints/pooling/openai/test_embedding.py
+++ b/tests/entrypoints/pooling/embed/test_online.py
@@ -19,6 +19,7 @@ from vllm.entrypoints.openai.protocol import (
EmbeddingResponse,
PoolingResponse,
)
+from vllm.platforms import current_platform
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils.serial_utils import (
EMBED_DTYPE_TO_TORCH_DTYPE,
@@ -28,6 +29,11 @@ from vllm.utils.serial_utils import (
decode_pooling_output,
)
+if current_platform.is_rocm():
+ pytest.skip(
+ "Encoder self-attention is not implemented on ROCm.", allow_module_level=True
+ )
+
MODEL_NAME = "intfloat/multilingual-e5-small"
DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501
DTYPE = "bfloat16"
diff --git a/tests/entrypoints/pooling/openai/test_embedding_dimensions.py b/tests/entrypoints/pooling/embed/test_online_dimensions.py
similarity index 95%
rename from tests/entrypoints/pooling/openai/test_embedding_dimensions.py
rename to tests/entrypoints/pooling/embed/test_online_dimensions.py
index ba9fb64262772..8018dac2d3ffe 100644
--- a/tests/entrypoints/pooling/openai/test_embedding_dimensions.py
+++ b/tests/entrypoints/pooling/embed/test_online_dimensions.py
@@ -12,6 +12,12 @@ from tests.models.language.pooling.embed_utils import run_embedding_correctness_
from tests.models.utils import EmbedModelInfo
from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import EmbeddingResponse
+from vllm.platforms import current_platform
+
+if current_platform.is_rocm():
+ pytest.skip(
+ "Encoder self-attention is not implemented on ROCm.", allow_module_level=True
+ )
MODELS = [
EmbedModelInfo("intfloat/multilingual-e5-small", is_matryoshka=False),
diff --git a/tests/entrypoints/pooling/openai/test_embedding_long_text.py b/tests/entrypoints/pooling/embed/test_online_long_text.py
similarity index 98%
rename from tests/entrypoints/pooling/openai/test_embedding_long_text.py
rename to tests/entrypoints/pooling/embed/test_online_long_text.py
index f977c81a9084e..a9ade09dad0b5 100644
--- a/tests/entrypoints/pooling/openai/test_embedding_long_text.py
+++ b/tests/entrypoints/pooling/embed/test_online_long_text.py
@@ -16,6 +16,12 @@ import pytest_asyncio
from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import EmbeddingResponse
+from vllm.platforms import current_platform
+
+if current_platform.is_rocm():
+ pytest.skip(
+ "Encoder self-attention is not implemented on ROCm.", allow_module_level=True
+ )
def _generate_random_text(word_count: int) -> str:
diff --git a/tests/entrypoints/pooling/openai/test_vision_embedding.py b/tests/entrypoints/pooling/embed/test_online_vision.py
similarity index 100%
rename from tests/entrypoints/pooling/openai/test_vision_embedding.py
rename to tests/entrypoints/pooling/embed/test_online_vision.py
diff --git a/tests/entrypoints/pooling/pooling/__init__.py b/tests/entrypoints/pooling/pooling/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/entrypoints/pooling/openai/test_pooling.py b/tests/entrypoints/pooling/pooling/test_online.py
similarity index 100%
rename from tests/entrypoints/pooling/openai/test_pooling.py
rename to tests/entrypoints/pooling/pooling/test_online.py
diff --git a/tests/entrypoints/pooling/reward/__init__.py b/tests/entrypoints/pooling/reward/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/entrypoints/pooling/llm/test_reward.py b/tests/entrypoints/pooling/reward/test_offline.py
similarity index 100%
rename from tests/entrypoints/pooling/llm/test_reward.py
rename to tests/entrypoints/pooling/reward/test_offline.py
diff --git a/tests/entrypoints/pooling/score/__init__.py b/tests/entrypoints/pooling/score/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/entrypoints/pooling/correctness/test_mteb_score.py b/tests/entrypoints/pooling/score/test_correctness_mteb.py
similarity index 91%
rename from tests/entrypoints/pooling/correctness/test_mteb_score.py
rename to tests/entrypoints/pooling/score/test_correctness_mteb.py
index 1afe68b189db8..81ad0097187b0 100644
--- a/tests/entrypoints/pooling/correctness/test_mteb_score.py
+++ b/tests/entrypoints/pooling/score/test_correctness_mteb.py
@@ -13,6 +13,12 @@ from tests.models.language.pooling_mteb_test.mteb_utils import (
run_mteb_rerank,
)
from tests.utils import RemoteOpenAIServer
+from vllm.platforms import current_platform
+
+if current_platform.is_rocm():
+ pytest.skip(
+ "Encoder self-attention is not implemented on ROCm.", allow_module_level=True
+ )
os.environ["VLLM_LOGGING_LEVEL"] = "WARNING"
diff --git a/tests/entrypoints/pooling/llm/test_score.py b/tests/entrypoints/pooling/score/test_offline.py
similarity index 90%
rename from tests/entrypoints/pooling/llm/test_score.py
rename to tests/entrypoints/pooling/score/test_offline.py
index b69c6a47c1913..ce36d61cb8476 100644
--- a/tests/entrypoints/pooling/llm/test_score.py
+++ b/tests/entrypoints/pooling/score/test_offline.py
@@ -9,6 +9,12 @@ import torch
from tests.models.utils import softmax
from vllm import LLM, PoolingParams
from vllm.distributed import cleanup_dist_env_and_memory
+from vllm.platforms import current_platform
+
+if current_platform.is_rocm():
+ pytest.skip(
+ "Encoder self-attention is not implemented on ROCm.", allow_module_level=True
+ )
MODEL_NAME = "tomaarsen/Qwen3-Reranker-0.6B-seq-cls"
diff --git a/tests/entrypoints/pooling/openai/test_rerank.py b/tests/entrypoints/pooling/score/test_online_rerank.py
similarity index 97%
rename from tests/entrypoints/pooling/openai/test_rerank.py
rename to tests/entrypoints/pooling/score/test_online_rerank.py
index 1d85190c12a19..5a772e22a7414 100644
--- a/tests/entrypoints/pooling/openai/test_rerank.py
+++ b/tests/entrypoints/pooling/score/test_online_rerank.py
@@ -8,6 +8,12 @@ import torch.nn.functional as F
from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import PoolingResponse, RerankResponse
+from vllm.platforms import current_platform
+
+if current_platform.is_rocm():
+ pytest.skip(
+ "Encoder self-attention is not implemented on ROCm.", allow_module_level=True
+ )
MODEL_NAME = "BAAI/bge-reranker-base"
DTYPE = "bfloat16"
diff --git a/tests/entrypoints/pooling/openai/test_score.py b/tests/entrypoints/pooling/score/test_online_score.py
similarity index 97%
rename from tests/entrypoints/pooling/openai/test_score.py
rename to tests/entrypoints/pooling/score/test_online_score.py
index b8f796d47efaa..ceff9d0181825 100644
--- a/tests/entrypoints/pooling/openai/test_score.py
+++ b/tests/entrypoints/pooling/score/test_online_score.py
@@ -10,6 +10,12 @@ from torch import tensor
from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import ScoreResponse
+from vllm.platforms import current_platform
+
+if current_platform.is_rocm():
+ pytest.skip(
+ "Encoder self-attention is not implemented on ROCm.", allow_module_level=True
+ )
MODELS = [
{"name": "BAAI/bge-reranker-v2-m3", "is_cross_encoder": True},
diff --git a/tests/entrypoints/test_responses_utils.py b/tests/entrypoints/test_responses_utils.py
index 48bf06088bc05..893d806b65742 100644
--- a/tests/entrypoints/test_responses_utils.py
+++ b/tests/entrypoints/test_responses_utils.py
@@ -1,7 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+from openai.types.responses.response_function_tool_call_output_item import (
+ ResponseFunctionToolCallOutputItem,
+)
+from openai.types.responses.response_reasoning_item import (
+ Content,
+ ResponseReasoningItem,
+ Summary,
+)
+
from vllm.entrypoints.responses_utils import (
+ construct_chat_message_with_tool_call,
convert_tool_responses_to_completions_format,
)
@@ -28,3 +39,65 @@ class TestResponsesUtils:
result = convert_tool_responses_to_completions_format(input_tool)
assert result == {"type": "function", "function": input_tool}
+
+ def test_construct_chat_message_with_tool_call(self):
+ item = ResponseReasoningItem(
+ id="lol",
+ summary=[],
+ type="reasoning",
+ content=[
+ Content(
+ text="Leroy Jenkins",
+ type="reasoning_text",
+ )
+ ],
+ encrypted_content=None,
+ status=None,
+ )
+ formatted_item = construct_chat_message_with_tool_call(item)
+ assert formatted_item["role"] == "assistant"
+ assert formatted_item["reasoning"] == "Leroy Jenkins"
+
+ item = ResponseReasoningItem(
+ id="lol",
+ summary=[
+ Summary(
+ text='Hmm, the user has just started with a simple "Hello,"',
+ type="summary_text",
+ )
+ ],
+ type="reasoning",
+ content=None,
+ encrypted_content=None,
+ status=None,
+ )
+
+ formatted_item = construct_chat_message_with_tool_call(item)
+ assert formatted_item["role"] == "assistant"
+ assert (
+ formatted_item["reasoning"]
+ == 'Hmm, the user has just started with a simple "Hello,"'
+ )
+
+ tool_call_output = ResponseFunctionToolCallOutputItem(
+ id="temp_id",
+ type="function_call_output",
+ call_id="temp",
+ output="1234",
+ status="completed",
+ )
+ formatted_item = construct_chat_message_with_tool_call(tool_call_output)
+ assert formatted_item["role"] == "tool"
+ assert formatted_item["content"] == "1234"
+ assert formatted_item["tool_call_id"] == "temp"
+
+ item = ResponseReasoningItem(
+ id="lol",
+ summary=[],
+ type="reasoning",
+ content=None,
+ encrypted_content="TOP_SECRET_MESSAGE",
+ status=None,
+ )
+ with pytest.raises(ValueError):
+ construct_chat_message_with_tool_call(item)
diff --git a/tests/kernels/attention/test_attention.py b/tests/kernels/attention/test_attention.py
index 9662e73321ebe..1a7d5ce0ddc1e 100644
--- a/tests/kernels/attention/test_attention.py
+++ b/tests/kernels/attention/test_attention.py
@@ -13,12 +13,6 @@ from vllm.attention.layer import Attention, MultiHeadAttention
from vllm.platforms import current_platform
from vllm.utils.mem_utils import get_max_shared_memory_bytes
-if not current_platform.is_rocm():
- from xformers import ops as xops
- from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
-
- from tests.kernels.utils import make_alibi_bias
-
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
@@ -448,129 +442,6 @@ def ref_multi_query_kv_attention(
return torch.cat(ref_outputs, dim=0)
-@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
-@pytest.mark.parametrize("num_heads", NUM_HEADS)
-@pytest.mark.parametrize("head_size", HEAD_SIZES)
-@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.skipif(
- current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
-)
-@torch.inference_mode()
-def test_multi_query_kv_attention(
- num_seqs: int,
- num_heads: tuple[int, int],
- head_size: int,
- dtype: torch.dtype,
- seed: int,
- device: str,
- use_alibi: bool = False,
-) -> None:
- current_platform.seed_everything(seed)
- torch.set_default_device(device)
- # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
- # As the xformers library is already tested with its own tests, we can use
- # a smaller MAX_SEQ_LEN here.
- max_len = min(MAX_SEQ_LEN, 4096)
- seq_lens = random.sample(range(1, max_len), num_seqs)
- num_tokens = sum(seq_lens)
-
- scale = float(1.0 / (head_size**0.5))
- num_query_heads, num_kv_heads = num_heads
- qkv = torch.empty(
- num_tokens, num_query_heads + 2 * num_kv_heads, head_size, dtype=dtype
- )
- qkv.uniform_(-scale, scale)
- query, key, value = qkv.split([num_query_heads, num_kv_heads, num_kv_heads], dim=1)
-
- num_queries_per_kv = num_query_heads // num_kv_heads
- if num_queries_per_kv > 1:
- # Handle MQA and GQA
- key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
- value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
- alibi_bias = None
- if use_alibi:
- alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
- attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
- output = torch.empty_like(query)
- start = 0
- # Dynamic sequence length not supported with custom attn_bias.
- for i, seq_len in enumerate(seq_lens):
- end = start + seq_len
- out = xops.memory_efficient_attention_forward(
- query[None, start:end],
- key[None, start:end],
- value[None, start:end],
- attn_bias=attn_bias[i],
- p=0.0,
- scale=scale,
- )
- output[start:end].copy_(out.view_as(query[start:end]))
- start += seq_len
- # xformers.AttentionBias to Tensor for use in reference impl.
- alibi_bias = [
- b.materialize((1, num_query_heads, i, i), device=device).squeeze()
- for b, i in zip(attn_bias, seq_lens)
- ]
- else:
- attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
- output = xops.memory_efficient_attention_forward(
- query.unsqueeze(0),
- key.unsqueeze(0),
- value.unsqueeze(0),
- attn_bias=attn_bias,
- p=0.0,
- scale=scale,
- )
- output = output.squeeze(0)
-
- cu_seq_lens = [0]
- for seq_len in seq_lens:
- cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
- ref_output = ref_multi_query_kv_attention(
- cu_seq_lens,
- query,
- key,
- value,
- scale,
- alibi_bias,
- dtype,
- )
- atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
- rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
- torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
-
-
-@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
-@pytest.mark.parametrize("num_heads", NUM_HEADS)
-@pytest.mark.parametrize("head_size", [64])
-@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.skipif(
- current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
-)
-@torch.inference_mode()
-def test_multi_query_kv_attention_with_alibi(
- num_seqs: int,
- num_heads: tuple[int, int],
- head_size: int,
- dtype: torch.dtype,
- seed: int,
- device: str,
-) -> None:
- return test_multi_query_kv_attention(
- num_seqs,
- num_heads,
- head_size,
- dtype,
- seed,
- device,
- use_alibi=True,
- )
-
-
@pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention])
def test_num_heads_not_divisble_by_num_kv_heads(attention_cls: type) -> None:
head_size = 64
diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py
index 9be56a33f76c8..cd34b520ea71b 100644
--- a/tests/kernels/attention/test_attention_selector.py
+++ b/tests/kernels/attention/test_attention_selector.py
@@ -34,7 +34,7 @@ DEVICE_MLA_BACKENDS = {
}
DEVICE_REGULAR_ATTN_BACKENDS = {
- "cuda": ["XFORMERS", "FLASHINFER", "FLASH_ATTN"],
+ "cuda": ["FLASHINFER", "FLASH_ATTN"],
"hip": ["ROCM_ATTN"],
"cpu": ["CPU_ATTN"],
}
@@ -207,12 +207,6 @@ def test_env(
)
expected = "FLASHINFER"
assert backend.get_name() == expected
- elif name == "XFORMERS":
- backend = get_attn_backend(
- 32, torch.float16, None, block_size, use_mla=use_mla
- )
- expected = "XFORMERS"
- assert backend.get_name() == expected
elif name == "FLASH_ATTN":
backend = get_attn_backend(
32, torch.float16, None, block_size, use_mla=use_mla
diff --git a/tests/kernels/attention/test_cache.py b/tests/kernels/attention/test_cache.py
index 028e164cb801b..acf46d75d62eb 100644
--- a/tests/kernels/attention/test_cache.py
+++ b/tests/kernels/attention/test_cache.py
@@ -921,12 +921,16 @@ def test_gather_and_maybe_dequant_cache_mla(
)
_fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
- seq_len_tensor = torch.randint(0, max_seq_len + 1, (batch_size,), device=device)
+ seq_len_tensor = torch.randint(
+ max_seq_len, max_seq_len + 1, (batch_size,), device=device
+ )
total_tokens = seq_len_tensor.sum()
cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device)
cu_seq_lens[0] = 0
cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)
+ token_to_seq = torch.arange(0, batch_size, dtype=torch.int32, device=device)
+ token_to_seq = torch.repeat_interleave(token_to_seq, seq_len_tensor)
print("seq_len_tensor", seq_len_tensor)
tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
@@ -977,7 +981,8 @@ def test_gather_and_maybe_dequant_cache_mla(
dst,
block_table,
cu_seq_lens,
- batch_size,
+ token_to_seq,
+ total_tokens,
kv_cache_dtype,
scale,
None,
@@ -990,7 +995,8 @@ def test_gather_and_maybe_dequant_cache_mla(
dst,
block_table,
cu_seq_lens,
- batch_size,
+ token_to_seq,
+ total_tokens,
kv_cache_dtype,
scale,
None,
diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py
index a878ac6396ce5..ae3c63cc62d6b 100644
--- a/tests/kernels/attention/test_mha_attn.py
+++ b/tests/kernels/attention/test_mha_attn.py
@@ -24,10 +24,6 @@ from vllm.platforms.rocm import RocmPlatform
def clear_cache():
"""Clear lru cache to ensure each test case runs without caching."""
_cached_get_attn_backend.cache_clear()
- # Clear xformers availability cache
- import vllm.attention.layer as layer_module
-
- layer_module.USE_XFORMERS_OPS = None
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py
index 2285709fa7d60..dab1207d78031 100644
--- a/tests/kernels/moe/test_batched_moe.py
+++ b/tests/kernels/moe/test_batched_moe.py
@@ -39,6 +39,11 @@ MNK_FACTORS = [
NUM_EXPERTS = [8, 64]
TOP_KS = [1, 2, 6]
+DTYPES = [torch.bfloat16]
+
+if not current_platform.is_fp8_fnuz():
+ DTYPES.append(torch.float8_e4m3fn)
+
vllm_config = VllmConfig()
@@ -96,7 +101,7 @@ class BatchedMMTensors:
@pytest.mark.parametrize("max_tokens_per_expert", [32, 224, 512])
@pytest.mark.parametrize("K", [128, 1024])
@pytest.mark.parametrize("N", [128, 1024])
-@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
+@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
def test_batched_mm(
@@ -229,7 +234,7 @@ def test_batched_mm(
@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
-@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
+@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("input_scales", [False])
diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py
index 88db4b3e537c2..b0ff1e64e3219 100644
--- a/tests/kernels/moe/test_block_fp8.py
+++ b/tests/kernels/moe/test_block_fp8.py
@@ -31,6 +31,11 @@ dg_available = has_deep_gemm()
if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
+if current_platform.is_fp8_fnuz():
+ pytest.skip(
+ "Tests in this file require float8_e4m3fn and platform does not support",
+ allow_module_level=True,
+ )
vllm_config = VllmConfig()
diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py
index 638741e91619b..a6977f222408d 100644
--- a/tests/kernels/moe/test_flashinfer.py
+++ b/tests/kernels/moe/test_flashinfer.py
@@ -11,7 +11,6 @@ from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
-from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8,
flashinfer_cutlass_moe_fp8,
@@ -151,14 +150,11 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True)
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids = Llama4MoE.custom_routing_function(
hidden_states=td.hidden_states,
- router_logits=score,
- use_grouped_topk=False,
- top_k=topk,
+ gating_output=score,
+ topk=topk,
renormalize=False,
- custom_routing_function=Llama4MoE.custom_routing_function,
- scoring_func="softmax",
)
quant_config = fp8_w8a8_moe_quant_config(
@@ -219,14 +215,11 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
)
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids = Llama4MoE.custom_routing_function(
hidden_states=td.hidden_states,
- router_logits=score,
- use_grouped_topk=False,
- top_k=topk,
+ gating_output=score,
+ topk=topk,
renormalize=False,
- custom_routing_function=Llama4MoE.custom_routing_function,
- scoring_func="softmax",
)
quant_config = fp8_w8a8_moe_quant_config(
diff --git a/tests/kernels/moe/test_gpt_oss_triton_kernels.py b/tests/kernels/moe/test_gpt_oss_triton_kernels.py
index af33fd4e3fc3b..98e80ec029777 100644
--- a/tests/kernels/moe/test_gpt_oss_triton_kernels.py
+++ b/tests/kernels/moe/test_gpt_oss_triton_kernels.py
@@ -270,6 +270,11 @@ class Case:
@pytest.mark.parametrize("num_token", [2])
@pytest.mark.parametrize("tp", [1, 2, 4, 8])
def test_equiv(num_token, a_dtype, w_dtype, tp):
+ from triton_kernels.tensor_details import layout
+
+ if not hasattr(layout, "make_default_matmul_mxfp4_w_layout"):
+ pytest.skip("make_default_matmul_mxfp4_w_layout not available")
+
M = num_token
E = ModelConfig.num_experts
K = ModelConfig.hidden_size
diff --git a/tests/kernels/moe/test_modular_kernel_combinations.py b/tests/kernels/moe/test_modular_kernel_combinations.py
index e3b8621b452fa..2a30ef2355529 100644
--- a/tests/kernels/moe/test_modular_kernel_combinations.py
+++ b/tests/kernels/moe/test_modular_kernel_combinations.py
@@ -46,6 +46,12 @@ meets_multi_gpu_requirements = pytest.mark.skipif(
reason="Requires deep_ep or deep_gemm or pplx or flashinfer packages",
)
+if current_platform.is_fp8_fnuz():
+ pytest.skip(
+ "Tests in this file require float8_e4m3fn and platform does not support",
+ allow_module_level=True,
+ )
+
def format_result(verbose, msg, ex=None):
if ex is not None:
diff --git a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py
index ba1f657b3ecda..12dd322dccc52 100644
--- a/tests/kernels/moe/test_moe_permute_unpermute.py
+++ b/tests/kernels/moe/test_moe_permute_unpermute.py
@@ -23,6 +23,12 @@ TOP_KS = [2, 6, 8]
EP_SIZE = [1, 4, 16]
current_platform.seed_everything(0)
+if current_platform.is_rocm():
+ pytest.skip(
+ "moe_permute_unpermute_supported is not defined for ROCm",
+ allow_module_level=True,
+ )
+
def torch_permute(
hidden_states: torch.Tensor,
diff --git a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
index d6b78dd2c2323..b220205759e2d 100644
--- a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
+++ b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
@@ -14,6 +14,12 @@ from vllm.platforms import current_platform
from vllm.utils.deep_gemm import DeepGemmQuantScaleFMT, has_deep_gemm
from vllm.utils.math_utils import cdiv, round_up
+if current_platform.is_fp8_fnuz():
+ pytest.skip(
+ "Tests in this file require float8_e4m3fn and platform does not support",
+ allow_module_level=True,
+ )
+
fp8_dtype = torch.float8_e4m3fn
CASES = [
diff --git a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py
index 7a467e160b784..0ab025dceca40 100644
--- a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py
+++ b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py
@@ -19,6 +19,12 @@ if current_platform.get_device_capability() < (9, 0):
vllm_config = VllmConfig()
+if current_platform.is_fp8_fnuz():
+ pytest.skip(
+ "Tests in this file require float8_e4m3fn and platform does not support",
+ allow_module_level=True,
+ )
+
def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
"""Matrix multiplication function that supports per-token input
diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py
index 5d5a26fbfc2cd..9307ef7814a8b 100644
--- a/tests/kernels/utils.py
+++ b/tests/kernels/utils.py
@@ -509,43 +509,6 @@ def pack_qkv(qkv: QKVInputs, device: torch.device | str) -> PackedQKVInputs:
)
-def make_alibi_bias(
- alibi_slopes: torch.Tensor,
- num_kv_heads: int,
- dtype: torch.dtype,
- seq_lens: list[int],
-) -> list[Any]:
- """Create ALiBi biases compatible with xFormers attention tests."""
- from xformers.ops.fmha.attn_bias import LowerTriangularMaskWithTensorBias
-
- if alibi_slopes is None:
- return [None for _ in seq_lens]
-
- attn_biases: list[Any] = []
- num_heads = alibi_slopes.shape[0]
- assert num_heads >= num_kv_heads, (
- "ALiBi slopes expect at least as many heads as KV heads"
- )
-
- for seq_len in seq_lens:
- bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
- bias = bias[None, :] - bias[:, None]
-
- padded_len = (seq_len + 7) // 8 * 8
- bias_tensor = torch.empty(
- 1,
- num_heads,
- seq_len,
- padded_len,
- device=alibi_slopes.device,
- dtype=dtype,
- )[:, :, :, :seq_len].copy_(bias)
- bias_tensor.mul_(alibi_slopes[:, None, None])
- attn_biases.append(LowerTriangularMaskWithTensorBias(bias_tensor))
-
- return attn_biases
-
-
def _make_metadata_tensors(
seq_lens: list[int] | None,
context_lens: list[int] | None,
@@ -649,23 +612,12 @@ def make_kv_cache(
Returns:
- * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
- * for backend 'XFORMERS'
* kv_cache: 2 x num_blocks x block_size x num_heads x head_size
* for backend 'FLASH_ATTN'
"""
- if backend == "XFORMERS":
- kv_cache = torch.rand((2, num_blocks, block_size * num_heads * head_size)).to(
- device
- )
- elif backend == "FLASH_ATTN":
- kv_cache = torch.rand((2, num_blocks, block_size, num_heads, head_size)).to(
- device
- )
- else:
- raise ValueError(
- f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or 'FLASH_ATTN'."
- )
+ if backend != "FLASH_ATTN":
+ raise ValueError(f"Unknown backend value: '{backend}'. Expected 'FLASH_ATTN'.")
+ kv_cache = torch.rand((2, num_blocks, block_size, num_heads, head_size)).to(device)
if default_val is not None:
kv_cache[:, :, :] = default_val
return kv_cache
@@ -843,22 +795,14 @@ def assert_actual_matches_ideal(
* output_under_test: actually observed output value
"""
ideal_output = test_params.packed_qkvo.ideal_output
- if backend == "XFORMERS":
- torch.testing.assert_close(
- ideal_output, output_under_test.view_as(ideal_output)
- )
-
- elif backend == "FLASH_ATTN":
- # For FlashAttention override the accuracy thresholds to non default
- # values since we notice a higher difference between the ideal and
- # actual output.
- torch.testing.assert_close(
- ideal_output, output_under_test.view_as(ideal_output), atol=0.01, rtol=0.016
- )
- else:
- raise ValueError(
- f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or 'FLASH_ATTN'."
- )
+ if backend != "FLASH_ATTN":
+ raise ValueError(f"Unknown backend value: '{backend}'. Expected 'FLASH_ATTN'.")
+ # For FlashAttention override the accuracy thresholds to non default
+ # values since we notice a higher difference between the ideal and
+ # actual output.
+ torch.testing.assert_close(
+ ideal_output, output_under_test.view_as(ideal_output), atol=0.01, rtol=0.016
+ )
# Copied/modified from torch._refs.__init__.py
diff --git a/tests/lora/test_gptoss_tp.py b/tests/lora/test_gptoss_tp.py
index 711d514a39eb3..f4269750feb6b 100644
--- a/tests/lora/test_gptoss_tp.py
+++ b/tests/lora/test_gptoss_tp.py
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+
import vllm
from vllm.lora.request import LoRARequest
@@ -84,14 +86,17 @@ def test_gpt_oss_lora(gptoss20b_lora_files):
@multi_gpu_test(num_gpus=2)
-def test_gpt_oss_lora_tp2(gptoss20b_lora_files):
+@pytest.mark.parametrize("fully_sharded_loras", [False, True])
+def test_gpt_oss_lora_tp2(gptoss20b_lora_files, fully_sharded_loras):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=2,
max_lora_rank=8,
+ max_num_seqs=16,
tensor_parallel_size=2,
+ fully_sharded_loras=fully_sharded_loras,
compilation_config=vllm.config.CompilationConfig( # Avoid OOM
cudagraph_specialize_lora=False,
),
diff --git a/tests/lora/test_minicpmv_tp.py b/tests/lora/test_minicpmv_tp.py
index 1cf8ed602b6a4..e430826461a14 100644
--- a/tests/lora/test_minicpmv_tp.py
+++ b/tests/lora/test_minicpmv_tp.py
@@ -57,10 +57,6 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
return generated_texts
-@pytest.mark.xfail(
- current_platform.is_rocm(),
- reason="MiniCPM-V dependency xformers incompatible with ROCm",
-)
def test_minicpmv_lora(minicpmv_lora_files):
llm = vllm.LLM(
MODEL_PATH,
@@ -84,10 +80,6 @@ def test_minicpmv_lora(minicpmv_lora_files):
@pytest.mark.skipif(
current_platform.is_cuda_alike(), reason="Skipping to avoid redundant model tests"
)
-@pytest.mark.xfail(
- current_platform.is_rocm(),
- reason="MiniCPM-V dependency xformers incompatible with ROCm",
-)
@multi_gpu_test(num_gpus=4)
def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files):
llm = vllm.LLM(
@@ -108,10 +100,6 @@ def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files):
@pytest.mark.skipif(
current_platform.is_cuda_alike(), reason="Skipping to avoid redundant model tests"
)
-@pytest.mark.xfail(
- current_platform.is_rocm(),
- reason="MiniCPM-V dependency xformers incompatible with ROCm",
-)
@multi_gpu_test(num_gpus=4)
def test_minicpmv_tp4_fully_sharded_loras(minicpmv_lora_files):
llm = vllm.LLM(
diff --git a/tests/lora/test_qwen2vl.py b/tests/lora/test_qwen2vl.py
index 1800ca107a426..7d8c940100ca4 100644
--- a/tests/lora/test_qwen2vl.py
+++ b/tests/lora/test_qwen2vl.py
@@ -2,12 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
-import pytest
-
import vllm
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest
-from vllm.platforms import current_platform
from vllm.sampling_params import BeamSearchParams
@@ -142,10 +139,6 @@ QWEN2VL_MODEL_PATH = "Qwen/Qwen2-VL-2B-Instruct"
QWEN25VL_MODEL_PATH = "Qwen/Qwen2.5-VL-3B-Instruct"
-@pytest.mark.xfail(
- current_platform.is_rocm(),
- reason="Qwen2-VL dependency xformers incompatible with ROCm",
-)
def test_qwen2vl_lora(qwen2vl_lora_files):
"""Test Qwen 2.0 VL model with LoRA"""
config = TestConfig(model_path=QWEN2VL_MODEL_PATH, lora_path=qwen2vl_lora_files)
@@ -156,10 +149,6 @@ def test_qwen2vl_lora(qwen2vl_lora_files):
tester.run_test(TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS, lora_id=lora_id)
-@pytest.mark.xfail(
- current_platform.is_rocm(),
- reason="Qwen2-VL dependency xformers incompatible with ROCm",
-)
def test_qwen2vl_lora_beam_search(qwen2vl_lora_files):
"""Test Qwen 2.0 VL model with LoRA through beam search."""
config = TestConfig(model_path=QWEN2VL_MODEL_PATH, lora_path=qwen2vl_lora_files)
@@ -178,10 +167,6 @@ def test_qwen2vl_lora_beam_search(qwen2vl_lora_files):
)
-@pytest.mark.xfail(
- current_platform.is_rocm(),
- reason="Qwen2.5-VL dependency xformers incompatible with ROCm",
-)
def test_qwen25vl_lora(qwen25vl_lora_files):
"""Test Qwen 2.5 VL model with LoRA"""
config = TestConfig(model_path=QWEN25VL_MODEL_PATH, lora_path=qwen25vl_lora_files)
diff --git a/tests/model_executor/test_qwen3_omni.py b/tests/model_executor/test_qwen3_omni.py
new file mode 100644
index 0000000000000..c92c61dcd3bc2
--- /dev/null
+++ b/tests/model_executor/test_qwen3_omni.py
@@ -0,0 +1,221 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from unittest.mock import Mock
+
+import pytest
+from transformers import PretrainedConfig
+
+from vllm.multimodal.processing import InputProcessingContext
+
+
+# Helper function to print input IDs with coalesced audio/video tokens.
+def print_input_ids(input_ids):
+ """
+ Print input IDs, compressing consecutive special tokens.
+ - 151675: <|audio_pad|>
+ - 151656: <|video_pad|>
+ """
+ if not input_ids:
+ print("[]")
+ return
+
+ result = []
+ i = 0
+
+ while i < len(input_ids):
+ current_id = input_ids[i]
+
+ # Check if it's a special token that should be compressed
+ if current_id in [151675, 151656]:
+ # Count consecutive occurrences
+ count = 1
+ while i + count < len(input_ids) and input_ids[i + count] == current_id:
+ count += 1
+
+ # Add compressed representation
+ token_name = "<|audio_pad|>" if current_id == 151675 else "<|video_pad|>"
+ result.append(f"{token_name} * {count}")
+ i += count
+ else:
+ # Regular token, just add it
+ result.append(str(current_id))
+ i += 1
+
+ print(", ".join(result))
+
+
+@pytest.fixture
+def mock_qwen3_omni_config():
+ """Create a mock Qwen3OmniMoeThinker config."""
+ config = Mock(spec=PretrainedConfig)
+ # Token IDs from https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Instruct/blob/main/tokenizer_config.json
+ config.audio_token_id = 151675 # <|audio_pad|>
+ config.video_token_id = 151656 # <|video_pad|>
+ config.image_token_id = 151655 # <|image_pad|>
+ config.audio_start_token_id = 151669 # <|audio_start|>
+ config.audio_end_token_id = 151670 # <|audio_end|>
+ config.vision_start_token_id = 151652 # <|vision_start|>
+ config.position_id_per_seconds = 12.5
+
+ # Vision config
+ vision_config = Mock()
+ vision_config.spatial_merge_size = 2
+ config.vision_config = vision_config
+
+ return config
+
+
+@pytest.fixture
+def mock_processor():
+ """Create a mock HF processor."""
+ from transformers.models.whisper import WhisperFeatureExtractor
+
+ processor = Mock()
+ processor.audio_token = "<|audio_pad|>"
+ processor.image_token = "<|image_pad|>"
+ processor.video_token = "<|video_pad|>"
+
+ # Create a real WhisperFeatureExtractor instance for the feature_extractor attribute
+ feature_extractor = WhisperFeatureExtractor()
+ processor.feature_extractor = feature_extractor
+
+ return processor
+
+
+@pytest.fixture
+def mock_tokenizer():
+ """Create a mock tokenizer."""
+ tokenizer = Mock()
+ # Token IDs from https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Instruct/blob/main/tokenizer_config.json
+ tokenizer.get_vocab = Mock(
+ return_value={
+ "<|audio_pad|>": 151675,
+ "<|video_pad|>": 151656,
+ "<|image_pad|>": 151655,
+ "<|audio_start|>": 151669,
+ "<|audio_end|>": 151670,
+ "<|vision_start|>": 151652,
+ "<|vision_end|>": 151653,
+ }
+ )
+ tokenizer.encode = Mock(
+ side_effect=lambda x: {
+ "<|vision_start|>": [151652],
+ "<|vision_end|>": [151653],
+ "<|audio_start|>": [151669],
+ "<|audio_end|>": [151670],
+ "<|audio_pad|>": [151675],
+ "<|image_pad|>": [151655],
+ "<|video_pad|>": [151656],
+ }.get(x, [0])
+ )
+ tokenizer.vision_bos_token = "<|vision_start|>"
+ tokenizer.vision_eos_token = "<|vision_end|>"
+ tokenizer.audio_bos_token = "<|audio_start|>"
+ tokenizer.audio_eos_token = "<|audio_end|>"
+ return tokenizer
+
+
+@pytest.fixture
+def mock_image_processor():
+ """Create a mock image processor."""
+ image_processor = Mock()
+ image_processor.merge_size = 2
+ return image_processor
+
+
+def test_qwen3_omni_get_updates_use_audio_in_video(
+ mock_qwen3_omni_config,
+ mock_processor,
+ mock_tokenizer,
+ mock_image_processor,
+):
+ """Test the get_updates_use_audio_in_video method directly."""
+
+ from vllm.model_executor.models.qwen3_omni_moe_thinker import (
+ Qwen3OmniMoeThinkerMultiModalProcessor,
+ Qwen3OmniMoeThinkerProcessingInfo,
+ )
+
+ # Create a mock context
+ mock_ctx = Mock(spec=InputProcessingContext)
+
+ # Create processing info
+ info = Qwen3OmniMoeThinkerProcessingInfo(mock_ctx)
+ info.get_hf_config = Mock(return_value=mock_qwen3_omni_config)
+ info.get_hf_processor = Mock(return_value=mock_processor)
+ info.get_tokenizer = Mock(return_value=mock_tokenizer)
+ info.get_image_processor = Mock(return_value=mock_image_processor)
+
+ # Create a mock dummy_inputs builder
+ mock_dummy_inputs = Mock()
+
+ # Create the processor
+ processor = Qwen3OmniMoeThinkerMultiModalProcessor(info, mock_dummy_inputs)
+
+ # Test parameters from reference video
+ # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/draw.mp4
+ audio_len = 85
+ video_grid_thw = [6, 36, 64]
+ video_second_per_grid_t = 2.0
+
+ # Call the method
+ updates = processor.get_updates_use_audio_in_video(
+ thinker_config=mock_qwen3_omni_config,
+ audio_len=audio_len,
+ video_grid_thw=video_grid_thw,
+ video_second_per_grid_t=video_second_per_grid_t,
+ )
+
+ # Updated input ids should align with HF implementation.
+ # 151669,
+ # <|video_pad|> * 576, <|audio_pad|> * 25,
+ # <|video_pad|> * 576, <|audio_pad|> * 25,
+ # <|video_pad|> * 576, <|audio_pad|> * 25,
+ # <|video_pad|> * 576, <|audio_pad|> * 10,
+ # <|video_pad|> * 1152,
+ # 151670
+ print_input_ids(updates)
+
+ # Verify structure
+ assert isinstance(updates, list)
+ assert len(updates) > 0
+
+ # Verify start and end tokens
+ audio_start_token_id = mock_qwen3_omni_config.audio_start_token_id
+ audio_end_token_id = mock_qwen3_omni_config.audio_end_token_id
+
+ assert updates[0] == audio_start_token_id
+ assert updates[-1] == audio_end_token_id
+
+ # Verify both audio and video tokens are present
+ audio_token_id = mock_qwen3_omni_config.audio_token_id
+ video_token_id = mock_qwen3_omni_config.video_token_id
+
+ audio_count = updates.count(audio_token_id)
+ video_count = updates.count(video_token_id)
+
+ assert audio_count == audio_len, (
+ f"Expected {audio_len} audio tokens, got {audio_count}"
+ )
+
+ # Calculate expected video token count
+ spatial_merge_size = mock_qwen3_omni_config.vision_config.spatial_merge_size
+ height = video_grid_thw[1] // spatial_merge_size
+ width = video_grid_thw[2] // spatial_merge_size
+ expected_video_count = video_grid_thw[0] * height * width
+
+ assert video_count == expected_video_count, (
+ f"Expected {expected_video_count} video tokens, got {video_count}"
+ )
+
+ # Total tokens should be: 1 (start) + audio_len + video_count + 1 (end)
+ expected_total = 1 + audio_len + expected_video_count + 1
+ assert len(updates) == expected_total, (
+ f"Expected {expected_total} total tokens, got {len(updates)}"
+ )
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py
index 0cdb7c9a603f2..df6c2cab7814b 100644
--- a/tests/models/language/generation/test_common.py
+++ b/tests/models/language/generation/test_common.py
@@ -10,13 +10,6 @@ from ....utils import large_gpu_mark
from ...registry import HF_EXAMPLE_MODELS
from ...utils import check_logprobs_close
-# These have unsupported head_dim for FA. We do not
-# have a clean way to fall back, so we fail with
-# a clear msg when it happens.
-# https://github.com/vllm-project/vllm/issues/14524
-# NOTE(woosuk): Skipping these tests until V1 supports them.
-# REQUIRES_V0 = ["microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"]
-
# This list contains the model that are using AITER kernel.
# Skip model that are not using AITER tests.
# When more AITER kernels are added, this list will not be
diff --git a/tests/models/registry.py b/tests/models/registry.py
index b088e16756d7a..f8b3470e6d39b 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -626,6 +626,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B",
trust_remote_code=True,
),
+ "HunYuanVLForConditionalGeneration": _HfExamplesInfo(
+ "tencent/HunyuanOCR",
+ is_available_online=False,
+ ),
"Idefics3ForConditionalGeneration": _HfExamplesInfo(
"HuggingFaceM4/Idefics3-8B-Llama3",
extras={"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"},
@@ -725,6 +729,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"NemotronH_Nano_VL_V2": _HfExamplesInfo(
"nano_vl_dummy", is_available_online=False, trust_remote_code=True
),
+ "OpenCUAForConditionalGeneration": _HfExamplesInfo(
+ "xlangai/OpenCUA-7B", trust_remote_code=True
+ ),
"Ovis": _HfExamplesInfo(
"AIDC-AI/Ovis2-1B",
trust_remote_code=True,
diff --git a/tests/models/test_gguf_download.py b/tests/models/test_gguf_download.py
new file mode 100644
index 0000000000000..155768ac9bff7
--- /dev/null
+++ b/tests/models/test_gguf_download.py
@@ -0,0 +1,240 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from vllm.config import ModelConfig
+from vllm.config.load import LoadConfig
+from vllm.model_executor.model_loader.gguf_loader import GGUFModelLoader
+from vllm.model_executor.model_loader.weight_utils import download_gguf
+
+
+class TestGGUFDownload:
+ """Test GGUF model downloading functionality."""
+
+ @patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf")
+ def test_download_gguf_single_file(self, mock_download):
+ """Test downloading a single GGUF file."""
+ # Setup mock
+ mock_folder = "/tmp/mock_cache"
+ mock_download.return_value = mock_folder
+
+ # Mock glob to return a single file
+ with patch("glob.glob") as mock_glob:
+ mock_glob.side_effect = lambda pattern, **kwargs: (
+ [f"{mock_folder}/model-IQ1_S.gguf"] if "IQ1_S" in pattern else []
+ )
+
+ result = download_gguf("unsloth/Qwen3-0.6B-GGUF", "IQ1_S")
+
+ # Verify download_weights_from_hf was called with correct patterns
+ mock_download.assert_called_once_with(
+ model_name_or_path="unsloth/Qwen3-0.6B-GGUF",
+ cache_dir=None,
+ allow_patterns=[
+ "*-IQ1_S.gguf",
+ "*-IQ1_S-*.gguf",
+ "*/*-IQ1_S.gguf",
+ "*/*-IQ1_S-*.gguf",
+ ],
+ revision=None,
+ ignore_patterns=None,
+ )
+
+ # Verify result is the file path, not folder
+ assert result == f"{mock_folder}/model-IQ1_S.gguf"
+
+ @patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf")
+ def test_download_gguf_sharded_files(self, mock_download):
+ """Test downloading sharded GGUF files."""
+ mock_folder = "/tmp/mock_cache"
+ mock_download.return_value = mock_folder
+
+ # Mock glob to return sharded files
+ with patch("glob.glob") as mock_glob:
+ mock_glob.side_effect = lambda pattern, **kwargs: (
+ [
+ f"{mock_folder}/model-Q2_K-00001-of-00002.gguf",
+ f"{mock_folder}/model-Q2_K-00002-of-00002.gguf",
+ ]
+ if "Q2_K" in pattern
+ else []
+ )
+
+ result = download_gguf("unsloth/gpt-oss-120b-GGUF", "Q2_K")
+
+ # Should return the first file after sorting
+ assert result == f"{mock_folder}/model-Q2_K-00001-of-00002.gguf"
+
+ @patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf")
+ def test_download_gguf_subdir(self, mock_download):
+ """Test downloading GGUF files from subdirectory."""
+ mock_folder = "/tmp/mock_cache"
+ mock_download.return_value = mock_folder
+
+ with patch("glob.glob") as mock_glob:
+ mock_glob.side_effect = lambda pattern, **kwargs: (
+ [f"{mock_folder}/Q2_K/model-Q2_K.gguf"]
+ if "Q2_K" in pattern or "**/*.gguf" in pattern
+ else []
+ )
+
+ result = download_gguf("unsloth/gpt-oss-120b-GGUF", "Q2_K")
+
+ assert result == f"{mock_folder}/Q2_K/model-Q2_K.gguf"
+
+ @patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf")
+ @patch("glob.glob", return_value=[])
+ def test_download_gguf_no_files_found(self, mock_glob, mock_download):
+ """Test error when no GGUF files are found."""
+ mock_folder = "/tmp/mock_cache"
+ mock_download.return_value = mock_folder
+
+ with pytest.raises(ValueError, match="Downloaded GGUF files not found"):
+ download_gguf("unsloth/Qwen3-0.6B-GGUF", "IQ1_S")
+
+
+class TestGGUFModelLoader:
+ """Test GGUFModelLoader class methods."""
+
+ @patch("os.path.isfile", return_value=True)
+ def test_prepare_weights_local_file(self, mock_isfile):
+ """Test _prepare_weights with local file."""
+ load_config = LoadConfig(load_format="gguf")
+ loader = GGUFModelLoader(load_config)
+
+ # Create a simple mock ModelConfig with only the model attribute
+ model_config = MagicMock()
+ model_config.model = "/path/to/model.gguf"
+
+ result = loader._prepare_weights(model_config)
+ assert result == "/path/to/model.gguf"
+ mock_isfile.assert_called_once_with("/path/to/model.gguf")
+
+ @patch("vllm.model_executor.model_loader.gguf_loader.hf_hub_download")
+ @patch("os.path.isfile", return_value=False)
+ def test_prepare_weights_https_url(self, mock_isfile, mock_hf_download):
+ """Test _prepare_weights with HTTPS URL."""
+ load_config = LoadConfig(load_format="gguf")
+ loader = GGUFModelLoader(load_config)
+
+ mock_hf_download.return_value = "/downloaded/model.gguf"
+
+ # Create a simple mock ModelConfig with only the model attribute
+ model_config = MagicMock()
+ model_config.model = "https://huggingface.co/model.gguf"
+
+ result = loader._prepare_weights(model_config)
+ assert result == "/downloaded/model.gguf"
+ mock_hf_download.assert_called_once_with(
+ url="https://huggingface.co/model.gguf"
+ )
+
+ @patch("vllm.model_executor.model_loader.gguf_loader.hf_hub_download")
+ @patch("os.path.isfile", return_value=False)
+ def test_prepare_weights_repo_filename(self, mock_isfile, mock_hf_download):
+ """Test _prepare_weights with repo_id/filename.gguf format."""
+ load_config = LoadConfig(load_format="gguf")
+ loader = GGUFModelLoader(load_config)
+
+ mock_hf_download.return_value = "/downloaded/model.gguf"
+
+ # Create a simple mock ModelConfig with only the model attribute
+ model_config = MagicMock()
+ model_config.model = "unsloth/Qwen3-0.6B-GGUF/model.gguf"
+
+ result = loader._prepare_weights(model_config)
+ assert result == "/downloaded/model.gguf"
+ mock_hf_download.assert_called_once_with(
+ repo_id="unsloth/Qwen3-0.6B-GGUF", filename="model.gguf"
+ )
+
+ @patch("vllm.config.model.get_hf_image_processor_config", return_value=None)
+ @patch("vllm.transformers_utils.config.file_or_path_exists", return_value=True)
+ @patch("vllm.config.model.get_config")
+ @patch("vllm.config.model.is_gguf", return_value=True)
+ @patch("vllm.model_executor.model_loader.gguf_loader.download_gguf")
+ @patch("os.path.isfile", return_value=False)
+ def test_prepare_weights_repo_quant_type(
+ self,
+ mock_isfile,
+ mock_download_gguf,
+ mock_is_gguf,
+ mock_get_config,
+ mock_file_exists,
+ mock_get_image_config,
+ ):
+ """Test _prepare_weights with repo_id:quant_type format."""
+ mock_hf_config = MagicMock()
+ mock_hf_config.architectures = ["Qwen3ForCausalLM"]
+
+ class MockTextConfig:
+ max_position_embeddings = 4096
+ sliding_window = None
+ model_type = "qwen3"
+ num_attention_heads = 32
+
+ mock_text_config = MockTextConfig()
+ mock_hf_config.get_text_config.return_value = mock_text_config
+ mock_hf_config.dtype = "bfloat16"
+ mock_get_config.return_value = mock_hf_config
+
+ load_config = LoadConfig(load_format="gguf")
+ loader = GGUFModelLoader(load_config)
+
+ mock_download_gguf.return_value = "/downloaded/model-IQ1_S.gguf"
+
+ model_config = ModelConfig(
+ model="unsloth/Qwen3-0.6B-GGUF:IQ1_S", tokenizer="Qwen/Qwen3-0.6B"
+ )
+ result = loader._prepare_weights(model_config)
+ # The actual result will be the downloaded file path from mock
+ assert result == "/downloaded/model-IQ1_S.gguf"
+ mock_download_gguf.assert_called_once_with(
+ "unsloth/Qwen3-0.6B-GGUF",
+ "IQ1_S",
+ cache_dir=None,
+ revision=None,
+ ignore_patterns=["original/**/*"],
+ )
+
+ @patch("vllm.config.model.get_hf_image_processor_config", return_value=None)
+ @patch("vllm.config.model.get_config")
+ @patch("vllm.config.model.is_gguf", return_value=False)
+ @patch("vllm.transformers_utils.utils.check_gguf_file", return_value=False)
+ @patch("os.path.isfile", return_value=False)
+ def test_prepare_weights_invalid_format(
+ self,
+ mock_isfile,
+ mock_check_gguf,
+ mock_is_gguf,
+ mock_get_config,
+ mock_get_image_config,
+ ):
+ """Test _prepare_weights with invalid format."""
+ mock_hf_config = MagicMock()
+ mock_hf_config.architectures = ["Qwen3ForCausalLM"]
+
+ class MockTextConfig:
+ max_position_embeddings = 4096
+ sliding_window = None
+ model_type = "qwen3"
+ num_attention_heads = 32
+
+ mock_text_config = MockTextConfig()
+ mock_hf_config.get_text_config.return_value = mock_text_config
+ mock_hf_config.dtype = "bfloat16"
+ mock_get_config.return_value = mock_hf_config
+
+ load_config = LoadConfig(load_format="gguf")
+ loader = GGUFModelLoader(load_config)
+
+ # Create ModelConfig with a valid repo_id to avoid validation errors
+ # Then test _prepare_weights with invalid format
+ model_config = ModelConfig(model="unsloth/Qwen3-0.6B")
+ # Manually set model to invalid format after creation
+ model_config.model = "invalid-format"
+ with pytest.raises(ValueError, match="Unrecognised GGUF reference"):
+ loader._prepare_weights(model_config)
diff --git a/tests/test_routing_simulator.py b/tests/test_routing_simulator.py
index 5a162fa8f791b..e8826eb441a24 100644
--- a/tests/test_routing_simulator.py
+++ b/tests/test_routing_simulator.py
@@ -9,9 +9,16 @@ different routing strategies and analyze their performance, including
integration tests with FusedMoE layer.
"""
+import tempfile
+
import pytest
import torch
+from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.distributed import (
+ init_distributed_environment,
+ initialize_model_parallel,
+)
from vllm.model_executor.layers.fused_moe.routing_simulator import (
DistributionBasedRouting,
RoutingSimulator,
@@ -89,6 +96,28 @@ def test_routing_strategy_integration(monkeypatch, device):
# Test different routing strategies
strategies = RoutingSimulator.get_available_strategies()
+ vllm_config = VllmConfig()
+ with set_current_vllm_config(vllm_config):
+ temp_file = tempfile.mkstemp()[1]
+ init_distributed_environment(
+ world_size=1,
+ rank=0,
+ local_rank=0,
+ distributed_init_method=f"file://{temp_file}",
+ )
+ initialize_model_parallel(
+ tensor_model_parallel_size=1,
+ pipeline_model_parallel_size=1,
+ )
+ fused_moe = FusedMoE(
+ num_experts=num_experts,
+ top_k=top_k,
+ hidden_size=hidden_size,
+ intermediate_size=0,
+ use_grouped_topk=False,
+ renormalize=True,
+ )
+
for strategy in strategies:
# Set environment variable
env_name = "VLLM_MOE_ROUTING_SIMULATION_STRATEGY"
@@ -98,13 +127,9 @@ def test_routing_strategy_integration(monkeypatch, device):
envs.environment_variables[env_name] = lambda s=strategy: s
# Test the select_experts method
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = fused_moe.select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
- top_k=top_k,
- use_grouped_topk=False,
- renormalize=True,
- indices_type=torch.long,
)
# Verify output shapes
diff --git a/tests/tool_use/test_parallel_tool_calls.py b/tests/tool_use/test_parallel_tool_calls.py
index 9af94a6a64a25..77084ec2d9456 100644
--- a/tests/tool_use/test_parallel_tool_calls.py
+++ b/tests/tool_use/test_parallel_tool_calls.py
@@ -212,3 +212,60 @@ async def test_parallel_tool_calls_with_results(
assert finish_reason_count == 1
assert len(chunks)
assert "".join(chunks) == choice.message.content
+
+
+@pytest.mark.asyncio
+async def test_parallel_tool_calls_false(client: openai.AsyncOpenAI):
+ """
+ Ensure only one tool call is returned when parallel_tool_calls is False.
+ """
+
+ models = await client.models.list()
+ model_name: str = models.data[0].id
+ chat_completion = await client.chat.completions.create(
+ messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
+ temperature=0,
+ max_completion_tokens=200,
+ model=model_name,
+ tools=[WEATHER_TOOL, SEARCH_TOOL],
+ logprobs=False,
+ parallel_tool_calls=False,
+ )
+
+ stop_reason = chat_completion.choices[0].finish_reason
+ non_streamed_tool_calls = chat_completion.choices[0].message.tool_calls
+
+ # make sure only 1 tool call is present
+ assert len(non_streamed_tool_calls) == 1
+ assert stop_reason == "tool_calls"
+
+ # make the same request, streaming
+ stream = await client.chat.completions.create(
+ model=model_name,
+ messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
+ temperature=0,
+ max_completion_tokens=200,
+ tools=[WEATHER_TOOL, SEARCH_TOOL],
+ logprobs=False,
+ parallel_tool_calls=False,
+ stream=True,
+ )
+
+ finish_reason_count: int = 0
+ tool_call_id_count: int = 0
+
+ async for chunk in stream:
+ # if there's a finish reason make sure it's tools
+ if chunk.choices[0].finish_reason:
+ finish_reason_count += 1
+ assert chunk.choices[0].finish_reason == "tool_calls"
+
+ streamed_tool_calls = chunk.choices[0].delta.tool_calls
+ if streamed_tool_calls and len(streamed_tool_calls) > 0:
+ tool_call = streamed_tool_calls[0]
+ if tool_call.id:
+ tool_call_id_count += 1
+
+ # make sure only 1 streaming tool call is present
+ assert tool_call_id_count == 1
+ assert finish_reason_count == 1
diff --git a/tests/transformers_utils/test_utils.py b/tests/transformers_utils/test_utils.py
index bfe1cec76c138..a8d0b9be9ec29 100644
--- a/tests/transformers_utils/test_utils.py
+++ b/tests/transformers_utils/test_utils.py
@@ -1,11 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from pathlib import Path
+from unittest.mock import patch
+import pytest
from vllm.transformers_utils.utils import (
is_cloud_storage,
is_gcs,
+ is_gguf,
+ is_remote_gguf,
is_s3,
+ split_remote_gguf,
)
@@ -28,3 +34,143 @@ def test_is_cloud_storage():
assert is_cloud_storage("s3://model-path/path-to-model")
assert not is_cloud_storage("/unix/local/path")
assert not is_cloud_storage("nfs://nfs-fqdn.local")
+
+
+class TestIsRemoteGGUF:
+ """Test is_remote_gguf utility function."""
+
+ def test_is_remote_gguf_with_colon_and_slash(self):
+ """Test is_remote_gguf with repo_id:quant_type format."""
+ # Valid quant types
+ assert is_remote_gguf("unsloth/Qwen3-0.6B-GGUF:IQ1_S")
+ assert is_remote_gguf("user/repo:Q2_K")
+ assert is_remote_gguf("repo/model:Q4_K")
+ assert is_remote_gguf("repo/model:Q8_0")
+
+ # Invalid quant types should return False
+ assert not is_remote_gguf("repo/model:quant")
+ assert not is_remote_gguf("repo/model:INVALID")
+ assert not is_remote_gguf("repo/model:invalid_type")
+
+ def test_is_remote_gguf_without_colon(self):
+ """Test is_remote_gguf without colon."""
+ assert not is_remote_gguf("repo/model")
+ assert not is_remote_gguf("unsloth/Qwen3-0.6B-GGUF")
+
+ def test_is_remote_gguf_without_slash(self):
+ """Test is_remote_gguf without slash."""
+ assert not is_remote_gguf("model.gguf")
+ # Even with valid quant_type, no slash means not remote GGUF
+ assert not is_remote_gguf("model:IQ1_S")
+ assert not is_remote_gguf("model:quant")
+
+ def test_is_remote_gguf_local_path(self):
+ """Test is_remote_gguf with local file path."""
+ assert not is_remote_gguf("/path/to/model.gguf")
+ assert not is_remote_gguf("./model.gguf")
+
+ def test_is_remote_gguf_with_path_object(self):
+ """Test is_remote_gguf with Path object."""
+ assert is_remote_gguf(Path("unsloth/Qwen3-0.6B-GGUF:IQ1_S"))
+ assert not is_remote_gguf(Path("repo/model"))
+
+ def test_is_remote_gguf_with_http_https(self):
+ """Test is_remote_gguf with HTTP/HTTPS URLs."""
+ # HTTP/HTTPS URLs should return False even with valid quant_type
+ assert not is_remote_gguf("http://example.com/repo/model:IQ1_S")
+ assert not is_remote_gguf("https://huggingface.co/repo/model:Q2_K")
+ assert not is_remote_gguf("http://repo/model:Q4_K")
+ assert not is_remote_gguf("https://repo/model:Q8_0")
+
+ def test_is_remote_gguf_with_cloud_storage(self):
+ """Test is_remote_gguf with cloud storage paths."""
+ # Cloud storage paths should return False even with valid quant_type
+ assert not is_remote_gguf("s3://bucket/repo/model:IQ1_S")
+ assert not is_remote_gguf("gs://bucket/repo/model:Q2_K")
+ assert not is_remote_gguf("s3://repo/model:Q4_K")
+ assert not is_remote_gguf("gs://repo/model:Q8_0")
+
+
+class TestSplitRemoteGGUF:
+ """Test split_remote_gguf utility function."""
+
+ def test_split_remote_gguf_valid(self):
+ """Test split_remote_gguf with valid repo_id:quant_type format."""
+ repo_id, quant_type = split_remote_gguf("unsloth/Qwen3-0.6B-GGUF:IQ1_S")
+ assert repo_id == "unsloth/Qwen3-0.6B-GGUF"
+ assert quant_type == "IQ1_S"
+
+ repo_id, quant_type = split_remote_gguf("repo/model:Q2_K")
+ assert repo_id == "repo/model"
+ assert quant_type == "Q2_K"
+
+ def test_split_remote_gguf_with_path_object(self):
+ """Test split_remote_gguf with Path object."""
+ repo_id, quant_type = split_remote_gguf(Path("unsloth/Qwen3-0.6B-GGUF:IQ1_S"))
+ assert repo_id == "unsloth/Qwen3-0.6B-GGUF"
+ assert quant_type == "IQ1_S"
+
+ def test_split_remote_gguf_invalid(self):
+ """Test split_remote_gguf with invalid format."""
+ # Invalid format (no colon) - is_remote_gguf returns False
+ with pytest.raises(ValueError, match="Wrong GGUF model"):
+ split_remote_gguf("repo/model")
+
+ # Invalid quant type - is_remote_gguf returns False
+ with pytest.raises(ValueError, match="Wrong GGUF model"):
+ split_remote_gguf("repo/model:INVALID_TYPE")
+
+ # HTTP URL - is_remote_gguf returns False
+ with pytest.raises(ValueError, match="Wrong GGUF model"):
+ split_remote_gguf("http://repo/model:IQ1_S")
+
+ # Cloud storage - is_remote_gguf returns False
+ with pytest.raises(ValueError, match="Wrong GGUF model"):
+ split_remote_gguf("s3://bucket/repo/model:Q2_K")
+
+
+class TestIsGGUF:
+ """Test is_gguf utility function."""
+
+ @patch("vllm.transformers_utils.utils.check_gguf_file", return_value=True)
+ def test_is_gguf_with_local_file(self, mock_check_gguf):
+ """Test is_gguf with local GGUF file."""
+ assert is_gguf("/path/to/model.gguf")
+ assert is_gguf("./model.gguf")
+
+ def test_is_gguf_with_remote_gguf(self):
+ """Test is_gguf with remote GGUF format."""
+ # Valid remote GGUF format (repo_id:quant_type with valid quant_type)
+ assert is_gguf("unsloth/Qwen3-0.6B-GGUF:IQ1_S")
+ assert is_gguf("repo/model:Q2_K")
+ assert is_gguf("repo/model:Q4_K")
+
+ # Invalid quant_type should return False
+ assert not is_gguf("repo/model:quant")
+ assert not is_gguf("repo/model:INVALID")
+
+ @patch("vllm.transformers_utils.utils.check_gguf_file", return_value=False)
+ def test_is_gguf_false(self, mock_check_gguf):
+ """Test is_gguf returns False for non-GGUF models."""
+ assert not is_gguf("unsloth/Qwen3-0.6B")
+ assert not is_gguf("repo/model")
+ assert not is_gguf("model")
+
+ def test_is_gguf_edge_cases(self):
+ """Test is_gguf with edge cases."""
+ # Empty string
+ assert not is_gguf("")
+
+ # Only colon, no slash (even with valid quant_type)
+ assert not is_gguf("model:IQ1_S")
+
+ # Only slash, no colon
+ assert not is_gguf("repo/model")
+
+ # HTTP/HTTPS URLs
+ assert not is_gguf("http://repo/model:IQ1_S")
+ assert not is_gguf("https://repo/model:Q2_K")
+
+ # Cloud storage
+ assert not is_gguf("s3://bucket/repo/model:IQ1_S")
+ assert not is_gguf("gs://bucket/repo/model:Q2_K")
diff --git a/tests/utils_/test_argparse_utils.py b/tests/utils_/test_argparse_utils.py
index 3310753d2b6d6..32d4eca541356 100644
--- a/tests/utils_/test_argparse_utils.py
+++ b/tests/utils_/test_argparse_utils.py
@@ -166,7 +166,7 @@ def test_dict_args(parser):
"--hf-overrides.key2.key4",
"val3",
# Test compile config and compilation mode
- "-O.use_inductor=true",
+ "-O.use_inductor_graph_partition=true",
"-O.backend",
"custom",
"-O1",
@@ -219,7 +219,7 @@ def test_dict_args(parser):
}
assert parsed_args.compilation_config == {
"mode": 1,
- "use_inductor": True,
+ "use_inductor_graph_partition": True,
"backend": "custom",
"custom_ops": ["-quant_fp8", "+silu_mul", "-rms_norm"],
}
diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py
index 1bd05e6183dc2..783e02ce89bdb 100644
--- a/tests/v1/attention/test_mla_backends.py
+++ b/tests/v1/attention/test_mla_backends.py
@@ -61,7 +61,7 @@ for backend in BACKENDS_TO_TEST:
BACKEND_BLOCK_SIZES = {}
for backend in BACKENDS_TO_TEST:
- supported_sizes = backend.get_class().supported_kernel_block_sizes
+ supported_sizes = backend.get_class().get_supported_kernel_block_sizes()
if supported_sizes:
default_size = supported_sizes[0]
block_size = (
diff --git a/tests/v1/attention/test_rocm_attention_backends_selection.py b/tests/v1/attention/test_rocm_attention_backends_selection.py
new file mode 100644
index 0000000000000..80158d4b7278c
--- /dev/null
+++ b/tests/v1/attention/test_rocm_attention_backends_selection.py
@@ -0,0 +1,343 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests for attention backend selectors."""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+
+from vllm.attention.backends.registry import AttentionBackendEnum
+from vllm.platforms import current_platform
+
+# ROCm-specific attention backend selection tests
+pytestmark = pytest.mark.skipif(
+ not current_platform.is_rocm(), reason="ROCm-specific tests"
+)
+
+
+@pytest.fixture
+def mock_vllm_config():
+ """Create a mock VllmConfig for testing."""
+ config = MagicMock()
+ config.model_config.dtype = torch.float16
+ config.model_config.hf_config.architectures = ["LlamaForCausalLM"]
+ config.cache_config.block_size = 16
+ return config
+
+
+@pytest.fixture
+def mock_on_gfx9():
+ """Mock the on_gfx9 function to return True."""
+ with patch("vllm.platforms.rocm.on_gfx9", return_value=True):
+ yield
+
+
+@pytest.mark.parametrize(
+ "env_vars, selected_backend, expected_backend_path",
+ [
+ # Test Case: Explicit FLEX_ATTENTION backend
+ (
+ {},
+ "FLEX_ATTENTION",
+ AttentionBackendEnum.FLEX_ATTENTION.get_path(),
+ ),
+ # Test Case 1: Default (no env vars, no explicit backend)
+ (
+ {},
+ None,
+ AttentionBackendEnum.TRITON_ATTN.get_path(),
+ ),
+ # Test Case 2: Explicit TRITON_ATTN backend
+ (
+ {},
+ "TRITON_ATTN",
+ AttentionBackendEnum.TRITON_ATTN.get_path(),
+ ),
+ # Test Case 3: Explicit ROCM_ATTN backend
+ (
+ {},
+ "ROCM_ATTN",
+ AttentionBackendEnum.ROCM_ATTN.get_path(),
+ ),
+ # Test Case 4: Explicit ROCM_AITER_FA backend
+ (
+ {},
+ "ROCM_AITER_FA",
+ AttentionBackendEnum.ROCM_AITER_FA.get_path(),
+ ),
+ # Test Case 5: Explicit ROCM_AITER_UNIFIED_ATTN backend
+ (
+ {},
+ "ROCM_AITER_UNIFIED_ATTN",
+ AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path(),
+ ),
+ # Test Case 6: VLLM_ROCM_USE_AITER=1
+ # (defaults to AITER FA when MHA not explicitly disabled)
+ (
+ {"VLLM_ROCM_USE_AITER": "1"},
+ None,
+ AttentionBackendEnum.ROCM_AITER_FA.get_path(),
+ ),
+ # Test Case 7: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=1
+ (
+ {"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "1"},
+ None,
+ AttentionBackendEnum.ROCM_AITER_FA.get_path(),
+ ),
+ # Test Case 8: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION=1
+ (
+ {
+ "VLLM_ROCM_USE_AITER": "1",
+ "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": "1",
+ },
+ None,
+ AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path(),
+ ),
+ # Test Case 9: VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1
+ (
+ {"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1"},
+ None,
+ AttentionBackendEnum.ROCM_ATTN.get_path(),
+ ),
+ # Test Case 10: VLLM_ROCM_USE_AITER=1 + explicit TRITON_ATTN
+ (
+ {"VLLM_ROCM_USE_AITER": "1"},
+ "TRITON_ATTN",
+ AttentionBackendEnum.TRITON_ATTN.get_path(),
+ ),
+ # Test Case 11: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=0
+ # (explicitly disabled)
+ (
+ {"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "0"},
+ None,
+ AttentionBackendEnum.TRITON_ATTN.get_path(),
+ ),
+ # Test Case 12: VLLM_ROCM_USE_AITER=1 + explicit ROCM_ATTN
+ (
+ {"VLLM_ROCM_USE_AITER": "1"},
+ "ROCM_ATTN",
+ AttentionBackendEnum.ROCM_ATTN.get_path(),
+ ),
+ ],
+)
+def test_standard_attention_backend_selection(
+ env_vars,
+ selected_backend,
+ expected_backend_path,
+ mock_vllm_config,
+ mock_on_gfx9,
+ monkeypatch,
+):
+ """Test standard attention backend selection with various configurations."""
+ # Set environment variables
+ for key, value in env_vars.items():
+ monkeypatch.setenv(key, value)
+
+ # Import after setting env vars to ensure they're picked up
+ # Reload envs to pick up new environment variables
+ import importlib
+
+ import vllm.envs as envs
+ from vllm.attention.backends.registry import _Backend
+
+ importlib.reload(envs)
+
+ # Convert string backend to enum if provided
+ backend_enum = None
+ if selected_backend:
+ backend_enum = getattr(_Backend, selected_backend)
+
+ # Get the backend class path
+ from vllm.platforms.rocm import RocmPlatform
+
+ backend_path = RocmPlatform.get_attn_backend_cls(
+ selected_backend=backend_enum,
+ head_size=128,
+ dtype=torch.float16,
+ kv_cache_dtype="auto",
+ block_size=16,
+ use_mla=False,
+ has_sink=False,
+ use_sparse=False,
+ )
+ assert backend_path == expected_backend_path
+
+
+@pytest.mark.parametrize(
+ "env_vars, selected_backend, block_size, expected_backend_path, should_raise",
+ [
+ # Test Case 1: TRITON_MLA with block_size != 1
+ (
+ {},
+ "TRITON_MLA",
+ 16,
+ AttentionBackendEnum.TRITON_MLA.get_path(),
+ False,
+ ),
+ # Test Case 2: TRITON_MLA with block_size == 1 (should raise)
+ (
+ {},
+ "TRITON_MLA",
+ 1,
+ None,
+ True,
+ ),
+ # Test Case 3: ROCM_AITER_MLA with block_size == 1
+ (
+ {},
+ "ROCM_AITER_MLA",
+ 1,
+ AttentionBackendEnum.ROCM_AITER_MLA.get_path(),
+ False,
+ ),
+ # Test Case 4: ROCM_AITER_MLA with block_size != 1 (should raise)
+ (
+ {},
+ "ROCM_AITER_MLA",
+ 16,
+ AttentionBackendEnum.ROCM_AITER_MLA.get_path(),
+ False,
+ ),
+ # Test Case 5: VLLM_ROCM_USE_AITER=1 with block_size == 1
+ (
+ {"VLLM_ROCM_USE_AITER": "1"},
+ None,
+ 1,
+ AttentionBackendEnum.ROCM_AITER_MLA.get_path(),
+ False,
+ ),
+ # Test Case 6: VLLM_ROCM_USE_AITER=1 with block_size == 16
+ # (should use ROCM_AITER_MLA now, as it supports block_size 16)
+ (
+ {"VLLM_ROCM_USE_AITER": "1"},
+ None,
+ 16,
+ AttentionBackendEnum.ROCM_AITER_MLA.get_path(),
+ False,
+ ),
+ # Test Case 7: VLLM_ROCM_USE_AITER=1 + explicit TRITON_MLA
+ (
+ {"VLLM_ROCM_USE_AITER": "1"},
+ "TRITON_MLA",
+ 16,
+ AttentionBackendEnum.TRITON_MLA.get_path(),
+ False,
+ ),
+ # Test Case 8: Explicit ROCM_AITER_TRITON_MLA
+ (
+ {},
+ "ROCM_AITER_TRITON_MLA",
+ 16,
+ AttentionBackendEnum.ROCM_AITER_TRITON_MLA.get_path(),
+ False,
+ ),
+ ],
+)
+def test_mla_backend_selection(
+ env_vars,
+ selected_backend,
+ block_size,
+ expected_backend_path,
+ should_raise,
+ mock_vllm_config,
+ monkeypatch,
+):
+ """Test MLA backend selection with various configurations."""
+ # Set environment variables
+ for key, value in env_vars.items():
+ monkeypatch.setenv(key, value)
+
+ # Import after setting env vars
+ # Reload envs
+ import importlib
+
+ import vllm.envs as envs
+ from vllm.attention.backends.registry import _Backend
+
+ importlib.reload(envs)
+
+ # Mock is_aiter_mla_enabled based on env vars and block_size
+ aiter_enabled = env_vars.get("VLLM_ROCM_USE_AITER") == "1"
+
+ mock_rocm_ops = MagicMock()
+ mock_rocm_ops.is_mla_enabled.return_value = aiter_enabled
+ mock_aiter_module = MagicMock()
+ mock_aiter_module.rocm_aiter_ops = mock_rocm_ops
+
+ with patch.dict("sys.modules", {"vllm._aiter_ops": mock_aiter_module}):
+ # Convert string backend to enum if provided
+ backend_enum = None
+ if selected_backend:
+ backend_enum = getattr(_Backend, selected_backend)
+
+ from vllm.platforms.rocm import RocmPlatform
+
+ if should_raise:
+ with pytest.raises(ValueError):
+ RocmPlatform.get_attn_backend_cls(
+ selected_backend=backend_enum,
+ head_size=128,
+ dtype=torch.float16,
+ kv_cache_dtype="auto",
+ block_size=block_size,
+ use_mla=True,
+ has_sink=False,
+ use_sparse=False,
+ )
+ else:
+ backend_path = RocmPlatform.get_attn_backend_cls(
+ selected_backend=backend_enum,
+ head_size=128,
+ dtype=torch.float16,
+ kv_cache_dtype="auto",
+ block_size=block_size,
+ use_mla=True,
+ has_sink=False,
+ use_sparse=False,
+ )
+ assert backend_path == expected_backend_path
+
+
+def test_aiter_fa_requires_gfx9(mock_vllm_config):
+ """Test that ROCM_AITER_FA requires gfx9 architecture."""
+ from vllm.attention.backends.registry import _Backend
+ from vllm.platforms.rocm import RocmPlatform
+
+ # Mock on_gfx9 to return False
+ with (
+ patch("vllm.platforms.rocm.on_gfx9", return_value=False),
+ pytest.raises(
+ ValueError,
+ match="only supported on gfx9",
+ ),
+ ):
+ RocmPlatform.get_attn_backend_cls(
+ selected_backend=_Backend.ROCM_AITER_FA,
+ head_size=128,
+ dtype=torch.float16,
+ kv_cache_dtype="auto",
+ block_size=16,
+ use_mla=False,
+ has_sink=False,
+ use_sparse=False,
+ )
+
+
+def test_sparse_not_supported(mock_vllm_config):
+ """Test that sparse attention is not supported on ROCm."""
+ from vllm.platforms.rocm import RocmPlatform
+
+ with pytest.raises(
+ AssertionError, match="Sparse MLA backend on ROCm only supports block size 1"
+ ):
+ RocmPlatform.get_attn_backend_cls(
+ selected_backend=None,
+ head_size=128,
+ dtype=torch.float16,
+ kv_cache_dtype="auto",
+ block_size=16,
+ use_mla=False,
+ has_sink=False,
+ use_sparse=True,
+ )
diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py
index dea89babd4b47..df3d53332c7cd 100644
--- a/tests/v1/attention/utils.py
+++ b/tests/v1/attention/utils.py
@@ -340,4 +340,11 @@ full_cg_backend_configs = {
"cudagraph_mode": "FULL_AND_PIECEWISE",
},
),
+ "RocmAttn": BackendConfig(
+ name="RocmAttn",
+ env_vars={"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1"},
+ comp_config={
+ "cudagraph_mode": "FULL",
+ },
+ ),
}
diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py
index 24611a4aaa1b8..58a7a2692bfc8 100644
--- a/tests/v1/core/test_kv_cache_utils.py
+++ b/tests/v1/core/test_kv_cache_utils.py
@@ -1248,7 +1248,9 @@ def test_allocate_with_lookahead():
)
# Test case 1: Requires additional lookahead tokens
- kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100)
+ kv_cache_manager = KVCacheManager(
+ kv_cache_config=config, max_model_len=100, hash_block_size=block_size
+ )
blocks = kv_cache_manager.allocate_slots(
request,
num_new_tokens=3,
@@ -1257,7 +1259,9 @@ def test_allocate_with_lookahead():
assert len(blocks.get_block_ids()[0]) == 2 # ceil(5/4)=2 blocks
# Test case 2: With precomputed blocks
- kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100)
+ kv_cache_manager = KVCacheManager(
+ kv_cache_config=config, max_model_len=100, hash_block_size=block_size
+ )
# required_blocks = ceil((3 + 2) /4) = 2
blocks = kv_cache_manager.allocate_slots(
request,
@@ -1268,7 +1272,9 @@ def test_allocate_with_lookahead():
# Test case 3: With precomputed blocks
# required_blocks = ceil((3 + 4) / 4) = 2
- kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100)
+ kv_cache_manager = KVCacheManager(
+ kv_cache_config=config, max_model_len=100, hash_block_size=block_size
+ )
blocks = kv_cache_manager.allocate_slots(
request,
num_new_tokens=3,
@@ -1436,7 +1442,67 @@ def test_get_kv_cache_config_one_worker():
],
)
- # different hidden size
+ # 6 full + 5 sliding, pad to 6 full + 6 sliding. This is a typical case for gpt-oss
+ # eagle where there is only one more full attention layer than sliding window layers
+ kv_cache_specs_hybrid = {
+ "layer_1": new_kv_cache_spec(),
+ "layer_2": new_kv_cache_spec(),
+ "layer_3": new_kv_cache_spec(),
+ "layer_4": new_kv_cache_spec(),
+ "layer_5": new_kv_cache_spec(),
+ "layer_6": new_kv_cache_spec(),
+ "layer_7": new_sliding_window_spec(),
+ "layer_8": new_sliding_window_spec(),
+ "layer_9": new_sliding_window_spec(),
+ "layer_10": new_sliding_window_spec(),
+ "layer_11": new_sliding_window_spec(),
+ }
+
+ kv_cache_config_hybrid = get_kv_cache_configs(
+ vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 6 * 32]
+ )[0]
+ print(kv_cache_config_hybrid)
+ assert kv_cache_config_hybrid == KVCacheConfig(
+ num_blocks=32,
+ kv_cache_tensors=[
+ KVCacheTensor(
+ size=mem_per_block_per_layer * 32,
+ shared_by=["layer_1", "layer_7"],
+ ),
+ KVCacheTensor(
+ size=mem_per_block_per_layer * 32,
+ shared_by=["layer_2", "layer_8"],
+ ),
+ KVCacheTensor(
+ size=mem_per_block_per_layer * 32,
+ shared_by=["layer_3", "layer_9"],
+ ),
+ KVCacheTensor(
+ size=mem_per_block_per_layer * 32,
+ shared_by=["layer_4", "layer_10"],
+ ),
+ KVCacheTensor(
+ size=mem_per_block_per_layer * 32,
+ shared_by=["layer_5", "layer_11"],
+ ),
+ KVCacheTensor(
+ size=mem_per_block_per_layer * 32,
+ shared_by=["layer_6"],
+ ),
+ ],
+ kv_cache_groups=[
+ KVCacheGroupSpec(
+ ["layer_1", "layer_2", "layer_3", "layer_4", "layer_5", "layer_6"],
+ new_kv_cache_spec(),
+ ),
+ KVCacheGroupSpec(
+ ["layer_7", "layer_8", "layer_9", "layer_10", "layer_11"],
+ new_sliding_window_spec(),
+ ),
+ ],
+ )
+
+ # different hidden size but same type, use UniformTypeKVCacheSpecs
kv_cache_specs_hybrid = {
"layer_1": new_kv_cache_spec(head_size=128),
"layer_2": new_kv_cache_spec(head_size=64),
@@ -1460,6 +1526,40 @@ def test_get_kv_cache_config_one_worker():
],
)
+ # Different hidden size and different type, align by different block size
+ kv_cache_specs_hybrid = {
+ "layer_1": new_kv_cache_spec(head_size=64),
+ "layer_2": new_sliding_window_spec(head_size=32),
+ }
+ kv_cache_config_hybrid = get_kv_cache_configs(
+ vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 32]
+ )[0]
+ assert kv_cache_config_hybrid == KVCacheConfig(
+ num_blocks=32,
+ kv_cache_tensors=[
+ KVCacheTensor(
+ size=mem_per_block_per_layer * 32, shared_by=["layer_1", "layer_2"]
+ ),
+ ],
+ kv_cache_groups=[
+ KVCacheGroupSpec(["layer_1"], new_kv_cache_spec(head_size=64)),
+ KVCacheGroupSpec(
+ ["layer_2"], new_sliding_window_spec(head_size=32, block_size=32)
+ ),
+ ],
+ )
+
+ # different hidden size that cannot be aligned by using different block size
+ kv_cache_specs_hybrid = {
+ "layer_1": new_kv_cache_spec(head_size=64),
+ "layer_2": new_sliding_window_spec(head_size=96),
+ }
+
+ with pytest.raises(NotImplementedError):
+ get_kv_cache_configs(
+ vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32]
+ )[0]
+
# Test num_gpu_blocks_override
vllm_config.cache_config.num_gpu_blocks_override = 16
kv_cache_config_override_blocks = get_kv_cache_configs(
diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py
index 2291f363731f2..64fd5ab1dd9aa 100644
--- a/tests/v1/core/test_prefix_caching.py
+++ b/tests/v1/core/test_prefix_caching.py
@@ -134,6 +134,7 @@ def test_prefill(hash_fn):
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
+ hash_block_size=block_size,
)
# Complete 3 blocks (48 tokens)
@@ -256,6 +257,7 @@ def test_prefill_hybrid_model():
make_kv_cache_config_hybrid_model(block_size, 21),
max_model_len=8192,
enable_caching=True,
+ hash_block_size=block_size,
)
hash_fn = sha256
@@ -416,6 +418,7 @@ def test_prefill_plp():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
+ hash_block_size=block_size,
)
# the default hash function is sha256
hash_fn = sha256
@@ -523,6 +526,7 @@ def test_decode():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
+ hash_block_size=block_size,
)
# Complete 3 blocks (48 tokens)
@@ -585,6 +589,7 @@ def test_evict():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
+ hash_block_size=block_size,
)
last_token_id = 5 * 16 + 7
@@ -643,6 +648,7 @@ def test_hash_block_correct_reuse():
make_kv_cache_config(16, 2),
max_model_len=8192,
enable_caching=True,
+ hash_block_size=block_size,
)
# Allocate 1 block and cache it.
@@ -683,6 +689,7 @@ def test_computed_blocks_not_evicted():
make_kv_cache_config(block_size, 3),
max_model_len=8192,
enable_caching=True,
+ hash_block_size=block_size,
)
# Allocate a block and cache it.
@@ -741,6 +748,7 @@ def test_basic_prefix_caching_disabled():
make_kv_cache_config(block_size, 5),
max_model_len=8192,
enable_caching=False,
+ hash_block_size=block_size,
)
req1 = make_request(
@@ -790,6 +798,7 @@ def test_cache_blocks(hash_fn):
block_pool = BlockPool(
num_gpu_blocks=5,
enable_caching=True,
+ hash_block_size=block_size,
)
# Req:
# Block 0: [0, 1, 2, 3]
@@ -833,7 +842,9 @@ def test_cache_blocks_multi_group():
This tests that blocks are cached correctly for different kv cache groups.
"""
block_size = 4
- block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True)
+ block_pool = BlockPool(
+ num_gpu_blocks=10, enable_caching=True, hash_block_size=block_size
+ )
# Req:
# Block 0/4: [0, 1, 2, 3]
@@ -921,6 +932,7 @@ def test_mm_prefix_caching():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
+ hash_block_size=block_size,
)
# Common prompt tokens (T is text tokens and P is image placeholder tokens)
@@ -1020,6 +1032,7 @@ def test_cache_key_salting():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
+ hash_block_size=block_size,
)
# 3 complete blocks and an incomplete block with 11 tokens.
@@ -1101,6 +1114,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
+ hash_block_size=block_size,
)
# Complete 3 blocks (48 tokens)
# | Common-0 | Common-1 | Common-2 | ... |
@@ -1173,6 +1187,7 @@ def test_reset_prefix_cache():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
+ hash_block_size=block_size,
)
full_block_token_ids = [i for i in range(3) for _ in range(16)]
@@ -1213,6 +1228,7 @@ def test_prefix_cache_stats_disabled():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
+ hash_block_size=block_size,
log_stats=False, # Disable logging stats
)
assert manager.prefix_cache_stats is None
@@ -1232,7 +1248,7 @@ def test_prefix_cache_stats_disabled():
def test_maybe_evict_cached_block():
- pool = BlockPool(num_gpu_blocks=4, enable_caching=True)
+ pool = BlockPool(num_gpu_blocks=4, enable_caching=True, hash_block_size=16)
block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000)
block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000)
block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000)
@@ -1293,6 +1309,7 @@ def test_kv_cache_events(blocks_to_cache: int):
max_model_len=8192,
enable_caching=True,
enable_kv_cache_events=True,
+ hash_block_size=block_size,
)
num_tokens = block_size * blocks_to_cache
@@ -1351,6 +1368,7 @@ def test_kv_cache_events_with_lora(blocks_to_cache: int):
max_model_len=8192,
enable_caching=True,
enable_kv_cache_events=True,
+ hash_block_size=block_size,
)
# Test with LoRA request
@@ -1405,6 +1423,7 @@ def test_eagle_enabled_removes_last_block():
max_model_len=8192,
enable_caching=True,
use_eagle=True,
+ hash_block_size=block_size,
)
# Request with 3 full blocks (48 tokens)
@@ -1437,6 +1456,7 @@ def test_eagle_with_partial_blocks():
max_model_len=8192,
enable_caching=True,
use_eagle=True,
+ hash_block_size=block_size,
)
# 2 full blocks + 5 tokens (non-divisible length)
token_ids = [0] * (2 * block_size + 5)
@@ -1476,6 +1496,7 @@ def test_eagle_with_sliding_window():
max_model_len=8192,
enable_caching=True,
use_eagle=True,
+ hash_block_size=block_size,
)
# 2 full blocks + 5 tokens (non-divisible length)
@@ -1522,6 +1543,76 @@ def test_eagle_with_sliding_window():
assert num_tokens == 0
+def test_different_block_size():
+ block_size = 16
+ # full attention and sliding window attention layers have the same page size:
+ # (32 tokens/block * float16 token, vs. 16 tokens/block * float32 token)
+ kv_cache_config = KVCacheConfig(
+ num_blocks=100,
+ kv_cache_tensors=[],
+ kv_cache_groups=[
+ KVCacheGroupSpec(
+ ["layer1"],
+ FullAttentionSpec(block_size * 2, 1, 1, torch.float16),
+ ),
+ KVCacheGroupSpec(
+ ["layer2"],
+ SlidingWindowSpec(
+ block_size,
+ 1,
+ 1,
+ torch.float32,
+ sliding_window=2 * block_size,
+ ),
+ ),
+ ],
+ )
+ manager = KVCacheManager(
+ kv_cache_config=kv_cache_config,
+ max_model_len=8192,
+ enable_caching=True,
+ hash_block_size=block_size,
+ )
+
+ # 10 blocks of 16 tokens each. Token ids are not strictly aligned for each block.
+ common_token_ids = [i for i in range(10) for _ in range(block_size)]
+
+ req0 = make_request("0", common_token_ids, block_size, sha256)
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
+ assert not computed_blocks.blocks[0]
+ assert not computed_blocks.blocks[1]
+ assert num_computed_tokens == 0
+ blocks = manager.allocate_slots(
+ req0, 7 * block_size, len(computed_blocks.blocks[0]) * 16, computed_blocks
+ )
+ assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7, 8, 9, 10, 11])
+ req1 = make_request("1", common_token_ids[: 7 * block_size + 1], block_size, sha256)
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
+ assert len(computed_blocks.blocks[0]) == 3
+ assert len(computed_blocks.blocks[1]) == 6
+ assert num_computed_tokens == 6 * 16
+
+ req2 = make_request("2", common_token_ids[: 6 * block_size + 1], block_size, sha256)
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
+ assert len(computed_blocks.blocks[0]) == 3
+ assert len(computed_blocks.blocks[1]) == 6
+ assert num_computed_tokens == 6 * 16
+
+ # Evict some blocks to make sliding window cache hit length 5*16
+ # But should return 4 * 16 because full attention cache hit length must be
+ # a multiple of 32
+ manager.block_pool.cached_block_hash_to_block.pop(
+ make_block_hash_with_group_id(req1.block_hashes[6], 1), 11
+ )
+ manager.block_pool.cached_block_hash_to_block.pop(
+ make_block_hash_with_group_id(req1.block_hashes[5], 1), 10
+ )
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
+ assert len(computed_blocks.blocks[0]) == 2
+ assert len(computed_blocks.blocks[1]) == 4
+ assert num_computed_tokens == 4 * 16
+
+
def test_block_lookup_cache_single_block_per_key():
cache = BlockHashToBlockMap()
key0 = BlockHashWithGroupId(b"hash0")
diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index 04e738293cd77..fe4153e609971 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -76,11 +76,11 @@ def test_get_num_unfinished_requests():
@pytest.mark.parametrize(
"enable_prefix_caching, prompt_logprobs",
[
- (None, None),
+ (False, None),
(True, 5),
],
)
-def test_schedule(enable_prefix_caching: bool | None, prompt_logprobs: int | None):
+def test_schedule(enable_prefix_caching: bool, prompt_logprobs: int | None):
"""Test scheduling.
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
"""
@@ -582,12 +582,12 @@ def test_check_stop_min_tokens():
@pytest.mark.parametrize(
"enable_prefix_caching, prompt_logprobs",
[
- (None, None),
+ (False, None),
(True, 5),
],
)
def test_schedule_concurrent_batches(
- enable_prefix_caching: bool | None, prompt_logprobs: int | None
+ enable_prefix_caching: bool, prompt_logprobs: int | None
):
scheduler = create_scheduler(
max_num_batched_tokens=1024,
@@ -641,6 +641,34 @@ def test_schedule_concurrent_batches(
scheduler.update_from_output(scheduler_output1, model_runner_output)
+@pytest.mark.parametrize("enable_chunked_prefill", [True, False])
+def test_schedule_order(enable_chunked_prefill: bool):
+ scheduler = create_scheduler(
+ max_num_batched_tokens=1024,
+ max_num_seqs=3,
+ enable_chunked_prefill=enable_chunked_prefill,
+ )
+
+ # long requests
+ requests = create_requests(num_requests=2, num_tokens=800)
+ # short requests
+ requests += create_requests(num_requests=2, num_tokens=10)
+
+ for request in requests:
+ scheduler.add_request(request)
+
+ scheduler_output1 = scheduler.schedule()
+
+ if enable_chunked_prefill:
+ # When enable chunked prefill, long requests will be chunked.
+ assert len(scheduler_output1.scheduled_new_reqs) == 2
+ else:
+ # When disable chunked prefill, should not skip the long requests,
+ # and scheduling subsequent short requests in advance,
+ # even though there is still token budgets remaining.
+ assert len(scheduler_output1.scheduled_new_reqs) == 1
+
+
def test_preempt_during_execution():
# NOTE(woosuk): The actual number of available blocks is 10 instead of 11
# because block 0 is reserved as the null block.
@@ -1057,7 +1085,8 @@ def test_kv_connector_basic(is_async: bool):
)
-def test_external_prefix_cache_metrics():
+@pytest.mark.parametrize("is_async", [False, True])
+def test_external_prefix_cache_metrics(is_async: bool):
"""
Verify connector prefix cache metrics are updated
correctly when the scheduler processes requests with KV connector hits.
@@ -1067,7 +1096,9 @@ def test_external_prefix_cache_metrics():
NUM_MATCHED_NEW_TOKENS = 4
scheduler = create_scheduler(
enable_prefix_caching=False,
- use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False),
+ use_kv_connector=mock_kv(
+ matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=is_async
+ ),
)
# --- Prepare simple requests ---
@@ -1079,9 +1110,15 @@ def test_external_prefix_cache_metrics():
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS,
)
+ req_ids = []
+ req_to_index = {}
+ for i, request in enumerate(requests):
+ scheduler.add_request(request)
+ req_ids.append(request.request_id)
+ req_to_index[request.request_id] = i
- for req in requests:
- scheduler.add_request(req)
+ if is_async:
+ _step_until_kv_transfer_finished(scheduler, req_ids)
# --- Trigger scheduling and simulate model output ---
output = scheduler.schedule()
@@ -1416,7 +1453,7 @@ def create_scheduler_with_priority(
model: str = "facebook/opt-125m",
max_num_seqs: int = 16,
max_num_batched_tokens: int = 8192,
- enable_prefix_caching: bool | None = None,
+ enable_prefix_caching: bool = False,
long_prefill_token_threshold: int = 0,
disable_chunked_mm_input: bool = False,
use_kv_connector: bool = False,
@@ -1435,7 +1472,7 @@ def create_scheduler_with_priority(
max_num_batch_tokens: max num tokens to batch
enable_prefix_caching: optionally force APC config
(True/False) or use default
- (None)
+ (False)
Returns:
{class}`Scheduler` instance with priority scheduling
@@ -1458,17 +1495,12 @@ def create_scheduler_with_priority(
seed=42,
)
# Cache config, optionally force APC
- kwargs_cache = (
- {}
- if enable_prefix_caching is None
- else {"enable_prefix_caching": enable_prefix_caching}
- )
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
- **kwargs_cache,
+ enable_prefix_caching=enable_prefix_caching,
)
kv_transfer_config = (
KVTransferConfig(
diff --git a/tests/v1/core/test_single_type_kv_cache_manager.py b/tests/v1/core/test_single_type_kv_cache_manager.py
index a27f32938c08b..e6a69dc8a949a 100644
--- a/tests/v1/core/test_single_type_kv_cache_manager.py
+++ b/tests/v1/core/test_single_type_kv_cache_manager.py
@@ -41,7 +41,9 @@ def test_chunked_local_attention_possible_cached_prefix():
attention_chunk_size=4,
)
- block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
+ block_pool = BlockPool(
+ num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
+ )
manager = get_chunked_local_attention_manager(
chunked_local_attention_spec, block_pool
)
@@ -70,6 +72,7 @@ def test_chunked_local_attention_possible_cached_prefix():
block_pool=block_pool,
kv_cache_spec=chunked_local_attention_spec,
use_eagle=False,
+ alignment_tokens=block_size,
)[0]
assert len(computed_blocks) == expect_length
@@ -111,7 +114,9 @@ def test_sliding_window_possible_cached_prefix():
sliding_window=4,
)
- block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
+ block_pool = BlockPool(
+ num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
+ )
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
def run_one_case(block_is_cached, expect_length):
@@ -138,6 +143,7 @@ def test_sliding_window_possible_cached_prefix():
block_pool=block_pool,
kv_cache_spec=sliding_window_spec,
use_eagle=False,
+ alignment_tokens=block_size,
)[0]
assert len(computed_blocks) == expect_length
@@ -178,7 +184,7 @@ def test_chunked_local_attention_remove_skipped_blocks():
attention_chunk_size=4,
)
- block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
+ block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True, hash_block_size=2)
manager = get_chunked_local_attention_manager(attention_spec, block_pool)
@@ -239,7 +245,7 @@ def test_sliding_window_remove_skipped_blocks():
sliding_window=4,
)
- block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
+ block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True, hash_block_size=2)
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
@@ -316,7 +322,9 @@ def test_get_num_blocks_to_allocate():
sliding_window=4, # Placeholder value, not related to test result
)
- block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
+ block_pool = BlockPool(
+ num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
+ )
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [
@@ -341,7 +349,9 @@ def test_chunked_local_attention_get_num_blocks_to_allocate():
attention_chunk_size=4, # Placeholder value, not related to test result
)
- block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
+ block_pool = BlockPool(
+ num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
+ )
manager = get_chunked_local_attention_manager(attention_spec, block_pool)
cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [
diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py
index 65511c17473b2..7537c7a60476b 100644
--- a/tests/v1/core/utils.py
+++ b/tests/v1/core/utils.py
@@ -42,7 +42,8 @@ def create_scheduler(
model: str = "facebook/opt-125m",
max_num_seqs: int = 16,
max_num_batched_tokens: int = 8192,
- enable_prefix_caching: bool | None = None,
+ enable_chunked_prefill: bool = True,
+ enable_prefix_caching: bool = False,
long_prefill_token_threshold: int = 0,
disable_chunked_mm_input: bool = False,
use_kv_connector: None | bool | MockKVConfig = None,
@@ -63,7 +64,7 @@ def create_scheduler(
max_num_batch_tokens: max num tokens to batch
enable_prefix_caching: optionally force APC config
(True/False) or use default
- (None)
+ (False)
Returns:
{class}`Scheduler` instance
@@ -76,7 +77,7 @@ def create_scheduler(
max_model_len=max_model_len,
long_prefill_token_threshold=long_prefill_token_threshold,
disable_chunked_mm_input=disable_chunked_mm_input,
- enable_chunked_prefill=True,
+ enable_chunked_prefill=enable_chunked_prefill,
async_scheduling=async_scheduling,
)
model_config = ModelConfig(
@@ -87,17 +88,12 @@ def create_scheduler(
skip_tokenizer_init=skip_tokenizer_init,
)
# Cache config, optionally force APC
- kwargs_cache = (
- {}
- if enable_prefix_caching is None
- else {"enable_prefix_caching": enable_prefix_caching}
- )
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
- **kwargs_cache,
+ enable_prefix_caching=enable_prefix_caching,
)
kv_transfer_config = None
if isinstance(use_kv_connector, MockKVConfig):
diff --git a/tests/v1/cudagraph/test_cudagraph_mode.py b/tests/v1/cudagraph/test_cudagraph_mode.py
index d6bde16eba36b..7f9c2a0571c3c 100644
--- a/tests/v1/cudagraph/test_cudagraph_mode.py
+++ b/tests/v1/cudagraph/test_cudagraph_mode.py
@@ -35,14 +35,22 @@ def temporary_environ(env_vars):
# test attention backend and cudagraph_mode combo
# (backend_name, cudagraph_mode, supported)
-combo_cases_1 = [
- ("FA3", "FULL", True),
- ("FA3", "FULL_AND_PIECEWISE", True),
- ("FA2", "FULL", True), # Should fallback to FULL_AND_PIECEWISE
- ("FA2", "FULL_AND_PIECEWISE", True),
- ("FlashInfer", "FULL", True), # Should fallback to FULL_AND_PIECEWISE
- ("FlashInfer", "FULL_AND_PIECEWISE", True),
-]
+if current_platform.is_rocm():
+ combo_cases_1 = [
+ ("RocmAttn", "FULL", True),
+ ("RocmAttn", "FULL_AND_PIECEWISE", True),
+ ("TritonAttn", "FULL", True),
+ ("TritonAttn", "FULL_AND_PIECEWISE", True),
+ ]
+else:
+ combo_cases_1 = [
+ ("FA3", "FULL", True),
+ ("FA3", "FULL_AND_PIECEWISE", True),
+ ("FA2", "FULL", True), # Should fallback to FULL_AND_PIECEWISE
+ ("FA2", "FULL_AND_PIECEWISE", True),
+ ("FlashInfer", "FULL", True), # Should fallback to FULL_AND_PIECEWISE
+ ("FlashInfer", "FULL_AND_PIECEWISE", True),
+ ]
@pytest.mark.parametrize("backend_name, cudagraph_mode, supported", combo_cases_1)
@@ -92,18 +100,32 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
# test cudagraph_mode with different compilation mode.
# (backend_name, cudagraph_mode, compilation_mode, supported)
-combo_cases_2 = [
- ("FA2", "FULL", CompilationMode.NONE, True),
- ("FA2", "FULL", CompilationMode.VLLM_COMPILE, True),
- ("FA2", "PIECEWISE", CompilationMode.NONE, False),
- ("FA2", "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
- ("FA2", "FULL_AND_PIECEWISE", CompilationMode.NONE, False),
- ("FA2", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
- ("FA2", "FULL_DECODE_ONLY", CompilationMode.NONE, True),
- ("FA2", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
- ("FA2", "NONE", CompilationMode.NONE, True),
- ("FA2", "NONE", CompilationMode.VLLM_COMPILE, True),
-]
+if current_platform.is_rocm():
+ combo_cases_2 = [
+ ("RocmAttn", "FULL", CompilationMode.NONE, True),
+ ("RocmAttn", "FULL", CompilationMode.VLLM_COMPILE, True),
+ ("RocmAttn", "PIECEWISE", CompilationMode.NONE, False),
+ ("RocmAttn", "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
+ ("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.NONE, False),
+ ("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
+ ("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.NONE, True),
+ ("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
+ ("RocmAttn", "NONE", CompilationMode.NONE, True),
+ ("RocmAttn", "NONE", CompilationMode.VLLM_COMPILE, True),
+ ]
+else:
+ combo_cases_2 = [
+ ("FA2", "FULL", CompilationMode.NONE, True),
+ ("FA2", "FULL", CompilationMode.VLLM_COMPILE, True),
+ ("FA2", "PIECEWISE", CompilationMode.NONE, False),
+ ("FA2", "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
+ ("FA2", "FULL_AND_PIECEWISE", CompilationMode.NONE, False),
+ ("FA2", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
+ ("FA2", "FULL_DECODE_ONLY", CompilationMode.NONE, True),
+ ("FA2", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
+ ("FA2", "NONE", CompilationMode.NONE, True),
+ ("FA2", "NONE", CompilationMode.VLLM_COMPILE, True),
+ ]
@pytest.mark.parametrize(
diff --git a/tests/v1/distributed/test_eagle_dp.py b/tests/v1/distributed/test_eagle_dp.py
new file mode 100644
index 0000000000000..9f6a6614fc1fd
--- /dev/null
+++ b/tests/v1/distributed/test_eagle_dp.py
@@ -0,0 +1,77 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import asyncio
+import os
+from contextlib import AsyncExitStack
+from dataclasses import replace
+
+import pytest
+
+from vllm import SamplingParams
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.sampling_params import RequestOutputKind
+from vllm.v1.engine.async_llm import AsyncLLM
+
+DP_SIZE = int(os.getenv("DP_SIZE", 2))
+
+
+@pytest.mark.asyncio
+async def test_run_eagle_dp():
+ target_model = "meta-llama/Llama-3.1-8B-Instruct"
+ draft_model = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
+
+ engine_args = AsyncEngineArgs(
+ model=target_model,
+ tokenizer_mode="auto",
+ enforce_eager=False,
+ tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
+ data_parallel_size=DP_SIZE,
+ data_parallel_backend="mp", # ray takes more time
+ trust_remote_code=True,
+ max_model_len=16384,
+ )
+
+ eagle_engine_args = replace(
+ engine_args,
+ speculative_config={
+ "model": draft_model,
+ "method": "eagle",
+ "num_speculative_tokens": 3,
+ },
+ )
+
+ prompt = "This is a test of data parallel with eagle"
+ num_expected_tokens = 100
+ sampling_params = SamplingParams(
+ min_tokens=num_expected_tokens,
+ max_tokens=num_expected_tokens,
+ ignore_eos=True,
+ output_kind=RequestOutputKind.FINAL_ONLY,
+ temperature=0,
+ )
+
+ async def generate_with_timeout(given_engine: AsyncLLM):
+ async for out in given_engine.generate(
+ request_id="test-eagle-dp", prompt=prompt, sampling_params=sampling_params
+ ):
+ token_ids = out.outputs[0].token_ids
+ assert len(token_ids) == num_expected_tokens
+ return token_ids
+
+ async def engine_create_and_generate(engine_args: AsyncEngineArgs):
+ async with AsyncExitStack() as after:
+ engine = AsyncLLM.from_engine_args(engine_args)
+ after.callback(engine.shutdown)
+
+ token_ids = await asyncio.wait_for(
+ generate_with_timeout(engine), timeout=30
+ )
+
+ assert not engine.output_processor.has_unfinished_requests()
+ return token_ids
+
+ token_ids_with_eagle = await engine_create_and_generate(eagle_engine_args)
+ token_ids_no_eagle = await engine_create_and_generate(engine_args)
+
+ # Test for correctness
+ assert token_ids_with_eagle == token_ids_no_eagle
diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py
index d1b037b7956cf..85f108786c05a 100644
--- a/tests/v1/entrypoints/llm/test_struct_output_generate.py
+++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py
@@ -3,7 +3,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
-from dataclasses import fields
from enum import Enum
from typing import TYPE_CHECKING, Any
@@ -21,7 +20,6 @@ from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
from vllm.sampling_params import (
- GuidedDecodingParams,
SamplingParams,
StructuredOutputsParams,
)
@@ -108,23 +106,6 @@ class CarDescription(BaseModel):
car_type: CarType
-def test_guided_decoding_deprecated():
- with pytest.warns(DeprecationWarning, match="GuidedDecodingParams is deprecated.*"):
- guided_decoding = GuidedDecodingParams(json_object=True)
-
- structured_outputs = StructuredOutputsParams(json_object=True)
- assert fields(guided_decoding) == fields(structured_outputs)
-
- with pytest.warns(DeprecationWarning, match="guided_decoding is deprecated.*"):
- sp1 = SamplingParams(guided_decoding=guided_decoding)
-
- with pytest.warns(DeprecationWarning, match="guided_decoding is deprecated.*"):
- sp2 = SamplingParams.from_optional(guided_decoding=guided_decoding)
-
- assert sp1 == sp2
- assert sp1.structured_outputs == guided_decoding
-
-
@pytest.mark.parametrize(
"model_name, backend, tokenizer_mode, speculative_config",
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE,
@@ -899,13 +880,11 @@ def test_structured_output_batched_with_non_structured_outputs_requests(
output_json = json.loads(generated_text)
-@pytest.mark.parametrize("guided_decoding_backend", ["xgrammar"])
-def test_structured_output_with_structural_tag(
- guided_decoding_backend: str,
-):
+@pytest.mark.parametrize("backend", ["xgrammar"])
+def test_structured_output_with_structural_tag(backend: str):
llm = LLM(
model="Qwen/Qwen2.5-1.5B-Instruct",
- guided_decoding_backend=guided_decoding_backend,
+ structured_outputs_config=StructuredOutputsConfig(backend=backend),
)
structural_tag_config = {
@@ -923,7 +902,7 @@ def test_structured_output_with_structural_tag(
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=500,
- guided_decoding=StructuredOutputsParams(
+ structured_outputs=StructuredOutputsParams(
structural_tag=json.dumps(structural_tag_config)
),
)
diff --git a/tests/v1/kv_connector/unit/test_lmcache_integration.py b/tests/v1/kv_connector/unit/test_lmcache_integration.py
index 11507d7cd4e7b..33418edc325af 100644
--- a/tests/v1/kv_connector/unit/test_lmcache_integration.py
+++ b/tests/v1/kv_connector/unit/test_lmcache_integration.py
@@ -9,6 +9,12 @@
# Assumption vs. Correctness Tests:
# these unit tests do *not* test correctness of LMCache-side or vLLM-side logic
# it is to ensure that assumptions LMCache makes about vLLM's interface are stable
+
+import pytest
+
+from vllm.platforms import current_platform
+
+
def assumes(obj, attr, is_callable=False, is_instance_of=None):
import inspect
from dataclasses import is_dataclass
@@ -48,6 +54,9 @@ def assumes(obj, attr, is_callable=False, is_instance_of=None):
assert isinstance(attr_value, is_instance_of), assumption_msg
+@pytest.mark.skipif(
+ current_platform.is_rocm(), reason="Requires libcudart.so, not available on ROCm"
+)
def test_multimodal_interface():
# protect against interface changes
from vllm.multimodal.inputs import PlaceholderRange
@@ -72,6 +81,9 @@ def test_multimodal_interface():
assert token_ids.tolist() == [0, 0, 0, 0, 4, 4369, 4369, 4369, 4369, 9]
+@pytest.mark.skipif(
+ current_platform.is_rocm(), reason="Requires libcudart.so, not available on ROCm"
+)
def test_config_interface():
# protect against interface changes
from vllm.config import VllmConfig
@@ -146,6 +158,9 @@ def test_config_interface():
)
+@pytest.mark.skipif(
+ current_platform.is_rocm(), reason="Requires libcudart.so, not available on ROCm"
+)
def test_request_interface():
# protect against interface changes
from types import NoneType
diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py
index 1c1ac915c758e..ffa7d884d2762 100644
--- a/tests/v1/kv_connector/unit/test_multi_connector.py
+++ b/tests/v1/kv_connector/unit/test_multi_connector.py
@@ -20,6 +20,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlKVConnectorStats,
)
+from vllm.platforms import current_platform
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
@@ -69,6 +70,13 @@ def _compare_directories(dir1: Path, dir2: Path) -> bool:
return True
+@pytest.mark.skipif(
+ current_platform.is_rocm(),
+ reason=(
+ "hipErrorLaunchFailure when running this test, see issue:"
+ "https://github.com/ROCm/pytorch/issues/2822"
+ ),
+)
def test_multi_shared_storage_connector_consistency():
"""
Tests that MultiConnector with two SharedStorageConnectors saves
diff --git a/tests/v1/kv_offload/test_cpu_offloading.py b/tests/v1/kv_offload/test_cpu_offloading.py
index 3ee41c40859dc..406d4c0b4c1fd 100644
--- a/tests/v1/kv_offload/test_cpu_offloading.py
+++ b/tests/v1/kv_offload/test_cpu_offloading.py
@@ -12,10 +12,14 @@ from tqdm import tqdm
from vllm import LLM, SamplingParams, TokensPrompt
from vllm.config import KVEventsConfig, KVTransferConfig
from vllm.distributed.kv_events import BlockStored, KVEventBatch
+from vllm.platforms import current_platform
from vllm.utils.system_utils import set_env_var
CPU_BLOCK_SIZES = [48]
-ATTN_BACKENDS = ["FLASH_ATTN", "FLASHINFER"]
+ATTN_BACKENDS = ["FLASH_ATTN"]
+
+if current_platform.is_cuda():
+ ATTN_BACKENDS.append("FLASHINFER")
class MockSubscriber:
diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py
index 42584938bc06f..c89c33be80c10 100644
--- a/tests/v1/sample/test_logprobs.py
+++ b/tests/v1/sample/test_logprobs.py
@@ -521,8 +521,8 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
pytest.param(
(
"eagle",
- "meta-llama/Llama-3.1-8B-Instruct",
- "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
+ "meta-llama/Llama-3.2-1B-Instruct",
+ "nm-testing/Llama3_2_1B_speculator.eagle3",
),
marks=large_gpu_mark(min_gb=32),
),
@@ -541,7 +541,7 @@ def test_spec_decode_logprobs(
"""
from vllm import LLM
- prompt = "Hello world"
+ prompt = "Hello world " * 50
sampling_params = SamplingParams(
temperature=0, logprobs=3, max_tokens=10, ignore_eos=False
)
@@ -582,6 +582,9 @@ def test_spec_decode_logprobs(
seed=42,
logprobs_mode=logprobs_mode,
gpu_memory_utilization=0.4,
+ # Force prefill chunking
+ enable_chunked_prefill=True,
+ max_num_batched_tokens=32,
)
spec_results = spec_llm.generate([prompt], sampling_params)
# Collect logprobs outputs from spec decode LLM.
@@ -597,6 +600,84 @@ def test_spec_decode_logprobs(
# Per-token logprobs are expected to be the same.
assert len(ref_logprobs) == len(spec_logprobs)
for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
- assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, abs_tol=1e-3)
+ assert math.isclose(
+ ref_logprob.logprob, spec_logprob.logprob, rel_tol=5e-2, abs_tol=1e-1
+ )
assert ref_logprob.rank == spec_logprob.rank
assert ref_logprob.decoded_token == spec_logprob.decoded_token
+
+
+def test_prompt_logprobs_with_chunking_and_preemption():
+ """Test that prompt logprobs are correctly returned when using
+ both chunked prefill and preemption.
+
+ This test ensures that the num_prompt_logprobs tracking persists
+ across preemptions and prefill chunks.
+ """
+
+ # Create prompts that will trigger chunking and preemption
+ prompts = [
+ "The following numbers of the sequence "
+ + ", ".join(str(i) for i in range(10))
+ + " are:",
+ "In one word, the capital of France is ",
+ ] + [f"Tell me about the number {i}: " for i in range(32)]
+
+ sampling_params = SamplingParams(
+ temperature=0.0,
+ max_tokens=40,
+ min_tokens=20,
+ prompt_logprobs=2, # Request prompt logprobs
+ )
+
+ with VllmRunner(
+ "Qwen/Qwen3-0.6B",
+ max_model_len=512,
+ enable_chunked_prefill=True,
+ max_num_batched_tokens=48, # Force prefill chunking
+ num_gpu_blocks_override=32, # Force preemptions
+ disable_log_stats=False,
+ gpu_memory_utilization=0.25,
+ ) as vllm_model:
+ metrics_before = vllm_model.llm.get_metrics()
+
+ # Generate with prompt logprobs using generate_w_logprobs which
+ # returns (output_ids, output_str, output_logprobs, prompt_logprobs)
+ outputs = vllm_model.generate_w_logprobs(
+ prompts, sampling_params=sampling_params, include_prompt_token_ids=True
+ )
+
+ # Verify that all outputs have prompt logprobs
+ for i, output in enumerate(outputs):
+ _, _, _, prompt_token_ids, prompt_logprobs = output
+ assert prompt_logprobs is not None and len(prompt_logprobs) > 0, (
+ f"Output {i} missing prompt logprobs"
+ )
+ assert len(prompt_logprobs) == len(prompt_token_ids), (
+ "Unexpected number of prompt logprob positions"
+ )
+
+ # Each position should have the requested number of logprobs
+ for pos, logprobs_dict in enumerate(prompt_logprobs):
+ if logprobs_dict is not None: # First token may be None
+ assert (
+ sampling_params.prompt_logprobs
+ <= len(logprobs_dict)
+ <= sampling_params.prompt_logprobs + 1
+ ), (
+ f"Output {i} position {pos} has {len(logprobs_dict)} "
+ f"logprobs, expected {sampling_params.prompt_logprobs}"
+ )
+
+ # Check that we actually had preemptions
+ metrics_after = vllm_model.llm.get_metrics()
+ preemptions_before = next(
+ (m.value for m in metrics_before if m.name == "vllm:num_preemptions"), 0
+ )
+ preemptions_after = next(
+ (m.value for m in metrics_after if m.name == "vllm:num_preemptions"), 0
+ )
+ preemptions = preemptions_after - preemptions_before
+ assert preemptions > 0, "Test did not trigger any preemptions"
+
+ print(f"Test passed with {preemptions} preemptions")
diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py
index 6958d62dc7e90..a4ee53008ce82 100644
--- a/tests/v1/spec_decode/test_tree_attention.py
+++ b/tests/v1/spec_decode/test_tree_attention.py
@@ -3,6 +3,7 @@
import math
+import pytest
import torch
from tests.v1.attention.utils import (
@@ -11,9 +12,16 @@ from tests.v1.attention.utils import (
try_get_attention_backend,
)
from vllm.attention.backends.registry import AttentionBackendEnum
+from vllm.attention.utils.fa_utils import is_flash_attn_varlen_func_available
from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+if not is_flash_attn_varlen_func_available():
+ pytest.skip(
+ "This test requires flash_attn_varlen_func, but it's not available.",
+ allow_module_level=True,
+ )
+
class MockAttentionLayer(torch.nn.Module):
_q_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index 01c1364f7ee62..d0f1b703fcb92 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -185,7 +185,9 @@ def _make_mock_backend_for_kernel_block_size(
supported_sizes: list[int | MultipleOf],
):
class _MockBackend:
- supported_kernel_block_sizes = supported_sizes
+ @staticmethod
+ def get_supported_kernel_block_sizes():
+ return supported_sizes
return _MockBackend()
diff --git a/tools/ep_kernels/install_python_libraries.sh b/tools/ep_kernels/install_python_libraries.sh
index 5ea543f4cb1e8..1cea1bef8dbc9 100755
--- a/tools/ep_kernels/install_python_libraries.sh
+++ b/tools/ep_kernels/install_python_libraries.sh
@@ -1,94 +1,79 @@
#!/usr/bin/env bash
set -ex
-# prepare workspace directory
-WORKSPACE=$1
-if [ -z "$WORKSPACE" ]; then
- export WORKSPACE=$(pwd)/ep_kernels_workspace
-fi
+# usage: ./build.sh [workspace_dir] [mode]
+# mode: "install" (default) → install directly into current Python env
+# "wheel" → build wheels into WORKSPACE/dist
-if [ ! -d "$WORKSPACE" ]; then
- mkdir -p $WORKSPACE
-fi
+WORKSPACE=${1:-$(pwd)/ep_kernels_workspace}
+MODE=${2:-install}
+mkdir -p "$WORKSPACE"
+
+WHEEL_DIR="$WORKSPACE/dist"
+mkdir -p "$WHEEL_DIR"
+NVSHMEM_VER=3.3.9
+
+pushd "$WORKSPACE"
-# configurable pip command (default: pip3)
-PIP_CMD=${PIP_CMD:-pip3}
CUDA_HOME=${CUDA_HOME:-/usr/local/cuda}
# install dependencies if not installed
-$PIP_CMD install cmake torch ninja
-
-# build nvshmem
-pushd $WORKSPACE
-mkdir -p nvshmem_src
-wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.2.5/source/nvshmem_src_3.2.5-1.txz
-tar -xvf nvshmem_src_3.2.5-1.txz -C nvshmem_src --strip-components=1
-pushd nvshmem_src
-wget https://github.com/deepseek-ai/DeepEP/raw/main/third-party/nvshmem.patch
-git init
-git apply -vvv nvshmem.patch
-
-# assume CUDA_HOME is set correctly
-if [ -z "$CUDA_HOME" ]; then
- echo "CUDA_HOME is not set, please set it to your CUDA installation directory."
- exit 1
+if [ -z "$VIRTUAL_ENV" ]; then
+ uv pip install --system cmake torch ninja
+else
+ uv pip install cmake torch ninja
fi
-# assume TORCH_CUDA_ARCH_LIST is set correctly
-if [ -z "$TORCH_CUDA_ARCH_LIST" ]; then
- echo "TORCH_CUDA_ARCH_LIST is not set, please set it to your desired architecture."
+# fetch nvshmem
+ARCH=$(uname -m)
+case "${ARCH,,}" in
+ x86_64|amd64)
+ NVSHMEM_SUBDIR="linux-x86_64"
+ NVSHMEM_FILE="libnvshmem-linux-x86_64-${NVSHMEM_VER}_cuda12-archive.tar.xz"
+ ;;
+ aarch64|arm64)
+ NVSHMEM_SUBDIR="linux-sbsa"
+ NVSHMEM_FILE="libnvshmem-linux-sbsa-${NVSHMEM_VER}_cuda12-archive.tar.xz"
+ ;;
+ *)
+ echo "Unsupported architecture: ${ARCH}" >&2
exit 1
-fi
+ ;;
+esac
-# disable all features except IBGDA
-export NVSHMEM_IBGDA_SUPPORT=1
-
-export NVSHMEM_SHMEM_SUPPORT=0
-export NVSHMEM_UCX_SUPPORT=0
-export NVSHMEM_USE_NCCL=0
-export NVSHMEM_PMIX_SUPPORT=0
-export NVSHMEM_TIMEOUT_DEVICE_POLLING=0
-export NVSHMEM_USE_GDRCOPY=0
-export NVSHMEM_IBRC_SUPPORT=0
-export NVSHMEM_BUILD_TESTS=0
-export NVSHMEM_BUILD_EXAMPLES=0
-export NVSHMEM_MPI_SUPPORT=0
-export NVSHMEM_BUILD_HYDRA_LAUNCHER=0
-export NVSHMEM_BUILD_TXZ_PACKAGE=0
-export NVSHMEM_TIMEOUT_DEVICE_POLLING=0
-
-cmake -G Ninja -S . -B $WORKSPACE/nvshmem_build/ -DCMAKE_INSTALL_PREFIX=$WORKSPACE/nvshmem_install
-cmake --build $WORKSPACE/nvshmem_build/ --target install
+NVSHMEM_URL="https://developer.download.nvidia.com/compute/nvshmem/redist/libnvshmem/${NVSHMEM_SUBDIR}/${NVSHMEM_FILE}"
+pushd "$WORKSPACE"
+echo "Downloading NVSHMEM ${NVSHMEM_VER} for ${NVSHMEM_SUBDIR} ..."
+curl -fSL "${NVSHMEM_URL}" -o "${NVSHMEM_FILE}"
+tar -xf "${NVSHMEM_FILE}"
+mv "${NVSHMEM_FILE%.tar.xz}" nvshmem
+rm -f "${NVSHMEM_FILE}"
+rm -rf nvshmem/lib/bin nvshmem/lib/share
popd
-export CMAKE_PREFIX_PATH=$WORKSPACE/nvshmem_install:$CMAKE_PREFIX_PATH
+export CMAKE_PREFIX_PATH=$WORKSPACE/nvshmem/lib/cmake:$CMAKE_PREFIX_PATH
is_git_dirty() {
local dir=$1
pushd "$dir" > /dev/null
-
- if [ -d ".git" ] && [ -n "$(git status --porcelain 2>/dev/null)" ]; then
+    if [ -d ".git" ] && [ -n "$(git status --porcelain 2>/dev/null)" ]; then
popd > /dev/null
- return 0 # dirty (true)
+ return 0
else
popd > /dev/null
- return 1 # clean (false)
+ return 1
fi
}
-# Function to handle git repository cloning with dirty/incomplete checks
clone_repo() {
local repo_url=$1
local dir_name=$2
local key_file=$3
local commit_hash=$4
-
if [ -d "$dir_name" ]; then
- # Check if directory has uncommitted changes (dirty)
if is_git_dirty "$dir_name"; then
echo "$dir_name directory is dirty, skipping clone"
- # Check if clone failed (directory exists but not a valid git repo or missing key files)
elif [ ! -d "$dir_name/.git" ] || [ ! -f "$dir_name/$key_file" ]; then
echo "$dir_name directory exists but clone appears incomplete, cleaning up and re-cloning"
rm -rf "$dir_name"
@@ -99,7 +84,7 @@ clone_repo() {
cd ..
fi
else
- echo "$dir_name directory exists and appears complete; manually update if needed"
+ echo "$dir_name directory exists and appears complete"
fi
else
git clone "$repo_url"
@@ -111,17 +96,55 @@ clone_repo() {
fi
}
-# build and install pplx, require pytorch installed
-pushd $WORKSPACE
-clone_repo "https://github.com/ppl-ai/pplx-kernels" "pplx-kernels" "setup.py" "c336faf"
-cd pplx-kernels
-$PIP_CMD install --no-build-isolation -vvv -e .
-popd
+deepep_cuda13_patch() {
+  cuda_version_major=$(${CUDA_HOME}/bin/nvcc --version | grep -Eo "release [0-9]+" | cut -d ' ' -f 2)
+ if [ ${cuda_version_major} -ge 13 ]; then
+ sed -i "s|f'{nvshmem_dir}/include']|f'{nvshmem_dir}/include', '${CUDA_HOME}/include/cccl']|" "setup.py"
+ fi
+}
-# build and install deepep, require pytorch installed
-pushd $WORKSPACE
-clone_repo "https://github.com/deepseek-ai/DeepEP" "DeepEP" "setup.py" "73b6ea4"
-cd DeepEP
-export NVSHMEM_DIR=$WORKSPACE/nvshmem_install
-$PIP_CMD install --no-build-isolation -vvv -e .
-popd
+do_build() {
+ local repo=$1
+ local name=$2
+ local key=$3
+ local commit=$4
+ local extra_env=$5
+
+ pushd "$WORKSPACE"
+ clone_repo "$repo" "$name" "$key" "$commit"
+ cd "$name"
+
+ if [ "$name" == "DeepEP" ]; then
+ deepep_cuda13_patch
+ fi
+
+ if [ "$MODE" = "install" ]; then
+ echo "Installing $name into environment"
+ eval "$extra_env" uv pip install --no-build-isolation -vvv .
+ else
+ echo "Building $name wheel into $WHEEL_DIR"
+ eval "$extra_env" uv build --wheel --no-build-isolation -vvv --out-dir "$WHEEL_DIR" .
+ fi
+ popd
+}
+
+# build pplx-kernels
+do_build \
+ "https://github.com/ppl-ai/pplx-kernels" \
+ "pplx-kernels" \
+ "setup.py" \
+ "12cecfd" \
+ ""
+
+# build DeepEP
+do_build \
+ "https://github.com/deepseek-ai/DeepEP" \
+ "DeepEP" \
+ "setup.py" \
+ "73b6ea4" \
+ "export NVSHMEM_DIR=$WORKSPACE/nvshmem; "
+
+if [ "$MODE" = "wheel" ]; then
+ echo "All wheels written to $WHEEL_DIR"
+ ls -l "$WHEEL_DIR"
+fi
diff --git a/tools/install_deepgemm.sh b/tools/install_deepgemm.sh
index 4f2cd302c3eff..ee9a5dd4aa643 100755
--- a/tools/install_deepgemm.sh
+++ b/tools/install_deepgemm.sh
@@ -1,12 +1,13 @@
#!/bin/bash
-# Script to install DeepGEMM from source
-# This script can be used both in Docker builds and by users locally
-
+# Script to build and/or install DeepGEMM from source
+# Default: build and install immediately
+# Optional: build wheels to a directory for later installation (useful in multi-stage builds)
set -e
# Default values
DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git"
DEEPGEMM_GIT_REF="594953acce41793ae00a1233eb516044d604bcb6"
+WHEEL_DIR=""
# Parse command line arguments
while [[ $# -gt 0 ]]; do
@@ -27,11 +28,20 @@ while [[ $# -gt 0 ]]; do
CUDA_VERSION="$2"
shift 2
;;
+ --wheel-dir)
+ if [[ -z "$2" || "$2" =~ ^- ]]; then
+ echo "Error: --wheel-dir requires a directory path." >&2
+ exit 1
+ fi
+ WHEEL_DIR="$2"
+ shift 2
+ ;;
-h|--help)
echo "Usage: $0 [OPTIONS]"
echo "Options:"
echo " --ref REF Git reference to checkout (default: $DEEPGEMM_GIT_REF)"
echo " --cuda-version VER CUDA version (auto-detected if not provided)"
+ echo " --wheel-dir PATH If set, build wheel into PATH but do not install"
echo " -h, --help Show this help message"
exit 0
;;
@@ -57,16 +67,15 @@ fi
CUDA_MAJOR="${CUDA_VERSION%%.*}"
CUDA_MINOR="${CUDA_VERSION#${CUDA_MAJOR}.}"
CUDA_MINOR="${CUDA_MINOR%%.*}"
-
echo "CUDA version: $CUDA_VERSION (major: $CUDA_MAJOR, minor: $CUDA_MINOR)"
# Check CUDA version requirement
if [ "$CUDA_MAJOR" -lt 12 ] || { [ "$CUDA_MAJOR" -eq 12 ] && [ "$CUDA_MINOR" -lt 8 ]; }; then
- echo "Skipping DeepGEMM installation (requires CUDA 12.8+ but got ${CUDA_VERSION})"
+ echo "Skipping DeepGEMM build/installation (requires CUDA 12.8+ but got ${CUDA_VERSION})"
exit 0
fi
-echo "Installing DeepGEMM from source..."
+echo "Preparing DeepGEMM build..."
echo "Repository: $DEEPGEMM_GIT_REPO"
echo "Reference: $DEEPGEMM_GIT_REF"
@@ -76,23 +85,31 @@ trap 'rm -rf "$INSTALL_DIR"' EXIT
# Clone the repository
git clone --recursive --shallow-submodules "$DEEPGEMM_GIT_REPO" "$INSTALL_DIR/deepgemm"
-
-echo "🏗️ Building DeepGEMM"
pushd "$INSTALL_DIR/deepgemm"
# Checkout the specific reference
git checkout "$DEEPGEMM_GIT_REF"
-# Build DeepGEMM
+# Clean previous build artifacts
# (Based on https://github.com/deepseek-ai/DeepGEMM/blob/main/install.sh)
-rm -rf build dist
-rm -rf *.egg-info
+rm -rf build dist *.egg-info
+
+# Build wheel
+echo "🏗️ Building DeepGEMM wheel..."
python3 setup.py bdist_wheel
-# Install the wheel
+# If --wheel-dir was specified, copy wheels there and exit
+if [ -n "$WHEEL_DIR" ]; then
+ mkdir -p "$WHEEL_DIR"
+ cp dist/*.whl "$WHEEL_DIR"/
+ echo "✅ Wheel built and copied to $WHEEL_DIR"
+ popd
+ exit 0
+fi
+
+# Default behaviour: install built wheel
if command -v uv >/dev/null 2>&1; then
echo "Installing DeepGEMM wheel using uv..."
- # Use --system in Docker contexts, respect user's environment otherwise
if [ -n "$VLLM_DOCKER_BUILD_CONTEXT" ]; then
uv pip install --system dist/*.whl
else
@@ -104,5 +121,4 @@ else
fi
popd
-
echo "✅ DeepGEMM installation completed successfully"
diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
index db79b3f5e8bcb..a8f472d147a0d 100644
--- a/vllm/_aiter_ops.py
+++ b/vllm/_aiter_ops.py
@@ -294,6 +294,8 @@ def _rocm_aiter_mla_decode_fwd_impl(
kv_last_page_lens: torch.Tensor | None = None,
sm_scale: float = 1.0,
logit_cap: float = 0.0,
+ q_scale: torch.Tensor | None = None,
+ kv_scale: torch.Tensor | None = None,
) -> None:
from aiter.mla import mla_decode_fwd
@@ -308,6 +310,8 @@ def _rocm_aiter_mla_decode_fwd_impl(
max_seqlen_qo,
sm_scale=sm_scale,
logit_cap=logit_cap,
+ q_scale=q_scale,
+ kv_scale=kv_scale,
)
@@ -322,6 +326,8 @@ def _rocm_aiter_mla_decode_fwd_fake(
kv_last_page_lens: torch.Tensor | None = None,
sm_scale: float = 1.0,
logit_cap: float = 0.0,
+ q_scale: torch.Tensor | None = None,
+ kv_scale: torch.Tensor | None = None,
) -> None:
pass
@@ -806,6 +812,8 @@ class rocm_aiter_ops:
kv_indices: torch.Tensor | None = None,
kv_last_page_lens: torch.Tensor | None = None,
logit_cap: float = 0.0,
+ q_scale: torch.Tensor | None = None,
+ kv_scale: torch.Tensor | None = None,
):
torch.ops.vllm.rocm_aiter_mla_decode_fwd(
q,
@@ -818,6 +826,8 @@ class rocm_aiter_ops:
kv_last_page_lens,
sm_scale=sm_scale,
logit_cap=logit_cap,
+ q_scale=q_scale,
+ kv_scale=kv_scale,
)
@staticmethod
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 0f625a7945241..4a1bcc761f994 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -2201,7 +2201,8 @@ def gather_and_maybe_dequant_cache(
dst: torch.Tensor,
block_table: torch.Tensor,
cu_seq_lens: torch.Tensor,
- batch_size: int,
+ token_to_seq: torch.Tensor,
+ num_tokens: int,
kv_cache_dtype: str,
scale: torch.Tensor,
seq_starts: torch.Tensor | None = None,
@@ -2211,7 +2212,8 @@ def gather_and_maybe_dequant_cache(
dst,
block_table,
cu_seq_lens,
- batch_size,
+ token_to_seq,
+ num_tokens,
kv_cache_dtype,
scale,
seq_starts,
diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 67ded88475243..bd7e81b15bfc3 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -46,9 +46,12 @@ class AttentionBackend(ABC):
# makes sure the output tensor is allocated inside the cudagraph.
accept_output_buffer: bool = False
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
- supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(1)]
supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"]
+ @staticmethod
+ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+ return [MultipleOf(1)]
+
@staticmethod
@abstractmethod
def get_name() -> str:
@@ -142,10 +145,11 @@ class AttentionBackend(ABC):
if block_size not in valid_sizes:
return False
- if not cls.supported_kernel_block_sizes:
+ supported_kernel_block_sizes = cls.get_supported_kernel_block_sizes()
+ if not supported_kernel_block_sizes:
return True
- for supported_size in cls.supported_kernel_block_sizes:
+ for supported_size in supported_kernel_block_sizes:
if isinstance(supported_size, MultipleOf):
supported_size = supported_size.base
# With hybrid_blocks feature, the framework-level block size
diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py
index 6747cf7743b14..125e4e3827747 100644
--- a/vllm/attention/backends/registry.py
+++ b/vllm/attention/backends/registry.py
@@ -43,7 +43,6 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
- XFORMERS = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"
ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"
ROCM_AITER_TRITON_MLA = (
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index a8e796a1eab63..f1d57ac50fb9f 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -51,31 +51,6 @@ else:
FP8_DTYPE = current_platform.fp8_dtype()
logger = init_logger(__name__)
-USE_XFORMERS_OPS = None
-
-
-def check_xformers_availability():
- global USE_XFORMERS_OPS
- if USE_XFORMERS_OPS is not None:
- return USE_XFORMERS_OPS
-
- if current_platform.is_cuda() and current_platform.has_device_capability(100):
- # Xformers FA is not compatible with B200
- USE_XFORMERS_OPS = False
- else:
- try:
- from importlib.util import find_spec
-
- find_spec("xformers.ops")
- USE_XFORMERS_OPS = True
- except ImportError:
- USE_XFORMERS_OPS = False
-
- # the warning only needs to be shown once
- if not USE_XFORMERS_OPS:
- logger.warning("Xformers is not available, falling back.")
-
- return USE_XFORMERS_OPS
def check_upstream_fa_availability(dtype: torch.dtype):
@@ -533,7 +508,6 @@ class MultiHeadAttention(nn.Module):
if backend
in {
AttentionBackendEnum.TORCH_SDPA,
- AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.PALLAS,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.FLASH_ATTN,
@@ -549,12 +523,6 @@ class MultiHeadAttention(nn.Module):
)
)
- if (
- self.attn_backend == AttentionBackendEnum.XFORMERS
- and not check_xformers_availability()
- ):
- self.attn_backend = AttentionBackendEnum.TORCH_SDPA
-
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
@@ -614,12 +582,6 @@ class MultiHeadAttention(nn.Module):
max_seqlen_k=kv_len,
softmax_scale=self.scale,
)
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- from xformers import ops as xops
-
- out = xops.memory_efficient_attention_forward(
- query, key, value, scale=self.scale
- )
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
diff --git a/vllm/attention/layers/cross_attention.py b/vllm/attention/layers/cross_attention.py
index 5b44c7e3e7ec8..068fd0a0eb7d0 100644
--- a/vllm/attention/layers/cross_attention.py
+++ b/vllm/attention/layers/cross_attention.py
@@ -25,15 +25,6 @@ from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec
logger = init_logger(__name__)
-def _get_max_encoder_len(vllm_config: "VllmConfig") -> int:
- """Gets the max number of encoder input tokens from the config."""
- sc = vllm_config.scheduler_config
- assert sc and isinstance(sc.max_num_encoder_input_tokens, int), (
- "max_num_encoder_input_tokens must be int for enc-dec models"
- )
- return sc.max_num_encoder_input_tokens
-
-
def _get_cross_slot_mapping(
encoder_seq_lens: np.ndarray,
block_table_tensor: torch.Tensor,
@@ -93,23 +84,32 @@ def create_cross_attention_backend(
) -> AttentionMetadata:
new_metadata = copy(common_attn_metadata)
new_metadata.causal = False
- max_encoder_len = _get_max_encoder_len(self.vllm_config)
+ max_encoder_len = int(new_metadata.encoder_seq_lens_cpu.max())
new_metadata.max_seq_len = max_encoder_len
+        # Any computed tokens indicate a decode step > 1 (no chunked prefill)
+ num_cache_decodes = (
+ (common_attn_metadata.num_computed_tokens_cpu > 0).sum().item()
+ )
+ if num_cache_decodes > 0:
+ # CrossAttn KV cache has already been populated on first decoder step,
+ # skip slot_mapping calculation for requests that do not need
+ # reshape_and_cache.
+ num_tokens = common_attn_metadata.num_computed_tokens_cpu.numpy()
+ new_metadata.encoder_seq_lens_cpu = np.where(
+ num_tokens > 0, 0, new_metadata.encoder_seq_lens_cpu
+ )
- new_metadata.seq_lens = torch.full(
- (new_metadata.num_reqs,),
- max_encoder_len,
- dtype=torch.int32,
- device=self.device,
- )
- new_metadata.seq_lens_cpu = torch.full(
- (new_metadata.num_reqs,),
- max_encoder_len,
- dtype=torch.int32,
- device="cpu",
+ # seq_lens is provided by model runner: initial encoder input length is
+ # needed here to know how many tokens to attend to from the cached
+ # cross-attention KV cache.
+ new_metadata.seq_lens = common_attn_metadata.encoder_seq_lens
+ new_metadata.seq_lens_cpu = torch.from_numpy(
+ common_attn_metadata.encoder_seq_lens_cpu
)
+
+ # NOTE (NickLucche) use `new_metadata` instead of `common_*` (initial) here
new_metadata.slot_mapping = _get_cross_slot_mapping(
- new_metadata.encoder_seq_lens,
+ new_metadata.encoder_seq_lens_cpu,
new_metadata.block_table_tensor,
self.kv_cache_spec,
self.device,
diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py
index 67c5f7dbba9c0..af6766bdd1615 100644
--- a/vllm/attention/ops/common.py
+++ b/vllm/attention/ops/common.py
@@ -194,7 +194,6 @@ def _cp_lse_common(
cp_attn_lse = cp_attn_lse.contiguous()
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
- assert out.is_contiguous()
return out, lse
diff --git a/vllm/attention/ops/vit_attn_wrappers.py b/vllm/attention/ops/vit_attn_wrappers.py
index 06a9f7cd82266..46f8f5117f7a7 100644
--- a/vllm/attention/ops/vit_attn_wrappers.py
+++ b/vllm/attention/ops/vit_attn_wrappers.py
@@ -3,7 +3,7 @@
"""
This file contains ops for ViT attention to be compatible with torch.compile
as there are operations here not supported by torch.compile (for instance,
-`to_list` in xformers attn, or `.item()` in flash attention)
+`.item()` in flash attention)
Using these ops and wrapping vision blocks with `torch.compile` can speed up
throughput in vision models by ~5% relative on H100, and improve token
@@ -19,42 +19,6 @@ import torch.nn.functional as F
from vllm.utils.torch_utils import direct_register_custom_op
-def xformers_attn_seqlens_wrapper(
- q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
-) -> torch.Tensor:
- from xformers import ops as xops
- from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-
- attn_bias = BlockDiagonalMask.from_seqlens(
- q_seqlen=seqlens.tolist(), kv_seqlen=None, device=q.device
- )
- context_layer = xops.memory_efficient_attention_forward(
- q, k, v, attn_bias=attn_bias, p=0, scale=None
- )
- context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
- return context_layer
-
-
-def xformers_attn_seqlens_wrapper_fake(
- q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
-) -> torch.Tensor:
- b, s, h, d = q.shape
- return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
-
-
-direct_register_custom_op(
- op_name="xformers_attn_seqlens_wrapper",
- op_func=xformers_attn_seqlens_wrapper,
- fake_impl=xformers_attn_seqlens_wrapper_fake,
-)
-
-
-def vit_xformers_attn_wrapper(
- q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
-) -> torch.Tensor:
- return torch.ops.vllm.xformers_attn_seqlens_wrapper(q, k, v, seqlens)
-
-
def flash_attn_maxseqlen_wrapper(
q: torch.Tensor,
k: torch.Tensor,
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index e9af08b2316d2..ad19b58aa155c 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -36,7 +36,14 @@ def get_env_variable_attn_backend() -> AttentionBackendEnum | None:
* None otherwise
"""
backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
- return None if backend_name is None else AttentionBackendEnum[backend_name]
+ if backend_name is None:
+ return None
+ if backend_name == "XFORMERS":
+ raise ValueError(
+ "Attention backend 'XFORMERS' has been removed (See PR #29262 for "
+ "details). Please select a supported attention backend."
+ )
+ return AttentionBackendEnum[backend_name]
# Global state allows a particular choice of backend
diff --git a/vllm/benchmarks/sweep/serve.py b/vllm/benchmarks/sweep/serve.py
index 45ac446a7aedf..1298e4acbd87d 100644
--- a/vllm/benchmarks/sweep/serve.py
+++ b/vllm/benchmarks/sweep/serve.py
@@ -211,6 +211,7 @@ def run_combs(
output_dir: Path,
num_runs: int,
dry_run: bool,
+ links: list[tuple[str, str]],
):
all_data = list[dict[str, object]]()
for serve_comb in serve_params:
@@ -226,6 +227,14 @@ def run_combs(
else contextlib.nullcontext()
) as server:
for bench_comb in bench_params:
+ should_run = all(
+ serve_key in serve_comb
+ and bench_key in bench_comb
+ and serve_comb[serve_key] == bench_comb[bench_key]
+ for serve_key, bench_key in links
+ )
+ if not should_run:
+ continue
base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)
comb_data = run_comb(
@@ -262,6 +271,7 @@ class SweepServeArgs:
num_runs: int
dry_run: bool
resume: str | None
+ link_vars: list[tuple[str, str]] | None
parser_name: ClassVar[str] = "serve"
parser_help: ClassVar[str] = "Run vLLM server benchmark under multiple settings."
@@ -285,7 +295,7 @@ class SweepServeArgs:
else:
# i.e.: run bench_cmd without any modification
bench_params = ParameterSweep.from_records([{}])
-
+ link_vars = cls.parse_link_vars(args.link_vars)
num_runs = args.num_runs
if num_runs < 1:
raise ValueError("`num_runs` should be at least 1.")
@@ -301,6 +311,7 @@ class SweepServeArgs:
num_runs=num_runs,
dry_run=args.dry_run,
resume=args.resume,
+ link_vars=link_vars,
)
@classmethod
@@ -376,8 +387,28 @@ class SweepServeArgs:
"parameter combinations for which there are still no output files.",
)
+ parser.add_argument(
+ "--link-vars",
+ type=str,
+ default="",
+ help=(
+ "Comma-separated list of linked variables between serve and bench, "
+ "e.g. max_num_seqs=max_concurrency,max_model_len=random_input_len"
+ ),
+ )
+
return parser
+ @staticmethod
+ def parse_link_vars(s: str) -> list[tuple[str, str]]:
+ if not s:
+ return []
+ pairs = []
+ for item in s.split(","):
+ a, b = item.split("=")
+ pairs.append((a.strip(), b.strip()))
+ return pairs
+
def run_main(args: SweepServeArgs):
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -397,6 +428,7 @@ def run_main(args: SweepServeArgs):
output_dir=output_dir,
num_runs=args.num_runs,
dry_run=args.dry_run,
+ links=args.link_vars,
)
except BaseException as exc:
raise RuntimeError(
diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index 1e66f21ff6388..2d8dd4c51c7ef 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -63,13 +63,14 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
else:
logger.debug("Using InductorAdaptor")
return InductorAdaptor()
- else:
- assert compilation_config.backend == "eager", (
- "Custom backends not supported with CompilationMode.VLLM_COMPILE"
- )
-
+ elif compilation_config.backend == "eager":
logger.debug("Using EagerAdaptor")
return EagerAdaptor()
+ else:
+ logger.debug("Using custom backend: %s", compilation_config.backend)
+ compiler = resolve_obj_by_qualname(current_platform.get_compile_backend())()
+ assert isinstance(compiler, CompilerInterface)
+ return compiler
class CompilerManager:
@@ -545,7 +546,10 @@ class VllmBackend:
self.prefix = prefix or model_tag
# Passes to run on the graph post-grad.
- self.post_grad_pass_manager = PostGradPassManager()
+ self.pass_manager = resolve_obj_by_qualname(
+ current_platform.get_pass_manager_cls()
+ )()
+ self.pass_key = current_platform.pass_key
self.sym_tensor_indices = []
self.input_buffers = []
@@ -562,24 +566,20 @@ class VllmBackend:
def configure_post_pass(self):
config = self.compilation_config
- self.post_grad_pass_manager.configure(self.vllm_config)
+ self.pass_manager.configure(self.vllm_config)
# Post-grad custom passes are run using the post_grad_custom_post_pass
# hook. If a pass for that hook exists, add it to the pass manager.
inductor_config = config.inductor_compile_config
- PASS_KEY = "post_grad_custom_post_pass"
- if PASS_KEY in inductor_config:
- if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
+ if self.pass_key in inductor_config:
+ if isinstance(inductor_config[self.pass_key], PostGradPassManager):
# PassManager already added to config, make sure it's correct
- assert (
- inductor_config[PASS_KEY].uuid()
- == self.post_grad_pass_manager.uuid()
- )
+ assert inductor_config[self.pass_key].uuid() == self.pass_manager.uuid()
else:
# Config should automatically wrap all inductor passes
- assert isinstance(inductor_config[PASS_KEY], InductorPass)
- self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
- inductor_config[PASS_KEY] = self.post_grad_pass_manager
+ assert isinstance(inductor_config[self.pass_key], InductorPass)
+ self.pass_manager.add(inductor_config[self.pass_key])
+ inductor_config[self.pass_key] = self.pass_manager
def __call__(
self, graph: fx.GraphModule, example_inputs
diff --git a/vllm/compilation/caching.py b/vllm/compilation/caching.py
index 63b7ad7279e37..6297d9f995aa4 100644
--- a/vllm/compilation/caching.py
+++ b/vllm/compilation/caching.py
@@ -116,7 +116,8 @@ class VllmSerializableFunction(SerializableCallable):
the AOT compiled path.
"""
compile_inputs = [
- inp or example_inputs[i] for i, inp in enumerate(fn.example_inputs)
+ inp if inp is not None else example_inputs[i]
+ for i, inp in enumerate(fn.example_inputs)
]
with tracing(TracingContext(fake_mode)):
fn.optimized_call = vllm_backend(
diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py
index 11a18c0e6bb78..6d9da1c488c6d 100644
--- a/vllm/compilation/decorators.py
+++ b/vllm/compilation/decorators.py
@@ -24,6 +24,7 @@ from vllm.config import (
get_current_vllm_config,
set_current_vllm_config,
)
+from vllm.config.compilation import DynamicShapesType
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import resolve_obj_by_qualname
@@ -104,6 +105,7 @@ def support_torch_compile(
dynamic_arg_dims: dict[str, int | list[int]] | None = None,
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
enable_if: Callable[[VllmConfig], bool] | None = None,
+ shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> Callable[[_T], _T] | _T:
"""
A decorator to add support for compiling the forward method of a class.
@@ -161,6 +163,14 @@ def support_torch_compile(
dim to be decorated with `mark_unbacked`. This is useful if we would like to
enforce that dynamo does not specialize on 0/1 values in the case of dummy input
such as for vision model compilation
+
+ `shape_invariants` is a function that gets compiled right before forward.
+ The function should have the torch._check calls that are needed to set
+ the relationships between different input sizes. For example:
+ torch._check(input_ids.size()[0] == inputs_embeds.size()[0])
+ This enforces constraints on the symbolic shapes without hardcoding
+ specific values. It is needed for some models to avoid data dependent
+ errors.
"""
def cls_decorator_helper(cls: _T) -> _T:
@@ -199,7 +209,11 @@ def support_torch_compile(
f"Argument {k} not found in the forward method of {cls}"
)
return _support_torch_compile(
- cls, inferred_dynamic_arg_dims, mark_unbacked_dims, enable_if
+ cls,
+ inferred_dynamic_arg_dims,
+ mark_unbacked_dims,
+ enable_if,
+ shape_invariants,
)
if cls is not None:
@@ -242,6 +256,7 @@ def _support_torch_compile(
dynamic_arg_dims: dict[str, int | list[int]],
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
enable_if: Callable[[VllmConfig], bool] | None = None,
+ shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> _T:
"""
A decorator to add support for compiling the forward method of a class.
@@ -276,11 +291,12 @@ def _support_torch_compile(
old_init(self, **kwargs)
self.vllm_config = vllm_config
+ self.compilation_config = self.vllm_config.compilation_config
enable_compile = enable_if is None or enable_if(vllm_config)
# for CompilationMode.STOCK_TORCH_COMPILE , the upper level model runner
# will handle the compilation, so we don't need to do anything here.
self.do_not_compile = (
- vllm_config.compilation_config.mode
+ self.compilation_config.mode
in [CompilationMode.NONE, CompilationMode.STOCK_TORCH_COMPILE]
or not supports_dynamo()
or _should_ignore_torch_compile(self.__class__)
@@ -289,29 +305,38 @@ def _support_torch_compile(
if self.do_not_compile:
return
+ self._check_shape_invariants = shape_invariants
+
compilation_counter.num_models_seen += 1
self.compiled = False
TorchCompileWithNoGuardsWrapper.__init__(self)
cls.__init__ = __init__
- def _mark_dynamic_inputs(mod, *args, **kwargs):
+ def _mark_dynamic_inputs(mod, type, *args, **kwargs):
+ def mark_dynamic(arg, dims):
+ if type == DynamicShapesType.UNBACKED:
+ torch._dynamo.decorators.mark_unbacked(arg, dims)
+ else:
+ torch._dynamo.mark_dynamic(arg, dims)
+
sig = inspect.signature(mod.__class__.forward)
bound_args = sig.bind(mod, *args, **kwargs)
bound_args.apply_defaults()
for k, dims in dynamic_arg_dims.items():
arg = bound_args.arguments.get(k)
+
if arg is not None:
dims = [dims] if isinstance(dims, int) else dims
if isinstance(arg, torch.Tensor):
# In case dims is specified with negative indexing
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
- torch._dynamo.mark_dynamic(arg, dims)
+ mark_dynamic(arg, dims)
elif isinstance(arg, IntermediateTensors):
for tensor in arg.tensors.values():
# In case dims is specified with negative indexing
dims = [tensor.ndim + dim if dim < 0 else dim for dim in dims]
- torch._dynamo.mark_dynamic(tensor, dims)
+ mark_dynamic(tensor, dims)
else:
raise ValueError(
"Unsupported dynamic dimensions"
@@ -338,6 +363,7 @@ def _support_torch_compile(
if getattr(self, "aot_compiled_fn", None) is not None:
return self.aot_compiled_fn(self, *args, **kwargs)
+ ds_type = self.compilation_config.dynamic_shapes_config.type
cache_dir = None
aot_compilation_path = None
if envs.VLLM_USE_AOT_COMPILE:
@@ -352,6 +378,14 @@ def _support_torch_compile(
serialized backend artifacts), then we need to generate a new AOT
compile artifact from scratch.
"""
+ # Validate that AOT compile is not used with unbacked dynamic
+ # shapes. aot_compile re-allocates backed symbols post dynamo!
+ if ds_type == DynamicShapesType.UNBACKED:
+ raise ValueError(
+ "AOT compilation is not compatible with UNBACKED dynamic shapes. "
+ "Please use BACKED or BACKED_SIZE_OBLIVIOUS dynamic shapes type "
+ "when VLLM_USE_AOT_COMPILE is enabled."
+ )
from .caching import compilation_config_hash_factors
factors: list[str] = compilation_config_hash_factors(self.vllm_config)
@@ -401,7 +435,12 @@ def _support_torch_compile(
# This is the path for the first compilation.
# the first compilation needs to have dynamic shapes marked
- _mark_dynamic_inputs(self, *args, **kwargs)
+ _mark_dynamic_inputs(
+ self,
+ ds_type,
+ *args,
+ **kwargs,
+ )
# here, it is the starting point of the `torch.compile` process
start_monitoring_torch_compile(self.vllm_config)
@@ -417,9 +456,7 @@ def _support_torch_compile(
# properly when any of these files change.
# 1. the file containing the top-level forward function
- self.vllm_config.compilation_config.traced_files.add(
- original_code_object.co_filename
- )
+ self.compilation_config.traced_files.add(original_code_object.co_filename)
# 2. every time Dynamo sees a function call, it will inline
# the function by calling InliningInstructionTranslator.inline_call_
@@ -429,7 +466,7 @@ def _support_torch_compile(
def patched_inline_call(self_):
code = self_.f_code
- self.vllm_config.compilation_config.traced_files.add(code.co_filename)
+ self.compilation_config.traced_files.add(code.co_filename)
return inline_call(self_)
# Disable the C++ compilation of symbolic shape guards. C++-fication
@@ -445,12 +482,18 @@ def _support_torch_compile(
# if the config doesn't exist
logger.debug("enable_cpp_symbolic_shape_guards config not available")
+ # Prepare backed_size_oblivious config patch if needed
+ fx_config_patches = {}
+ if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS:
+ fx_config_patches["backed_size_oblivious"] = True
+
with (
patch.object(
InliningInstructionTranslator, "inline_call_", patched_inline_call
),
torch._dynamo.config.patch(**dynamo_config_patches),
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
+ torch.fx.experimental._config.patch(**fx_config_patches),
_torch27_patch_tensor_subclasses(),
):
if envs.VLLM_USE_AOT_COMPILE:
diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py
index 493e57f97f0f4..b120c85bf232e 100644
--- a/vllm/compilation/wrapper.py
+++ b/vllm/compilation/wrapper.py
@@ -6,6 +6,7 @@ import sys
from abc import abstractmethod
from contextlib import contextmanager
from types import CodeType
+from typing import Any
import torch
import torch._C._dynamo.guards
@@ -85,6 +86,12 @@ class TorchCompileWithNoGuardsWrapper:
since we drop all guards.
"""
+ def check_invariants_and_forward(self, *args, **kwargs):
+ assert hasattr(self, "_check_shape_invariants")
+ self._check_shape_invariants(*args, **kwargs)
+
+ return self.forward(*args, **kwargs)
+
def __init__(self):
self.compiled = False
@@ -104,6 +111,21 @@ class TorchCompileWithNoGuardsWrapper:
# Drop all the guards.
options["guard_filter_fn"] = lambda x: [False for _ in x]
+ # Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
+ from vllm.compilation.decorators import DynamicShapesType
+
+ ds_type = vllm_config.compilation_config.dynamic_shapes_config.type
+ compiled_ptr: Any = self.forward
+ if ds_type == DynamicShapesType.UNBACKED:
+ if envs.VLLM_USE_BYTECODE_HOOK:
+                # Reason: the bytecode hook calls torch._dynamo.eval_frame.
+                # remove_from_cache(self.original_code_object()) to force a new
+                # re-compilation, which is incompatible with unbacked shapes.
+ raise ValueError(
+ "UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0. "
+ )
+ compiled_ptr = self.check_invariants_and_forward
+
if envs.VLLM_USE_AOT_COMPILE:
if hasattr(torch._dynamo.config, "enable_aot_compile"):
torch._dynamo.config.enable_aot_compile = True
@@ -114,7 +136,7 @@ class TorchCompileWithNoGuardsWrapper:
logger.warning(msg)
self._compiled_callable = torch.compile(
- self.forward,
+ compiled_ptr,
fullgraph=True,
dynamic=False,
backend=backend,
diff --git a/vllm/config/cache.py b/vllm/config/cache.py
index 2652c7c06ad0f..ef6928d8ebd5c 100644
--- a/vllm/config/cache.py
+++ b/vllm/config/cache.py
@@ -73,8 +73,8 @@ class CacheConfig:
sliding_window: int | None = None
"""Sliding window size for the KV cache. This is primarily set in
`ModelConfig` and that value should be manually duplicated here."""
- enable_prefix_caching: bool | None = None
- """Whether to enable prefix caching. Enabled by default for V1."""
+ enable_prefix_caching: bool = True
+ """Whether to enable prefix caching."""
prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256"
"""Set the hash algorithm for prefix caching:\n
- "sha256" uses Pickle for object serialization before hashing.\n
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index abdae49106120..865d045676d14 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -192,6 +192,54 @@ class PassConfig:
self.enable_qk_norm_rope_fusion = False
+class DynamicShapesType(str, enum.Enum):
+ """Types of dynamic shapes handling in torch.compile().
+ See "Dynamic shapes and vLLM guard dropping" in torch_compile.md
+ for more details."""
+
+ BACKED = "backed"
+ """Use backed dynamic shapes. torch.compile() guards on backed dynamic
+ shapes and may add guards. Symbols are specialized to 0, 1, or >=2 even
+ without encountering branching on those ranges."""
+
+ UNBACKED = "unbacked"
+ """Use unbacked dynamic shapes. Guaranteed not to be guarded on and not
+ 0/1 specialized, but may throw data dependent errors when branches require
+ their value without explicit unbacked handling."""
+
+ BACKED_SIZE_OBLIVIOUS = "backed_size_oblivious"
+ """Experimental flag that treats backed symbols as unbacked when explicit
+ unbacked handling is defined."""
+
+
+@config
+@dataclass
+class DynamicShapesConfig:
+ """Configuration to control/debug torch compile dynamic shapes."""
+
+ type: DynamicShapesType = DynamicShapesType.BACKED
+ """Controls the type of dynamic shapes handling to use with torch.compile().
+
+ - BACKED: Default PyTorch behavior with potential guards ignored.
+ - UNBACKED: No guards guaranteed (most sound) but may throw
+ data dependent errors.
+ - BACKED_SIZE_OBLIVIOUS: Experimental safer alternative to
+ backed/unbacked.
+ """
+
+ # TODO add a debug mode to fail
+
+ def compute_hash(self) -> str:
+ """
+ Provide a hash for DynamicShapesConfig
+ """
+
+ from vllm.config.utils import get_hash_factors, hash_factors
+
+ factors = get_hash_factors(self, {})
+ return hash_factors(factors)
+
+
@config
@dataclass
class CompilationConfig:
@@ -216,7 +264,6 @@ class CompilationConfig:
- [`cudagraph_copy_inputs`]
[vllm.config.CompilationConfig.cudagraph_copy_inputs]
- Inductor compilation:
- - [`use_inductor`][vllm.config.CompilationConfig.use_inductor]
- [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes]
- [`inductor_compile_config`]
[vllm.config.CompilationConfig.inductor_compile_config]
@@ -283,9 +330,9 @@ class CompilationConfig:
We use string to avoid serialization issues when using compilation in a
distributed setting. When the compilation mode is 1 or 2, the backend is
used for the compilation directly (it sees the whole graph). When the
- compilation mode is 3, the backend is used for the piecewise compilation
- (it sees a part of the graph). The backend can not be custom for compilation
- mode 3, i.e. the backend must be either eager or inductor. Furthermore,
+ compilation mode is 3, the backend supports both whole graph and piecewise
+ compilation, available backends include eager, inductor, and custom backends,
+ the latter of which can be defined via `get_compile_backend`. Furthermore,
compilation is only piecewise if splitting ops is set accordingly and
use_inductor_graph_partition is off. Note that the default options for
splitting ops are sufficient for piecewise compilation.
@@ -300,7 +347,7 @@ class CompilationConfig:
- 'none,+op1,+op2' to enable only op1 and op2
By default, all custom ops are enabled when running without Inductor and
- disabled when running with Inductor: mode>=VLLM_COMPILE and use_inductor=True.
+ disabled when running with Inductor: mode>=VLLM_COMPILE and backend="inductor".
Inductor generates (fused) Triton kernels for disabled custom ops."""
splitting_ops: list[str] | None = None
"""A list of ops to exclude from cudagraphs, used in piecewise compilation.
@@ -322,35 +369,19 @@ class CompilationConfig:
If empty list [], no ops are excluded (suitable for full cudagraphs)."""
compile_mm_encoder: bool = False
"""Whether or not to compile the multimodal encoder.
- Currently, this only works for `Qwen2_5_vl` on selected platforms.
+ Currently, this only works for `Qwen2_5_vl` on selected platforms.
Disabled by default until more models are supported/tested to work."""
# Inductor capture
- use_inductor: bool | None = None
- """
- Whether to use inductor compilation.
-
- This flag is deprecated and will be removed in the next release 0.12.0.
- Please use the 'backend' option instead.
-
- - False: inductor compilation is not used. graph runs in eager
- (custom_ops enabled by default).
- - True: inductor compilation is used (custom_ops disabled by default).
- One graph for symbolic shape and one graph per size in compile_sizes
- are compiled using configurations in inductor_compile_config.
-
- This setting is ignored if mode512) that would
greatly increase startup time with limited performance benefit.
"""
+
+ dynamic_shapes_config: DynamicShapesConfig = field(
+ default_factory=DynamicShapesConfig
+ )
+ """Configuration for dynamic shapes options"""
+
local_cache_dir: str = field(default=None, init=False) # type: ignore
"""local cache dir for each rank"""
+
bs_to_padded_graph_size: list[int] = field(
default=None, # type: ignore
init=False,
@@ -530,6 +568,7 @@ class CompilationConfig:
from vllm.config.utils import get_hash_factors, hash_factors
factors = get_hash_factors(self, ignored_factors)
+
factors["pass_config"] = self.pass_config.compute_hash()
return hash_factors(factors)
@@ -701,16 +740,8 @@ class CompilationConfig:
f"Invalid backend for piecewise compilation: {self.backend}"
)
- if self.use_inductor is not None:
- logger.warning_once(
- "The 'use_inductor' flag is deprecated and will be "
- "removed in the next release (v0.12.0). "
- "Please use the 'backend' option instead.",
- )
- self.backend = "inductor" if self.use_inductor else "eager"
-
if self.backend == "":
- self.backend = current_platform.simple_compile_backend
+ self.backend = current_platform.get_compile_backend()
def init_backend(self, vllm_config: "VllmConfig") -> str | Callable:
"""
@@ -742,9 +773,7 @@ class CompilationConfig:
assert self.mode == CompilationMode.VLLM_COMPILE
if self.backend not in ["eager", "inductor"]:
- raise ValueError(
- f"Invalid backend for piecewise compilation: {self.backend}"
- )
+ logger.info("Using OOT custom backend for compilation.")
from vllm.compilation.backends import VllmBackend
@@ -950,14 +979,18 @@ class CompilationConfig:
)
)
+ if len(rounded_sizes) == 0 and multiple_of <= self.max_cudagraph_capture_size:
+ # multiple_of itself is a valid size but was rounded away; keep it as the only size
+ rounded_sizes = [multiple_of]
+
if len(rounded_sizes) == 0:
- logger.warning(
- "No valid cudagraph sizes after rounding to multiple of "
- " num_speculative_tokens + 1 (%d); please adjust num_speculative_tokens"
- " or max_cudagraph_capture_size (or cudagraph_capture_sizes)",
- multiple_of,
+ raise ValueError(
+ f"No valid cudagraph sizes after rounding to multiple of {multiple_of} "
+ f"(num_speculative_tokens + 1 or tp if sequence parallelism is enabled)"
+ f" please adjust num_speculative_tokens ({uniform_decode_query_len - 1}"
+ f") or max_cudagraph_capture_size ({self.max_cudagraph_capture_size})"
+ f" or cudagraph_capture_sizes ({self.cudagraph_capture_sizes})"
)
- return
self.max_cudagraph_capture_size = rounded_sizes[-1]
self.cudagraph_capture_sizes = rounded_sizes
diff --git a/vllm/config/model.py b/vllm/config/model.py
index 49688e17cf932..ce5e824da5c22 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -33,12 +33,18 @@ from vllm.transformers_utils.config import (
try_get_safetensors_metadata,
try_get_tokenizer_config,
uses_mrope,
+ uses_xdrope_dim,
)
from vllm.transformers_utils.gguf_utils import (
maybe_patch_hf_config_from_gguf,
)
from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri
-from vllm.transformers_utils.utils import check_gguf_file, maybe_model_redirect
+from vllm.transformers_utils.utils import (
+ is_gguf,
+ is_remote_gguf,
+ maybe_model_redirect,
+ split_remote_gguf,
+)
from vllm.utils.import_utils import LazyLoader
from vllm.utils.torch_utils import common_broadcastable_dtype
@@ -293,9 +299,6 @@ class ModelConfig:
pooler_config: PoolerConfig | None = None
"""Pooler config which controls the behaviour of output pooling in pooling
models."""
- override_pooler_config: dict | PoolerConfig | None = None
- """[DEPRECATED] Use `pooler_config` instead. This field will be removed in
- v0.12.0 or v1.0.0, whichever is sooner."""
# Multimodal config and init vars
multimodal_config: MultiModalConfig | None = None
@@ -353,7 +356,6 @@ class ModelConfig:
"logits_processors",
"io_processor_plugin",
"pooler_config",
- "override_pooler_config",
"multimodal_config",
"limit_mm_per_prompt",
"media_io_kwargs",
@@ -439,7 +441,8 @@ class ModelConfig:
self.model = maybe_model_redirect(self.model)
# The tokenizer is consistent with the model by default.
if self.tokenizer is None:
- if check_gguf_file(self.model):
+ # Check if this is a GGUF model (either local file or remote GGUF)
+ if is_gguf(self.model):
raise ValueError(
"Using a tokenizer is mandatory when loading a GGUF model. "
"Please specify the tokenizer path or name using the "
@@ -585,16 +588,26 @@ class ModelConfig:
else: # task == "auto"
pass
else:
- debug_info = {
- "architectures": architectures,
- "is_generative_model": is_generative_model,
- "is_pooling_model": is_pooling_model,
- }
- raise AssertionError(
- "The model should be a generative or "
- "pooling model when task is set to "
- f"{self.task!r}. Found: {debug_info}"
- )
+ # Neither generative nor pooling model - try to convert if possible
+ if is_pooling_task:
+ runner = "pooling"
+ convert = _task_to_convert(self.task)
+ msg_hint = (
+ "Please replace this option with `--runner pooling "
+ f"--convert {convert}` to continue using this model "
+ "as a pooling model."
+ )
+ else:
+ debug_info = {
+ "architectures": architectures,
+ "is_generative_model": is_generative_model,
+ "is_pooling_model": is_pooling_model,
+ }
+ raise AssertionError(
+ "The model should be a generative or "
+ "pooling model when task is set to "
+ f"{self.task!r}. Found: {debug_info}"
+ )
self.runner = runner
self.convert = convert
@@ -631,18 +644,6 @@ class ModelConfig:
# Init pooler config if needed
if self.runner_type == "pooling":
- if self.override_pooler_config is not None:
- logger.warning_once(
- "`override_pooler_config` is deprecated and will be "
- "removed in v0.12.0 or v1.0.0, whichever is sooner. "
- "Please use `pooler_config` instead."
- )
-
- if isinstance(self.override_pooler_config, dict):
- self.pooler_config = PoolerConfig(**self.override_pooler_config)
- else:
- self.pooler_config = self.override_pooler_config
-
if self.pooler_config is None:
self.pooler_config = PoolerConfig()
@@ -821,7 +822,10 @@ class ModelConfig:
self.tokenizer = object_storage_tokenizer.dir
def _get_encoder_config(self):
- return get_sentence_transformer_tokenizer_config(self.model, self.revision)
+ model = self.model
+ if is_remote_gguf(model):
+ model, _ = split_remote_gguf(model)
+ return get_sentence_transformer_tokenizer_config(model, self.revision)
def _verify_tokenizer_mode(self) -> None:
tokenizer_mode = cast(TokenizerMode, self.tokenizer_mode.lower())
@@ -1605,6 +1609,10 @@ class ModelConfig:
def uses_mrope(self) -> bool:
return uses_mrope(self.hf_config)
+ @property
+ def uses_xdrope_dim(self) -> int:
+ return uses_xdrope_dim(self.hf_config)
+
@property
def is_multimodal_model(self) -> bool:
return self.multimodal_config is not None
diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py
index 9f62b35ed515c..00a81a319bf72 100644
--- a/vllm/config/multimodal.py
+++ b/vllm/config/multimodal.py
@@ -173,6 +173,12 @@ class MultiModalConfig:
# We need to import the real type here (deferred to avoid circular import).
from vllm.attention.backends.registry import AttentionBackendEnum
+ if isinstance(value, str) and value.upper() == "XFORMERS":
+ raise ValueError(
+ "Attention backend 'XFORMERS' has been removed (See PR #29262 for "
+ "details). Please select a supported attention backend."
+ )
+
if value is None or isinstance(value, AttentionBackendEnum):
return value
diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py
index 4b0236d8de3f5..913e97250d3d3 100644
--- a/vllm/config/parallel.py
+++ b/vllm/config/parallel.py
@@ -60,6 +60,10 @@ class EPLBConfig:
Log the balancedness each step of expert parallelism.
This is turned off by default since it will cause communication overhead.
"""
+ use_async: bool = False
+ """
+ Whether to use non-blocking EPLB.
+ """
@config
@@ -137,22 +141,6 @@ class ParallelConfig:
- "deepep_high_throughput": Use deepep high-throughput kernels
- "deepep_low_latency": Use deepep low-latency kernels
- "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""
- num_redundant_experts: int | None = None
- """`num_redundant_experts` is deprecated and has been replaced with
- `eplb_config.num_redundant_experts`. This will be removed in v0.12.0.
- Please use `eplb_config.num_redundant_experts` instead."""
- eplb_window_size: int | None = None
- """`eplb_window_size` is deprecated and has been replaced with
- `eplb_config.window_size`. This will be removed in v0.12.0.
- Please use `eplb_config.window_size` instead."""
- eplb_step_interval: int | None = None
- """`eplb_step_interval` is deprecated and has been replaced with
- `eplb_config.step_interval`. This will be removed in v0.12.0.
- Please use `eplb_config.step_interval` instead."""
- eplb_log_balancedness: bool | None = None
- """`eplb_log_balancedness` is deprecated and has been replaced with
- `eplb_config.log_balancedness`. This will be removed in v0.12.0.
- Please use `eplb_config.log_balancedness` instead."""
max_parallel_loading_workers: int | None = None
"""Maximum number of parallel loading workers when loading model
@@ -512,40 +500,6 @@ class ParallelConfig:
"--all2all-backend command-line argument instead."
)
- # Forward deprecated fields to their new location
- if self.num_redundant_experts is not None:
- self.eplb_config.num_redundant_experts = self.num_redundant_experts
- logger.warning_once(
- "num_redundant_experts is deprecated and has been replaced "
- "with eplb_config.num_redundant_experts. This will be removed "
- "in v0.12.0. Changing this field after initialization will "
- "have no effect."
- )
- if self.eplb_window_size is not None:
- self.eplb_config.window_size = self.eplb_window_size
- logger.warning_once(
- "eplb_window_size is deprecated and has been replaced "
- "with eplb_config.window_size. This will be removed "
- "in v0.12.0. Changing this field after initialization will "
- "have no effect."
- )
- if self.eplb_step_interval is not None:
- self.eplb_config.step_interval = self.eplb_step_interval
- logger.warning_once(
- "eplb_step_interval is deprecated and has been replaced "
- "with eplb_config.step_interval. This will be removed "
- "in v0.12.0. Changing this field after initialization will "
- "have no effect."
- )
- if self.eplb_log_balancedness is not None:
- self.eplb_config.log_balancedness = self.eplb_log_balancedness
- logger.warning_once(
- "eplb_log_balancedness is deprecated and has been replaced "
- "with eplb_config.log_balancedness. This will be removed "
- "in v0.12.0. Changing this field after initialization will "
- "have no effect."
- )
-
# Continue with the rest of the initialization
self.world_size = (
self.pipeline_parallel_size
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index d64e315b4fe39..8a3599416bc72 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -96,7 +96,7 @@ class VllmConfig:
"""`torch.compile` and cudagraph capture configuration for the model.
As a shorthand, one can append compilation arguments via
- -0.parameter=arguement such as `-O.mode=3` (same as `-O='{"mode":3}'`).
+ -O.parameter=argument such as `-O.mode=3` (same as `-O='{"mode":3}'`).
You can specify the full compilation config like so:
`{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
diff --git a/vllm/distributed/device_communicators/symm_mem.py b/vllm/distributed/device_communicators/symm_mem.py
index eb1f173b11925..7a049b003cf73 100644
--- a/vllm/distributed/device_communicators/symm_mem.py
+++ b/vllm/distributed/device_communicators/symm_mem.py
@@ -131,7 +131,7 @@ class SymmMemCommunicator:
return None
if out is None:
out = torch.empty_like(inp)
- self.buffer[: inp.numel()].copy_(inp.view(-1))
+ self.buffer[: inp.numel()].copy_(inp.reshape(-1))
# Determine which algorithm to use
use_multimem = False
diff --git a/vllm/distributed/eplb/async_worker.py b/vllm/distributed/eplb/async_worker.py
new file mode 100644
index 0000000000000..e4b4fc92eeaaa
--- /dev/null
+++ b/vllm/distributed/eplb/async_worker.py
@@ -0,0 +1,115 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+The async worker that transfers experts in the background.
+"""
+
+import asyncio
+import threading
+from typing import TYPE_CHECKING
+
+import torch
+from torch.distributed import ProcessGroup
+
+from vllm.distributed.parallel_state import get_ep_group
+from vllm.logger import init_logger
+
+from .rebalance_execute import transfer_layer
+
+if TYPE_CHECKING:
+ from .eplb_state import EplbState
+
+logger = init_logger(__name__)
+
+
+def start_async_worker(
+ state: "EplbState",
+ rank_mapping: dict[int, int] | None = None,
+ is_profile: bool = False,
+) -> threading.Thread:
+ ep_group = get_ep_group().device_group
+ rank = ep_group.rank()
+ device_index = state.cuda_device_index
+
+ def thread_target() -> None:
+ assert device_index is not None
+ torch.cuda.set_device(device_index)
+ cuda_stream = torch.cuda.Stream(device=device_index)
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ try:
+ loop.run_until_complete(
+ transfer_run_periodically(
+ state=state,
+ ep_group=ep_group,
+ is_profile=is_profile,
+ rank_mapping=rank_mapping,
+ cuda_stream=cuda_stream,
+ )
+ )
+ except Exception as exc: # pragma: no cover - diagnostic path
+ logger.exception("async loop error (Rank %d): %s", rank, str(exc))
+ finally:
+ loop.close()
+
+ thread = threading.Thread(target=thread_target, daemon=True)
+ thread.start()
+ return thread
+
+
+async def transfer_run_periodically(
+ state: "EplbState",
+ ep_group: ProcessGroup,
+ is_profile: bool = False,
+ rank_mapping: dict[int, int] | None = None,
+ cuda_stream: torch.cuda.Stream = None,
+) -> None:
+ while True:
+ await asyncio.to_thread(state.rearrange_event.wait)
+ logger.info("async worker woke up for EPLB transfer")
+
+ for model_state in state.model_states.values():
+ if not model_state.is_async_enabled:
+ continue
+ current_num_layers = model_state.model.num_moe_layers
+ while (
+ model_state.rebalanced
+ and model_state.layer_to_transfer < current_num_layers
+ ):
+ if (
+ not model_state.ep_buffer_ready
+ and model_state.rebalanced
+ and model_state.new_physical_to_logical_map is not None
+ ):
+ await asyncio.to_thread(model_state.buffer_lock.acquire)
+ try:
+ if model_state.layer_to_transfer >= current_num_layers:
+ break
+
+ (
+ model_state.is_unchanged,
+ model_state.is_received_locally,
+ model_state.experts_recv_loc,
+ ) = await transfer_layer(
+ old_global_expert_indices=model_state.physical_to_logical_map,
+ new_global_expert_indices=model_state.new_physical_to_logical_map,
+ expert_weights=model_state.model.expert_weights,
+ expert_weights_buffer=model_state.expert_buffer,
+ ep_group=ep_group,
+ is_profile=is_profile,
+ layer=model_state.layer_to_transfer,
+ cuda_stream=cuda_stream,
+ rank_mapping=rank_mapping,
+ )
+ event = torch.cuda.Event(blocking=False)
+ cuda_stream.record_event(event)
+ model_state.buffer_ready_event = event
+ model_state.ep_buffer_ready = 1
+ finally:
+ model_state.buffer_lock.release()
+ else:
+ if not model_state.rebalanced:
+ break
+ await asyncio.sleep(0.001)
+
+ state.rearrange_event.clear()
diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py
index 526d3ceac7b8f..9f8798a96a2fc 100644
--- a/vllm/distributed/eplb/eplb_state.py
+++ b/vllm/distributed/eplb/eplb_state.py
@@ -26,6 +26,7 @@ MoE layer. If we have 32 EP ranks, then each GPU will hold 288 / 32 = 9 local
physical experts.
"""
+import threading
import time
from collections.abc import Sequence
from dataclasses import dataclass
@@ -43,8 +44,9 @@ from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MixtureOfExperts
+from .async_worker import start_async_worker
from .rebalance_algo import rebalance_experts
-from .rebalance_execute import rearrange_expert_weights_inplace
+from .rebalance_execute import move_from_buffer, rearrange_expert_weights_inplace
logger = init_logger(__name__)
@@ -132,6 +134,74 @@ class EplbModelState:
"""
model_name: str
model: MixtureOfExperts
+ expert_buffer: list[torch.Tensor]
+ """
+ The buffer to store the expert weights during transfer.
+ """
+ buffer_lock: threading.Lock
+ """
+ The lock to protect the expert buffer.
+ """
+ buffer_ready_event: torch.cuda.Event | None
+ """
+ CUDA event recorded when the async worker finishes filling the buffer.
+ The main thread waits on this before consuming the buffer.
+ """
+ ep_buffer_ready: int
+ """
+ The flag indicates whether the expert buffer is ready for transfer.
+ 0 or 1.
+ """
+ layer_to_transfer: int
+ """
+ The layer index to transfer in async mode.
+ """
+ rebalanced: bool
+ """
+ The flag indicates whether the experts rebalance have been computed.
+ """
+ pending_global_ready_check: bool
+ """
+ Whether the async EPLB needs to poll peers for buffer readiness.
+ """
+ is_unchanged: list[bool]
+ """
+ intermediate variable between `move_to_buffer` and `move_to_workspace`.
+ The size is same as the num of physical experts in the current layer.
+ """
+ is_received_locally: list[bool]
+ """
+ intermediate variable between `move_to_buffer` and `move_to_workspace`.
+ The size is same as the num of physical experts in the current layer.
+ """
+ experts_recv_loc: dict[int, int]
+ """
+ intermediate variable between `move_to_buffer` and `move_to_workspace`.
+ The size is same as the num of physical experts in the current layer.
+ """
+ is_async_enabled: bool
+ """
+ The flag indicates whether the EPLB is running in async mode.
+ """
+ cuda_device_index: int | None
+ """
+ CUDA device index for the async EPLB worker thread.
+ """
+ new_physical_to_logical_map: torch.Tensor | None = None
+ """
+ intermediate variable between `move_to_buffer` and `move_to_workspace`.
+ the size is same as physical_to_logical_map
+ """
+ new_logical_to_physical_map: torch.Tensor | None = None
+ """
+ intermediate variable between `move_to_buffer` and `move_to_workspace`.
+ the size is same as logical_to_physical_map
+ """
+ new_logical_replica_count: torch.Tensor | None = None
+ """
+ intermediate variable between `move_to_buffer` and `move_to_workspace`.
+ the size is same as logical_replica_count
+ """
class EplbState:
@@ -164,12 +234,31 @@ class EplbState:
Otherwise, the rearrangement will hang at collective
communication calls.
"""
- self.expert_rearrangement_step: int = 0
+ self.expert_rearrangement_step_interval: int = 0
"""
Interval for expert rearrangement steps.
This is a constant and is taken from the config.
"""
- self.expert_rearrangement_step_interval: int = 0
+ self.is_async: bool = False
+ """
+ The flag indicates whether the EPLB is running in async mode.
+ """
+ self.rearrange_event = threading.Event()
+ """
+ Event to signal when a new rearrangement is needed for the async thread.
+ """
+ self.async_worker: threading.Thread | None = None
+ """
+ Background thread handling async transfers.
+ """
+ self.cuda_device_index: int | None = None
+ """
+ CUDA device index for the async EPLB worker thread.
+ """
+ if self.device.type == "cuda":
+ self.cuda_device_index = self.device.index
+ if self.cuda_device_index is None and torch.cuda.is_available():
+ self.cuda_device_index = torch.cuda.current_device()
@staticmethod
def build_initial_global_physical_to_logical_map(
@@ -239,6 +328,8 @@ class EplbState:
Build the initial EPLB state.
"""
self.validate_ep_configuration(model)
+ self.is_async = self.parallel_config.eplb_config.use_async
+
physical_to_logical_map_list = (
EplbState.build_initial_global_physical_to_logical_map(
model.num_routed_experts,
@@ -368,7 +459,12 @@ class EplbState:
physical_to_logical_map = new_physical_to_logical_map.to(self.device)
logical_to_physical_map.copy_(new_logical_to_physical_map)
logical_replica_count.copy_(new_logical_replica_count)
+ else:
+ new_physical_to_logical_map = None
+ new_logical_to_physical_map = None
+
+ new_logical_replica_count = None
model.set_eplb_state(
expert_load_pass,
logical_to_physical_map,
@@ -385,15 +481,33 @@ class EplbState:
)
self.expert_rearrangement_step = 0
- self.model_states[model_config.compute_hash()] = EplbModelState(
- physical_to_logical_map,
- logical_to_physical_map,
- logical_replica_count,
- expert_load_pass,
- expert_load_window,
- model_config.model,
- model,
+ expert_buffer = [torch.empty_like(w) for w in model.expert_weights[0]]
+
+ model_state = EplbModelState(
+ physical_to_logical_map=physical_to_logical_map,
+ logical_to_physical_map=logical_to_physical_map,
+ logical_replica_count=logical_replica_count,
+ expert_load_pass=expert_load_pass,
+ expert_load_window=expert_load_window,
+ model_name=model_config.model,
+ model=model,
+ expert_buffer=expert_buffer,
+ buffer_lock=threading.Lock(),
+ buffer_ready_event=None,
+ ep_buffer_ready=0,
+ layer_to_transfer=0,
+ rebalanced=False,
+ pending_global_ready_check=False,
+ is_unchanged=[],
+ is_received_locally=[],
+ experts_recv_loc={},
+ is_async_enabled=self.is_async,
+ cuda_device_index=self.cuda_device_index,
+ new_physical_to_logical_map=new_physical_to_logical_map,
+ new_logical_to_physical_map=new_logical_to_physical_map,
+ new_logical_replica_count=new_logical_replica_count,
)
+ self.model_states[model_config.compute_hash()] = model_state
def step(
self,
@@ -420,7 +534,7 @@ class EplbState:
- `max_tokens`: The maximum load across ranks.
- `balancedness`: The ratio of average load to maximum load.
"""
-
+ ep_group = get_ep_group().device_group
if is_profile:
self.rearrange(is_profile=True)
return
@@ -488,7 +602,49 @@ class EplbState:
# rearrangement step and perform rearrangement to ensure all ranks are
# performing collective communication.
self.expert_rearrangement_step += 1
+
+ if self.is_async:
+ for eplb_model_state in self.model_states.values():
+ if not eplb_model_state.is_async_enabled:
+ continue
+
+ all_ranks_buffer_ready = False
+ if eplb_model_state.pending_global_ready_check:
+ all_ranks_buffer_ready = self._all_ranks_buffer_ready(
+ eplb_model_state
+ )
+ if (
+ eplb_model_state.is_async_enabled
+ and eplb_model_state.ep_buffer_ready
+ and all_ranks_buffer_ready
+ ):
+ self.move_to_workspace(
+ model_state=eplb_model_state,
+ ep_group=ep_group,
+ is_profile=is_profile,
+ )
+ if (
+ eplb_model_state.layer_to_transfer
+ >= eplb_model_state.model.num_moe_layers
+ ):
+ self.post_eplb(eplb_model_state, is_profile)
+ eplb_model_state.rebalanced = False
+ eplb_model_state.layer_to_transfer = 0
+ eplb_model_state.pending_global_ready_check = False
+ logger.info(
+ "finish async transfer for model %s rank %d layer %d",
+ eplb_model_state.model_name,
+ ep_group.rank(),
+ eplb_model_state.model.num_moe_layers,
+ )
+
if self.expert_rearrangement_step >= self.expert_rearrangement_step_interval:
+ if any(
+ eplb_model_state.is_async_enabled and eplb_model_state.rebalanced
+ for eplb_model_state in self.model_states.values()
+ ):
+ # Still performing asynchronous rearrangement
+ return
self.expert_rearrangement_step = 0
self.rearrange()
@@ -524,7 +680,11 @@ class EplbState:
if is_main_rank:
torch.cuda.synchronize()
time_start = time.perf_counter()
- logger.info("Rearranging experts %s...", "(profile)" if is_profile else "")
+ logger.info(
+ "Rearranging experts %s %s...",
+ "(async mode)" if self.is_async else "sync mode",
+ "(profile)" if is_profile else "",
+ )
if global_expert_loads is None:
# Map the physical expert load to global logical experts
@@ -593,6 +753,7 @@ class EplbState:
model = eplb_model_state.model
num_replicas = model.num_physical_experts
num_groups = model.num_expert_groups
+
if rank_mapping is not None and len(rank_mapping) == ep_group.size():
# NOTE(yongji): scale down, we need to rebalance the experts on
# remaining GPUs, transfer the experts while we haven't shutdown
@@ -608,7 +769,7 @@ class EplbState:
num_gpus = ep_group.size()
if num_gpus % num_nodes != 0:
- self.num_nodes = 1
+ num_nodes = 1
logger.warning_once(
f"num_gpus % num_nodes != 0, "
"not using hierarchical rearrangement algorithm.\n"
@@ -631,60 +792,216 @@ class EplbState:
num_gpus,
)
- # Update expert weights
- rearrange_expert_weights_inplace(
- eplb_model_state.physical_to_logical_map,
- new_physical_to_logical_map,
- eplb_model_state.model.expert_weights,
- ep_group,
- is_profile,
- rank_mapping,
- )
+ if not eplb_model_state.is_async_enabled or is_profile:
+ # Update expert weights
+ rearrange_expert_weights_inplace(
+ eplb_model_state.physical_to_logical_map,
+ new_physical_to_logical_map,
+ eplb_model_state.model.expert_weights,
+ ep_group,
+ is_profile,
+ rank_mapping,
+ )
- if not is_profile:
- if (
- eplb_model_state.physical_to_logical_map.shape[1]
- != new_physical_to_logical_map.shape[1]
- ):
- eplb_model_state.physical_to_logical_map = (
- new_physical_to_logical_map.to(
- eplb_model_state.physical_to_logical_map.device
+ if not is_profile:
+ if (
+ eplb_model_state.physical_to_logical_map.shape[1]
+ != new_physical_to_logical_map.shape[1]
+ ):
+ eplb_model_state.physical_to_logical_map = (
+ new_physical_to_logical_map.to(
+ eplb_model_state.physical_to_logical_map.device
+ )
)
+ else:
+ eplb_model_state.physical_to_logical_map.copy_(
+ new_physical_to_logical_map
+ )
+ max_physical_slots = new_logical_to_physical_map.shape[-1]
+ assert (
+ max_physical_slots
+ <= eplb_model_state.logical_to_physical_map.shape[-1]
)
- else:
- eplb_model_state.physical_to_logical_map.copy_(
- new_physical_to_logical_map
+ new_logical_to_physical_map = torch.nn.functional.pad(
+ new_logical_to_physical_map,
+ (
+ 0,
+ eplb_model_state.logical_to_physical_map.shape[-1]
+ - max_physical_slots,
+ ),
+ value=-1,
)
- max_physical_slots = new_logical_to_physical_map.shape[-1]
- assert (
- max_physical_slots
- <= eplb_model_state.logical_to_physical_map.shape[-1]
- )
- new_logical_to_physical_map = torch.nn.functional.pad(
+ eplb_model_state.logical_to_physical_map.copy_(
+ new_logical_to_physical_map
+ )
+ eplb_model_state.logical_replica_count.copy_(
+ new_logical_replica_count
+ )
+ if is_main_rank:
+ assert time_start is not None
+ torch.cuda.synchronize()
+ time_end = time.perf_counter()
+ logger.info(
+ "Rearranged experts%sin %.2f seconds.",
+ " (profile) " if is_profile else " ",
+ time_end - time_start,
+ )
+ else:
+ device = eplb_model_state.physical_to_logical_map.device
+ new_physical = new_physical_to_logical_map.to(device)
+ max_slots = eplb_model_state.logical_to_physical_map.shape[-1]
+ padded_logical = torch.nn.functional.pad(
new_logical_to_physical_map,
- (
- 0,
- eplb_model_state.logical_to_physical_map.shape[-1]
- - max_physical_slots,
- ),
+ (0, max(0, max_slots - new_logical_to_physical_map.shape[-1])),
value=-1,
+ ).to(eplb_model_state.logical_to_physical_map.device)
+ new_replica = new_logical_replica_count.to(
+ eplb_model_state.logical_replica_count.device
)
- eplb_model_state.logical_to_physical_map.copy_(
- new_logical_to_physical_map
- )
- eplb_model_state.logical_replica_count.copy_(new_logical_replica_count)
- if is_main_rank:
- assert time_start is not None
- torch.cuda.synchronize()
- time_end = time.perf_counter()
- logger.info(
- "Rearranged experts%sin %.2f seconds.",
- " (profile) " if is_profile else " ",
- time_end - time_start,
- )
+ eplb_model_state.new_physical_to_logical_map = new_physical
+ eplb_model_state.new_logical_to_physical_map = padded_logical
+ eplb_model_state.new_logical_replica_count = new_replica
+
+ eplb_model_state.rebalanced = True
+ eplb_model_state.layer_to_transfer = 0
+ eplb_model_state.pending_global_ready_check = True
+
+ # Signal async thread to start transferring layers
+ if self.is_async and (not is_profile):
+ self.rearrange_event.set()
return None
+ def start_async_loop(
+ self,
+ rank_mapping: dict[int, int] | None = None,
+ is_profile: bool = False,
+ ):
+ if not self.is_async:
+ return
+ if self.async_worker is None:
+ self.async_worker = start_async_worker(
+ self,
+ rank_mapping=rank_mapping,
+ is_profile=is_profile,
+ )
+
+ def _update_layer_mapping_from_new(
+ self, model_state: EplbModelState, layer: int
+ ) -> None:
+ if (
+ model_state.new_physical_to_logical_map is None
+ or model_state.new_logical_to_physical_map is None
+ or model_state.new_logical_replica_count is None
+ ):
+ return
+
+ target_device = model_state.physical_to_logical_map.device
+ new_physical = model_state.new_physical_to_logical_map
+ if model_state.physical_to_logical_map.shape[1] != new_physical.shape[1]:
+ model_state.physical_to_logical_map = new_physical.to(target_device)
+ else:
+ model_state.physical_to_logical_map[layer].copy_(
+ new_physical[layer].to(target_device)
+ )
+
+ logical_device = model_state.logical_to_physical_map.device
+ new_logical = model_state.new_logical_to_physical_map[layer].to(logical_device)
+ max_slots = model_state.logical_to_physical_map.shape[-1]
+ slot_delta = max_slots - new_logical.shape[-1]
+ if slot_delta > 0:
+ new_logical = torch.nn.functional.pad(
+ new_logical, (0, slot_delta), value=-1
+ )
+ model_state.logical_to_physical_map[layer].copy_(new_logical)
+
+ replica_device = model_state.logical_replica_count.device
+ model_state.logical_replica_count[layer].copy_(
+ model_state.new_logical_replica_count[layer].to(replica_device)
+ )
+
+ def _all_ranks_buffer_ready(self, model_state: EplbModelState) -> bool:
+ parallel_state = get_ep_group()
+ cpu_group = getattr(parallel_state, "cpu_group", None)
+ if cpu_group is not None and cpu_group.size() > 1:
+ flag = torch.tensor(
+ (int(model_state.ep_buffer_ready),), dtype=torch.int32, device="cpu"
+ )
+ all_reduce(flag, group=cpu_group)
+ return int(flag.item()) == cpu_group.size()
+
+ device_group = parallel_state.device_group
+ if device_group.size() <= 1:
+ return bool(model_state.ep_buffer_ready)
+
+ device = getattr(
+ parallel_state, "device", model_state.physical_to_logical_map.device
+ )
+ flag = torch.tensor(
+ (int(model_state.ep_buffer_ready),), dtype=torch.int32, device=device
+ )
+ all_reduce(flag, group=device_group)
+ return int(flag.item()) == device_group.size()
+
+ def move_to_workspace(
+ self,
+ model_state: EplbModelState,
+ ep_group: ProcessGroup,
+ is_profile: bool = False,
+ ):
+ if not model_state.buffer_lock.acquire(blocking=False):
+ return
+ try:
+ assert model_state.new_physical_to_logical_map is not None
+ device_index = model_state.cuda_device_index or self.cuda_device_index
+ if model_state.buffer_ready_event is not None and device_index is not None:
+ stream = torch.cuda.current_stream(device=device_index)
+ stream.wait_event(model_state.buffer_ready_event)
+ model_state.buffer_ready_event = None
+ move_from_buffer(
+ expert_weights=model_state.model.expert_weights[
+ model_state.layer_to_transfer
+ ],
+ expert_weights_buffer=model_state.expert_buffer,
+ is_unchanged=model_state.is_unchanged,
+ is_received_locally=model_state.is_received_locally,
+ experts_recv_loc=model_state.experts_recv_loc,
+ new_indices=model_state.new_physical_to_logical_map[
+ model_state.layer_to_transfer
+ ].tolist(),
+ ep_group=ep_group,
+ )
+ transferred_layer = model_state.layer_to_transfer
+ self._update_layer_mapping_from_new(model_state, transferred_layer)
+ # After the main thread consumes, advance layer_to_transfer
+ model_state.layer_to_transfer += 1
+ model_state.ep_buffer_ready = 0
+ logger.info(
+ "model %s successfully move_to_workspace layer %d",
+ model_state.model_name,
+ transferred_layer,
+ )
+ finally:
+ try:
+ model_state.buffer_lock.release()
+ except Exception as e:
+ logger.error(
+ "Rank %d: buffer_lock release failed in move_to_workspace: %s",
+ ep_group.rank(),
+ str(e),
+ )
+
+ def post_eplb(self, model_state: EplbModelState, is_profile: bool = False) -> None:
+ assert model_state.new_physical_to_logical_map is not None
+ assert model_state.new_logical_to_physical_map is not None
+ assert model_state.new_logical_replica_count is not None
+ if not is_profile:
+ for layer_idx in range(model_state.physical_to_logical_map.shape[0]):
+ self._update_layer_mapping_from_new(model_state, layer_idx)
+ model_state.new_physical_to_logical_map = None
+ model_state.new_logical_to_physical_map = None
+ model_state.new_logical_replica_count = None
+
@staticmethod
def recv_state() -> tuple[list[torch.Tensor], list[torch.Tensor]]:
"""
diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py
index 5c1efbaf03bab..376dad8a72ef1 100644
--- a/vllm/distributed/eplb/rebalance_execute.py
+++ b/vllm/distributed/eplb/rebalance_execute.py
@@ -100,18 +100,19 @@ def get_ep_ranks_with_expert(
return ranks_to_send, ranks_to_recv_actual
-def shuffle_layer(
+def move_to_buffer(
num_local_experts: int,
- ep_rank: int,
old_indices: Sequence[int],
new_indices: Sequence[int],
expert_weights: Iterable[torch.Tensor],
expert_weights_buffer: Sequence[torch.Tensor],
+ cuda_stream: torch.cuda.Stream | None,
ep_group: ProcessGroup,
-) -> None:
+) -> tuple[list[bool], list[bool], dict[int, int]]:
"""
Perform expert weights rearrangement of one layer.
"""
+ ep_rank = ep_group.rank()
local2global = partial(
idx_local_to_global,
local_cnt=num_local_experts,
@@ -137,7 +138,8 @@ def shuffle_layer(
if old_indices[src_global] == new_indices[dst_global]:
is_received_locally[dst] = True
for weight, buffer in zip(expert_weights, expert_weights_buffer):
- buffer[dst].copy_(weight[src])
+ with torch.cuda.stream(cuda_stream):
+ buffer[dst].copy_(weight[src], non_blocking=True)
p2p_ops: list[P2POp] = []
@@ -225,25 +227,115 @@ def shuffle_layer(
]
# 4. Execute the P2P operations. The real communication happens here.
- if p2p_ops:
+ if p2p_ops and cuda_stream is not None:
+ with torch.cuda.stream(cuda_stream):
+ reqs = batch_isend_irecv(p2p_ops)
+ for req in reqs:
+ req.wait()
+ elif p2p_ops:
reqs = batch_isend_irecv(p2p_ops)
for req in reqs:
req.wait()
+ # wait for the communication to finish
+ return is_unchanged, is_received_locally, experts_recv_loc
+
+
+def move_from_buffer(
+ expert_weights: Iterable[torch.Tensor],
+ expert_weights_buffer: list[torch.Tensor],
+ is_unchanged: list[bool],
+ is_received_locally: list[bool],
+ experts_recv_loc: dict[int, int],
+ new_indices: Sequence[int],
+ ep_group: ProcessGroup,
+) -> None:
+ ep_rank = ep_group.rank()
+ num_local_experts = len(is_unchanged)
+
+ local2global = partial(
+ idx_local_to_global, local_cnt=num_local_experts, ep_rank=ep_rank
+ )
- # 5. Copy the weights from the buffer back to the original weights.
for dst in range(num_local_experts):
if is_unchanged[dst]:
continue
if is_received_locally[dst]:
for weight, buffer in zip(expert_weights, expert_weights_buffer):
- weight[dst].copy_(buffer[dst])
+ weight[dst].copy_(buffer[dst], non_blocking=True)
else:
expert = new_indices[local2global(dst)]
if expert == -1:
continue
src = experts_recv_loc[expert]
for weight, buffer in zip(expert_weights, expert_weights_buffer):
- weight[dst].copy_(buffer[src])
+ weight[dst].copy_(buffer[src], non_blocking=True)
+
+
+async def transfer_layer(
+ old_global_expert_indices: torch.Tensor,
+ new_global_expert_indices: torch.Tensor,
+ expert_weights: Sequence[Iterable[torch.Tensor]],
+ expert_weights_buffer: Sequence[torch.Tensor],
+ ep_group: ProcessGroup,
+ is_profile: bool = False,
+ layer: int = 0,
+ cuda_stream: torch.cuda.Stream | None = None,
+ rank_mapping: dict[int, int] | None = None,
+) -> tuple[list[bool], list[bool], dict[int, int]]:
+ """
+ Rearranges the expert weights in place according to the new expert indices.
+
+ The value of the indices arguments are logical indices of the experts,
+ while keys are physical.
+
+ Args:
+ old_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
+ new_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
+ expert_weights: A sequence of shape (num_moe_layers)(weight_count)
+ of tensors of shape (num_local_physical_experts, hidden_size_i).
+ For example, a linear layer may have up and down projection,
+ so weight_count = 2. Each weight's hidden size can be different.
+ ep_group: The device process group for expert parallelism.
+ is_profile (bool): If `True`, do not perform any actual weight copy.
+ This is used during profile run, where we only perform dummy
+ communications to reserve enough memory for the buffers.
+ """
+ ep_size = ep_group.size()
+ if rank_mapping is not None:
+ if len(rank_mapping) == ep_group.size():
+ # scale down
+ new_global_expert_indices = _map_new_expert_indices_with_rank_mapping(
+ new_global_expert_indices,
+ rank_mapping,
+ )
+ else:
+ # scale up
+ old_global_expert_indices = _map_old_expert_indices_with_rank_mapping(
+ old_global_expert_indices,
+ rank_mapping,
+ ep_group.size(),
+ )
+
+ assert old_global_expert_indices.shape[1] == new_global_expert_indices.shape[1]
+ num_moe_layers, num_physical_experts = old_global_expert_indices.shape
+ assert len(expert_weights) == num_moe_layers
+ num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
+ assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
+ assert num_physical_experts == ep_size * num_local_physical_experts
+ # A buffer to hold the expert weights in one layer during the exchange.
+ # NOTE: Currently we assume the same weights across different layers
+ # have the same shape.
+
+ is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer(
+ num_local_experts=num_local_physical_experts,
+ old_indices=old_global_expert_indices[layer].tolist(),
+ new_indices=new_global_expert_indices[layer].tolist(),
+ expert_weights=expert_weights[layer],
+ expert_weights_buffer=expert_weights_buffer,
+ cuda_stream=cuda_stream,
+ ep_group=ep_group,
+ )
+ return is_unchanged, is_received_locally, experts_recv_loc
def rearrange_expert_weights_inplace(
@@ -296,7 +388,6 @@ def rearrange_expert_weights_inplace(
num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
- ep_rank = ep_group.rank()
ep_size = ep_group.size()
assert num_physical_experts == ep_size * num_local_physical_experts
@@ -329,14 +420,24 @@ def rearrange_expert_weights_inplace(
torch.cuda.synchronize()
for layer in range(num_moe_layers):
- shuffle_layer(
- num_local_physical_experts,
- ep_rank,
- old_global_expert_indices_cpu[layer].tolist(),
- new_global_expert_indices_cpu[layer].tolist(),
- expert_weights[layer],
- expert_weights_buffer,
- ep_group,
+ is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer(
+ num_local_experts=num_local_physical_experts,
+ old_indices=old_global_expert_indices_cpu[layer].tolist(),
+ new_indices=new_global_expert_indices_cpu[layer].tolist(),
+ expert_weights=expert_weights[layer],
+ expert_weights_buffer=expert_weights_buffer,
+ cuda_stream=None,
+ ep_group=ep_group,
+ )
+
+ move_from_buffer(
+ expert_weights=expert_weights[layer],
+ expert_weights_buffer=expert_weights_buffer,
+ is_unchanged=is_unchanged,
+ is_received_locally=is_received_locally,
+ experts_recv_loc=experts_recv_loc,
+ new_indices=new_global_expert_indices[layer].tolist(),
+ ep_group=ep_group,
)
@@ -428,4 +529,4 @@ def _map_new_expert_indices_with_rank_mapping(
return mapped_expert_indices
-__all__ = ["rearrange_expert_weights_inplace"]
+__all__ = ["transfer_layer", "move_from_buffer"]
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
index ab2eeed9f6b8a..6acfb73997f25 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
@@ -310,7 +310,6 @@ class LMCacheMPWorkerAdapter:
request_id,
result,
)
- logger.info("Retrieve request for request_id=%s finished", request_id)
# Remove the finished requests from the tracking dicts
for request_id in finished_stores:
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
index 22ddabbf1e352..d1d3e475cc889 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
@@ -469,9 +469,6 @@ class LMCacheMPConnector(KVConnectorBase_V1):
ops.append(meta.op)
if len(request_ids) > 0:
- logger.info(
- "HERE! SUBMITTING THE BATCHED RETRIEVE REQUESTS %s", request_ids
- )
self.worker_adapter.batched_submit_retrieve_requests(
request_ids, ops, event
)
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 7c0911240493c..493938d4aad92 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -4,7 +4,6 @@ import contextlib
import copy
import logging
import math
-import os
import queue
import threading
import time
@@ -810,9 +809,6 @@ class NixlConnectorWorker:
self.nixl_backends = vllm_config.kv_transfer_config.get_from_extra_config(
"backends", ["UCX"]
)
- # TODO temporary, once nixl allows for telemetry flag in config
- # (next release), we can remove this env var.
- os.environ["NIXL_TELEMETRY_ENABLE"] = "1"
# Agent.
non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"]
@@ -828,10 +824,11 @@ class NixlConnectorWorker:
if nixl_agent_config is None:
config = None
else:
+ # Enable telemetry by default for NIXL 0.7.1 and above.
config = (
- nixl_agent_config(backends=self.nixl_backends)
+ nixl_agent_config(backends=self.nixl_backends, capture_telemetry=True)
if len(non_ucx_backends) > 0
- else nixl_agent_config(num_threads=num_threads)
+ else nixl_agent_config(num_threads=num_threads, capture_telemetry=True)
)
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config)
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 888f57b1ac1df..696ff3a1f4024 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -29,7 +29,7 @@ import regex as re
import torch
from pydantic import TypeAdapter, ValidationError
from pydantic.fields import FieldInfo
-from typing_extensions import TypeIs, deprecated
+from typing_extensions import TypeIs
import vllm.envs as envs
from vllm.attention.backends.registry import AttentionBackendEnum
@@ -77,7 +77,7 @@ from vllm.config.observability import DetailedTraceModules
from vllm.config.parallel import DistributedExecutorBackend, ExpertPlacementStrategy
from vllm.config.scheduler import SchedulerPolicy
from vllm.config.utils import get_field
-from vllm.logger import init_logger
+from vllm.logger import init_logger, suppress_logging
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
from vllm.ray.lazy_utils import is_in_ray_actor, is_ray_initialized
@@ -86,7 +86,7 @@ from vllm.transformers_utils.config import (
is_interleaved,
maybe_override_with_speculators,
)
-from vllm.transformers_utils.utils import check_gguf_file, is_cloud_storage
+from vllm.transformers_utils.utils import is_cloud_storage, is_gguf
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.network_utils import get_ip
@@ -247,11 +247,13 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
default = field.default
# Handle pydantic.Field defaults
if isinstance(default, FieldInfo):
- default = (
- default.default
- if default.default_factory is None
- else default.default_factory()
- )
+ if default.default_factory is None:
+ default = default.default
+ else:
+ # VllmConfig's Fields have default_factory set to config classes.
+ # These could emit logs on init, which would be confusing.
+ with suppress_logging():
+ default = default.default_factory()
elif field.default_factory is not MISSING:
default = field.default_factory()
@@ -425,7 +427,7 @@ class EngineArgs:
ParallelConfig.max_parallel_loading_workers
)
block_size: BlockSize | None = CacheConfig.block_size
- enable_prefix_caching: bool | None = CacheConfig.enable_prefix_caching
+ enable_prefix_caching: bool | None = None
prefix_caching_hash_algo: PrefixCachingHashAlgo = (
CacheConfig.prefix_caching_hash_algo
)
@@ -502,11 +504,6 @@ class EngineArgs:
)
reasoning_parser: str = StructuredOutputsConfig.reasoning_parser
reasoning_parser_plugin: str | None = None
- # Deprecated guided decoding fields
- guided_decoding_backend: str | None = None
- guided_decoding_disable_fallback: bool | None = None
- guided_decoding_disable_any_whitespace: bool | None = None
- guided_decoding_disable_additional_properties: bool | None = None
logits_processor_pattern: str | None = ModelConfig.logits_processor_pattern
@@ -523,9 +520,6 @@ class EngineArgs:
scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
pooler_config: PoolerConfig | None = ModelConfig.pooler_config
- override_pooler_config: dict | PoolerConfig | None = (
- ModelConfig.override_pooler_config
- )
compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
worker_cls: str = ParallelConfig.worker_cls
worker_extension_cls: str = ParallelConfig.worker_extension_cls
@@ -662,11 +656,6 @@ class EngineArgs:
)
model_group.add_argument("--hf-overrides", **model_kwargs["hf_overrides"])
model_group.add_argument("--pooler-config", **model_kwargs["pooler_config"])
- model_group.add_argument(
- "--override-pooler-config",
- **model_kwargs["override_pooler_config"],
- deprecated=True,
- )
model_group.add_argument(
"--logits-processor-pattern", **model_kwargs["logits_processor_pattern"]
)
@@ -725,19 +714,6 @@ class EngineArgs:
"--reasoning-parser-plugin",
**structured_outputs_kwargs["reasoning_parser_plugin"],
)
- # Deprecated guided decoding arguments
- for arg, type in [
- ("--guided-decoding-backend", str),
- ("--guided-decoding-disable-fallback", bool),
- ("--guided-decoding-disable-any-whitespace", bool),
- ("--guided-decoding-disable-additional-properties", bool),
- ]:
- structured_outputs_group.add_argument(
- arg,
- type=type,
- help=(f"[DEPRECATED] {arg} will be removed in v0.12.0."),
- deprecated=True,
- )
# Parallel arguments
parallel_kwargs = get_kwargs(ParallelConfig)
@@ -855,30 +831,6 @@ class EngineArgs:
"--expert-placement-strategy",
**parallel_kwargs["expert_placement_strategy"],
)
- parallel_group.add_argument(
- "--num-redundant-experts",
- type=int,
- help="[DEPRECATED] --num-redundant-experts will be removed in v0.12.0.",
- deprecated=True,
- )
- parallel_group.add_argument(
- "--eplb-window-size",
- type=int,
- help="[DEPRECATED] --eplb-window-size will be removed in v0.12.0.",
- deprecated=True,
- )
- parallel_group.add_argument(
- "--eplb-step-interval",
- type=int,
- help="[DEPRECATED] --eplb-step-interval will be removed in v0.12.0.",
- deprecated=True,
- )
- parallel_group.add_argument(
- "--eplb-log-balancedness",
- action=argparse.BooleanOptionalAction,
- help="[DEPRECATED] --eplb-log-balancedness will be removed in v0.12.0.",
- deprecated=True,
- )
parallel_group.add_argument(
"--max-parallel-loading-workers",
@@ -920,7 +872,11 @@ class EngineArgs:
"--num-gpu-blocks-override", **cache_kwargs["num_gpu_blocks_override"]
)
cache_group.add_argument(
- "--enable-prefix-caching", **cache_kwargs["enable_prefix_caching"]
+ "--enable-prefix-caching",
+ **{
+ **cache_kwargs["enable_prefix_caching"],
+ "default": None,
+ },
)
cache_group.add_argument(
"--prefix-caching-hash-algo", **cache_kwargs["prefix_caching_hash_algo"]
@@ -1184,8 +1140,8 @@ class EngineArgs:
return engine_args
def create_model_config(self) -> ModelConfig:
- # gguf file needs a specific model loader and doesn't use hf_repo
- if check_gguf_file(self.model):
+ # gguf file needs a specific model loader
+ if is_gguf(self.model):
self.quantization = self.load_format = "gguf"
# NOTE(woosuk): In V1, we use separate processes for workers (unless
@@ -1279,7 +1235,6 @@ class EngineArgs:
mm_encoder_tp_mode=self.mm_encoder_tp_mode,
mm_encoder_attn_backend=self.mm_encoder_attn_backend,
pooler_config=self.pooler_config,
- override_pooler_config=self.override_pooler_config,
logits_processor_pattern=self.logits_processor_pattern,
generation_config=self.generation_config,
override_generation_config=self.override_generation_config,
@@ -1392,11 +1347,10 @@ class EngineArgs:
# Set default arguments for V1 Engine.
self._set_default_args(usage_context, model_config)
# Disable chunked prefill and prefix caching for:
- # POWER (ppc64le)/ARM/s390x/RISCV CPUs in V1
+ # POWER (ppc64le)/s390x/RISCV CPUs in V1
if current_platform.is_cpu() and current_platform.get_cpu_architecture() in (
CpuArchEnum.POWERPC,
CpuArchEnum.S390X,
- CpuArchEnum.ARM,
CpuArchEnum.RISCV,
):
logger.info(
@@ -1613,6 +1567,12 @@ class EngineArgs:
model_config.skip_tokenizer_init = True
logger.info("Skipping tokenizer initialization for tokens-only mode.")
+ if self.async_scheduling and not self.disable_nccl_for_dp_synchronization:
+ logger.info(
+ "Disabling NCCL for DP synchronization when using async scheduling."
+ )
+ self.disable_nccl_for_dp_synchronization = True
+
# Forward the deprecated CLI args to the EPLB config.
if self.num_redundant_experts is not None:
self.eplb_config.num_redundant_experts = self.num_redundant_experts
@@ -1737,21 +1697,6 @@ class EngineArgs:
self.reasoning_parser_plugin
)
- # Forward the deprecated CLI args to the StructuredOutputsConfig
- so_config = self.structured_outputs_config
- if self.guided_decoding_backend is not None:
- so_config.guided_decoding_backend = self.guided_decoding_backend
- if self.guided_decoding_disable_fallback is not None:
- so_config.disable_fallback = self.guided_decoding_disable_fallback
- if self.guided_decoding_disable_any_whitespace is not None:
- so_config.disable_any_whitespace = (
- self.guided_decoding_disable_any_whitespace
- )
- if self.guided_decoding_disable_additional_properties is not None:
- so_config.disable_additional_properties = (
- self.guided_decoding_disable_additional_properties
- )
-
observability_config = ObservabilityConfig(
show_hidden_metrics_for_version=self.show_hidden_metrics_for_version,
otlp_traces_endpoint=self.otlp_traces_endpoint,
@@ -1862,9 +1807,11 @@ class EngineArgs:
if model_config.runner_type != "pooling":
default_chunked_prefill = True
- # Disable prefix caching default for hybrid models
- # since the feature is still experimental.
- default_prefix_caching = not model_config.is_hybrid
+ # Disable prefix caching default for hybrid models and mamba-only
+ # models since the feature is still experimental.
+ default_prefix_caching = not (
+ model_config.is_hybrid or model_config.is_attention_free
+ )
else:
assert model_config.pooler_config is not None
@@ -1975,10 +1922,11 @@ class EngineArgs:
if self.prefill_context_parallel_size > 1:
default_chunked_prefill = False
default_prefix_caching = False
- logger.warning(
+ logger.warning_once(
"--prefill-context-parallel-size > 1 is not compatible with "
"chunked prefill and prefix caching now. Chunked prefill "
- "and prefix caching have been disabled by default."
+ "and prefix caching have been disabled by default.",
+ scope="local",
)
if self.enable_chunked_prefill is None:
@@ -1988,15 +1936,27 @@ class EngineArgs:
"%s chunked prefill by default",
"Enabling" if default_chunked_prefill else "Disabling",
)
+ elif (
+ model_config.runner_type == "generate"
+ and not self.enable_chunked_prefill
+ and default_chunked_prefill
+ ):
+ logger.warning_once(
+ "This model does not officially support disabling chunked prefill. "
+ "Disabling this manually may cause the engine to crash "
+ "or produce incorrect outputs.",
+ scope="local",
+ )
elif (
model_config.runner_type == "pooling"
and self.enable_chunked_prefill
and not default_chunked_prefill
):
- logger.warning(
+ logger.warning_once(
"This model does not officially support chunked prefill. "
"Enabling this manually may cause the engine to crash "
"or produce incorrect outputs.",
+ scope="local",
)
if self.enable_prefix_caching is None:
@@ -2011,10 +1971,11 @@ class EngineArgs:
and self.enable_prefix_caching
and not default_prefix_caching
):
- logger.warning(
+ logger.warning_once(
"This model does not officially support prefix caching. "
"Enabling this manually may cause the engine to crash "
"or produce incorrect outputs.",
+ scope="local",
)
world_size = self.pipeline_parallel_size * self.tensor_parallel_size
@@ -2077,24 +2038,6 @@ class AsyncEngineArgs(EngineArgs):
enable_log_requests: bool = False
- @property
- @deprecated(
- "`disable_log_requests` is deprecated and has been replaced with "
- "`enable_log_requests`. This will be removed in v0.12.0. Please use "
- "`enable_log_requests` instead."
- )
- def disable_log_requests(self) -> bool:
- return not self.enable_log_requests
-
- @disable_log_requests.setter
- @deprecated(
- "`disable_log_requests` is deprecated and has been replaced with "
- "`enable_log_requests`. This will be removed in v0.12.0. Please use "
- "`enable_log_requests` instead."
- )
- def disable_log_requests(self, value: bool):
- self.enable_log_requests = not value
-
@staticmethod
def add_cli_args(
parser: FlexibleArgumentParser, async_args_only: bool = False
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index aaf8a3ae9d2dd..bf80856c1bbfc 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -1283,6 +1283,7 @@ MM_PARSER_MAP: dict[
"text": lambda part: _TextParser(part).get("text", None),
"thinking": lambda part: _ThinkParser(part).get("thinking", None),
"input_text": lambda part: _TextParser(part).get("text", None),
+ "output_text": lambda part: _TextParser(part).get("text", None),
"input_image": lambda part: _ResponsesInputImageParser(part).get("image_url", None),
"image_url": lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
"image_embeds": lambda part: _ImageEmbedsParser(part).get("image_embeds", None),
@@ -1463,7 +1464,7 @@ def _parse_chat_message_content_part(
)
return None
- if part_type in ("text", "input_text", "refusal", "thinking"):
+ if part_type in ("text", "input_text", "output_text", "refusal", "thinking"):
str_content = cast(str, content)
if wrap_dicts:
return {"type": "text", "text": str_content}
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 848916dbd8763..1860f383d45fb 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -174,9 +174,6 @@ class LLM:
For example, for Phi-3-Vision: `{"num_crops": 4}`.
pooler_config: Initialize non-default pooling config for the pooling
model. e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
- override_pooler_config: [DEPRECATED] Use `pooler_config` instead. This
- argument is deprecated and will be removed in v0.12.0 or v1.0.0,
- whichever is sooner.
compilation_config: Either an integer or a dictionary. If it is an
integer, it is used as the mode of compilation optimization. If it
is a dictionary, it can specify the full compilation configuration.
@@ -214,7 +211,6 @@ class LLM:
hf_overrides: HfOverrides | None = None,
mm_processor_kwargs: dict[str, Any] | None = None,
pooler_config: PoolerConfig | None = None,
- override_pooler_config: PoolerConfig | None = None,
structured_outputs_config: dict[str, Any]
| StructuredOutputsConfig
| None = None,
@@ -330,7 +326,6 @@ class LLM:
hf_overrides=hf_overrides,
mm_processor_kwargs=mm_processor_kwargs,
pooler_config=pooler_config,
- override_pooler_config=override_pooler_config,
structured_outputs_config=structured_outputs_instance,
compilation_config=compilation_config_instance,
logits_processors=logits_processors,
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index b352c3ad01db0..688ea9697d9d6 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -29,7 +29,6 @@ from openai.types.responses import (
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponsePrompt,
- ResponseReasoningItem,
ResponseReasoningTextDeltaEvent,
ResponseReasoningTextDoneEvent,
ResponseStatus,
@@ -304,9 +303,7 @@ def get_logits_processors(
return None
-ResponseInputOutputItem: TypeAlias = (
- ResponseInputItemParam | ResponseReasoningItem | ResponseFunctionToolCall
-)
+ResponseInputOutputItem: TypeAlias = ResponseInputItemParam | ResponseOutputItem
class ResponsesRequest(OpenAIBaseModel):
@@ -559,9 +556,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
) = "none"
reasoning_effort: Literal["low", "medium", "high"] | None = None
include_reasoning: bool = True
+ parallel_tool_calls: bool | None = True
- # NOTE this will be ignored by vLLM -- the model determines the behavior
- parallel_tool_calls: bool | None = False
+ # NOTE this will be ignored by vLLM
user: str | None = None
# --8<-- [start:chat-completion-sampling-params]
@@ -652,62 +649,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
default=None,
description="Additional kwargs for structured outputs",
)
- guided_json: str | dict | BaseModel | None = Field(
- default=None,
- description=(
- "`guided_json` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `json` to `structured_outputs` instead."
- ),
- )
- guided_regex: str | None = Field(
- default=None,
- description=(
- "`guided_regex` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `regex` to `structured_outputs` instead."
- ),
- )
- guided_choice: list[str] | None = Field(
- default=None,
- description=(
- "`guided_choice` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `choice` to `structured_outputs` instead."
- ),
- )
- guided_grammar: str | None = Field(
- default=None,
- description=(
- "`guided_grammar` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `grammar` to `structured_outputs` instead."
- ),
- )
- structural_tag: str | None = Field(
- default=None,
- description=(
- "`structural_tag` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `structural_tag` to `structured_outputs` instead."
- ),
- )
- guided_decoding_backend: str | None = Field(
- default=None,
- description=(
- "`guided_decoding_backend` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please remove it from your request."
- ),
- )
- guided_whitespace_pattern: str | None = Field(
- default=None,
- description=(
- "`guided_whitespace_pattern` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `whitespace_pattern` to `structured_outputs` instead."
- ),
- )
priority: int = Field(
default=0,
description=(
@@ -717,7 +658,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
),
)
request_id: str = Field(
- default_factory=lambda: f"{random_uuid()}",
+ default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
@@ -841,20 +782,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.top_logprobs
- # Forward deprecated guided_* parameters to structured_outputs
- if self.structured_outputs is None:
- kwargs = dict[str, Any](
- json=self.guided_json,
- regex=self.guided_regex,
- choice=self.guided_choice,
- grammar=self.guided_grammar,
- whitespace_pattern=self.guided_whitespace_pattern,
- structural_tag=self.structural_tag,
- )
- kwargs = {k: v for k, v in kwargs.items() if v is not None}
- if len(kwargs) > 0:
- self.structured_outputs = StructuredOutputsParams(**kwargs)
-
response_format = self.response_format
if response_format is not None:
# If structured outputs wasn't already enabled,
@@ -863,24 +790,23 @@ class ChatCompletionRequest(OpenAIBaseModel):
self.structured_outputs = StructuredOutputsParams()
# Set structured output params for response format
- if response_format is not None:
- if response_format.type == "json_object":
- self.structured_outputs.json_object = True
- elif response_format.type == "json_schema":
- json_schema = response_format.json_schema
- assert json_schema is not None
- self.structured_outputs.json = json_schema.json_schema
- elif response_format.type == "structural_tag":
- structural_tag = response_format
- assert structural_tag is not None and isinstance(
- structural_tag,
- (
- LegacyStructuralTagResponseFormat,
- StructuralTagResponseFormat,
- ),
- )
- s_tag_obj = structural_tag.model_dump(by_alias=True)
- self.structured_outputs.structural_tag = json.dumps(s_tag_obj)
+ if response_format.type == "json_object":
+ self.structured_outputs.json_object = True
+ elif response_format.type == "json_schema":
+ json_schema = response_format.json_schema
+ assert json_schema is not None
+ self.structured_outputs.json = json_schema.json_schema
+ elif response_format.type == "structural_tag":
+ structural_tag = response_format
+ assert structural_tag is not None and isinstance(
+ structural_tag,
+ (
+ LegacyStructuralTagResponseFormat,
+ StructuralTagResponseFormat,
+ ),
+ )
+ s_tag_obj = structural_tag.model_dump(by_alias=True)
+ self.structured_outputs.structural_tag = json.dumps(s_tag_obj)
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
if self.kv_transfer_params:
@@ -1140,58 +1066,6 @@ class CompletionRequest(OpenAIBaseModel):
default=None,
description="Additional kwargs for structured outputs",
)
- guided_json: str | dict | BaseModel | None = Field(
- default=None,
- description=(
- "`guided_json` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `json` to `structured_outputs` instead."
- ),
- )
- guided_regex: str | None = Field(
- default=None,
- description=(
- "`guided_regex` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `regex` to `structured_outputs` instead."
- ),
- )
- guided_choice: list[str] | None = Field(
- default=None,
- description=(
- "`guided_choice` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `choice` to `structured_outputs` instead."
- ),
- )
- guided_grammar: str | None = Field(
- default=None,
- description=(
- "`guided_grammar` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `grammar` to `structured_outputs` instead."
- ),
- )
- structural_tag: str | None = Field(
- default=None,
- description=("If specified, the output will follow the structural tag schema."),
- )
- guided_decoding_backend: str | None = Field(
- default=None,
- description=(
- "`guided_decoding_backend` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please remove it from your request."
- ),
- )
- guided_whitespace_pattern: str | None = Field(
- default=None,
- description=(
- "`guided_whitespace_pattern` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `whitespace_pattern` to `structured_outputs` instead."
- ),
- )
priority: int = Field(
default=0,
description=(
@@ -1201,7 +1075,7 @@ class CompletionRequest(OpenAIBaseModel):
),
)
request_id: str = Field(
- default_factory=lambda: f"{random_uuid()}",
+ default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
@@ -1336,35 +1210,31 @@ class CompletionRequest(OpenAIBaseModel):
echo_without_generation = self.echo and self.max_tokens == 0
- guided_json_object = None
- if self.response_format is not None:
- if self.response_format.type == "json_object":
- guided_json_object = True
- elif self.response_format.type == "json_schema":
- json_schema = self.response_format.json_schema
+ response_format = self.response_format
+ if response_format is not None:
+ # If structured outputs wasn't already enabled,
+ # we must enable it for these features to work
+ if self.structured_outputs is None:
+ self.structured_outputs = StructuredOutputsParams()
+
+ # Set structured output params for response format
+ if response_format.type == "json_object":
+ self.structured_outputs.json_object = True
+ elif response_format.type == "json_schema":
+ json_schema = response_format.json_schema
assert json_schema is not None
- self.guided_json = json_schema.json_schema
- elif self.response_format.type == "structural_tag":
- structural_tag = self.response_format
+ self.structured_outputs.json = json_schema.json_schema
+ elif response_format.type == "structural_tag":
+ structural_tag = response_format
assert structural_tag is not None and isinstance(
- structural_tag, StructuralTagResponseFormat
+ structural_tag,
+ (
+ LegacyStructuralTagResponseFormat,
+ StructuralTagResponseFormat,
+ ),
)
s_tag_obj = structural_tag.model_dump(by_alias=True)
- self.structural_tag = json.dumps(s_tag_obj)
-
- # Forward deprecated guided_* parameters to structured_outputs
- if self.structured_outputs is None:
- kwargs = dict[str, Any](
- json=self.guided_json,
- json_object=guided_json_object,
- regex=self.guided_regex,
- choice=self.guided_choice,
- grammar=self.guided_grammar,
- whitespace_pattern=self.guided_whitespace_pattern,
- )
- kwargs = {k: v for k, v in kwargs.items() if v is not None}
- if len(kwargs) > 0:
- self.structured_outputs = StructuredOutputsParams(**kwargs)
+ self.structured_outputs.structural_tag = json.dumps(s_tag_obj)
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
if self.kv_transfer_params:
@@ -1502,7 +1372,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
),
)
request_id: str = Field(
- default_factory=lambda: f"{random_uuid()}",
+ default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
@@ -1597,7 +1467,7 @@ class EmbeddingChatRequest(OpenAIBaseModel):
),
)
request_id: str = Field(
- default_factory=lambda: f"{random_uuid()}",
+ default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
@@ -2019,7 +1889,7 @@ class ClassificationCompletionRequest(OpenAIBaseModel):
),
)
request_id: str = Field(
- default_factory=lambda: f"{random_uuid()}",
+ default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
@@ -2110,7 +1980,7 @@ class ClassificationChatRequest(OpenAIBaseModel):
)
request_id: str = Field(
- default_factory=lambda: f"{random_uuid()}",
+ default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
@@ -3221,7 +3091,7 @@ class TranslationResponseVerbose(OpenAIBaseModel):
####### Tokens IN <> Tokens OUT #######
class GenerateRequest(BaseModel):
request_id: str = Field(
- default_factory=lambda: f"{random_uuid()}",
+ default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
@@ -3278,7 +3148,7 @@ class GenerateResponseChoice(BaseModel):
class GenerateResponse(BaseModel):
request_id: str = Field(
- default_factory=lambda: f"{random_uuid()}",
+ default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 6cc685acd6728..9a7051e0920af 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -55,6 +55,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_l
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall
+from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.logger import init_logger
@@ -273,6 +274,11 @@ class OpenAIServingChat(OpenAIServing):
try:
for i, engine_prompt in enumerate(engine_prompts):
prompt_text, _, _ = self._get_prompt_components(request_prompts[i])
+ # If we are creating sub requests for multiple prompts, ensure that they
+ # have unique request ids.
+ sub_request_id = (
+ request_id if len(engine_prompts) == 1 else f"{request_id}_{i}"
+ )
if self.default_sampling_params is None:
self.default_sampling_params = {}
@@ -301,7 +307,7 @@ class OpenAIServingChat(OpenAIServing):
)
self._log_inputs(
- request_id,
+ sub_request_id,
request_prompts[i],
params=sampling_params,
lora_request=lora_request,
@@ -316,14 +322,14 @@ class OpenAIServingChat(OpenAIServing):
if isinstance(sampling_params, BeamSearchParams):
generator = self.beam_search(
prompt=engine_prompt,
- request_id=request_id,
+ request_id=sub_request_id,
params=sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
)
else:
engine_request, tokenization_kwargs = await self._process_inputs(
- request_id,
+ sub_request_id,
engine_prompt,
sampling_params,
lora_request=lora_request,
@@ -334,7 +340,7 @@ class OpenAIServingChat(OpenAIServing):
generator = self.engine_client.generate(
engine_request,
sampling_params,
- request_id,
+ sub_request_id,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
@@ -1201,6 +1207,7 @@ class OpenAIServingChat(OpenAIServing):
finish_reason_sent[i] = True
+ choice_data = maybe_filter_parallel_tool_calls(choice_data, request)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
@@ -1526,6 +1533,7 @@ class OpenAIServingChat(OpenAIServing):
as_list(output.token_ids) if request.return_token_ids else None
),
)
+ choice_data = maybe_filter_parallel_tool_calls(choice_data, request)
choices.append(choice_data)
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 7dab5dbacd28c..d9feee917ff4e 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -296,11 +296,7 @@ class OpenAIServing:
parser = None
if not enable_auto_tools or tool_parser_name is None:
return parser
- logger.info(
- '"auto" tool choice has been enabled please note that while'
- " the parallel_tool_calls client option is preset for "
- "compatibility reasons, it will be ignored."
- )
+ logger.info('"auto" tool choice has been enabled.')
try:
if tool_parser_name == "pythonic" and self.model_config.model.startswith(
@@ -1242,16 +1238,19 @@ class OpenAIServing:
):
prompt_text, _, _ = self._get_prompt_components(request_prompt)
orig_priority = priority
+ sub_request = 0
while True:
+ # Ensure that each sub-request has a unique request id.
+ sub_request_id = f"{request_id}_{sub_request}"
self._log_inputs(
- request_id,
+ sub_request_id,
request_prompt,
params=sampling_params,
lora_request=lora_request,
)
trace_headers = kwargs.get("trace_headers")
engine_request, tokenization_kwargs = await self._process_inputs(
- request_id,
+ sub_request_id,
engine_prompt,
sampling_params,
lora_request=lora_request,
@@ -1262,7 +1261,7 @@ class OpenAIServing:
generator = self.engine_client.generate(
engine_request,
sampling_params,
- request_id,
+ sub_request_id,
lora_request=lora_request,
priority=priority,
prompt_text=prompt_text,
@@ -1295,6 +1294,7 @@ class OpenAIServing:
sampling_params.max_tokens = self.max_model_len - len(prompt_token_ids)
# OPTIMIZATION
priority = orig_priority - 1
+ sub_request += 1
def _get_prompt_components(
self,
@@ -1345,11 +1345,12 @@ class OpenAIServing:
raw_request: Request | None, default: str | None = None
) -> str | None:
"""Pulls the request id to use from a header, if provided"""
- default = default or random_uuid()
- if raw_request is None:
- return default
+ if raw_request is not None and (
+ (req_id := raw_request.headers.get("X-Request-Id")) is not None
+ ):
+ return req_id
- return raw_request.headers.get("X-Request-Id", default)
+ return random_uuid() if default is None else default
@staticmethod
def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py
index 06efb43ecb7b8..f546dbda7fef5 100644
--- a/vllm/entrypoints/openai/serving_responses.py
+++ b/vllm/entrypoints/openai/serving_responses.py
@@ -94,7 +94,7 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.responses_utils import (
- construct_chat_message_with_tool_call,
+ construct_input_messages,
convert_tool_responses_to_completions_format,
extract_tool_types,
)
@@ -504,7 +504,12 @@ class OpenAIServingResponses(OpenAIServing):
for tool in request.tools
]
# Construct the input messages.
- messages = self._construct_input_messages(request, prev_response)
+ messages = construct_input_messages(
+ request_instructions=request.instructions,
+ request_input=request.input,
+ prev_msg=self.msg_store.get(prev_response.id) if prev_response else None,
+ prev_response_output=prev_response.output if prev_response else None,
+ )
_, request_prompts, engine_prompts = await self._preprocess_chat(
request,
tokenizer,
@@ -869,47 +874,6 @@ class OpenAIServingResponses(OpenAIServing):
output_items.extend(last_items)
return output_items
- def _construct_input_messages(
- self,
- request: ResponsesRequest,
- prev_response: ResponsesResponse | None = None,
- ) -> list[ChatCompletionMessageParam]:
- messages: list[ChatCompletionMessageParam] = []
- if request.instructions:
- messages.append(
- {
- "role": "system",
- "content": request.instructions,
- }
- )
-
- # Prepend the conversation history.
- if prev_response is not None:
- # Add the previous messages.
- prev_msg = self.msg_store[prev_response.id]
- messages.extend(prev_msg)
-
- # Add the previous output.
- for output_item in prev_response.output:
- # NOTE: We skip the reasoning output.
- if isinstance(output_item, ResponseOutputMessage):
- for content in output_item.content:
- messages.append(
- {
- "role": "assistant",
- "content": content.text,
- }
- )
-
- # Append the new input.
- # Responses API supports simple text inputs without chat format.
- if isinstance(request.input, str):
- messages.append({"role": "user", "content": request.input})
- else:
- for item in request.input:
- messages.append(construct_chat_message_with_tool_call(item))
- return messages
-
def _construct_harmony_system_input_message(
self, request: ResponsesRequest, with_custom_tools: bool, tool_types: set[str]
) -> OpenAIHarmonyMessage:
diff --git a/vllm/entrypoints/openai/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text.py
index b9b9b1ab30ad8..3dece07748cc4 100644
--- a/vllm/entrypoints/openai/speech_to_text.py
+++ b/vllm/entrypoints/openai/speech_to_text.py
@@ -201,10 +201,10 @@ class OpenAISpeechToText(OpenAIServing):
self.engine_client.generate(
prompt,
sampling_params,
- request_id,
+ f"{request_id}_{i}",
lora_request=lora_request,
)
- for prompt in prompts
+ for i, prompt in enumerate(prompts)
]
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
index 02fc9b8a4d34e..e1fe6e90dfd0b 100644
--- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
@@ -9,6 +9,7 @@ import regex as re
from partial_json_parser.core.options import Allow
from transformers import PreTrainedTokenizerBase
+import vllm.envs as envs
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
@@ -56,12 +57,10 @@ class Llama3JsonToolParser(ToolParser):
self.bot_token_id = tokenizer.encode(self.bot_token, add_special_tokens=False)[
0
]
- # Updated regex to match multiple JSONs separated by semicolons
- # This pattern is more robust and can handle nested JSON objects
- self.tool_call_regex = re.compile(
- r"{[^{}]*(?:{[^{}]*}[^{}]*)*}(?:\s*;\s*{[^{}]*(?:{[^{}]*}[^{}]*)*})*",
- re.DOTALL,
- )
+ # Simple regex to find opening braces - we'll use JSON decoder for parsing
+ # This handles arbitrary nesting depth correctly
+ self.tool_call_start_regex = re.compile(r"\{")
+ self.json_decoder = json.JSONDecoder()
def extract_tool_calls(
self, model_output: str, request: ChatCompletionRequest
@@ -77,49 +76,84 @@ class Llama3JsonToolParser(ToolParser):
tools_called=False, tool_calls=[], content=model_output
)
- # Find JSON object(s) in the text using regex
- match = self.tool_call_regex.search(model_output)
- if not match:
+ # Keep track of the end index of the last parsed JSON object
+ # so we don't parse inner brackets
+ end_index = -1
+ tool_calls: list[ToolCall] = []
+
+ try:
+ for match in self.tool_call_start_regex.finditer(
+ model_output, timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS
+ ):
+ start_index = match.start()
+ # Skip if this brace is inside a previously parsed JSON object
+ if start_index <= end_index:
+ continue
+
+ try:
+ obj, json_end_index = self.json_decoder.raw_decode(
+ model_output[start_index:]
+ )
+ end_index = start_index + json_end_index
+
+ # raise KeyError if missing
+ name = obj["name"]
+ arguments_or_params = (
+ obj["arguments"] if "arguments" in obj else obj["parameters"]
+ )
+
+ tool_calls.append(
+ ToolCall(
+ type="function",
+ function=FunctionCall(
+ name=name,
+ # function call args are JSON but as a string
+ arguments=json.dumps(
+ arguments_or_params, ensure_ascii=False
+ ),
+ ),
+ )
+ )
+ except KeyError as e:
+ # Missing required key
+ missing_key = str(e).strip("'\"")
+ logger.exception(
+ "Couldn't extract tool call from JSON response. "
+ "Required key '%s' not present. "
+ "Returning output in content with empty tool calls.",
+ missing_key,
+ )
+ return ExtractedToolCallInformation(
+ tools_called=False, tool_calls=[], content=model_output
+ )
+ except Exception:
+ # Any other error during parsing
+ logger.exception(
+ "Error in extracting tool call from response. "
+ "Returning output in content with empty tool calls"
+ )
+ return ExtractedToolCallInformation(
+ tools_called=False, tool_calls=[], content=model_output
+ )
+ except TimeoutError:
+ logger.warning("Regex timeout occurred when matching tool call pattern.")
+ logger.debug(
+ "Regex timeout occurred when matching user input: %s", model_output
+ )
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
- try:
- json_str = match.group(0)
- # Split by semicolon and strip whitespace
- json_objects = [obj.strip() for obj in json_str.split(";")]
-
- tool_calls: list[ToolCall] = []
- for json_obj in json_objects:
- if not json_obj: # Skip empty strings
- continue
- obj = json.loads(json_obj)
- tool_calls.append(
- ToolCall(
- type="function",
- function=FunctionCall(
- name=obj["name"],
- # function call args are JSON but as a string
- arguments=json.dumps(
- obj["arguments"]
- if "arguments" in obj
- else obj["parameters"],
- ensure_ascii=False,
- ),
- ),
- )
- )
-
+ # If we have valid tool calls, return them normally
+ if tool_calls:
return ExtractedToolCallInformation(
tools_called=True, tool_calls=tool_calls, content=None
)
- except Exception:
- logger.exception("Error in extracting tool call from response.")
- # return information to just treat the tool call as regular JSON
- return ExtractedToolCallInformation(
- tools_called=False, tool_calls=[], content=model_output
- )
+ # No valid tool calls found
+ return ExtractedToolCallInformation(
+ tools_called=False, tool_calls=[], content=model_output
+ )
def extract_tool_calls_streaming(
self,
diff --git a/vllm/entrypoints/openai/utils.py b/vllm/entrypoints/openai/utils.py
new file mode 100644
index 0000000000000..6f37f6adff4c2
--- /dev/null
+++ b/vllm/entrypoints/openai/utils.py
@@ -0,0 +1,37 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import TypeVar
+
+from vllm.entrypoints.openai.protocol import (
+ ChatCompletionRequest,
+ ChatCompletionResponseChoice,
+ ChatCompletionResponseStreamChoice,
+)
+
+# Used internally
+_ChatCompletionResponseChoiceT = TypeVar(
+ "_ChatCompletionResponseChoiceT",
+ ChatCompletionResponseChoice,
+ ChatCompletionResponseStreamChoice,
+)
+
+
+def maybe_filter_parallel_tool_calls(
+ choice: _ChatCompletionResponseChoiceT, request: ChatCompletionRequest
+) -> _ChatCompletionResponseChoiceT:
+ """Filter to first tool call only when parallel_tool_calls is False."""
+
+ if request.parallel_tool_calls:
+ return choice
+
+ if isinstance(choice, ChatCompletionResponseChoice) and choice.message.tool_calls:
+ choice.message.tool_calls = choice.message.tool_calls[:1]
+ elif (
+ isinstance(choice, ChatCompletionResponseStreamChoice)
+ and choice.delta.tool_calls
+ ):
+ choice.delta.tool_calls = [
+ tool_call for tool_call in choice.delta.tool_calls if tool_call.index == 0
+ ]
+
+ return choice
diff --git a/vllm/entrypoints/responses_utils.py b/vllm/entrypoints/responses_utils.py
index d966f58804b67..07abb80ebc9e3 100644
--- a/vllm/entrypoints/responses_utils.py
+++ b/vllm/entrypoints/responses_utils.py
@@ -9,7 +9,12 @@ from openai.types.chat import (
from openai.types.chat.chat_completion_message_tool_call_param import (
Function as FunctionCallTool,
)
-from openai.types.responses import ResponseFunctionToolCall
+from openai.types.responses import ResponseFunctionToolCall, ResponseOutputItem
+from openai.types.responses.response_function_tool_call_output_item import (
+ ResponseFunctionToolCallOutputItem,
+)
+from openai.types.responses.response_output_message import ResponseOutputMessage
+from openai.types.responses.response_reasoning_item import ResponseReasoningItem
from openai.types.responses.tool import Tool
from vllm import envs
@@ -19,6 +24,49 @@ from vllm.entrypoints.openai.protocol import (
)
+def construct_input_messages(
+ *,
+ request_instructions: str | None = None,
+ request_input: str | list[ResponseInputOutputItem],
+ prev_msg: list[ChatCompletionMessageParam] | None = None,
+ prev_response_output: list[ResponseOutputItem] | None = None,
+):
+ messages: list[ChatCompletionMessageParam] = []
+ if request_instructions:
+ messages.append(
+ {
+ "role": "system",
+ "content": request_instructions,
+ }
+ )
+
+ # Prepend the conversation history.
+ if prev_msg is not None:
+ # Add the previous messages.
+ messages.extend(prev_msg)
+ if prev_response_output is not None:
+ # Add the previous output.
+ for output_item in prev_response_output:
+ # NOTE: We skip the reasoning output.
+ if isinstance(output_item, ResponseOutputMessage):
+ for content in output_item.content:
+ messages.append(
+ {
+ "role": "assistant",
+ "content": content.text,
+ }
+ )
+
+ # Append the new input.
+ # Responses API supports simple text inputs without chat format.
+ if isinstance(request_input, str):
+ messages.append({"role": "user", "content": request_input})
+ else:
+ for item in request_input:
+ messages.append(construct_chat_message_with_tool_call(item))
+ return messages
+
+
def construct_chat_message_with_tool_call(
item: ResponseInputOutputItem,
) -> ChatCompletionMessageParam:
@@ -37,6 +85,24 @@ def construct_chat_message_with_tool_call(
)
],
)
+ elif isinstance(item, ResponseReasoningItem):
+ reasoning_content = ""
+ if item.encrypted_content:
+ raise ValueError("Encrypted content is not supported.")
+ if len(item.summary) == 1:
+ reasoning_content = item.summary[0].text
+ elif item.content and len(item.content) == 1:
+ reasoning_content = item.content[0].text
+ return {
+ "role": "assistant",
+ "reasoning": reasoning_content,
+ }
+ elif isinstance(item, ResponseFunctionToolCallOutputItem):
+ return ChatCompletionToolMessageParam(
+ role="tool",
+ content=item.output,
+ tool_call_id=item.call_id,
+ )
elif item.get("type") == "function_call_output":
# Append the function call output as a tool message.
return ChatCompletionToolMessageParam(
diff --git a/vllm/envs.py b/vllm/envs.py
index 9b1ed1fc680b4..56558548d3981 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -640,7 +640,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Example options:
# - "TORCH_SDPA": use torch.nn.MultiheadAttention
# - "FLASH_ATTN": use FlashAttention
- # - "XFORMERS": use XFormers
# - "FLASHINFER": use flashinfer
# - "FLASHMLA": use FlashMLA
# - "FLASH_ATTN_MLA": use FlashAttention for MLA
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 25fb7181a8f29..7cb490e391abb 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -153,7 +153,7 @@ class DPMetadata:
@contextmanager
def sp_local_sizes(self, sequence_parallel_size: int):
"""
- Context mamager for setting self.local_sizes. Same as self.chunked_sizes
+ Context manager for setting self.local_sizes. Same as self.chunked_sizes
but without any chunking.
"""
self.local_sizes = _compute_sp_num_tokens(
diff --git a/vllm/logger.py b/vllm/logger.py
index 772e36497b45e..ad3123c0f0149 100644
--- a/vllm/logger.py
+++ b/vllm/logger.py
@@ -7,7 +7,8 @@ import json
import logging
import os
import sys
-from collections.abc import Hashable
+from collections.abc import Generator, Hashable
+from contextlib import contextmanager
from functools import lru_cache, partial
from logging import Logger
from logging.config import dictConfig
@@ -212,6 +213,14 @@ def init_logger(name: str) -> _VllmLogger:
return cast(_VllmLogger, logger)
+@contextmanager
+def suppress_logging(level: int = logging.INFO) -> Generator[None, Any, None]:
+ current_level = logging.root.manager.disable
+ logging.disable(level)
+ yield
+ logging.disable(current_level)
+
+
# The root logger is initialized when the module is imported.
# This is thread-safe as the module is only imported once,
# guaranteed by the Python GIL.
diff --git a/vllm/lora/layers/__init__.py b/vllm/lora/layers/__init__.py
index 8a4f5ff175d4f..25364a5881364 100644
--- a/vllm/lora/layers/__init__.py
+++ b/vllm/lora/layers/__init__.py
@@ -11,7 +11,7 @@ from vllm.lora.layers.column_parallel_linear import (
QKVParallelLinearWithLoRA,
QKVParallelLinearWithShardedLoRA,
)
-from vllm.lora.layers.fused_moe import FusedMoEWithLoRA
+from vllm.lora.layers.fused_moe import FusedMoE3DWithLoRA, FusedMoEWithLoRA
from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA
from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA
from vllm.lora.layers.row_parallel_linear import (
@@ -38,4 +38,5 @@ __all__ = [
"ReplicatedLinearWithLoRA",
"LoRAMapping",
"FusedMoEWithLoRA",
+ "FusedMoE3DWithLoRA",
]
diff --git a/vllm/lora/layers/base.py b/vllm/lora/layers/base.py
index 62326c05b2bd1..3bfb88c007622 100644
--- a/vllm/lora/layers/base.py
+++ b/vllm/lora/layers/base.py
@@ -42,8 +42,8 @@ class BaseLayerWithLoRA(nn.Module):
def set_lora(
self,
index: int,
- lora_a: torch.Tensor,
- lora_b: torch.Tensor,
+ lora_a: torch.Tensor | list[torch.Tensor],
+ lora_b: torch.Tensor | list[torch.Tensor],
):
"""Overwrites lora tensors at index."""
...
diff --git a/vllm/lora/layers/base_linear.py b/vllm/lora/layers/base_linear.py
index e85c5bd70b072..06ecc8d2f634c 100644
--- a/vllm/lora/layers/base_linear.py
+++ b/vllm/lora/layers/base_linear.py
@@ -94,13 +94,15 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
def set_lora(
self,
index: int,
- lora_a: torch.Tensor,
- lora_b: torch.Tensor,
+ lora_a: torch.Tensor | list[torch.Tensor],
+ lora_b: torch.Tensor | list[torch.Tensor],
):
# Except for QKVParallelLinearWithLoRA and
# MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
# store weights in a tuple of size 1. These two layers will
# override this function.
+ assert isinstance(lora_a, torch.Tensor)
+ assert isinstance(lora_b, torch.Tensor)
assert (
len(self.lora_a_stacked) == len(self.lora_b_stacked) == self.n_slices == 1
)
diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py
index 273c4950e3239..3e21d426c304a 100644
--- a/vllm/lora/layers/column_parallel_linear.py
+++ b/vllm/lora/layers/column_parallel_linear.py
@@ -246,8 +246,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def set_lora(
self,
index: int,
- lora_a: torch.Tensor,
- lora_b: torch.Tensor,
+ lora_a: torch.Tensor | list[torch.Tensor],
+ lora_b: torch.Tensor | list[torch.Tensor],
):
self.reset_lora(index)
diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py
index adf30855cafc3..0eb6562bec6cd 100644
--- a/vllm/lora/layers/fused_moe.py
+++ b/vllm/lora/layers/fused_moe.py
@@ -42,6 +42,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.device = base_layer.w2_weight.device
+ self._w13_slices = 2
self._inject_lora_into_fused_moe()
def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]:
@@ -60,8 +61,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
def _get_lora_moe_configs(
self,
op_prefix: str,
- lora_a_stacked: torch.Tensor,
- lora_b_stacked: torch.Tensor,
+ num_loras: int,
+ rank: int,
num_slices: int,
M: int,
layer: FusedMoE,
@@ -69,23 +70,25 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
config_dtype: str,
):
if envs.VLLM_TUNED_CONFIG_FOLDER:
+ hidden_size = layer.hidden_size
+ intermediate_size = layer.intermediate_size_per_partition
shrink_config = get_lora_op_configs(
op_type=f"fused_moe_lora_{op_prefix}_shrink",
- max_loras=lora_a_stacked.shape[0],
+ max_loras=num_loras,
batch=M,
- hidden_size=lora_a_stacked.shape[-1],
- rank=lora_a_stacked.shape[-2],
+ hidden_size=hidden_size,
+ rank=rank,
num_slices=num_slices,
- moe_intermediate_size=lora_b_stacked.shape[-2],
+ moe_intermediate_size=intermediate_size,
)
expand_config = get_lora_op_configs(
op_type=f"fused_moe_lora_{op_prefix}_expand",
- max_loras=lora_a_stacked.shape[0],
+ max_loras=num_loras,
batch=M,
- hidden_size=lora_a_stacked.shape[-1],
- rank=lora_a_stacked.shape[-2],
+ hidden_size=hidden_size, # lora_a_stacked.shape[-1],
+ rank=rank,
num_slices=num_slices,
- moe_intermediate_size=lora_b_stacked.shape[-2],
+ moe_intermediate_size=intermediate_size, # lora_b_stacked.shape[-2],
)
else: # fall back to the default config
get_config_func = functools.partial(
@@ -152,12 +155,12 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
num_tokens = hidden_states.size(0)
M = min(num_tokens, CHUNK_SIZE)
-
+ max_lora_rank = self.w13_lora_a_stacked[0].shape[-2]
shrink_config, expand_config = self._get_lora_moe_configs(
op_prefix="w13",
- lora_a_stacked=self.w1_lora_a_stacked,
- lora_b_stacked=self.w1_lora_b_stacked,
- num_slices=2,
+ num_loras=self.max_loras,
+ rank=max_lora_rank,
+ num_slices=self._w13_slices,
M=M,
layer=layer,
top_k=top_k,
@@ -165,7 +168,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
)
# get the block size of m from customized config or default config
- max_loras = self.w1_lora_a_stacked.shape[0]
(
sorted_token_ids_lora,
expert_ids_lora,
@@ -175,7 +177,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
num_tokens,
shrink_config["BLOCK_SIZE_M"],
self.base_layer.local_num_experts,
- max_loras,
+ self.max_loras,
self.adapter_enabled,
expert_map,
)
@@ -186,17 +188,15 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
num_tokens_post_padded_lora
)
- w13_lora_a_stacked = [self.w1_lora_a_stacked, self.w3_lora_a_stacked]
- w13_lora_b_stacked = [self.w1_lora_b_stacked, self.w3_lora_b_stacked]
- max_lora_rank = self.w1_lora_a_stacked.shape[-2]
- expert_ids_lora = expert_ids_lora.view(max_loras, -1)
- sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
+ expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
+ sorted_token_ids_lora = sorted_token_ids_lora.view(self.max_loras, -1)
+ #
self.punica_wrapper.add_lora_fused_moe(
input.view(-1, top_k, input.shape[-1]),
hidden_states,
- w13_lora_a_stacked,
- w13_lora_b_stacked,
+ self.w13_lora_a_stacked,
+ self.w13_lora_b_stacked,
topk_weights,
sorted_token_ids_lora,
expert_ids_lora,
@@ -230,11 +230,11 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
num_tokens = hidden_states.size(0)
M = min(num_tokens, CHUNK_SIZE)
-
+ max_lora_rank = self.w2_lora_a_stacked[0].shape[-2]
shrink_config, expand_config = self._get_lora_moe_configs(
op_prefix="w2",
- lora_a_stacked=self.w2_lora_a_stacked,
- lora_b_stacked=self.w2_lora_b_stacked,
+ num_loras=self.max_loras,
+ rank=max_lora_rank,
num_slices=1,
M=M,
layer=layer,
@@ -247,20 +247,19 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
num_tokens_post_padded_lora = moe_state_dict[
"num_tokens_post_padded_lora"
]
- max_loras = self.w1_lora_a_stacked.shape[0]
- expert_ids_lora = expert_ids_lora.view(max_loras, -1)
- sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
+
+ expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
+ sorted_token_ids_lora = sorted_token_ids_lora.view(self.max_loras, -1)
intermediate_cache2 = moe_state_dict["intermediate_cache2"]
intermediate_cache3 = args[0]
- max_lora_rank = self.w2_lora_a_stacked.shape[-2]
shard_size_w2 = divide(self.base_layer.hidden_size, self.tp_size)
self.punica_wrapper.add_lora_fused_moe(
intermediate_cache3,
intermediate_cache2,
- [self.w2_lora_a_stacked],
- [self.w2_lora_b_stacked],
+ self.w2_lora_a_stacked,
+ self.w2_lora_b_stacked,
topk_weights,
sorted_token_ids_lora,
expert_ids_lora,
@@ -289,11 +288,72 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
fused_experts.moe_sum = moe_sum_decorator(
self.base_layer, fused_experts.moe_sum
)
-
self.base_layer.quant_method = FusedMoEModularMethod(
self.base_layer.quant_method, m_fused_moe_fn
)
+ def _create_lora_a_weights(
+ self,
+ max_loras: int,
+ lora_config: LoRAConfig,
+ ):
+ self.w13_lora_a_stacked: tuple[torch.Tensor, ...] = tuple(
+ torch.zeros(
+ (
+ max_loras,
+ self.base_layer.local_num_experts,
+ lora_config.max_lora_rank
+ if not self.fully_sharded
+ else divide(lora_config.max_lora_rank, self.tp_size),
+ self.base_layer.hidden_size,
+ ),
+ dtype=lora_config.lora_dtype,
+ device=self.device,
+ )
+ for _ in range(self._w13_slices)
+ )
+ self.w2_lora_a_stacked: tuple[torch.Tensor, ...] = (
+ torch.zeros(
+ (
+ max_loras,
+ self.base_layer.local_num_experts,
+ lora_config.max_lora_rank,
+ self.base_layer.intermediate_size_per_partition,
+ ),
+ dtype=lora_config.lora_dtype,
+ device=self.device,
+ ),
+ )
+
+ def _create_lora_b_weights(self, max_loras: int, lora_config: LoRAConfig):
+ self.w13_lora_b_stacked: tuple[torch.Tensor, ...] = tuple(
+ torch.zeros(
+ (
+ max_loras,
+ self.base_layer.local_num_experts,
+ self.base_layer.intermediate_size_per_partition,
+ lora_config.max_lora_rank,
+ ),
+ dtype=lora_config.lora_dtype,
+ device=self.device,
+ )
+ for _ in range(self._w13_slices)
+ )
+ self.w2_lora_b_stacked: tuple[torch.Tensor, ...] = (
+ torch.zeros(
+ (
+ max_loras,
+ self.base_layer.local_num_experts,
+ self.base_layer.hidden_size
+ if not self.fully_sharded
+ else divide(self.base_layer.hidden_size, self.tp_size),
+ lora_config.max_lora_rank,
+ ),
+ dtype=lora_config.lora_dtype,
+ device=self.device,
+ ),
+ )
+
def create_lora_weights(
self,
max_loras: int,
@@ -301,113 +361,63 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
model_config: PretrainedConfig | None = None,
) -> None:
"""Initializes lora matrices."""
+ self.max_loras = lora_config.max_loras
self.fully_sharded = lora_config.fully_sharded_loras
self.adapter_enabled = torch.tensor(
[0] * (max_loras + 1), dtype=torch.int, device=self.device
)
- self.w1_lora_a_stacked = torch.zeros(
- (
- max_loras,
- self.base_layer.local_num_experts,
- lora_config.max_lora_rank
- if not self.fully_sharded
- else divide(lora_config.max_lora_rank, self.tp_size),
- self.base_layer.hidden_size,
- ),
- dtype=lora_config.lora_dtype,
- device=self.device,
- )
- self.w1_lora_b_stacked = torch.zeros(
- (
- max_loras,
- self.base_layer.local_num_experts,
- self.base_layer.intermediate_size_per_partition,
- lora_config.max_lora_rank,
- ),
- dtype=lora_config.lora_dtype,
- device=self.device,
- )
-
- self.w2_lora_a_stacked = torch.zeros(
- (
- max_loras,
- self.base_layer.local_num_experts,
- lora_config.max_lora_rank,
- self.base_layer.intermediate_size_per_partition,
- ),
- dtype=lora_config.lora_dtype,
- device=self.device,
- )
- self.w2_lora_b_stacked = torch.zeros(
- (
- max_loras,
- self.base_layer.local_num_experts,
- self.base_layer.hidden_size
- if not self.fully_sharded
- else divide(self.base_layer.hidden_size, self.tp_size),
- lora_config.max_lora_rank,
- ),
- dtype=lora_config.lora_dtype,
- device=self.device,
- )
-
- self.w3_lora_a_stacked = torch.zeros(
- (
- max_loras,
- self.base_layer.local_num_experts,
- lora_config.max_lora_rank
- if not self.fully_sharded
- else divide(lora_config.max_lora_rank, self.tp_size),
- self.base_layer.hidden_size,
- ),
- dtype=lora_config.lora_dtype,
- device=self.device,
- )
- self.w3_lora_b_stacked = torch.zeros(
- (
- max_loras,
- self.base_layer.local_num_experts,
- self.base_layer.intermediate_size_per_partition,
- lora_config.max_lora_rank,
- ),
- dtype=lora_config.lora_dtype,
- device=self.device,
- )
-
+ self._create_lora_a_weights(max_loras, lora_config)
+ self._create_lora_b_weights(max_loras, lora_config)
# They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
# to create a dummy LoRA weights.
+ # TODO Optimize this section
self.lora_a_stacked = []
self.lora_b_stacked = []
for lora_id in range(max_loras):
for experts_id in range(self.base_layer.local_num_experts):
# gate_proj,down_proj,up_proj
- self.lora_a_stacked.append(self.w1_lora_a_stacked[lora_id][experts_id])
- self.lora_a_stacked.append(self.w2_lora_a_stacked[lora_id][experts_id])
- self.lora_a_stacked.append(self.w3_lora_a_stacked[lora_id][experts_id])
+ self.lora_a_stacked.append(
+ self.w13_lora_a_stacked[0][lora_id][experts_id]
+ )
+ self.lora_a_stacked.append(
+ self.w2_lora_a_stacked[0][lora_id][experts_id]
+ )
- self.lora_b_stacked.append(self.w1_lora_b_stacked[lora_id][experts_id])
- self.lora_b_stacked.append(self.w2_lora_b_stacked[lora_id][experts_id])
- self.lora_b_stacked.append(self.w3_lora_b_stacked[lora_id][experts_id])
+ self.lora_b_stacked.append(
+ self.w13_lora_b_stacked[0][lora_id][experts_id]
+ )
+ self.lora_b_stacked.append(
+ self.w2_lora_b_stacked[0][lora_id][experts_id]
+ )
+
+ self.lora_a_stacked.append(
+ self.w13_lora_a_stacked[1][lora_id][experts_id]
+ )
+ self.lora_b_stacked.append(
+ self.w13_lora_b_stacked[1][lora_id][experts_id]
+ )
def reset_lora(self, index: int):
"""Resets the lora weights at index back to 0."""
- self.w1_lora_a_stacked[index] = 0
- self.w1_lora_b_stacked[index] = 0
- self.w3_lora_a_stacked[index] = 0
- self.w3_lora_b_stacked[index] = 0
- self.w2_lora_a_stacked[index] = 0
- self.w2_lora_b_stacked[index] = 0
+ for pos in range(self._w13_slices):
+ self.w13_lora_a_stacked[pos][index] = 0
+ self.w13_lora_b_stacked[pos][index] = 0
+
+ self.w2_lora_a_stacked[0][index] = 0
+ self.w2_lora_b_stacked[0][index] = 0
self.adapter_enabled[index] = 0
def set_lora(
self,
index: int,
- lora_a: torch.Tensor,
- lora_b: torch.Tensor,
+ lora_a: torch.Tensor | list[torch.Tensor],
+ lora_b: torch.Tensor | list[torch.Tensor],
):
"""Overwrites lora tensors at index."""
+ assert isinstance(lora_a, list)
+ assert isinstance(lora_b, list)
self.reset_lora(index)
self.adapter_enabled[index] = 1
for eid in range(len(lora_a) // 3):
@@ -434,50 +444,41 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
if self.fully_sharded:
# Based on S-LoRA, we slice W1 and W3 A along the rank dim,
# and W2 B along the hidden_size dim.
- w13_shard_size = self.w1_lora_a_stacked[index, eid].shape[0]
+ w13_shard_size = self.w13_lora_a_stacked[0][index, eid].shape[0]
w13_start_idx = self.tp_rank * w13_shard_size
w13_end_idx = (self.tp_rank + 1) * w13_shard_size
w1_lora_a = w1_lora_a[w13_start_idx:w13_end_idx, :]
w3_lora_a = w3_lora_a[w13_start_idx:w13_end_idx, :]
- w2_shard_size = self.w2_lora_b_stacked[index, eid].shape[0]
+ w2_shard_size = self.w2_lora_b_stacked[0][index, eid].shape[0]
w2_start_idx = self.tp_rank * w2_shard_size
w2_end_idx = (self.tp_rank + 1) * w2_shard_size
w2_lora_b = w2_lora_b[w2_start_idx:w2_end_idx, :]
-
- self.w1_lora_a_stacked[
+ # w1 lora_a
+ self.w13_lora_a_stacked[0][
index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1]
].copy_(w1_lora_a, non_blocking=True)
-
- self.w3_lora_a_stacked[
+ # w3 lora_a
+ self.w13_lora_a_stacked[1][
index, eid, : w3_lora_a.shape[0], : w3_lora_a.shape[1]
].copy_(w3_lora_a, non_blocking=True)
- self.w2_lora_b_stacked[
- index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1]
- ].copy_(w2_lora_b, non_blocking=True)
-
- self.w1_lora_b_stacked[
+ # w1 lora_b
+ self.w13_lora_b_stacked[0][
index, eid, : w1_lora_b.shape[0], : w1_lora_b.shape[1]
].copy_(w1_lora_b, non_blocking=True)
- self.w3_lora_b_stacked[
+ # w3 lora_b
+ self.w13_lora_b_stacked[1][
index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1]
].copy_(w3_lora_b, non_blocking=True)
- self.w2_lora_a_stacked[
+
+ self.w2_lora_a_stacked[0][
index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1]
].copy_(w2_lora_a, non_blocking=True)
- @classmethod
- def can_replace_layer(
- cls,
- source_layer: nn.Module,
- lora_config: LoRAConfig,
- packed_modules_list: list,
- model_config: PretrainedConfig | None,
- ) -> bool:
- """Returns True if the layer can be replaced by this LoRA layer."""
- # return type(source_layer) is FusedMoE
- return isinstance(source_layer, FusedMoE)
+ self.w2_lora_b_stacked[0][
+ index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1]
+ ].copy_(w2_lora_b, non_blocking=True)
def forward(self, *args, **kwargs):
return self.base_layer.forward(*args, **kwargs)
@@ -496,3 +497,220 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
@property
def is_internal_router(self) -> bool:
return self.base_layer.is_internal_router
+
+ @classmethod
+ def can_replace_layer(
+ cls,
+ source_layer: nn.Module,
+ lora_config: LoRAConfig,
+ packed_modules_list: list,
+ model_config: PretrainedConfig | None,
+ ) -> bool:
+ """Returns True if the layer can be replaced by this LoRA layer."""
+ # return type(source_layer) is FusedMoE
+
+ return type(source_layer) is FusedMoE and len(packed_modules_list) == 2
+
+
+class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
+ def __init__(self, base_layer):
+ super().__init__(base_layer)
+ self._w13_slices = 1
+
+ def _create_lora_b_weights(self, max_loras, lora_config):
+ self.w13_lora_b_stacked: tuple[torch.Tensor] = tuple(
+ torch.zeros(
+ (
+ max_loras,
+ self.base_layer.local_num_experts,
+ self.base_layer.intermediate_size_per_partition * 2,
+ lora_config.max_lora_rank,
+ ),
+ dtype=lora_config.lora_dtype,
+ device=self.device,
+ )
+ for _ in range(self._w13_slices)
+ )
+ self.w2_lora_b_stacked: tuple[torch.Tensor] = (
+ torch.zeros(
+ (
+ max_loras,
+ self.base_layer.local_num_experts,
+ self.base_layer.hidden_size
+ if not self.fully_sharded
+ else divide(self.base_layer.hidden_size, self.tp_size),
+ lora_config.max_lora_rank,
+ ),
+ dtype=lora_config.lora_dtype,
+ device=self.device,
+ ),
+ )
+
+ def create_lora_weights(
+ self,
+ max_loras: int,
+ lora_config: LoRAConfig,
+ model_config: PretrainedConfig | None = None,
+ ) -> None:
+ """Initializes lora matrices."""
+ self.max_loras = lora_config.max_loras
+ self.fully_sharded = lora_config.fully_sharded_loras
+
+ self.adapter_enabled = torch.tensor(
+ [0] * (max_loras + 1), dtype=torch.int, device=self.device
+ )
+
+ self._create_lora_a_weights(max_loras, lora_config)
+ self._create_lora_b_weights(max_loras, lora_config)
+
+ def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
+ if self.tp_size == 1 or not self.fully_sharded:
+ return w13_lora_a
+
+ # w13_lora_a shape (num_experts,rank,input_size)
+ current_lora_rank = w13_lora_a.shape[1]
+ assert current_lora_rank % self.tp_size == 0
+
+ sliced_rank = current_lora_rank // self.tp_size
+ start_idx = self.tp_rank * sliced_rank
+ end_idx = (self.tp_rank + 1) * sliced_rank
+ return w13_lora_a[:, start_idx:end_idx, :]
+
+ def _slice_w13_b(self, w13_lora_b: torch.Tensor, is_interleave: bool = True):
+ if self.tp_size == 1:
+ return w13_lora_b
+
+ # w13_lora_b shape (num_experts,output_size,rank)
+ shard_size = self.base_layer.intermediate_size_per_partition
+ start_idx = self.tp_rank * shard_size
+ end_idx = (self.tp_rank + 1) * shard_size
+ if is_interleave:
+ # For models like GPT-OSS, the weights of w1 (gate_proj) and w3 (up_proj)
+ # in the interleaved order, and corresponding LoRA need to be processed.
+ w1_lora_b = w13_lora_b[:, ::2, :]
+ w3_lora_b = w13_lora_b[:, 1::2, :]
+ sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
+ sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]
+
+ return torch.stack([sliced_w1_lora_b, sliced_w3_lora_b], dim=2).flatten(
+ 1, 2
+ )
+ else:
+ slice_size = w13_lora_b.shape[1] // 2
+ w1_lora_b = w13_lora_b[:, :slice_size, :]
+ w3_lora_b = w13_lora_b[:, slice_size:, :]
+ sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
+ sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]
+
+ return torch.cat([sliced_w1_lora_b, sliced_w3_lora_b], dim=1)
+
+ def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
+ if self.tp_size == 1:
+ return w2_lora_a
+ # w2_lora_a shape (num_experts,rank,input_size)
+ shard_size = self.base_layer.intermediate_size_per_partition
+ start_idx = self.tp_rank * shard_size
+ end_idx = (self.tp_rank + 1) * shard_size
+
+ return w2_lora_a[:, :, start_idx:end_idx]
+
+ def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
+ if self.tp_size == 1 or not self.fully_sharded:
+ return w2_lora_b
+ # Based on S-LoRA, we slice W2 B along the hidden_size dim.
+ # w2_lora_b shape (num_experts,output_size,rank)
+ current_lora_size = w2_lora_b.shape[1]
+
+ sliced_size = current_lora_size // self.tp_size
+ start_idx = self.tp_rank * sliced_size
+ end_idx = (self.tp_rank + 1) * sliced_size
+ return w2_lora_b[:, start_idx:end_idx, :]
+
+ def set_lora(
+ self,
+ index: int,
+ lora_a: torch.Tensor | list[torch.Tensor],
+ lora_b: torch.Tensor | list[torch.Tensor],
+ ):
+ """Overwrites lora tensors at index."""
+ # Make mypy happy
+ assert isinstance(lora_a, list)
+ assert isinstance(lora_b, list)
+ assert len(lora_a) == len(lora_b) == 2
+
+ self.reset_lora(index)
+ self.adapter_enabled[index] = 1
+
+ num_experts = self.w13_lora_a_stacked[0].shape[1]
+ w13_lora_a, w2_lora_a = lora_a
+ w13_lora_b, w2_lora_b = lora_b
+
+ # (num_experts,rank,input_size)
+ w13_lora_a = w13_lora_a.reshape(num_experts, -1, w13_lora_a.shape[-1])
+ w2_lora_a = w2_lora_a.reshape(num_experts, -1, w2_lora_a.shape[-1])
+ # (output_size,num_experts,rank)
+ w13_lora_b = w13_lora_b.reshape(w13_lora_b.shape[0], num_experts, -1)
+ w2_lora_b = w2_lora_b.reshape(w2_lora_b.shape[0], num_experts, -1)
+ # (num_experts,output_size,rank)
+ w13_lora_b = w13_lora_b.permute(1, 0, 2)
+ w2_lora_b = w2_lora_b.permute(1, 0, 2)
+
+ sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
+ sliced_w13_lora_b = self._slice_w13_b(w13_lora_b, is_interleave=True)
+
+ sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
+ sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)
+
+ self.w13_lora_a_stacked[0][
+ index, :, : sliced_w13_lora_a.shape[1], : sliced_w13_lora_a.shape[2]
+ ].copy_(sliced_w13_lora_a, non_blocking=True)
+ self.w2_lora_a_stacked[0][
+ index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
+ ].copy_(sliced_w2_lora_a, non_blocking=True)
+
+ self.w13_lora_b_stacked[0][
+ index, :, : sliced_w13_lora_b.shape[1], : sliced_w13_lora_b.shape[2]
+ ].copy_(sliced_w13_lora_b, non_blocking=True)
+ self.w2_lora_b_stacked[0][
+ index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
+ ].copy_(sliced_w2_lora_b, non_blocking=True)
+
+ @property
+ def w13_input_size(self):
+ """
+ Full size
+ """
+ return self.w13_lora_a_stacked[0].shape[-1]
+
+ @property
+ def w13_output_size(self):
+ """
+ Full size
+ """
+ return self.w13_lora_b_stacked[0].shape[-2] * self.tp_size
+
+ @property
+ def w2_input_size(self):
+ """
+ Full size
+ """
+ return self.w2_lora_a_stacked[0].shape[-1] * self.tp_size
+
+ @property
+ def w2_output_size(self):
+ """
+ Full size
+ """
+ return self.w2_lora_a_stacked[0].shape[-2]
+
+ @classmethod
+ def can_replace_layer(
+ cls,
+ source_layer: nn.Module,
+ lora_config: LoRAConfig,
+ packed_modules_list: list,
+ model_config: PretrainedConfig | None,
+ ) -> bool:
+ """Returns True if the layer can be replaced by this LoRA layer."""
+
+ return type(source_layer) is FusedMoE and len(packed_modules_list) == 1
diff --git a/vllm/lora/layers/logits_processor.py b/vllm/lora/layers/logits_processor.py
index 06f92652031e1..c01984db4e64c 100644
--- a/vllm/lora/layers/logits_processor.py
+++ b/vllm/lora/layers/logits_processor.py
@@ -128,9 +128,11 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
def set_lora(
self,
index: int,
- lora_a: torch.Tensor,
- lora_b: torch.Tensor,
+ lora_a: torch.Tensor | list[torch.Tensor],
+ lora_b: torch.Tensor | list[torch.Tensor],
):
+ assert isinstance(lora_a, torch.Tensor)
+ assert isinstance(lora_b, torch.Tensor)
self.reset_lora(index)
self.lora_a_stacked[index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
lora_a, non_blocking=True
diff --git a/vllm/lora/layers/vocal_parallel_embedding.py b/vllm/lora/layers/vocal_parallel_embedding.py
index 5b1f7886bc238..c87ca9e24dece 100644
--- a/vllm/lora/layers/vocal_parallel_embedding.py
+++ b/vllm/lora/layers/vocal_parallel_embedding.py
@@ -77,12 +77,15 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def set_lora(
self,
index: int,
- lora_a: torch.Tensor,
- lora_b: torch.Tensor,
+ lora_a: torch.Tensor | list[torch.Tensor],
+ lora_b: torch.Tensor | list[torch.Tensor],
):
+ assert isinstance(lora_a, torch.Tensor)
+ assert isinstance(lora_b, torch.Tensor)
self.reset_lora(index)
# NOTE self.lora_a_stacked is row-major, and lora_a is col-major,
# so we need transpose here
+
self.lora_a_stacked[index, : lora_a.shape[1], : lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True
)
diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index eb11cd0afc487..636f062feb7b0 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -22,11 +22,13 @@ from vllm.lora.utils import (
from_layer_logits_processor,
get_supported_lora_modules,
is_base_embeddding_weights,
+ is_moe_model,
is_regex_target_modules,
parse_fine_tuned_lora_name,
process_packed_modules_mapping,
replace_submodule,
)
+from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.interfaces import is_pooling_model
@@ -356,7 +358,11 @@ class LoRAModelManager:
self.modules: dict[str, BaseLayerWithLoRA] = {}
# Dict instead of a set for compatibility with LRUCache.
self._last_mapping: LoRAMapping | None = None
+ self._is_3d_moe_model = is_moe_model(self.model) and hasattr(
+ self.model, "is_3d_moe_weight"
+ )
self._create_lora_modules()
+
self.model.lora_manager = self
def __len__(self) -> int:
@@ -400,22 +406,36 @@ class LoRAModelManager:
self.lora_index_to_id[index] = lora_model.id
for module_name, module in self.modules.items():
module_lora = self._get_lora_layer_weights(lora_model, module_name)
- if module_lora:
- # Note (gnovack) - If MOE lora weights are not split into
- # num_experts chunks, we split them here
- if isinstance(module, FusedMoEWithLoRA) and torch.is_tensor(
- module_lora.lora_a
- ):
- # Handle FSDP file format where experts.base_layer is the
- # gate_up_proj and experts is the down_proj
- gate_up_proj_lora = self._get_lora_layer_weights(
- lora_model, module_name + ".base_layer"
- )
-
- assert gate_up_proj_lora is not None
- assert module_lora is not None
-
- down_proj_lora = module_lora
+ if not module_lora:
+ module.reset_lora(index)
+ continue
+ # Note (gnovack) - If MOE lora weights are not split into
+ # num_experts chunks, we split them here
+ if isinstance(module, FusedMoEWithLoRA) and torch.is_tensor(
+ module_lora.lora_a
+ ):
+ # Handle PEFT file format where experts.base_layer is the
+ # gate_up_proj and experts is the down_proj
+ gate_up_proj_lora = self._get_lora_layer_weights(
+ lora_model, module_name + ".base_layer"
+ )
+ down_proj_lora = module_lora
+ # FIXME Edge case where LoRA is not added to gate_up_proj
+ # or down_proj
+ assert gate_up_proj_lora is not None
+ assert down_proj_lora is not None
+ if self._is_3d_moe_model:
+ module_lora.lora_a = [
+ gate_up_proj_lora.lora_a,
+ down_proj_lora.lora_a,
+ ]
+ module_lora.lora_b = [
+ gate_up_proj_lora.lora_b,
+ down_proj_lora.lora_b,
+ ]
+ else:
+ # Some 3D MoE models haven't added the `is_3d_moe_weight`
+ # attribute yet, so fallback here
num_experts = module_lora.lora_a.shape[0] // module_lora.rank
gate_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
@@ -444,14 +464,12 @@ class LoRAModelManager:
module_lora.lora_a = lora_a
module_lora.lora_b = lora_b
+ module.set_lora(
+ index,
+ module_lora.lora_a,
+ module_lora.lora_b,
+ )
- module.set_lora(
- index,
- module_lora.lora_a,
- module_lora.lora_b,
- )
- else:
- module.reset_lora(index)
return True
def _deactivate_adapter(self, lora_id: int):
@@ -512,6 +530,13 @@ class LoRAModelManager:
continue
parts = module_name.split(".")[-1]
packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
+ if isinstance(module, FusedMoE):
+ # packed_moduled_lst is used here to just determine whether to
+ # instantiate FusedMoE3DWithLoRA or FusedMoEWithLoRA, and the
+ # difference between these two LoRA layers is whether the
+ # LoRA weights of w1 and w3 have already been fused on disk.
+
+ packed_moduled_lst = ["w13"] if self._is_3d_moe_model else ["w1", "w3"]
new_module = replace_submodule(
self.model,
module_name,
@@ -560,6 +585,7 @@ class LoRAModelManager:
self._register_packed_modules(module_name)
# All lora layers share the same punica_wrapper based on reference.
new_module.set_mapping(self.punica_wrapper)
+ pass
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
assert isinstance(module, BaseLayerWithLoRA), (
@@ -605,6 +631,30 @@ class LoRAModelManager:
module.lora_a_stacked[0].dtype,
"cpu",
)
+ model.loras[module_name] = lora
+ elif module.__class__.__name__ == "FusedMoE3DWithLoRA":
+ # Case for 3D moe model
+ # w2
+ lora = LoRALayerWeights.create_dummy_lora_weights(
+ module_name,
+ module.w2_input_size,
+ module.w2_output_size,
+ rank * module.w2_lora_a_stacked[0].shape[1], # rank*num_experts
+ module.w2_lora_a_stacked[0].dtype,
+ "cpu",
+ )
+ model.loras[module_name] = lora
+ # w13
+ lora = LoRALayerWeights.create_dummy_lora_weights(
+ module_name,
+ module.w13_input_size,
+ module.w13_output_size,
+ rank
+ * module.w13_lora_a_stacked[0].shape[1], # rank*num_experts
+ module.w13_lora_a_stacked[0].dtype,
+ "cpu",
+ )
+ model.loras[module_name + ".base_layer"] = lora
else:
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name,
@@ -614,6 +664,7 @@ class LoRAModelManager:
module.lora_a_stacked[0].dtype,
"cpu",
)
+ model.loras[module_name] = lora
else:
parts = module_name.split(".")
replacements = self.packed_modules_mapping[parts[-1]]
@@ -629,7 +680,7 @@ class LoRAModelManager:
)
subloras.append(lora)
lora = PackedLoRALayerWeights.pack(subloras)
- model.loras[module_name] = lora
+ model.loras[module_name] = lora
return model
def _match_target_modules(self, module_name: str):
diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py
index 7c0fc8167711d..ce38751e4b6a7 100644
--- a/vllm/lora/punica_wrapper/punica_base.py
+++ b/vllm/lora/punica_wrapper/punica_base.py
@@ -470,8 +470,8 @@ class PunicaWrapperBase(PunicaWrapperABC):
self,
y: torch.Tensor,
x: torch.Tensor,
- lora_a_stacked: list[torch.Tensor],
- lora_b_stacked: list[torch.Tensor],
+ lora_a_stacked: tuple[torch.Tensor, ...],
+ lora_b_stacked: tuple[torch.Tensor, ...],
topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py
index 52138ef0cc3b0..ef4b4ab7c3497 100644
--- a/vllm/lora/punica_wrapper/punica_gpu.py
+++ b/vllm/lora/punica_wrapper/punica_gpu.py
@@ -360,8 +360,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
self,
y: torch.Tensor,
x: torch.Tensor,
- lora_a_stacked: list[torch.Tensor],
- lora_b_stacked: list[torch.Tensor],
+ lora_a_stacked: tuple[torch.Tensor, ...],
+ lora_b_stacked: tuple[torch.Tensor, ...],
topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py
index a49a7d9d1669d..12524994d4968 100644
--- a/vllm/lora/utils.py
+++ b/vllm/lora/utils.py
@@ -23,6 +23,7 @@ from vllm.lora.layers import (
BaseLayerWithLoRA,
ColumnParallelLinearWithLoRA,
ColumnParallelLinearWithShardedLoRA,
+ FusedMoE3DWithLoRA,
FusedMoEWithLoRA,
LogitsProcessorWithLoRA,
MergedColumnParallelLinearWithLoRA,
@@ -62,6 +63,7 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = {
MergedQKVParallelLinearWithShardedLoRA,
RowParallelLinearWithShardedLoRA,
FusedMoEWithLoRA,
+ FusedMoE3DWithLoRA,
}
@@ -288,10 +290,12 @@ def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]:
# the expert indices are expanded based on the configured number
# of routed experts.
packed_modules_mapping = get_packed_modules_mapping(model)
-
- packed_modules_mapping["experts"] = [
- weight_name.rstrip(".") for _, weight_name, _, _ in moe_packed_mapping
- ]
+ if not hasattr(model, "is_3d_moe_weight"):
+ # 3D MoE LoRA does not need `packed_modules_mapping`
+ packed_modules_mapping["experts"] = [
+ weight_name.rstrip(".")
+ for _, weight_name, _, _ in moe_packed_mapping
+ ]
return packed_modules_mapping
else:
diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py
index bec7af0286345..be7f673e5618f 100644
--- a/vllm/model_executor/layers/batch_invariant.py
+++ b/vllm/model_executor/layers/batch_invariant.py
@@ -805,26 +805,26 @@ def override_envs_for_invariance():
"FLASH_ATTN", # best supported backend
"FLASHINFER",
"FLASH_ATTN_MLA",
- "TRITON_MLA",
# Not yet supported MLA backends
# "FLASHMLA",
# "FLEX_ATTENTION", # IMA issue even if we disable batch invariance
# "FLASHINFER_MLA", https://github.com/vllm-project/vllm/pull/28967
+ # "TRITON_MLA",
]
if curr_attn_backend not in supported_backends:
- warning = (
- "Forcibly updating attention backend to"
- f" {supported_backends[0]} for batch_invariant. "
- f" Supported backends: {supported_backends}."
+ error = (
+ "VLLM batch_invariant mode requires an attention backend in "
+ f"{supported_backends}, but got '{curr_attn_backend}'. "
+ "Please set the 'VLLM_ATTENTION_BACKEND' environment variable "
+ "to one of the supported backends before enabling batch_invariant."
)
- logger.warning_once(warning)
- os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0]
+ raise RuntimeError(error)
if os.environ["VLLM_ATTENTION_BACKEND"] != supported_backends[0]:
warning = (
"You are using a decode-invariant form of batch invariance. "
"This will not be invariant between prefill and decode."
)
- logger.warning_once(warning)
+ logger.warning_once(warning, scope="local")
os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index 21eb4d590a7d1..1826fafa8c4f5 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -28,10 +28,11 @@ logger = init_logger(__name__)
if has_triton_kernels():
try:
from triton_kernels.matmul_ogs import PrecisionConfig
- except ImportError:
+ except (ImportError, AttributeError) as e:
logger.error(
"Failed to import Triton kernels. Please make sure your triton "
- "version is compatible."
+ "version is compatible. Error: %s",
+ e,
)
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
new file mode 100644
index 0000000000000..54fe5374cb95d
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
@@ -0,0 +1,147 @@
+{
+ "triton_version": "3.5.0",
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "24": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "32": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "48": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "64": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "96": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "128": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "256": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "512": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 2
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 2
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 2
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 3
+ }
+}
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8.json
new file mode 100644
index 0000000000000..8b78f87e7f73b
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8.json
@@ -0,0 +1,147 @@
+{
+ "triton_version": "3.5.0",
+ "1": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 2
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "4": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 8,
+ "num_stages": 2
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "24": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "32": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "48": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "64": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "96": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "128": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "256": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "512": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 3
+ }
+}
diff --git a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
index 572307052b489..659a2d4ee5b39 100644
--- a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
@@ -6,22 +6,7 @@ import torch
from torch.nn import functional as F
from vllm import _custom_ops as ops
-
-
-def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
- d = x.shape[-1] // 2
- return F.silu(x[..., :d]) * x[..., d:]
-
-
-def swigluoai_and_mul(
- x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0
-) -> torch.Tensor:
- d = x.shape[-1] // 2
- gate, up = x[..., :d], x[..., d:]
- gate = gate.clamp(max=limit)
- up = up.clamp(min=-limit, max=limit)
- glu = gate * torch.sigmoid(alpha * gate)
- return (up + 1) * glu
+from vllm.model_executor.layers.activation import SiluAndMul, SwigluOAIAndMul
def grouped_topk(
@@ -227,6 +212,11 @@ class CPUFusedMOE:
layer.w13_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
layer.w2_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
+ self.act_to_impl = {
+ "silu": SiluAndMul(),
+ "swigluoai": SwigluOAIAndMul(),
+ }
+
def __call__(
self,
layer: torch.nn.Module,
@@ -246,7 +236,7 @@ class CPUFusedMOE:
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
- assert activation in {"silu", "swigluoai"}, f"{activation} is not supported."
+ assert activation in self.act_to_impl, f"{activation} is not supported."
assert not apply_router_weight_on_input
topk_weights, topk_ids = select_experts(
hidden_states=x,
@@ -283,10 +273,7 @@ class CPUFusedMOE:
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
gate_up = layer.gate_up_linear[i](tokens_for_this_expert)
- if activation == "swigluoai":
- gate_up = swigluoai_and_mul(gate_up)
- else:
- gate_up = silu_and_mul(gate_up)
+ gate_up = self.act_to_impl[activation].forward_native(gate_up)
expert_out = layer.down_linear[i](gate_up)
outputs.append(expert_out)
start_idx = end_idx
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
index 073e90a4e6808..ef7090c349fc6 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
@@ -90,10 +90,14 @@ class FusedMoEMethodBase(QuantizeMethodBase):
def allow_inplace(self) -> bool:
return False
+ @property
+ def method_name(self) -> str:
+ return self.__class__.__name__
+
@abstractmethod
def apply(
self,
- layer: torch.nn.Module,
+ layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
index c6dc95acdb636..c23c41df226f0 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
@@ -66,6 +66,10 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
def allow_inplace(self) -> bool:
return self.old_quant_method.allow_inplace
+ @property
+ def method_name(self) -> str:
+ return self.old_quant_method.method_name
+
def create_weights(
self,
layer: torch.nn.Module,
@@ -84,7 +88,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
def apply(
self,
- layer: torch.nn.Module,
+ layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -105,42 +109,9 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- # Is getattr needed?
- zero_expert_num = getattr(layer, "zero_expert_num", 0)
- zero_expert_type = getattr(layer, "zero_expert_type", None)
-
- if enable_eplb:
- if self.supports_eplb:
- assert expert_load_view is not None
- assert logical_to_physical_map is not None
- assert logical_replica_count is not None
- else:
- raise NotImplementedError(
- "EPLB is not supported for "
- f"{self.old_quant_method.__class__.__name__}."
- )
-
topk_weights, topk_ids, zero_expert_result = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
- enable_eplb=enable_eplb,
- expert_map=expert_map,
- expert_load_view=expert_load_view,
- logical_to_physical_map=logical_to_physical_map,
- logical_replica_count=logical_replica_count,
- global_num_experts=global_num_experts,
- zero_expert_num=zero_expert_num,
- zero_expert_type=zero_expert_type,
)
result = self.fused_experts(
@@ -156,7 +127,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
expert_map=None if self.disable_expert_map else expert_map,
)
- if zero_expert_num != 0 and zero_expert_type is not None:
+ if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
assert not isinstance(result, tuple), (
"Shared + zero experts are mutually exclusive not yet supported"
)
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index b2f554efd8a6f..bb30f1292a5fa 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1391,7 +1391,48 @@ class FusedMoE(CustomOp):
yield param_name
def get_expert_weights(self) -> Iterable[torch.Tensor]:
+ def _maybe_make_contiguous(
+ name: str, p: torch.nn.Parameter
+ ) -> torch.nn.Parameter:
+ """
+ In some cases, the last 2 dimensions (the non-expert dimensions)
+ of the weight scale tensor are transposed. This function
+ transforms the tensor (view update) so the tensor is contiguous().
+ Example: A non-contiguous scale tensor,
+ `x` of shape (E, 32, 16) and stride (512, 1, 32) is transformed to
+ `x_` of shape (E, 16, 32) and stride (512, 32, 1).
+ Note that we specifically use torch.transpose() so `x_` refers
+ to the same underlying memory. The tensors `x` and `x_`, pointing
+ to the same underlying memory make this transformation safe in the
+ context of EPLB. i.e. It is the same memory and just the view
+ is different.
+ Note: This function handles the "weight_scale" tensors specifically.
+ This could however be generalized to handle similar tensors.
+ """
+ if p.ndim != 3:
+ return p
+ if p.is_contiguous():
+ # Already contiguous. do nothing.
+ return p
+ # p is non-contiguous. We only handle the case where the last 2
+ # dimensions of the scales tensor is transposed. We can handle
+ # other cases when they become relevant.
+ is_transposed_12 = p.stride(1) == 1 and p.stride(2) != 1
+ if "weight_scale" not in name or not is_transposed_12:
+ # do nothing.
+ return p
+
+        # Do not update the layer parameter as the layer's MoE operations would
+        # expect the parameter's tensor to be the same shape / stride. Instead,
+ # make a new torch.nn.Parameter that is used just in the context of
+ # EPLB.
+ return torch.nn.Parameter(
+ torch.transpose(p.data, 1, 2), requires_grad=False
+ )
+
weights = list(self.named_parameters())
+ weights = [(name, _maybe_make_contiguous(name, p)) for name, p in weights]
+
assert all(
weight.is_contiguous()
for name, weight in weights
@@ -1469,30 +1510,11 @@ class FusedMoE(CustomOp):
logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
)
- @staticmethod
def select_experts(
+ self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
- top_k: int,
- use_grouped_topk: bool,
- renormalize: bool,
- topk_group: int | None = None,
- num_expert_group: int | None = None,
- custom_routing_function: Callable | None = None,
- scoring_func: str = "softmax",
- routed_scaling_factor: float = 1.0,
- e_score_correction_bias: torch.Tensor | None = None,
- indices_type: torch.dtype | None = None,
- enable_eplb: bool = False,
- expert_map: torch.Tensor | None = None,
- expert_load_view: torch.Tensor | None = None,
- logical_to_physical_map: torch.Tensor | None = None,
- logical_replica_count: torch.Tensor | None = None,
- global_num_experts: int | None = None,
- zero_expert_num: int | None = None,
- zero_expert_type: str | None = None,
- num_fused_shared_experts: int = 0,
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
"""
Route the input hidden states to the top-k experts based on the
router logits.
@@ -1511,6 +1533,27 @@ class FusedMoE(CustomOp):
fused_topk_bias,
)
+ if self.enable_eplb:
+ if self.quant_method.supports_eplb:
+ if self.expert_load_view is None:
+ raise ValueError(
+                        "enable_eplb=True requires expert_load_view != None"
+ )
+ if self.logical_to_physical_map is None:
+ raise ValueError(
+                        "enable_eplb=True requires logical_to_physical_map != None"
+ )
+ if self.logical_replica_count is None:
+ raise ValueError(
+                        "enable_eplb=True requires logical_replica_count != None"
+ )
+ else:
+ raise NotImplementedError(
+ f"EPLB is not supported for {self.quant_method.method_name}."
+ )
+
+ indices_type = self.quant_method.topk_indices_dtype
+
# Check if we should use a routing simulation strategy
routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
if routing_strategy != "":
@@ -1518,20 +1561,20 @@ class FusedMoE(CustomOp):
hidden_states=hidden_states,
router_logits=router_logits,
strategy_name=routing_strategy,
- top_k=top_k,
+ top_k=self.top_k,
indices_type=indices_type,
)
# DeepSeekv2 uses grouped_top_k
- elif use_grouped_topk:
- assert topk_group is not None
- assert num_expert_group is not None
+ elif self.use_grouped_topk:
+ assert self.topk_group is not None
+ assert self.num_expert_group is not None
if rocm_aiter_ops.is_fused_moe_enabled():
if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
- assert num_fused_shared_experts == 0
+ assert self.num_fused_shared_experts == 0
grouped_topk_impl = partial(
rocm_aiter_grouped_topk,
- num_fused_shared_experts=num_fused_shared_experts,
+ num_fused_shared_experts=self.num_fused_shared_experts,
)
else:
grouped_topk_impl = grouped_topk
@@ -1539,50 +1582,46 @@ class FusedMoE(CustomOp):
topk_weights, topk_ids = grouped_topk_impl(
hidden_states=hidden_states,
gating_output=router_logits,
- topk=top_k,
- renormalize=renormalize,
- num_expert_group=num_expert_group,
- topk_group=topk_group,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
+ topk=self.top_k,
+ renormalize=self.renormalize,
+ num_expert_group=self.num_expert_group,
+ topk_group=self.topk_group,
+ scoring_func=self.scoring_func,
+ routed_scaling_factor=self.routed_scaling_factor,
+ e_score_correction_bias=self.e_score_correction_bias,
)
- elif e_score_correction_bias is not None:
+ elif self.e_score_correction_bias is not None:
topk_weights, topk_ids = fused_topk_bias(
hidden_states=hidden_states,
gating_output=router_logits,
- e_score_correction_bias=e_score_correction_bias.data,
- topk=top_k,
- renormalize=renormalize,
+ e_score_correction_bias=self.e_score_correction_bias.data,
+ topk=self.top_k,
+ renormalize=self.renormalize,
)
- if routed_scaling_factor != 1.0:
- topk_weights *= routed_scaling_factor
- elif custom_routing_function is None:
+ if self.routed_scaling_factor != 1.0:
+ topk_weights *= self.routed_scaling_factor
+ elif self.custom_routing_function is None:
topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states=hidden_states,
gating_output=router_logits,
- topk=top_k,
- renormalize=renormalize,
+ topk=self.top_k,
+ renormalize=self.renormalize,
indices_type=indices_type,
)
else:
- topk_weights, topk_ids = custom_routing_function(
+ topk_weights, topk_ids = self.custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
- topk=top_k,
- renormalize=renormalize,
+ topk=self.top_k,
+ renormalize=self.renormalize,
)
- if enable_eplb:
- assert expert_load_view is not None
- assert logical_to_physical_map is not None
- assert logical_replica_count is not None
-
+ if self.enable_eplb:
topk_ids = eplb_map_to_physical_and_record(
topk_ids=topk_ids,
- expert_load_view=expert_load_view,
- logical_to_physical_map=logical_to_physical_map,
- logical_replica_count=logical_replica_count,
+ expert_load_view=self.expert_load_view,
+ logical_to_physical_map=self.logical_to_physical_map,
+ logical_replica_count=self.logical_replica_count,
)
if (indices_type is not None) and topk_ids.dtype != indices_type:
@@ -1592,16 +1631,16 @@ class FusedMoE(CustomOp):
# Compute zero expert result if needed
if (
- zero_expert_num is not None
- and zero_expert_num > 0
- and zero_expert_type is not None
- and global_num_experts is not None
+ self.zero_expert_num is not None
+ and self.zero_expert_num > 0
+ and self.zero_expert_type is not None
+ and self.global_num_experts is not None
):
zero_expert_result = zero_experts_compute_triton(
expert_indices=topk_ids,
expert_scales=topk_weights,
- num_experts=global_num_experts,
- zero_expert_type=zero_expert_type,
+ num_experts=self.global_num_experts,
+ zero_expert_type=self.zero_expert_type,
hidden_states=hidden_states,
)
else:
@@ -1651,6 +1690,10 @@ class FusedMoE(CustomOp):
)
def reduce_output(states: torch.Tensor) -> torch.Tensor:
+ # Slice before all_reduce to enable possible fusion
+ if self.hidden_size != og_hidden_states:
+ states = states[..., :og_hidden_states]
+
if (
not self.is_sequence_parallel
and not self.use_dp_chunking
@@ -1673,11 +1716,12 @@ class FusedMoE(CustomOp):
if self.zero_expert_num is not None and self.zero_expert_num > 0:
assert isinstance(fused_output, tuple)
fused_output, zero_expert_result = fused_output
- return (reduce_output(fused_output) + zero_expert_result)[
- ..., :og_hidden_states
- ]
+ return (
+ reduce_output(fused_output)
+ + zero_expert_result[..., :og_hidden_states]
+ )
else:
- return reduce_output(fused_output)[..., :og_hidden_states]
+ return reduce_output(fused_output)
else:
if current_platform.is_tpu():
# TODO: Once the OOM issue for the TPU backend is resolved, we
@@ -1690,8 +1734,8 @@ class FusedMoE(CustomOp):
hidden_states, router_logits, self.layer_name
)
return (
- reduce_output(shared_output)[..., :og_hidden_states],
- reduce_output(fused_output)[..., :og_hidden_states],
+ reduce_output(shared_output),
+ reduce_output(fused_output),
)
def forward_cuda(
diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
index 63b0e6f573d65..48e5a8907f926 100644
--- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
+++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
@@ -331,7 +331,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_cuda(
self,
- layer: torch.nn.Module,
+ layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
@@ -352,31 +352,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- zero_expert_num = getattr(layer, "zero_expert_num", 0)
- zero_expert_type = getattr(layer, "zero_expert_type", None)
-
topk_weights, topk_ids, zero_expert_result = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
- enable_eplb=enable_eplb,
- expert_map=expert_map,
- expert_load_view=expert_load_view,
- logical_to_physical_map=logical_to_physical_map,
- logical_replica_count=logical_replica_count,
- global_num_experts=global_num_experts,
- zero_expert_num=zero_expert_num,
- zero_expert_type=zero_expert_type,
- num_fused_shared_experts=layer.num_fused_shared_experts,
)
if self.rocm_aiter_moe_enabled:
@@ -415,7 +393,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map=expert_map,
)
- if zero_expert_num != 0 and zero_expert_type is not None:
+ if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
assert not isinstance(result, tuple), (
"Shared + zero experts are mutually exclusive not yet supported"
)
@@ -425,7 +403,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_cpu(
self,
- layer: torch.nn.Module,
+ layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
@@ -474,7 +452,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_xpu(
self,
- layer: torch.nn.Module,
+ layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
@@ -515,7 +493,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_tpu(
self,
- layer: torch.nn.Module,
+ layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index 3f6ea68072b40..66945e2d2a7c8 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -597,7 +597,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -618,24 +618,11 @@ class AWQMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError("EPLB not supported for `AWQMoEMethod` yet.")
-
assert activation == "silu", "Only SiLU activation is supported."
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
return fused_marlin_moe(
diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py
index e5a741e639ad9..1e57fa218b797 100644
--- a/vllm/model_executor/layers/quantization/bitsandbytes.py
+++ b/vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -495,7 +495,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -518,25 +518,11 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `BitsAndBytesMoEMethod` yet."
- )
-
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
+ # TODO(bnell): Do these need to be called on the hot path?
if self.quant_config.load_in_8bit:
w13, w2 = self._apply_8bit_dequant(layer)
else:
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index ad547dd409822..71d7de97d4a10 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -103,7 +103,7 @@ __all__ = [
"CompressedTensorsW8A8Int8MoEMethod",
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod",
- "CompressedTensorsW4A4MoeMethod",
+ "CompressedTensorsW4A4Nvfp4MoeMethod",
"CompressedTensorsW4A8Int8MoEMethod",
]
@@ -171,7 +171,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
quant_config, layer.moe_config
)
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
- return CompressedTensorsW4A4MoeMethod(layer.moe_config)
+ return CompressedTensorsW4A4Nvfp4MoeMethod(layer.moe_config)
elif (
quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
@@ -188,7 +188,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
)
-class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
+class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
def __init__(self, moe: FusedMoEConfig):
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
detect_nvfp4_moe_support,
@@ -205,8 +205,12 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
logger.info_once(
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
- " for CompressedTensorsW4A4MoeMethod."
+ " for CompressedTensorsW4A4Nvfp4MoeMethod."
)
+ elif self.use_marlin:
+ logger.info_once("Using Marlin for CompressedTensorsW4A4Nvfp4MoeMethod.")
+ else:
+ logger.info_once("Using Cutlass for CompressedTensorsW4A4Nvfp4MoeMethod.")
def create_weights(
self,
@@ -511,7 +515,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -532,16 +536,17 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `CompressedTensorsW4A4MoeMethod` yet."
- )
assert activation == "silu", "Only SiLU activation is supported."
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
+ if enable_eplb:
+ raise NotImplementedError(
+                "EPLB not supported for `CompressedTensorsW4A4Nvfp4MoeMethod` yet."
+ )
+
return flashinfer_trtllm_fp4_moe(
layer=layer,
x=x,
@@ -554,19 +559,9 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
e_score_correction_bias=e_score_correction_bias,
)
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
if self.use_marlin:
@@ -621,7 +616,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
assert expert_map is None, (
"Expert Parallelism / expert_map "
"is currently not supported for "
- "CompressedTensorsW4A4MoeMethod."
+ "CompressedTensorsW4A4Nvfp4MoeMethod."
)
assert self.moe_quant_config is not None
@@ -1109,7 +1104,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -1130,31 +1125,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- assert expert_load_view is not None
- assert logical_to_physical_map is not None
- assert logical_replica_count is not None
- assert isinstance(layer, FusedMoE)
-
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
- num_fused_shared_experts=layer.num_fused_shared_experts,
- enable_eplb=enable_eplb,
- expert_map=expert_map,
- expert_load_view=expert_load_view,
- logical_to_physical_map=logical_to_physical_map,
- logical_replica_count=logical_replica_count,
)
per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
@@ -1377,7 +1350,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -1398,26 +1371,11 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `CompressedTensorsW8A8Int8MoEMethod` yet."
- )
-
from vllm.model_executor.layers.fused_moe import fused_experts
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
return fused_experts(
@@ -1738,7 +1696,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -1759,26 +1717,11 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `CompressedTensorsWNA16MarlinMoEMethod` yet."
- )
-
assert activation == "silu", f"{activation} not supported for Marlin MoE."
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
return fused_marlin_moe(
@@ -2001,7 +1944,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -2022,43 +1965,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- if expert_load_view is None:
- raise ValueError("enable_eplb=True requiere expert_load_view != None")
- if logical_to_physical_map is None:
- raise ValueError(
- "enable_eplb=True requiere logical_to_physical_map != None"
- )
- if logical_replica_count is None:
- raise ValueError(
- "enable_eplb=True requiere logical_replica_count != None"
- )
- if not isinstance(layer, FusedMoE):
- raise TypeError(
- "EPLB is only supported when `layer` is a instance of FusedMoE."
- )
-
from vllm.model_executor.layers.fused_moe import fused_experts
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
- num_fused_shared_experts=getattr(layer, "num_fused_shared_experts", 0),
- enable_eplb=enable_eplb,
- expert_map=expert_map,
- expert_load_view=expert_load_view,
- logical_to_physical_map=logical_to_physical_map,
- logical_replica_count=logical_replica_count,
)
return fused_experts(
diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py
index 5241f9a2301be..7ebe40ec84687 100644
--- a/vllm/model_executor/layers/quantization/experts_int8.py
+++ b/vllm/model_executor/layers/quantization/experts_int8.py
@@ -137,7 +137,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -158,26 +158,11 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `ExpertsInt8MoEMethod` yet."
- )
-
from vllm.model_executor.layers.fused_moe import fused_experts
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
return fused_experts(
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 91bd45bf879cb..e033032903e87 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -28,6 +28,7 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEParallelConfig,
FusedMoEQuantConfig,
RoutingMethodType,
fp8_w8a8_moe_quant_config,
@@ -118,7 +119,9 @@ class Fp8MoeBackend(Enum):
TRITON = 6
-def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
+def get_fp8_moe_backend(
+ block_quant: bool, moe_parallel_config: FusedMoEParallelConfig
+) -> Fp8MoeBackend:
"""
Select the primary FP8 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
@@ -159,8 +162,19 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
logger.info_once("Using Marlin backend for FP8 MoE")
return Fp8MoeBackend.MARLIN
- # deepGEMM on supported platforms with block-quantized weights
- if envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM and block_quant:
+ # Determine if we should use DeepGEMM with block-quantized weights:
+ # - If explicitly set by user, respect their choice
+ # - If not explicitly set (default), disable when TP size is >= 8
+ moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM
+ if not envs.is_set("VLLM_MOE_USE_DEEP_GEMM") and moe_parallel_config.tp_size >= 8:
+ moe_use_deep_gemm = False
+ logger.info_once(
+ "DeepGEMM MoE is disabled by default when TP size is >= 8. "
+ "Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.",
+ scope="local",
+ )
+
+ if envs.VLLM_USE_DEEP_GEMM and moe_use_deep_gemm and block_quant:
if not has_deep_gemm():
logger.warning_once(
"DeepGEMM backend requested but not available.", scope="local"
@@ -641,7 +655,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.quant_config = quant_config
self.weight_block_size = self.quant_config.weight_block_size
self.block_quant: bool = self.weight_block_size is not None
- self.fp8_backend = get_fp8_moe_backend(self.block_quant)
+ self.fp8_backend = get_fp8_moe_backend(
+ self.block_quant, layer.moe_parallel_config
+ )
self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
@@ -1140,7 +1156,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -1216,31 +1232,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=apply_router_weight_on_input,
)
- zero_expert_num = getattr(layer, "zero_expert_num", 0)
- zero_expert_type = getattr(layer, "zero_expert_type", None)
-
- select_result = FusedMoE.select_experts(
+ select_result = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
- enable_eplb=enable_eplb,
- expert_map=expert_map,
- expert_load_view=expert_load_view,
- logical_to_physical_map=logical_to_physical_map,
- logical_replica_count=logical_replica_count,
- global_num_experts=global_num_experts,
- zero_expert_num=zero_expert_num,
- zero_expert_type=zero_expert_type,
- num_fused_shared_experts=layer.num_fused_shared_experts,
)
topk_weights, topk_ids, zero_expert_result = select_result
@@ -1322,7 +1316,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.allow_cutlass_block_scaled_grouped_gemm
),
)
- if zero_expert_num != 0 and zero_expert_type is not None:
+
+ if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
assert not isinstance(result, tuple), (
"Shared + zero experts are mutually exclusive not yet supported"
)
diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py
index 42d7a67371ae8..bcdfafb50fc5a 100644
--- a/vllm/model_executor/layers/quantization/gguf.py
+++ b/vllm/model_executor/layers/quantization/gguf.py
@@ -621,7 +621,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -642,9 +642,6 @@ class GGUFMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError("EPLB not supported for `GGUFMoEMethod` yet.")
-
assert activation == "silu", "Only SiLU activation is supported."
if apply_router_weight_on_input:
raise NotImplementedError(
@@ -652,19 +649,9 @@ class GGUFMoEMethod(FusedMoEMethodBase):
"fused GGUF MoE method."
)
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
return fused_moe_gguf(
x,
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 68a122fd46c6b..77b15db373a3a 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -722,7 +722,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -743,26 +743,11 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `GPTQMarlinMoEMethod` yet."
- )
-
assert activation == "silu", "Only SiLU activation is supported."
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
return fused_marlin_moe(
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 01a23168bdde3..2cf7089e0ff90 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -696,7 +696,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -717,12 +717,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `ModelOptFp8MoEMethod` yet."
- )
-
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
+ if layer.enable_eplb:
+ raise NotImplementedError(
+ "EPLB not supported for `ModelOptFp8MoEMethod` yet."
+ )
assert activation == "silu", (
f"Expected 'silu' activation but got {activation}"
)
@@ -740,19 +739,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
)
# Expert selection
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
@@ -1143,6 +1132,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
" for ModelOptNvFp4FusedMoE."
)
+ elif self.use_marlin:
+ logger.info_once("Using Marlin for ModelOptNvFp4FusedMoE.")
+ else:
+ logger.info_once("Using Cutlass for ModelOptNvFp4FusedMoE.")
def maybe_make_prepare_finalize(
self,
@@ -1459,7 +1452,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -1480,16 +1473,16 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
- )
assert activation == "silu", "Only SiLU activation is supported."
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
+ if enable_eplb:
+ raise NotImplementedError(
+ "EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
+ )
return flashinfer_trtllm_fp4_moe(
layer=layer,
x=x,
@@ -1502,19 +1495,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
e_score_correction_bias=e_score_correction_bias,
)
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
if self.use_marlin:
diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py
index 2090c86f78dc8..cf348290a2716 100644
--- a/vllm/model_executor/layers/quantization/moe_wna16.py
+++ b/vllm/model_executor/layers/quantization/moe_wna16.py
@@ -359,7 +359,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -380,25 +380,12 @@ class MoeWNA16Method(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError("EPLB not supported for `MoeWNA16Method` yet.")
-
from vllm.model_executor.layers.fused_moe import fused_experts
assert activation == "silu", "Only SiLU activation is supported."
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
return fused_experts(
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index 66ae2e94c60a5..198feb03be3e4 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -132,12 +132,15 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
)
# If FlashInfer is not available, try either Marlin or Triton
- if (
- envs.VLLM_MXFP4_USE_MARLIN
- or current_platform.get_device_capability()[0] < 9
- or not has_triton_kernels()
- or not is_torch_equal_or_newer("2.8.0")
- ):
+ triton_kernels_supported = (
+ has_triton_kernels()
+ and is_torch_equal_or_newer("2.8.0")
+ # NOTE: triton_kernels are only confirmed to work on SM90 and SM100
+ # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
+ # SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
+ and (9, 0) <= current_platform.get_device_capability() < (11, 0)
+ )
+ if envs.VLLM_MXFP4_USE_MARLIN or not triton_kernels_supported:
logger.info_once("Using Marlin backend")
return Mxfp4Backend.MARLIN
else:
@@ -862,7 +865,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -887,18 +890,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
raise NotImplementedError("EPLB is not supported for mxfp4")
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
)
return fused_marlin_moe(
@@ -989,17 +983,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
):
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- e_score_correction_bias=e_score_correction_bias,
)
# Backend-specific preparation
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index 30772c3665b06..8be0299eaa66f 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -334,7 +334,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -355,24 +355,9 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet."
- )
-
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
if self.rocm_aiter_moe_enabled:
@@ -609,7 +594,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -630,24 +615,9 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `QuarkOCP_MX_MoEMethod` yet."
- )
-
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
if not self.emulate:
diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py
index 52656263a601b..7b51b828009fc 100644
--- a/vllm/model_executor/layers/quantization/rtn.py
+++ b/vllm/model_executor/layers/quantization/rtn.py
@@ -356,7 +356,7 @@ class RTNMoEMethod(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -377,22 +377,9 @@ class RTNMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError("EPLB not supported for `RTNMoEMethod` yet.")
-
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
return fused_marlin_moe(
diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=10240,K=5120,device_name=NVIDIA_L40S,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=10240,K=5120,device_name=NVIDIA_L40S,dtype=fp8_w8a8,block_shape=[128,128].json
new file mode 100644
index 0000000000000..6b2c1dc1312bf
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/utils/configs/N=10240,K=5120,device_name=NVIDIA_L40S,dtype=fp8_w8a8,block_shape=[128,128].json
@@ -0,0 +1,146 @@
+{
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 5
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "24": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "32": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "48": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "64": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "96": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "128": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "256": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "512": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 4
+ }
+}
diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=5120,K=25600,device_name=NVIDIA_L40S,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=5120,K=25600,device_name=NVIDIA_L40S,dtype=fp8_w8a8,block_shape=[128,128].json
new file mode 100644
index 0000000000000..b0eaf02a541ad
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/utils/configs/N=5120,K=25600,device_name=NVIDIA_L40S,dtype=fp8_w8a8,block_shape=[128,128].json
@@ -0,0 +1,146 @@
+{
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "24": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "32": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "48": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "64": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "96": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "128": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "256": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "512": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 3
+ }
+}
diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=5120,K=8192,device_name=NVIDIA_L40S,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=5120,K=8192,device_name=NVIDIA_L40S,dtype=fp8_w8a8,block_shape=[128,128].json
new file mode 100644
index 0000000000000..4cd357d5086ca
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/utils/configs/N=5120,K=8192,device_name=NVIDIA_L40S,dtype=fp8_w8a8,block_shape=[128,128].json
@@ -0,0 +1,146 @@
+{
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "24": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "32": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "48": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "64": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "96": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "128": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "256": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 2
+ },
+ "512": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 3
+ }
+}
diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=51200,K=5120,device_name=NVIDIA_L40S,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=51200,K=5120,device_name=NVIDIA_L40S,dtype=fp8_w8a8,block_shape=[128,128].json
new file mode 100644
index 0000000000000..ca2179ddf3d2f
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/utils/configs/N=51200,K=5120,device_name=NVIDIA_L40S,dtype=fp8_w8a8,block_shape=[128,128].json
@@ -0,0 +1,146 @@
+{
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "4": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 8,
+ "num_stages": 5
+ },
+ "8": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "16": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 5
+ },
+ "24": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "32": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "48": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "64": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "96": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "128": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "256": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "512": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 8,
+ "num_stages": 3
+ }
+}
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
index 7eba8359b92f6..eef7a0896c375 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
@@ -282,6 +282,16 @@ def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
if flashinfer_moe_backend in backend_map:
+ if (
+ flashinfer_moe_backend == "latency"
+ and not current_platform.is_device_capability(100)
+ ):
+ logger.info_once(
+ "Flashinfer TRTLLM MOE backend is only supported on "
+ "SM100 and later, using CUTLASS backend instead",
+ scope="local",
+ )
+ return FlashinferMoeBackend.CUTLASS
return backend_map[flashinfer_moe_backend]
elif current_platform.is_device_capability(90):
return FlashinferMoeBackend.CUTLASS
diff --git a/vllm/model_executor/layers/rotary_embedding/__init__.py b/vllm/model_executor/layers/rotary_embedding/__init__.py
index 152d9401b8e94..0f10bff6ac4f5 100644
--- a/vllm/model_executor/layers/rotary_embedding/__init__.py
+++ b/vllm/model_executor/layers/rotary_embedding/__init__.py
@@ -17,6 +17,7 @@ from .llama4_vision_rope import Llama4VisionRotaryEmbedding
from .mrope import MRotaryEmbedding
from .ntk_scaling_rope import NTKScalingRotaryEmbedding
from .phi3_long_rope_scaled_rope import Phi3LongRoPEScaledRotaryEmbedding
+from .xdrope import XDRotaryEmbedding
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding
_ROPE_DICT: dict[tuple, RotaryEmbedding] = {}
@@ -184,6 +185,18 @@ def get_rope(
raise ValueError(
"Dynamic rope scaling must contain either 'alpha' or 'factor' field"
)
+ elif scaling_type == "xdrope":
+ scaling_alpha = rope_parameters["alpha"]
+ rotary_emb = XDRotaryEmbedding(
+ head_size,
+ rotary_dim,
+ max_position,
+ base,
+ is_neox_style,
+ scaling_alpha,
+ dtype,
+ xdrope_section=rope_parameters["xdrope_section"],
+ )
elif scaling_type == "yarn":
scaling_factor = rope_parameters["factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
diff --git a/vllm/model_executor/layers/rotary_embedding/xdrope.py b/vllm/model_executor/layers/rotary_embedding/xdrope.py
new file mode 100644
index 0000000000000..2432273faf195
--- /dev/null
+++ b/vllm/model_executor/layers/rotary_embedding/xdrope.py
@@ -0,0 +1,102 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import numpy as np
+import torch
+
+from .common import apply_rotary_emb_dispatch
+from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
+
+
+class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding):
+ """DynamicNTKAlphaRotaryEmbedding extended with MultiModal(XD) Sections.
+
+ Based on the original DynamicNTKAlphaRotaryEmbedding implementation.
+ """
+
+ def __init__(
+ self,
+ head_size: int,
+ rotary_dim: int,
+ max_position_embeddings: int,
+ base: float,
+ is_neox_style: bool,
+ scaling_alpha: float,
+ dtype: torch.dtype,
+ xdrope_section: list[int],
+ ) -> None:
+ self.xdrope_section = xdrope_section
+ super().__init__(
+ head_size,
+ rotary_dim,
+ max_position_embeddings,
+ base,
+ is_neox_style,
+ scaling_alpha,
+ dtype,
+ )
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ query: torch.Tensor,
+ key: torch.Tensor | None = None,
+ offsets: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ """PyTorch-native implementation equivalent to forward().
+
+ Args:
+ positions:
+ [4, num_tokens] (P/W/H/T positions with multimodal inputs)
+ query: [num_tokens, num_heads * head_size]
+ key: [num_tokens, num_kv_heads * head_size]
+ """
+ assert positions.ndim == 2
+ assert key is not None
+
+ num_tokens = positions.shape[-1]
+ cos_sin = self.cos_sin_cache[positions]
+ cos, sin = cos_sin.chunk(2, dim=-1)
+ cos = torch.cat(
+ [m[i] for i, m in enumerate(cos.split(self.xdrope_section, dim=-1))], dim=-1
+ )
+ sin = torch.cat(
+ [m[i] for i, m in enumerate(sin.split(self.xdrope_section, dim=-1))], dim=-1
+ )
+
+ query_shape = query.shape
+ query = query.view(num_tokens, -1, self.head_size)
+ query_rot = query[..., : self.rotary_dim]
+ query_pass = query[..., self.rotary_dim :]
+ query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
+ query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
+
+ key_shape = key.shape
+ key = key.view(num_tokens, -1, self.head_size)
+ key_rot = key[..., : self.rotary_dim]
+ key_pass = key[..., self.rotary_dim :]
+ key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
+ key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
+ return query, key
+
+ @staticmethod
+ def get_next_input_positions(
+ context_len: int,
+ seq_len: int,
+ xd_sections: int = 4,
+ ) -> list[list[int]]:
+ return [list(range(context_len, seq_len)) for _ in range(xd_sections)]
+
+ @staticmethod
+ def get_next_input_positions_tensor(
+ out: np.ndarray,
+ out_offset: int,
+ context_len: int,
+ num_new_tokens: int,
+ ):
+ values = np.arange(
+ context_len,
+ context_len + num_new_tokens,
+ dtype=out.dtype,
+ )
+ out[:, out_offset : out_offset + num_new_tokens] = values
diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py
index 2416836be03c4..74052f72ceab9 100644
--- a/vllm/model_executor/model_loader/gguf_loader.py
+++ b/vllm/model_executor/model_loader/gguf_loader.py
@@ -18,6 +18,7 @@ from vllm.model_executor.model_loader.utils import (
process_weights_after_loading,
)
from vllm.model_executor.model_loader.weight_utils import (
+ download_gguf,
get_gguf_extra_tensor_names,
get_gguf_weight_type_map,
gguf_quant_weights_iterator,
@@ -43,7 +44,8 @@ class GGUFModelLoader(BaseModelLoader):
f"load format {load_config.load_format}"
)
- def _prepare_weights(self, model_name_or_path: str):
+ def _prepare_weights(self, model_config: ModelConfig):
+ model_name_or_path = model_config.model
if os.path.isfile(model_name_or_path):
return model_name_or_path
# for raw HTTPS link
@@ -55,12 +57,23 @@ class GGUFModelLoader(BaseModelLoader):
if "/" in model_name_or_path and model_name_or_path.endswith(".gguf"):
repo_id, filename = model_name_or_path.rsplit("/", 1)
return hf_hub_download(repo_id=repo_id, filename=filename)
- else:
- raise ValueError(
- f"Unrecognised GGUF reference: {model_name_or_path} "
-            "(expected local file, raw URL, or <repo_id>/<filename>.gguf)"
+ # repo_id:quant_type
+ elif "/" in model_name_or_path and ":" in model_name_or_path:
+ repo_id, quant_type = model_name_or_path.rsplit(":", 1)
+ return download_gguf(
+ repo_id,
+ quant_type,
+ cache_dir=self.load_config.download_dir,
+ revision=model_config.revision,
+ ignore_patterns=self.load_config.ignore_patterns,
)
+ raise ValueError(
+ f"Unrecognised GGUF reference: {model_name_or_path} "
+        "(expected local file, raw URL, <repo_id>/<filename>.gguf, "
+        "or <repo_id>:<quant_type>)"
+ )
+
def _get_gguf_weights_map(self, model_config: ModelConfig):
"""
GGUF uses this naming convention for their tensors from HF checkpoint:
@@ -244,7 +257,7 @@ class GGUFModelLoader(BaseModelLoader):
gguf_to_hf_name_map: dict[str, str],
) -> dict[str, str]:
weight_type_map = get_gguf_weight_type_map(
- model_config.model, gguf_to_hf_name_map
+ model_name_or_path, gguf_to_hf_name_map
)
is_multimodal = hasattr(model_config.hf_config, "vision_config")
if is_multimodal:
@@ -290,10 +303,10 @@ class GGUFModelLoader(BaseModelLoader):
yield from gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map)
def download_model(self, model_config: ModelConfig) -> None:
- self._prepare_weights(model_config.model)
+ self._prepare_weights(model_config)
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
- local_model_path = self._prepare_weights(model_config.model)
+ local_model_path = self._prepare_weights(model_config)
gguf_weights_map = self._get_gguf_weights_map(model_config)
model.load_weights(
self._get_weights_iterator(model_config, local_model_path, gguf_weights_map)
@@ -303,7 +316,7 @@ class GGUFModelLoader(BaseModelLoader):
self, vllm_config: VllmConfig, model_config: ModelConfig
) -> nn.Module:
device_config = vllm_config.device_config
- local_model_path = self._prepare_weights(model_config.model)
+ local_model_path = self._prepare_weights(model_config)
gguf_weights_map = self._get_gguf_weights_map(model_config)
# we can only know if tie word embeddings after mapping weights
if "lm_head.weight" in get_gguf_extra_tensor_names(
diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py
index e74434e9d12cb..2021b68b8a60b 100644
--- a/vllm/model_executor/model_loader/utils.py
+++ b/vllm/model_executor/model_loader/utils.py
@@ -19,12 +19,6 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
-from vllm.model_executor.models.adapters import (
- as_embedding_model,
- as_reward_model,
- as_seq_cls_model,
- try_create_mm_pooling_model_cls,
-)
from vllm.model_executor.models.interfaces import SupportsQuant, supports_multimodal
from vllm.utils.platform_utils import is_pin_memory_available
@@ -172,6 +166,13 @@ _MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]()
def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
+ from vllm.model_executor.models.adapters import (
+ as_embedding_model,
+ as_reward_model,
+ as_seq_cls_model,
+ try_create_mm_pooling_model_cls,
+ )
+
architectures = getattr(model_config.hf_config, "architectures", [])
model_cls, arch = model_config.registry.resolve_model_cls(
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index 4572ebe2ea11b..0809bdfa9d4c2 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -369,6 +369,52 @@ def get_sparse_attention_config(
return config
+def download_gguf(
+ repo_id: str,
+ quant_type: str,
+ cache_dir: str | None = None,
+ revision: str | None = None,
+ ignore_patterns: str | list[str] | None = None,
+) -> str:
+ # Use patterns that snapshot_download can handle directly
+ # Patterns to match:
+ # - *-{quant_type}.gguf (root)
+ # - *-{quant_type}-*.gguf (root sharded)
+ # - */*-{quant_type}.gguf (subdir)
+ # - */*-{quant_type}-*.gguf (subdir sharded)
+ allow_patterns = [
+ f"*-{quant_type}.gguf",
+ f"*-{quant_type}-*.gguf",
+ f"*/*-{quant_type}.gguf",
+ f"*/*-{quant_type}-*.gguf",
+ ]
+
+ # Use download_weights_from_hf which handles caching and downloading
+ folder = download_weights_from_hf(
+ model_name_or_path=repo_id,
+ cache_dir=cache_dir,
+ allow_patterns=allow_patterns,
+ revision=revision,
+ ignore_patterns=ignore_patterns,
+ )
+
+ # Find the downloaded file(s) in the folder
+ local_files = []
+ for pattern in allow_patterns:
+ # Convert pattern to glob pattern for local filesystem
+ glob_pattern = os.path.join(folder, pattern)
+ local_files.extend(glob.glob(glob_pattern))
+
+ if not local_files:
+ raise ValueError(
+ f"Downloaded GGUF files not found in {folder} for quant_type {quant_type}"
+ )
+
+ # Sort to ensure consistent ordering (prefer non-sharded files)
+ local_files.sort(key=lambda x: (x.count("-"), x))
+ return local_files[0]
+
+
def download_weights_from_hf(
model_name_or_path: str,
cache_dir: str | None,
diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py
index edf47270e5277..024788918d024 100644
--- a/vllm/model_executor/models/baichuan.py
+++ b/vllm/model_executor/models/baichuan.py
@@ -233,7 +233,7 @@ class BaiChuanDecoderLayer(nn.Module):
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
position_embedding=position_embedding,
- rope_parameters=config.rope_parameters,
+ rope_parameters=getattr(config, "rope_parameters", None),
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py
index 3cf4bf991e667..d7e802ba1aca0 100644
--- a/vllm/model_executor/models/config.py
+++ b/vllm/model_executor/models/config.py
@@ -289,9 +289,6 @@ class MambaModelConfig(VerifyAndUpdateConfig):
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
- if cache_config.mamba_block_size is None:
- cache_config.mamba_block_size = model_config.max_model_len
-
if cache_config.enable_prefix_caching:
if model_config.supports_mamba_prefix_caching:
logger.info(
@@ -299,6 +296,11 @@ class MambaModelConfig(VerifyAndUpdateConfig):
"Its support for Mamba layers is experimental. "
"Please report any issues you may observe."
)
+ # By default, mamba block size will be set to max_model_len (see
+ # below). When enabling prefix caching, we align mamba block size
+ # to the block size as the basic granularity for prefix caching.
+ if cache_config.mamba_block_size is None:
+ cache_config.mamba_block_size = cache_config.block_size
else:
logger.info(
"Hybrid or mamba-based model detected without "
@@ -306,6 +308,9 @@ class MambaModelConfig(VerifyAndUpdateConfig):
)
cache_config.enable_prefix_caching = False
+ if cache_config.mamba_block_size is None:
+ cache_config.mamba_block_size = model_config.max_model_len
+
# TODO(tdoublep): remove once cascade attention is supported
logger.info(
"Disabling cascade attention since it is not supported for hybrid models."
diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py
index e028dc497aa6a..6e23037b919ab 100644
--- a/vllm/model_executor/models/deepseek_mtp.py
+++ b/vllm/model_executor/models/deepseek_mtp.py
@@ -1,15 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Iterable
+import typing
+from collections.abc import Callable, Iterable
import torch
import torch.nn as nn
from transformers import PretrainedConfig
+from vllm._aiter_ops import rocm_aiter_ops
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -231,6 +233,9 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
return self.model.compute_logits(hidden_states, spec_step_idx)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ rocm_aiter_moe_shared_expert_enabled = (
+ rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
+ )
stacked_params_mapping = [
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
@@ -238,11 +243,16 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
]
- expert_params_mapping = FusedMoE.make_expert_params_mapping(
+ expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
- num_experts=self.config.n_routed_experts,
+ num_experts=self.config.n_routed_experts
+ + (
+ self.config.n_shared_experts
+ if rocm_aiter_moe_shared_expert_enabled
+ else 0
+ ),
)
params_dict = dict(self.named_parameters())
@@ -253,6 +263,9 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is None:
continue
+ is_fusion_moe_shared_experts_layer = (
+ rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
+ )
name = self._rewrite_spec_layer_name(spec_layer, name)
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
@@ -266,6 +279,8 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
+ if is_fusion_moe_shared_experts_layer:
+ continue
name_mapped = name.replace(weight_name, param_name)
# QKV fusion is optional, fall back to normal
@@ -286,45 +301,105 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
weight_loader(param, loaded_weight, shard_id)
break
else:
- for mapping in expert_params_mapping:
- param_name, weight_name, expert_id, shard_id = mapping
- if weight_name not in name:
- continue
- name = name.replace(weight_name, param_name)
-
- param = params_dict[name]
- weight_loader = param.weight_loader
- weight_loader(
- param,
- loaded_weight,
- name,
- shard_id=shard_id,
- expert_id=expert_id,
+ # Special handling: when AITER fusion_shared_experts is enabled,
+ # checkpoints may provide a single widened shared_experts tensor
+ # without explicit expert indices
+ # (e.g. ...mlp.shared_experts.gate_proj.weight).
+ # For models with multiple shared experts, split that tensor
+ # evenly into per-shared-expert slices and load them into
+ # appended expert slots mlp.experts.{n_routed_experts + j}.*
+ # accordingly.
+ num_chunks = 1
+ if is_fusion_moe_shared_experts_layer:
+ num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
+ # Determine split axis based on op type
+ # gate/up: ColumnParallel → split along dim 0
+ # down: RowParallel → split along dim 1
+ split_dim = 1 if "down_proj.weight" in name else 0
+ total = loaded_weight.shape[split_dim]
+ assert total % num_chunks == 0, (
+ f"Shared expert weight dim {total} "
+ f"not divisible by num_chunks {num_chunks}"
)
- break
- else:
- # Skip loading extra bias for GPTQ models.
- if name.endswith(".bias") and name not in params_dict:
- continue
+ chunk_size = total // num_chunks
- name = maybe_remap_kv_scale_name(name, params_dict)
- if name is None:
- continue
+ for j in range(num_chunks):
+ chunk_name = name
+ weight_to_load = loaded_weight
- # According to DeepSeek-V3 Technical Report, MTP modules
- # shares embedding layer. We only load the first weights.
- if (
- spec_layer != self.model.mtp_start_layer_idx
- and ".layers" not in name
- ):
- continue
+ if is_fusion_moe_shared_experts_layer:
+ if split_dim == 0:
+ weight_to_load = loaded_weight[
+ j * chunk_size : (j + 1) * chunk_size, :
+ ]
+ else:
+ weight_to_load = loaded_weight[
+ :, j * chunk_size : (j + 1) * chunk_size
+ ]
+ # Synthesize an expert-style name so expert mapping
+ # can route it
+ chunk_name = name.replace(
+ "mlp.shared_experts",
+ f"mlp.experts.{self.config.n_routed_experts + j}",
+ )
- param = params_dict[name]
- weight_loader = getattr(
- param, "weight_loader", default_weight_loader
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(name)
+ # Use expert_params_mapping to locate the destination
+ # param and delegate to its expert-aware weight_loader
+ # with expert_id.
+ for mapping in expert_params_mapping:
+ param_name, weight_name, expert_id, shard_id = mapping
+ if weight_name not in chunk_name:
+ continue
+
+ # Do not modify `name` since the loop may continue here
+ # Instead, create a new variable
+ name_mapped = chunk_name.replace(weight_name, param_name)
+
+ param = params_dict[name_mapped]
+ # We should ask the weight loader to return success or
+ # not here since otherwise we may skip experts with
+ # other available replicas.
+ weight_loader = typing.cast(
+ Callable[..., bool], param.weight_loader
+ )
+ success = weight_loader(
+ param,
+ weight_to_load,
+ name_mapped,
+ shard_id=shard_id,
+ expert_id=expert_id,
+ return_success=True,
+ )
+ if success:
+ if not is_fusion_moe_shared_experts_layer:
+ name = name_mapped
+ else:
+ loaded_params.add(name_mapped)
+ break
+ else:
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+
+ name = maybe_remap_kv_scale_name(name, params_dict)
+ if name is None:
+ continue
+
+ # According to DeepSeek-V3 Technical Report, MTP modules
+ # shares embedding layer. We only load the first weights.
+ if (
+ spec_layer != self.model.mtp_start_layer_idx
+ and ".layers" not in name
+ ):
+ continue
+
+ param = params_dict[name]
+ weight_loader = getattr(
+ param, "weight_loader", default_weight_loader
+ )
+ weight_loader(param, loaded_weight)
+ if not is_fusion_moe_shared_experts_layer:
+ loaded_params.add(name)
return loaded_params
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 7cfd381592b49..ad932559b983d 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -1479,8 +1479,8 @@ class DeepseekV2ForCausalLM(
if spec_layer is not None:
continue # skip spec decode layers for main model
- is_fuse_shared_experts_layer = rocm_aiter_moe_shared_expert_enabled and (
- "mlp.shared_experts" in name
+ is_fusion_moe_shared_experts_layer = (
+ rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
)
for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -1495,7 +1495,7 @@ class DeepseekV2ForCausalLM(
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
- if is_fuse_shared_experts_layer:
+ if is_fusion_moe_shared_experts_layer:
continue
name_mapped = name.replace(weight_name, param_name)
@@ -1531,7 +1531,7 @@ class DeepseekV2ForCausalLM(
# appended expert slots mlp.experts.{n_routed_experts + j}.*
# accordingly.
num_chunks = 1
- if is_fuse_shared_experts_layer:
+ if is_fusion_moe_shared_experts_layer:
num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
# Determine split axis based on op type
# gate/up: ColumnParallel → split along dim 0
@@ -1548,7 +1548,7 @@ class DeepseekV2ForCausalLM(
chunk_name = name
weight_to_load = loaded_weight
- if is_fuse_shared_experts_layer:
+ if is_fusion_moe_shared_experts_layer:
if split_dim == 0:
weight_to_load = loaded_weight[
j * chunk_size : (j + 1) * chunk_size, :
@@ -1599,7 +1599,7 @@ class DeepseekV2ForCausalLM(
return_success=True,
)
if success:
- if not is_fuse_shared_experts_layer:
+ if not is_fusion_moe_shared_experts_layer:
name = name_mapped
else:
loaded_params.add(name_mapped)
@@ -1628,7 +1628,7 @@ class DeepseekV2ForCausalLM(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
- if not is_fuse_shared_experts_layer:
+ if not is_fusion_moe_shared_experts_layer:
loaded_params.add(name)
return loaded_params
diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py
index 2d2251e83b5b1..5460018d0d67a 100644
--- a/vllm/model_executor/models/dots_ocr.py
+++ b/vllm/model_executor/models/dots_ocr.py
@@ -306,7 +306,6 @@ class DotsVisionAttention(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
- AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -324,7 +323,6 @@ class DotsVisionAttention(nn.Module):
rotary_pos_emb: torch.Tensor | None = None,
*,
max_seqlen: int | None = None,
- seqlens: list[int] | None = None,
) -> torch.Tensor:
# [S, C] -> [S, B=1, C]
x = hidden_states.unsqueeze(1)
@@ -374,16 +372,6 @@ class DotsVisionAttention(nn.Module):
out_i = out_i.permute(0, 2, 1, 3)
outputs.append(out_i)
context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- from xformers import ops as xops
- from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-
- attn_bias = BlockDiagonalMask.from_seqlens(
- q_seqlen=seqlens, kv_seqlen=None, device=q.device
- )
- context_layer = xops.memory_efficient_attention_forward(
- q, k, v, attn_bias=attn_bias, p=0, scale=None
- )
else:
raise RuntimeError("Unsupported attention backend")
@@ -545,14 +533,12 @@ class DotsVisionBlock(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: int | None = None,
- seqlens: list[int] | None = None,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
@@ -663,18 +649,14 @@ class DotsVisionTransformer(nn.Module):
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
- def compute_attn_mask_seqlen(
- self, cu_seqlens: torch.Tensor
- ) -> tuple[int | None, list[int] | None]:
- max_seqlen, seqlens = None, None
+ def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
+ max_seqlen = None
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
- return max_seqlen, seqlens
+ return max_seqlen
def forward(
self, hidden_states: torch.Tensor, grid_thw: list[list[int]]
@@ -694,14 +676,13 @@ class DotsVisionTransformer(nn.Module):
)
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
- max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+ max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
for blk in self.blocks:
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
if self.post_trunk_norm is not None:
diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py
index daa5bf03ea4a9..07b34fbc8addb 100644
--- a/vllm/model_executor/models/ernie45_vl.py
+++ b/vllm/model_executor/models/ernie45_vl.py
@@ -214,7 +214,6 @@ class Ernie4_5_VisionAttention(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
- AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -259,7 +258,6 @@ class Ernie4_5_VisionAttention(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
- seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
@@ -311,20 +309,6 @@ class Ernie4_5_VisionAttention(nn.Module):
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- from xformers import ops as xops
- from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-
- attn_bias = BlockDiagonalMask.from_seqlens(
- q_seqlen=seqlens, kv_seqlen=None, device=q.device
- )
-
- context_layer = xops.memory_efficient_attention_forward(
- q, k, v, attn_bias=attn_bias, p=0, scale=None
- )
- context_layer = rearrange(
- context_layer, "b s h d -> s b (h d)"
- ).contiguous()
output, _ = self.proj(context_layer)
return output
@@ -404,14 +388,12 @@ class Ernie4_5_VisionBlock(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
- seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
@@ -562,18 +544,14 @@ class Ernie4_5_VisionTransformer(nn.Module):
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
- def compute_attn_mask_seqlen(
- self, cu_seqlens: torch.Tensor
- ) -> tuple[int | None, list[int] | None]:
- max_seqlen, seqlens = None, None
+ def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
+ max_seqlen = None
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
- return max_seqlen, seqlens
+ return max_seqlen
def forward(
self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, num_pad=0
@@ -598,8 +576,8 @@ class Ernie4_5_VisionTransformer(nn.Module):
if hidden_states.ndim == 2:
hidden_states = hidden_states.unsqueeze(dim=1)
- # pre-compute seqlens for attn mask to reduce cuMemcpy operations
- max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+ # pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
+ max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
for i, blk in enumerate(self.blocks):
hidden_states = blk(
@@ -607,7 +585,6 @@ class Ernie4_5_VisionTransformer(nn.Module):
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
final_output = self.ln(hidden_states)
diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py
index d141e95498064..7e0370886884f 100644
--- a/vllm/model_executor/models/glm4_1v.py
+++ b/vllm/model_executor/models/glm4_1v.py
@@ -309,7 +309,6 @@ class Glm4vVisionAttention(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
- AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -345,7 +344,6 @@ class Glm4vVisionAttention(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
- seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
@@ -400,20 +398,6 @@ class Glm4vVisionAttention(nn.Module):
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- from xformers import ops as xops
- from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-
- attn_bias = BlockDiagonalMask.from_seqlens(
- q_seqlen=seqlens, kv_seqlen=None, device=q.device
- )
-
- context_layer = xops.memory_efficient_attention_forward(
- q, k, v, attn_bias=attn_bias, p=0, scale=None
- )
- context_layer = rearrange(
- context_layer, "b s h d -> s b (h d)"
- ).contiguous()
output, _ = self.proj(context_layer)
return output
@@ -461,7 +445,6 @@ class Glm4vVisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
- seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
x_attn = self.attn(
self.norm1(x),
@@ -469,7 +452,6 @@ class Glm4vVisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
x_fused_norm, residual = self.norm2(x, residual=x_attn)
x = residual + self.mlp(x_fused_norm)
@@ -803,15 +785,14 @@ class Glm4vVisionTransformer(nn.Module):
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
- ) -> tuple[int | None, list[int] | None]:
- max_seqlen, seqlens = None, None
- seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+ ) -> int | None:
+ max_seqlen = None
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
- return max_seqlen, seqlens
+ return max_seqlen
def forward(
self,
@@ -836,8 +817,9 @@ class Glm4vVisionTransformer(nn.Module):
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
- # pre-compute seqlens for attn mask to reduce cuMemcpy operations
- max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+ # pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
+ max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
+ seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
x = self.embeddings(
x, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]
)
@@ -851,7 +833,6 @@ class Glm4vVisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
# adapter
diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py
index e94de8952fa63..bd1bfea3c0fef 100644
--- a/vllm/model_executor/models/gpt_j.py
+++ b/vllm/model_executor/models/gpt_j.py
@@ -100,7 +100,7 @@ class GPTJAttention(nn.Module):
self.head_size,
rotary_dim=config.rotary_dim,
max_position=max_position_embeddings,
- rope_parameters=config.rope_parameters,
+ rope_parameters=getattr(config, "rope_parameters", None),
is_neox_style=False,
)
self.attn = Attention(
diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py
index 8835acb8ec65c..1bc0ad38765d5 100644
--- a/vllm/model_executor/models/gpt_oss.py
+++ b/vllm/model_executor/models/gpt_oss.py
@@ -656,6 +656,7 @@ class GptOssModel(nn.Module):
class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
+ is_3d_moe_weight: bool = True
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
hf_to_vllm_mapper = WeightsMapper(
diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py
index 4bf23cd6fd19a..cfca564920111 100644
--- a/vllm/model_executor/models/grok1.py
+++ b/vllm/model_executor/models/grok1.py
@@ -239,7 +239,7 @@ class Grok1DecoderLayer(nn.Module):
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
- rope_parameters=config.rope_parameters,
+ rope_parameters=getattr(config, "rope_parameters", None),
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py
index 9fa5e2bd33f21..53fb444ed622d 100644
--- a/vllm/model_executor/models/hunyuan_v1.py
+++ b/vllm/model_executor/models/hunyuan_v1.py
@@ -576,7 +576,16 @@ class HunYuanDecoderLayer(nn.Module):
return hidden_states, residual, ori_kv_states
-@support_torch_compile
+@support_torch_compile(
+ dynamic_arg_dims={
+ "input_ids": 0,
+ # positions is of shape (xd, seq_len) if xdrope is enabled for hunyuan-vl,
+ # otherwise (seq_len, ).
+ "positions": -1,
+ "intermediate_tensors": 0,
+ "inputs_embeds": 0,
+ }
+)
class HunYuanModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
diff --git a/vllm/model_executor/models/hunyuan_vision.py b/vllm/model_executor/models/hunyuan_vision.py
new file mode 100644
index 0000000000000..e83addd0c092f
--- /dev/null
+++ b/vllm/model_executor/models/hunyuan_vision.py
@@ -0,0 +1,1028 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# coding=utf-8
+# Copyright 2025 The HunYuan team.
+# Copyright 2025 The vLLM team.
+# Copyright 2025 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only HunYuan-VL model compatible with HuggingFace weights."""
+
+from collections.abc import Callable, Iterable, Mapping, Sequence
+from functools import partial
+from typing import Annotated, Any, Literal, TypeAlias
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import BatchFeature
+
+from vllm.attention.backends.registry import AttentionBackendEnum
+from vllm.attention.layer import MultiHeadAttention
+from vllm.config import MultiModalConfig, VllmConfig
+from vllm.config.multimodal import BaseDummyOptions
+from vllm.distributed import parallel_state
+from vllm.distributed import utils as dist_utils
+from vllm.logger import init_logger
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+ ColumnParallelLinear,
+ QKVParallelLinear,
+ RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.module_mapping import MultiModelKeys
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (
+ ImageItem,
+ ModalityData,
+ MultiModalDataDict,
+ MultiModalFeatureSpec,
+ MultiModalFieldConfig,
+ MultiModalKwargsItems,
+)
+from vllm.multimodal.parse import (
+ DictEmbeddingItems,
+ ImageSize,
+ MultiModalDataItems,
+ MultiModalDataParser,
+)
+from vllm.multimodal.processing import (
+ BaseMultiModalProcessor,
+ BaseProcessingInfo,
+ PromptReplacement,
+ PromptUpdate,
+)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.configs.hunyuan_vl import (
+ HunYuanVLConfig,
+ HunYuanVLVisionConfig,
+)
+from vllm.transformers_utils.processors.hunyuan_vl import HunYuanVLProcessor
+from vllm.transformers_utils.processors.hunyuan_vl_image import smart_resize
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
+
+from .interfaces import (
+ MultiModalEmbeddings,
+ SupportsLoRA,
+ SupportsMultiModal,
+ SupportsPP,
+ SupportsQuant,
+ SupportsXDRoPE,
+)
+from .utils import (
+ AutoWeightsLoader,
+ WeightsMapper,
+ init_vllm_registered_model,
+ maybe_prefix,
+)
+
+logger = init_logger(__name__)
+
+# === Vision Inputs === #
+
+
+class HunYuanVLImagePixelInputs(TensorSchema):
+ """
+ Dimensions:
+ - np: Number of patches
+ - ni: Number of images
+ - cps: Number of channels * patch_size * patch_size
+ """
+
+ type: Literal["pixel_values"]
+
+ pixel_values: Annotated[
+ torch.Tensor,
+ TensorShape("np", "cps"),
+ ]
+
+ image_grid_thw: Annotated[
+ torch.Tensor,
+ TensorShape("ni", 3),
+ ]
+
+
+class HunYuanVLImageEmbeddingInputs(TensorSchema):
+ """
+ Dimensions:
+ - nf: Number of image features
+ - hs: Hidden size
+ - ni: Number of images
+ """
+
+ type: Literal["image_embeds"]
+
+ image_embeds: Annotated[
+ torch.Tensor,
+ TensorShape("nf", "hs"),
+ ]
+
+ image_grid_thw: Annotated[
+ torch.Tensor,
+ TensorShape("ni", 3),
+ ]
+
+
+HunYuanVLImageInputs: TypeAlias = (
+ HunYuanVLImagePixelInputs | HunYuanVLImageEmbeddingInputs
+)
+
+# === Vision Encoder === #
+
+
+class HunYuanVisionMLP(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: int,
+ bias: bool = True,
+ act_fn: Callable[[torch.Tensor], torch.Tensor] = F.gelu,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ use_data_parallel: bool = False,
+ ):
+ super().__init__()
+ self.dense_h_to_4h = ColumnParallelLinear(
+ in_features,
+ hidden_features,
+ bias=bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.dense_h_to_4h",
+ disable_tp=use_data_parallel,
+ )
+ self.dense_4h_to_h = RowParallelLinear(
+ hidden_features,
+ in_features,
+ bias=bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.dense_4h_to_h",
+ disable_tp=use_data_parallel,
+ )
+ self.act_fn = act_fn
+
+ def forward(self, x: torch.Tensor):
+ x_up, _ = self.dense_h_to_4h(x)
+ x_down, _ = self.dense_4h_to_h(self.act_fn(x_up))
+ return x_down
+
+
+class HunYuanVisionAttention(nn.Module):
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ projection_size: int,
+ quant_config: QuantizationConfig | None = None,
+ multimodal_config: MultiModalConfig | None = None,
+ prefix: str = "",
+ use_data_parallel: bool = False,
+ ) -> None:
+ super().__init__()
+ # Per attention head and per partition values.
+ self.tp_size = (
+ 1
+ if use_data_parallel
+ else parallel_state.get_tensor_model_parallel_world_size()
+ )
+ self.hidden_size_per_attention_head = dist_utils.divide(
+ projection_size, num_heads
+ )
+ self.num_attention_heads_per_partition = dist_utils.divide(
+ num_heads, self.tp_size
+ )
+
+ self.qkv = QKVParallelLinear(
+ hidden_size=embed_dim,
+ head_size=self.hidden_size_per_attention_head,
+ total_num_heads=num_heads,
+ total_num_kv_heads=num_heads,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv",
+ disable_tp=use_data_parallel,
+ )
+
+ self.o_proj = RowParallelLinear(
+ input_size=projection_size,
+ output_size=embed_dim,
+ quant_config=quant_config,
+ prefix=f"{prefix}.o_proj",
+ disable_tp=use_data_parallel,
+ )
+
+ self.scale = self.hidden_size_per_attention_head**-0.5
+ self.attn = MultiHeadAttention(
+ self.num_attention_heads_per_partition,
+ self.hidden_size_per_attention_head,
+ self.scale,
+ prefix=f"{prefix}.attn",
+ multimodal_config=multimodal_config,
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ ) -> torch.Tensor:
+ qkv, _ = self.qkv(x)
+ q, k, v = qkv.chunk(3, dim=-1)
+ out = self.attn(q, k, v)
+ output, _ = self.o_proj(out)
+ return output
+
+
+class HunYuanVisionBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_hidden_dim: int,
+ act_fn: Callable[[torch.Tensor], torch.Tensor] = F.gelu,
+ norm_layer: Callable[[int], nn.Module] | None = None,
+ quant_config: QuantizationConfig | None = None,
+ multimodal_config: MultiModalConfig | None = None,
+ prefix: str = "",
+ use_data_parallel: bool = False,
+ ) -> None:
+ super().__init__()
+ if norm_layer is None:
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+ self.input_layernorm = norm_layer(dim)
+ self.post_attention_layernorm = norm_layer(dim)
+ self.self_attn = HunYuanVisionAttention(
+ embed_dim=dim,
+ num_heads=num_heads,
+ projection_size=dim,
+ quant_config=quant_config,
+ multimodal_config=multimodal_config,
+ prefix=f"{prefix}.self_attn",
+ use_data_parallel=use_data_parallel,
+ )
+ self.mlp = HunYuanVisionMLP(
+ dim,
+ mlp_hidden_dim,
+ act_fn=act_fn,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp",
+ use_data_parallel=use_data_parallel,
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ ) -> torch.Tensor:
+ x = x + self.self_attn(self.input_layernorm(x))
+ x = x + self.mlp(self.post_attention_layernorm(x))
+ return x
+
+
+class HunYuanVisionPatchEmbed(nn.Module):
+ def __init__(self, config: HunYuanVLVisionConfig):
+ super().__init__()
+
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.patch_size = config.patch_size
+ self.num_channels = config.num_channels
+ self.spatial_merge_size = config.spatial_merge_size
+ self.interpolate_mode = config.interpolate_mode
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ bias=True,
+ )
+
+ self.max_num_patches = (config.max_image_size // self.patch_size) ** 2
+
+ self.num_positions = self.max_num_patches + 1
+ self.position_edge = int(self.num_positions**0.5)
+ # first token is cls token, skip it
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+ self.patch_pos_embed = None
+
+ def forward(
+ self, pixel_values: torch.Tensor, grid_thw: list[list[int]]
+ ) -> torch.Tensor:
+ num_patches = pixel_values.size(0)
+ pixel_values = pixel_values.reshape(
+ num_patches, self.num_channels, self.patch_size, self.patch_size
+ )
+
+ patch_embeds = self.patch_embedding(pixel_values)
+ patch_embeds = patch_embeds.squeeze(-1).squeeze(-1).unsqueeze(0)
+
+ if self.patch_pos_embed is None:
+ patch_pos_shape = (
+ 1,
+ self.position_edge,
+ self.position_edge,
+ self.embed_dim,
+ )
+ self.patch_pos_embed = (
+ self.position_embedding.weight[1:, :]
+ .reshape(patch_pos_shape)
+ .permute(0, 3, 1, 2)
+ .float()
+ )
+
+ patch_pos_embed_list = []
+ for grid in grid_thw:
+ _, h0, w0 = grid
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ h0, w0 = h0 + 0.1, w0 + 0.1
+ patch_pos_embed = nn.functional.interpolate(
+ self.patch_pos_embed,
+ scale_factor=(h0 / self.position_edge, w0 / self.position_edge),
+ mode=self.interpolate_mode,
+ align_corners=False,
+ )
+
+ patch_pos_embed = (
+ patch_pos_embed.reshape(self.embed_dim, -1)
+ .transpose(0, 1)
+ .unsqueeze(0)
+ .to(patch_embeds.dtype)
+ )
+ patch_pos_embed_list.append(patch_pos_embed)
+
+ patch_pos_embed = torch.cat(patch_pos_embed_list, dim=1)
+ embeddings = patch_embeds + patch_pos_embed
+
+ return embeddings
+
+
+class HunYuanVisionPatchMerger(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ spatial_merge_size=2,
+ rms_norm_eps=1e-5,
+ prefix="",
+ ):
+ super().__init__()
+ self.spatial_merge_size = spatial_merge_size
+ embed_std = out_channels**-0.5
+
+ self.proj = nn.Sequential(
+ nn.Conv2d(
+ in_channels,
+ in_channels * 2,
+ kernel_size=spatial_merge_size,
+ stride=spatial_merge_size,
+ ),
+ nn.GELU(),
+ nn.Conv2d(in_channels * 2, in_channels * 4, kernel_size=1),
+ )
+ self.mlp = nn.Linear(in_channels * 4, out_channels)
+
+ self.image_newline = nn.Parameter(torch.randn(in_channels * 4) * embed_std)
+ self.image_begin = nn.Parameter(torch.randn(out_channels) * embed_std)
+ self.image_end = nn.Parameter(torch.randn(out_channels) * embed_std)
+ self.image_sep = nn.Parameter(torch.randn(out_channels) * embed_std)
+
+ self.before_rms = RMSNorm(in_channels, eps=rms_norm_eps)
+ self.after_rms = RMSNorm(out_channels, eps=rms_norm_eps)
+
+ def forward(self, x, size=(16, 16)):
+ x = self.before_rms(x)
+
+ h, w = size
+ dtype = x.dtype
+ x = x.permute(0, 2, 1).reshape(x.shape[0], -1, h, w)
+
+ x = self.proj(x) # b,c,h,w
+ b, c, h, w = x.shape
+ x = torch.cat(
+ [x, self.image_newline.reshape(1, c, 1, 1).expand(b, c, h, 1).to(dtype)],
+ dim=-1,
+ )
+ x = x.reshape(b, c, -1).permute(0, 2, 1)
+ x = self.mlp(x)
+
+ begin = self.image_begin.reshape(1, 1, -1).expand(b, 1, x.shape[-1]).to(dtype)
+ end = self.image_end.reshape(1, 1, -1).expand(b, 1, x.shape[-1]).to(dtype)
+ x = torch.cat([begin, x, end], dim=1)
+
+ return self.after_rms(x)
+
+
+class HunYuanVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ vision_config: HunYuanVLVisionConfig,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ use_data_parallel: bool = False,
+ multimodal_config: MultiModalConfig | None = None,
+ attn_backend_override: AttentionBackendEnum | None = None,
+ ) -> None:
+ super().__init__()
+
+ num_hidden_layers = vision_config.num_hidden_layers
+ self.hidden_size = vision_config.hidden_size
+ self.num_heads = vision_config.num_attention_heads
+ self.spatial_merge_size = vision_config.spatial_merge_size
+
+ from vllm.compilation.backends import set_model_tag
+
+ with set_model_tag("HunYuanVisionPatchEmbed"):
+ self.embeddings = HunYuanVisionPatchEmbed(vision_config)
+
+ norm_layer = partial(nn.LayerNorm, eps=vision_config.rms_norm_eps)
+
+ with set_model_tag("HunYuanVisionBlock"):
+ self.layers = nn.ModuleList(
+ [
+ HunYuanVisionBlock(
+ dim=vision_config.hidden_size,
+ num_heads=vision_config.num_attention_heads,
+ mlp_hidden_dim=vision_config.intermediate_size,
+ act_fn=get_act_fn(vision_config.hidden_act),
+ norm_layer=norm_layer,
+ quant_config=quant_config,
+ multimodal_config=multimodal_config,
+ prefix=f"{prefix}.layers.{layer_idx}",
+ use_data_parallel=use_data_parallel,
+ )
+ for layer_idx in range(num_hidden_layers)
+ ]
+ )
+
+ with set_model_tag("HunYuanVisionPatchMerger"):
+ self.perceive = HunYuanVisionPatchMerger(
+ vision_config.hidden_size,
+ vision_config.out_hidden_size,
+ spatial_merge_size=vision_config.spatial_merge_size,
+ rms_norm_eps=vision_config.rms_norm_eps,
+ prefix=f"{prefix}.perceive",
+ )
+
+ @property
+ def dtype(self) -> torch.dtype:
+ return self.embeddings.patch_embedding.weight.dtype
+
+ @property
+ def device(self) -> torch.device:
+ return self.embeddings.patch_embedding.weight.device
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ grid_thw: list[list[int]],
+ ) -> torch.Tensor:
+ # patchify
+ seq_len = x.size(0)
+ cu_seqlens: list = [0]
+
+ hidden_states = x.to(device=self.device, dtype=self.dtype)
+ hidden_states = self.embeddings(hidden_states, grid_thw)
+
+ for t, h, w in grid_thw:
+ t, h, w = int(t), int(h), int(w)
+ cu_seqlens.append(h * w)
+
+ cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32)
+ cu_seqlens = torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32)
+
+ cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True)
+
+ hidden_states = hidden_states.reshape(seq_len, -1)
+ hidden_states = hidden_states.unsqueeze(0)
+ for layer_num, layer in enumerate(self.layers):
+ hidden_states = layer(hidden_states)
+
+ # adapter
+ split_lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+ split_items = hidden_states.split(split_lengths, dim=1)
+ image_embeds_list = []
+ for grid, split_item in zip(grid_thw, split_items):
+ image_embeds_list.append(
+ self.perceive(split_item.contiguous(), size=grid[1:]).squeeze(0)
+ )
+
+ return image_embeds_list
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ (".qkv", ".q_proj", "q"),
+ (".qkv", ".k_proj", "k"),
+ (".qkv", ".v_proj", "v"),
+ ]
+ params_dict = dict(self.named_parameters(remove_duplicate=False))
+ loaded_params: set[str] = set()
+
+ for name, loaded_weight in weights:
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
+
+
+def _hunyuan_vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
+ image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
+ image_grid_sizes = image_grid_thw.prod(-1)
+ return dict(
+ pixel_values=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes),
+ image_embeds=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes),
+ image_grid_thw=MultiModalFieldConfig.batched("image"),
+ )
+
+
+class HunYuanVLMultiModalDataParser(MultiModalDataParser):
+ def _parse_image_data(
+ self,
+ data: dict[str, torch.Tensor] | ModalityData[ImageItem],
+ ):
+ if isinstance(data, dict):
+ return DictEmbeddingItems(
+ data,
+ modality="image",
+ required_fields={"image_embeds", "image_grid_thw"},
+ fields_factory=_hunyuan_vl_field_config,
+ )
+
+ return super()._parse_image_data(data)
+
+
+class HunYuanVLProcessingInfo(BaseProcessingInfo):
+ def get_hf_config(self):
+ return self.ctx.get_hf_config(HunYuanVLConfig)
+
+ def get_hf_processor(
+ self,
+ **kwargs: object,
+ ) -> HunYuanVLProcessor:
+ return self.ctx.get_hf_processor(
+ HunYuanVLProcessor,
+ use_fast=kwargs.pop("use_fast", True),
+ **kwargs,
+ )
+
+ def get_image_processor(
+ self,
+ **kwargs: object,
+ ) -> HunYuanVLProcessor:
+ return self.get_hf_processor(**kwargs).image_processor
+
+ def get_supported_mm_limits(self) -> Mapping[str, int | None]:
+ return {"image": None}
+
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
+ max_image_tokens = self.get_max_image_tokens()
+ # TODO: support video
+ max_video_tokens = 0
+ return {"image": max_image_tokens, "video": max_video_tokens}
+
+ def _get_vision_info(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ num_frames: int = 1,
+ do_resize: bool = True,
+ image_processor: HunYuanVLProcessor | None,
+ ) -> tuple[ImageSize, int]:
+ if image_processor is None:
+ image_processor = self.get_image_processor()
+
+ hf_config = self.get_hf_config()
+ vision_config = hf_config.vision_config
+ patch_size = vision_config.patch_size
+ spatial_merge_size = vision_config.spatial_merge_size
+
+ if do_resize:
+ resized_height, resized_width = smart_resize(
+ height=image_height,
+ width=image_width,
+ factor=patch_size * spatial_merge_size,
+ min_pixels=image_processor.min_pixels,
+ max_pixels=image_processor.max_pixels,
+ )
+ preprocessed_size = ImageSize(width=resized_width, height=resized_height)
+ else:
+ preprocessed_size = ImageSize(width=image_width, height=image_height)
+
+ grid_t = 1
+ grid_h = preprocessed_size.height // patch_size
+ grid_w = preprocessed_size.width // patch_size
+
+ num_vision_tokens = (
+ grid_t * grid_h // spatial_merge_size * (grid_w // spatial_merge_size + 1)
+ + 2
+ )
+
+ return preprocessed_size, num_vision_tokens
+
+ def get_num_image_tokens(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ image_processor: HunYuanVLProcessor | None,
+ ) -> int:
+ _, num_image_tokens = self._get_vision_info(
+ image_width=image_width,
+ image_height=image_height,
+ image_processor=image_processor,
+ )
+ return num_image_tokens
+
+ def get_image_size_with_most_features(self) -> ImageSize:
+ max_image_size, _ = self._get_vision_info(
+ image_width=512,
+ image_height=8192,
+ image_processor=None,
+ )
+ return max_image_size
+
+ def get_max_image_tokens(self) -> int:
+ target_width, target_height = self.get_image_size_with_most_features()
+ return self.get_num_image_tokens(
+ image_width=target_width,
+ image_height=target_height,
+ image_processor=None,
+ )
+
+
+class HunYuanVLDummyInputsBuilder(BaseDummyInputsBuilder[HunYuanVLProcessingInfo]):
+ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+ num_images = mm_counts.get("image", 0)
+
+ hf_processor = self.info.get_hf_processor()
+ image_token: str = hf_processor.image_token
+
+ return image_token * num_images
+
+ def get_dummy_mm_data(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ mm_options: Mapping[str, BaseDummyOptions] | None = None,
+ ) -> MultiModalDataDict:
+ num_images = mm_counts.get("image", 1)
+
+ target_width, target_height = self.info.get_image_size_with_most_features()
+
+ return {
+ "image": self._get_dummy_images(
+ width=target_width, height=target_height, num_images=num_images
+ ),
+ }
+
+
+class HunYuanVLMultiModalProcessor(BaseMultiModalProcessor[HunYuanVLProcessingInfo]):
+ def _get_data_parser(self) -> MultiModalDataParser:
+ return HunYuanVLMultiModalDataParser()
+
+ def _call_hf_processor(
+ self,
+ prompt: str,
+ mm_data: Mapping[str, object],
+ mm_kwargs: Mapping[str, object],
+ tok_kwargs: Mapping[str, object],
+ ) -> BatchFeature:
+ return self.info.ctx.call_hf_processor(
+ self.info.get_hf_processor(**mm_kwargs),
+ dict(text=prompt, **mm_data),
+ dict(**mm_kwargs, **tok_kwargs),
+ )
+
+ def _get_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, Any],
+ out_mm_kwargs: MultiModalKwargsItems,
+ ) -> Sequence[PromptUpdate]:
+ hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+ image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
+
+ placeholder = {
+ "image": hf_processor.image_token_id,
+ }
+
+ merge_size = image_processor.merge_size
+
+ def get_replacement_hunyuan_vl(item_idx: int, modality: str):
+ out_item = out_mm_kwargs[modality][item_idx]
+ grid_thw = out_item[f"{modality}_grid_thw"].data
+ assert isinstance(grid_thw, torch.Tensor)
+
+ _, grid_h, grid_w = grid_thw
+ num_tokens = (int(grid_h) // merge_size) * (
+ int(grid_w) // merge_size + 1
+ ) + 2
+ return [placeholder[modality]] * num_tokens
+
+ return [
+ PromptReplacement(
+ modality=modality,
+ target=[placeholder[modality]],
+ replacement=partial(get_replacement_hunyuan_vl, modality=modality),
+ )
+ for modality in ("image",)
+ ]
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ return _hunyuan_vl_field_config(hf_inputs)
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ HunYuanVLMultiModalProcessor,
+ info=HunYuanVLProcessingInfo,
+ dummy_inputs=HunYuanVLDummyInputsBuilder,
+)
+class HunYuanVLForConditionalGeneration(
+ nn.Module,
+ SupportsMultiModal,
+ SupportsLoRA,
+ SupportsPP,
+ SupportsQuant,
+ SupportsXDRoPE,
+):
+ multimodal_cpu_fields = {"image_grid_thw"}
+
+ # To ensure correct weight loading and mapping.
+ hf_to_vllm_mapper = WeightsMapper(
+ orig_to_new_prefix={
+ # mapping for new names in checkpoint saved after transformers v4.52
+ "vit.vit.": "visual.",
+ "vit.": "visual.",
+ "model.": "language_model.model.",
+ }
+ )
+
+ supports_encoder_tp_data = True
+
+ def get_xdrope_input_positions(
+ self,
+ input_tokens: list[int],
+ mm_features: list[MultiModalFeatureSpec],
+ ) -> torch.Tensor:
+ kwargs = MultiModalFeatureSpec.gather_kwargs(
+ mm_features,
+ {"image_grid_thw"},
+ )
+ image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
+
+ hf_config = self.config
+ image_start_token_id = hf_config.image_start_token_id
+ spatial_merge_size = hf_config.vision_config.spatial_merge_size
+ xd_num = len(hf_config.rope_scaling["xdrope_section"])
+
+ input_tokens_tensor = torch.tensor(input_tokens)
+ image_start_indices = torch.argwhere(
+ input_tokens_tensor == image_start_token_id
+ ).squeeze(1)
+
+ p_index = torch.arange(len(input_tokens_tensor))
+ w_index = torch.arange(len(input_tokens_tensor))
+ h_index = torch.arange(len(input_tokens_tensor))
+ t_index = torch.arange(len(input_tokens_tensor))
+ for image_index in range(len(image_start_indices)):
+ # +1 : first image_token, +2: for xdrope positions
+ pos = image_start_indices[image_index] + 2
+ t, h, w = image_grid_thw[image_index]
+ _, llm_grid_h, llm_grid_w = (
+ t,
+ h // spatial_merge_size,
+ w // spatial_merge_size,
+ )
+
+ token_num = (llm_grid_w + 1) * llm_grid_h
+ w_index[pos : pos + token_num].copy_(
+ torch.arange(0, llm_grid_w + 1)
+ .reshape(1, -1)
+ .expand(llm_grid_h, -1)
+ .reshape(-1)
+ )
+ h_index[pos : pos + token_num].copy_(
+ torch.arange(0, llm_grid_h)
+ .reshape(-1, 1)
+ .expand(-1, llm_grid_w + 1)
+ .reshape(-1)
+ )
+ h_index[pos : pos + token_num] = 0
+
+ if xd_num == 4:
+ llm_positions = torch.stack([p_index, w_index, h_index, t_index])
+ elif xd_num == 3:
+ llm_positions = torch.stack([w_index, h_index, t_index])
+
+ return llm_positions
+
+ @classmethod
+ def get_placeholder_str(cls, modality: str, i: int) -> str | None:
+ if modality.startswith("image"):
+ return "<|hy_place▁holder▁no▁100|><|hy_place▁holder▁no▁102|><|hy_place▁holder▁no▁101|>" # noqa: E501
+
+ raise ValueError("Only image modality is supported")
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config: HunYuanVLConfig = vllm_config.model_config.hf_config
+ multimodal_config = vllm_config.model_config.multimodal_config
+
+ self.config = config
+ self.multimodal_config = multimodal_config
+
+ if multimodal_config.get_limit_per_prompt("image"):
+ attn_backend_override = (
+ multimodal_config.mm_encoder_attn_backend
+ if multimodal_config is not None
+ else None
+ )
+ self.visual = HunYuanVisionTransformer(
+ config.vision_config,
+ quant_config=self.quant_config,
+ prefix=maybe_prefix(prefix, "visual"),
+ multimodal_config=multimodal_config,
+ attn_backend_override=attn_backend_override,
+ )
+ else:
+ self.visual = None
+
+ self.language_model = init_vllm_registered_model(
+ vllm_config=vllm_config,
+ prefix=maybe_prefix(prefix, "language_model.model"),
+ architectures=[
+ "HunYuanDenseV1ForCausalLM",
+ "HunYuanMoEV1ForCausalLM",
+ ],
+ )
+
+ self.make_empty_intermediate_tensors = (
+ self.language_model.make_empty_intermediate_tensors
+ )
+
+ def _parse_and_validate_image_input(
+ self, **kwargs: object
+ ) -> HunYuanVLImageInputs | None:
+ pixel_values = kwargs.pop("pixel_values", None)
+ image_embeds = kwargs.pop("image_embeds", None)
+ image_grid_thw = kwargs.pop("image_grid_thw", None)
+
+ if pixel_values is None and image_embeds is None:
+ return None
+
+ # TODO: refine
+ if isinstance(pixel_values, list):
+ pixel_values = torch.cat(pixel_values, dim=0)
+ if len(pixel_values.shape) == 3:
+ last_dim = pixel_values.shape[-1]
+ pixel_values = pixel_values.reshape(-1, last_dim)
+ image_grid_thw = image_grid_thw.reshape(-1, 3)
+
+ if pixel_values is not None:
+ return HunYuanVLImagePixelInputs(
+ type="pixel_values",
+ pixel_values=pixel_values,
+ image_grid_thw=image_grid_thw,
+ )
+
+ if image_embeds is not None:
+ return HunYuanVLImageEmbeddingInputs(
+ type="image_embeds",
+ image_embeds=image_embeds,
+ image_grid_thw=image_grid_thw,
+ )
+
+ def _process_image_input(
+ self, image_input: HunYuanVLImageInputs
+ ) -> tuple[torch.Tensor, ...]:
+ grid_thw = image_input["image_grid_thw"]
+ assert grid_thw.ndim == 2
+ grid_thw_list = grid_thw.tolist()
+
+ if image_input["type"] == "image_embeds":
+ image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+ else:
+ pixel_values = image_input["pixel_values"]
+
+ # TODO: use_data_parallel (split image_embeds in visual)
+ image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
+
+ return image_embeds
+
+ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
+ mm_input_by_modality = {}
+
+ # Preserve the order of modalities if there are multiple of them
+ # from the order of kwargs.
+ for input_key in kwargs:
+ if (
+ input_key in ("pixel_values", "image_embeds")
+ and "image" not in mm_input_by_modality
+ ):
+ mm_input_by_modality["image"] = self._parse_and_validate_image_input(
+ **kwargs
+ )
+ return mm_input_by_modality
+
+ def get_language_model(self) -> torch.nn.Module:
+ return self.language_model
+
+ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
+ mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
+ if not mm_input_by_modality:
+ return []
+
+ # The result multimodal_embeddings is tuple of tensors, with each
+        # tensor corresponding to a multimodal data item (image or video).
+ multimodal_embeddings: tuple[torch.Tensor, ...] = ()
+
+ # NOTE: It is important to iterate over the keys in this dictionary
+ # to preserve the order of the modalities.
+ for modality in mm_input_by_modality:
+ multimodal_input = mm_input_by_modality[modality]
+ if modality == "image":
+ image_embeddings = self._process_image_input(multimodal_input)
+ multimodal_embeddings += tuple(image_embeddings)
+ return multimodal_embeddings
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None,
+ inputs_embeds: torch.Tensor | None,
+ **kwargs: object,
+ ) -> torch.Tensor | IntermediateTensors:
+ if intermediate_tensors is not None:
+ inputs_embeds = None
+
+ hidden_states = self.language_model(
+ input_ids=input_ids,
+ positions=positions,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ )
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor | None:
+ return self.language_model.compute_logits(hidden_states)
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ loader = AutoWeightsLoader(
+ self,
+ skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
+ )
+ return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
+ def get_mm_mapping(self) -> MultiModelKeys:
+ """
+ Get the module prefix in multimodal models
+ """
+ return MultiModelKeys.from_string_field(
+ language_model="language_model.model",
+ connector="visual.perceive",
+ tower_model="visual",
+ )
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index 9966498e1b4c9..6f6ce32538b71 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -1047,7 +1047,7 @@ class SupportsMRoPE(Protocol):
supports_mrope: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports M-RoPE.
-
+
Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
@@ -1088,3 +1088,52 @@ def supports_mrope(
model: type[object] | object,
) -> TypeIs[type[SupportsMRoPE]] | TypeIs[SupportsMRoPE]:
return isinstance(model, SupportsMRoPE)
+
+
+@runtime_checkable
+class SupportsXDRoPE(Protocol):
+ """The interface required for all models that support XD-RoPE."""
+
+ supports_xdrope: ClassVar[Literal[True]] = True
+ """
+ A flag that indicates this model supports XD-RoPE.
+
+ Note:
+ There is no need to redefine this flag if this class is in the
+        MRO of your model class.
+ """
+
+ def get_xdrope_input_positions(
+ self,
+ input_tokens: list[int],
+ mm_features: list["MultiModalFeatureSpec"],
+ ) -> torch.Tensor:
+ """
+ Get XD-RoPE input positions and delta value for this specific model.
+
+ This method should be implemented by each model that supports XD-RoPE
+ to provide model-specific logic for computing input positions.
+
+ Args:
+ input_tokens: List of input token IDs
+ mm_features: Information about each multi-modal data item
+
+ Returns:
+ llm_positions: Tensor of shape `[xdrope_dim, num_tokens]` with
+ 4D(P/W/H/T) or 3D(W/H/T) positions.
+ """
+ ...
+
+
+@overload
+def supports_xdrope(model: type[object]) -> TypeIs[type[SupportsXDRoPE]]: ...
+
+
+@overload
+def supports_xdrope(model: object) -> TypeIs[SupportsXDRoPE]: ...
+
+
+def supports_xdrope(
+ model: type[object] | object,
+) -> TypeIs[type[SupportsXDRoPE]] | TypeIs[SupportsXDRoPE]:
+ return isinstance(model, SupportsXDRoPE)
diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py
index 8fc3db296aa79..302260b952992 100644
--- a/vllm/model_executor/models/keye.py
+++ b/vllm/model_executor/models/keye.py
@@ -9,6 +9,7 @@ from typing import Annotated, Any, Literal, TypeAlias, TypeVar
import numpy as np
import torch
import torch.nn as nn
+import torch.nn.functional as F
from einops import rearrange
from transformers import PretrainedConfig
from transformers.activations import GELUActivation
@@ -424,7 +425,7 @@ class KeyeSiglipAttention(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
- AttentionBackendEnum.XFORMERS,
+ AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -451,7 +452,6 @@ class KeyeSiglipAttention(nn.Module):
)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
- seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
batch_size = q.shape[0]
if rope_emb is None:
@@ -498,17 +498,21 @@ class KeyeSiglipAttention(nn.Module):
softmax_scale=self.scale,
)
context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- from xformers import ops as xops
- from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-
- attn_bias = BlockDiagonalMask.from_seqlens(
- q_seqlen=seqlens, kv_seqlen=None, device=q.device
- )
-
- context_layer = xops.memory_efficient_attention_forward(
- q, k, v, attn_bias=attn_bias, p=0, scale=None
- )
+ elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
+ outputs = []
+ for i in range(1, len(cu_seqlens)):
+ start_idx = cu_seqlens[i - 1]
+ end_idx = cu_seqlens[i]
+ q_i = q[:, start_idx:end_idx]
+ k_i = k[:, start_idx:end_idx]
+ v_i = v[:, start_idx:end_idx]
+ q_i, k_i, v_i = (
+ rearrange(x, "b s h d -> b h s d") for x in (q_i, k_i, v_i)
+ )
+ output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
+ output_i = rearrange(output_i, "b h s d -> b s h d ")
+ outputs.append(output_i)
+ context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]
context_layer = rearrange(context_layer, "b s h d -> b s (h d)").contiguous()
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index ebf8addda4a54..f6af2bb3b12e9 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -262,7 +262,7 @@ class LlamaAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
- rope_parameters=config.rope_parameters,
+ rope_parameters=getattr(config, "rope_parameters", None),
is_neox_style=is_neox_style,
partial_rotary_factor=self.partial_rotary_factor,
)
@@ -354,7 +354,17 @@ class LlamaDecoderLayer(nn.Module):
return vllm_config.quant_config
-@support_torch_compile
+def llama_model_invariants(
+ input_ids, positions, intermediate_tensors=None, inputs_embeds=None
+):
+ """Shape invariants for Llama model compilation, those are translated to
+ runtime assertions for unbacked dynamic shapes and are compiled away for
+ backed"""
+ if input_ids is not None:
+ torch._check(positions.size()[0] == input_ids.size()[0])
+
+
+@support_torch_compile(shape_invariants=llama_model_invariants)
class LlamaModel(nn.Module):
def __init__(
self,
diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index 3eaf2d80082f1..7a57644db1b13 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -142,6 +142,12 @@ class LlamaModel(nn.Module):
# Get drafter's quantization config
self.quant_config = get_draft_quant_config(vllm_config)
+ eagle_config = getattr(self.config, "eagle_config", None)
+ if eagle_config is not None and "use_aux_hidden_state" in eagle_config:
+ self.use_aux_hidden_state = eagle_config["use_aux_hidden_state"]
+ else:
+ self.use_aux_hidden_state = True
+
current_vllm_config = get_current_vllm_config()
self.embed_tokens = VocabParallelEmbedding(
@@ -161,20 +167,20 @@ class LlamaModel(nn.Module):
for layer_idx in range(self.config.num_hidden_layers)
]
)
- if hasattr(self.config, "target_hidden_size"):
- fc_input_size = self.config.target_hidden_size * 3
- else:
- fc_input_size = self.config.hidden_size * 3
- self.fc = ReplicatedLinear(
- input_size=fc_input_size,
- output_size=self.config.hidden_size,
- bias=False,
- params_dtype=vllm_config.model_config.dtype,
- quant_config=self.quant_config,
- prefix=maybe_prefix(prefix, "fc"),
- return_bias=False,
- )
-
+ if self.use_aux_hidden_state:
+ if hasattr(self.config, "target_hidden_size"):
+ fc_input_size = self.config.target_hidden_size * 3
+ else:
+ fc_input_size = self.config.hidden_size * 3
+ self.fc = ReplicatedLinear(
+ input_size=fc_input_size,
+ output_size=self.config.hidden_size,
+ bias=False,
+ params_dtype=vllm_config.model_config.dtype,
+ quant_config=self.quant_config,
+ prefix=maybe_prefix(prefix, "fc"),
+ return_bias=False,
+ )
self.norm = RMSNorm(
self.config.hidden_size,
eps=self.config.rms_norm_eps,
@@ -332,6 +338,8 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
+ if not self.model.use_aux_hidden_state:
+ return hidden_states
# combine multiple auxiliary hidden states returned by eagle3
return self.model.fc(hidden_states)
@@ -357,6 +365,8 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
skip_substrs.append("draft_id_to_target_id")
if not includes_embed_tokens:
skip_substrs.append("embed_tokens")
+ if not self.model.use_aux_hidden_state:
+ skip_substrs.append("fc.")
loader = AutoWeightsLoader(
self,
skip_prefixes=None,
diff --git a/vllm/model_executor/models/moonvit.py b/vllm/model_executor/models/moonvit.py
index 2e3e6dc166ad8..63ea6b259a71d 100644
--- a/vllm/model_executor/models/moonvit.py
+++ b/vllm/model_executor/models/moonvit.py
@@ -56,10 +56,13 @@ from transformers.utils import is_flash_attn_2_available
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.models.utils import maybe_prefix
+from vllm.platforms import current_platform
from vllm.transformers_utils.configs.moonvit import MoonViTConfig
if is_flash_attn_2_available():
from flash_attn import flash_attn_varlen_func
+elif current_platform.is_xpu():
+ from vllm.attention.utils.fa_utils import flash_attn_varlen_func
else:
flash_attn_varlen_func = None
@@ -106,10 +109,10 @@ def multihead_attention(
q,
k,
v,
- q_cu_seqlens,
- k_cu_seqlens,
- max_seqlen_q,
- max_seqlen_k,
+ cu_seqlens_q=q_cu_seqlens,
+ cu_seqlens_k=k_cu_seqlens,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_k=max_seqlen_k,
causal=False,
)
attn_out = attn_out.flatten(start_dim=-2)
@@ -291,7 +294,12 @@ class Rope2DPosEmb(nn.Module):
"""
def __init__(
- self, dim: int, max_height: int, max_width: int, theta_base=10000, device="cuda"
+ self,
+ dim: int,
+ max_height: int,
+ max_width: int,
+ theta_base=10000,
+ device=current_platform.device_type,
):
super().__init__()
self.dim = dim
@@ -437,7 +445,7 @@ class MoonVitEncoderLayer(nn.Module):
self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
self.attn_implementation = attn_implementation
# use fa2 in vllm by default
- if is_flash_attn_2_available():
+ if is_flash_attn_2_available() or current_platform.is_xpu():
self.attn_implementation = "flash_attention_2"
self.norm0 = nn.LayerNorm(hidden_dim)
diff --git a/vllm/model_executor/models/opencua.py b/vllm/model_executor/models/opencua.py
new file mode 100644
index 0000000000000..121bf896fa6ba
--- /dev/null
+++ b/vllm/model_executor/models/opencua.py
@@ -0,0 +1,271 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+#
+# Adapted from Qwen2.5-VL implementation
+# Copyright 2025 The vLLM team.
+# Copyright 2025 XLANG Lab, The University of Hong Kong
+
+"""Inference-only OpenCUA-7B model compatible with HuggingFace weights."""
+
+from collections.abc import Mapping, Sequence
+from typing import Any
+
+import torch
+import torch.nn as nn
+from transformers import BatchFeature
+from transformers.models.qwen2_vl import (
+ Qwen2VLImageProcessor,
+ Qwen2VLProcessor,
+ Qwen2VLVideoProcessor,
+)
+
+from vllm.config import VllmConfig
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (
+ MultiModalFieldConfig,
+ MultiModalKwargs,
+)
+from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
+from vllm.multimodal.processing import (
+ BaseMultiModalProcessor,
+ PromptReplacement,
+ PromptUpdate,
+)
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+from .qwen2_5_vl import (
+ Qwen2_5_VisionTransformer as OpenCUAVisionTransformer,
+)
+from .qwen2_5_vl import (
+ Qwen2_5_VLForConditionalGeneration,
+)
+from .qwen2_vl import (
+ Qwen2VLDummyInputsBuilder,
+ Qwen2VLMultiModalDataParser,
+ Qwen2VLProcessingInfo,
+ _create_qwen2vl_field_factory,
+)
+from .utils import (
+ WeightsMapper,
+ init_vllm_registered_model,
+ maybe_prefix,
+)
+
+
+class OpenCUAProcessingInfo(Qwen2VLProcessingInfo):
+ def get_hf_config(self):
+ return self.ctx.get_hf_config()
+
+ def get_supported_mm_limits(self) -> Mapping[str, int | None]:
+ return {"image": None}
+
+ def get_hf_processor(self, **kwargs: object):
+ """Load OpenCUA processor."""
+ tokenizer = self.get_tokenizer()
+ vision_config = self.ctx.get_hf_image_processor_config()
+ return OpenCUAProcessor(
+ vision_config=vision_config,
+ tokenizer=tokenizer,
+ **kwargs,
+ )
+
+
+class OpenCUAProcessor(Qwen2VLProcessor):
+ def check_argument_for_proper_class(self, attribute_name: str, arg: object) -> None:
+ if attribute_name == "tokenizer":
+ return
+ return super().check_argument_for_proper_class(attribute_name, arg)
+
+ def __init__(
+ self,
+ vision_config: dict,
+ tokenizer: AnyTokenizer,
+ **kwargs,
+ ):
+ image_processor = Qwen2VLImageProcessor(**vision_config)
+ video_processor = Qwen2VLVideoProcessor(**vision_config)
+ chat_template = kwargs.pop("chat_template", None)
+
+ super().__init__(
+ image_processor=image_processor,
+ tokenizer=tokenizer,
+ video_processor=video_processor,
+ chat_template=chat_template,
+ **kwargs,
+ )
+
+ self.image_token = "<|media_placeholder|>"
+
+ def __call__(
+ self,
+ text=None,
+ images=None,
+ return_tensors=None,
+ **kwargs,
+ ):
+ if text is not None:
+ if not isinstance(text, list):
+ text = [text]
+ text_inputs = self.tokenizer(text, **kwargs)
+ else:
+ text_inputs = {}
+
+ image_inputs = {}
+ if images is not None:
+ if not isinstance(images, list):
+ images = [images]
+ if len(images) > 0:
+ image_inputs = self.image_processor(
+ images, return_tensors=return_tensors or "pt"
+ )
+
+ combined_inputs = {**text_inputs, **image_inputs}
+
+ return BatchFeature(combined_inputs, tensor_type=return_tensors)
+
+
+class OpenCUAMultiModalProcessor(BaseMultiModalProcessor[OpenCUAProcessingInfo]):
+ def _get_data_parser(self) -> MultiModalDataParser:
+ return Qwen2VLMultiModalDataParser(
+ self.info.get_hf_config().vision_config.spatial_merge_size
+ )
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ return _create_qwen2vl_field_factory(
+ self.info.get_hf_config().vision_config.spatial_merge_size
+ )(hf_inputs)
+
+ def _hf_processor_applies_updates(
+ self,
+ prompt_text: str,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ tokenization_kwargs: Mapping[str, object],
+ ) -> bool:
+ """vLLM이 prompt 업데이트를 처리하도록 False 반환."""
+ return False
+
+ def _get_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, Any],
+ out_mm_kwargs: MultiModalKwargs,
+ ) -> Sequence[PromptUpdate]:
+ hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+ image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
+ tokenizer = self.info.get_tokenizer()
+ vocab = tokenizer.get_vocab()
+ hf_config = self.info.get_hf_config()
+
+ image_token_str = getattr(hf_processor, "image_token", "<|media_placeholder|>")
+ image_token_id = vocab.get(
+ image_token_str,
+ getattr(hf_config, "media_placeholder_token_id", 151664),
+ )
+
+ merge_length = image_processor.merge_size**2
+
+ def get_replacement_opencua(item_idx: int):
+ out_item = out_mm_kwargs["image"][item_idx]
+ grid_thw = out_item["image_grid_thw"].data
+ assert isinstance(grid_thw, torch.Tensor)
+
+ num_tokens = int(grid_thw.prod()) // merge_length
+ return [image_token_id] * num_tokens
+
+ return [
+ PromptReplacement(
+ modality="image",
+ target=[image_token_id],
+ replacement=get_replacement_opencua,
+ )
+ ]
+
+
+class OpenCUADummyInputsBuilder(Qwen2VLDummyInputsBuilder):
+ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+ num_images = mm_counts.get("image", 0)
+
+ image_token = "<|media_placeholder|>"
+
+ return image_token * num_images
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ OpenCUAMultiModalProcessor,
+ info=OpenCUAProcessingInfo,
+ dummy_inputs=OpenCUADummyInputsBuilder,
+)
+class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
+ merge_by_field_config = True
+ multimodal_cpu_fields = {"image_grid_thw"}
+
+ packed_modules_mapping = {
+ "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+ "gate_up_proj": ["gate_proj", "up_proj"],
+ }
+
+ hf_to_vllm_mapper = WeightsMapper(
+ orig_to_new_prefix={
+ "model.language_model.": "language_model.model.",
+ "model.visual.": "visual.",
+ "vision_tower.": "visual.",
+ "lm_head.": "language_model.lm_head.",
+ "model.": "language_model.model.",
+ }
+ )
+
+ supports_encoder_tp_data = True
+
+ @classmethod
+ def get_placeholder_str(cls, modality: str, i: int) -> str | None:
+ if modality.startswith("image"):
+ return "<|media_placeholder|>"
+ raise ValueError("Only image modality is supported")
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ nn.Module.__init__(self)
+ config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ multimodal_config = vllm_config.model_config.multimodal_config
+
+ self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
+ self.config = config
+ self.vllm_config = vllm_config
+ self.multimodal_config = multimodal_config
+ self.quant_config = quant_config
+ self.is_multimodal_pruning_enabled = (
+ multimodal_config.is_multimodal_pruning_enabled()
+ )
+
+ if multimodal_config.get_limit_per_prompt("image"):
+ attn_backend_override = (
+ multimodal_config.mm_encoder_attn_backend
+ if multimodal_config is not None
+ else None
+ )
+ self.visual = OpenCUAVisionTransformer(
+ vision_config=config.vision_config,
+ norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+ quant_config=self.quant_config,
+ prefix=maybe_prefix(prefix, "visual"),
+ use_data_parallel=self.use_data_parallel,
+ attn_backend_override=attn_backend_override,
+ )
+ else:
+ self.visual = None
+
+ self.language_model = init_vllm_registered_model(
+ vllm_config=vllm_config,
+ hf_config=config.text_config,
+ prefix=maybe_prefix(prefix, "language_model"),
+ architectures=["Qwen2ForCausalLM"],
+ )
+
+ self.make_empty_intermediate_tensors = (
+ self.language_model.make_empty_intermediate_tensors
+ )
diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py
index dee0c16ab0f63..74bb868492da9 100644
--- a/vllm/model_executor/models/paddleocr_vl.py
+++ b/vllm/model_executor/models/paddleocr_vl.py
@@ -38,7 +38,6 @@ from vllm.attention.layer import (
)
from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
- vit_xformers_attn_wrapper,
)
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
@@ -657,7 +656,6 @@ class SiglipAttention(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor | None,
max_seqlen: torch.Tensor | None,
- seqlens: torch.Tensor | None,
) -> torch.Tensor:
batch_size, _, _ = hidden_states.shape
@@ -703,10 +701,6 @@ class SiglipAttention(nn.Module):
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- if seqlens is None:
- raise ValueError("xFormers attention backend requires seqlens tensor.")
- context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
else:
raise RuntimeError(
f"PaddleOCR-VL does not support {self.attn_backend} backend now."
@@ -818,7 +812,6 @@ class SiglipEncoderLayer(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor | None,
max_seqlen: torch.Tensor | None,
- seqlens: torch.Tensor | None,
) -> torch.Tensor:
residual = hidden_states
@@ -828,7 +821,6 @@ class SiglipEncoderLayer(nn.Module):
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
hidden_states = residual + hidden_states
@@ -870,7 +862,6 @@ class SiglipEncoder(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
- AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -943,14 +934,11 @@ class SiglipEncoder(nn.Module):
cu_seqlens = cu_seqlens.to(device=device)
max_seqlen = None
- seqlens = None
if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
hidden_states = inputs_embeds
for encoder_layer in self.layers:
@@ -959,7 +947,6 @@ class SiglipEncoder(nn.Module):
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
return hidden_states
diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index 8a034fd72b02a..6011d93a795d1 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -74,6 +74,7 @@ from .vision import (
)
try:
+ # Note: vLLM does not install xformers by default.
from xformers import ops as xops
if current_platform.is_cuda() and current_platform.has_device_capability(100):
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index 32b6d6dd07b83..5831ce0b3d64b 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -274,6 +274,38 @@ class Qwen2DecoderLayer(nn.Module):
return hidden_states, residual
+def qwen_2_model_invariants(
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+):
+ """Shape invariants for Qwen2Model Model, those are translated to
+ runtime assertions for unbacked dynamic shapes and are compiled away for
+ backed"""
+ # All these should be equal.
+ # input_ids.size()[0]
+ # positions.size()[-1]
+ # intermediate_tensors["hidden_states"].size()[0]
+ # inputs_embeds.size()[0]
+ torch._check(input_ids.size()[0] == positions.size()[-1])
+ if intermediate_tensors is not None:
+ torch._check(
+ input_ids.size()[0] == intermediate_tensors["hidden_states"].size()[0]
+ )
+
+ if inputs_embeds is not None:
+ torch._check(input_ids.size()[0] == inputs_embeds.size()[0])
+
+ # Hidden dimensions should match (hidden_size)
+ # intermediate_tensors["hidden_states"].size()[1]
+ # inputs_embeds.size()[1]
+ if inputs_embeds is not None and intermediate_tensors is not None:
+ torch._check(
+ inputs_embeds.size()[1] == intermediate_tensors["hidden_states"].size()[1]
+ )
+
+
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
@@ -282,7 +314,8 @@ class Qwen2DecoderLayer(nn.Module):
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
- }
+ },
+ shape_invariants=qwen_2_model_invariants,
)
class Qwen2Model(nn.Module):
def __init__(
diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py
index 262ea771d9cdf..7506ee8656fda 100644
--- a/vllm/model_executor/models/qwen2_5_omni_thinker.py
+++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py
@@ -23,7 +23,6 @@
"""Inference-only Qwen2.5-Omni model (thinker part)."""
from collections.abc import Callable, Iterable, Mapping, Sequence
-from copy import copy
from functools import partial
from typing import Annotated, Any, Literal
@@ -387,15 +386,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
self._validate_mm_updates(mm_prompt_updates, mm_item_counts)
- use_audio_in_video = False
- if "video" in mm_kwargs:
- video_items = [item for item in mm_kwargs["video"] if item is not None]
- # only check video items (if there are any)
- if video_items:
- use_audio_in_video = all(
- item["use_audio_in_video"].data for item in video_items
- )
-
if is_update_applied:
mm_placeholders = self._find_mm_placeholders(
prompt_ids,
@@ -404,7 +394,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
self._validate_mm_placeholders(
mm_placeholders,
mm_item_counts,
- use_audio_in_video=use_audio_in_video,
)
else:
prompt_ids, mm_placeholders = self._apply_prompt_updates(
@@ -414,7 +403,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
self._validate_mm_placeholders(
mm_placeholders,
mm_item_counts,
- use_audio_in_video=use_audio_in_video,
)
return prompt_ids, mm_placeholders
@@ -640,19 +628,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
return mm_processed_data
- def _validate_mm_placeholders(
- self,
- mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
- mm_item_counts: Mapping[str, int],
- use_audio_in_video: bool = False,
- ) -> None:
- if use_audio_in_video:
- mm_item_counts = copy(mm_item_counts)
- if "video" in mm_item_counts:
- assert "audio" in mm_item_counts
- mm_item_counts["audio"] -= mm_item_counts["video"]
- super()._validate_mm_placeholders(mm_placeholders, mm_item_counts)
-
class Qwen2_5OmniConditionalGenerationMixin:
def _parse_and_validate_audio_input(
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 8e3c0e84dfe51..8c707c2561af1 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -46,7 +46,6 @@ from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
vit_torch_sdpa_wrapper,
- vit_xformers_attn_wrapper,
)
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
@@ -230,6 +229,9 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
- hidden_size must match the hidden size of language model backbone.
- video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
format
+ - second_per_grid_ts: The video time interval (in seconds) for each
+ grid along the temporal dimension in the 3D position IDs. Returned
+ when `videos` is not `None`.
"""
type: Literal["video_embeds"]
@@ -244,6 +246,11 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
TensorShape("nv", 3),
]
+ second_per_grid_ts: Annotated[
+ torch.Tensor | None,
+ TensorShape("nv"),
+ ] = None
+
Qwen2_5_VLVideoInputs: TypeAlias = (
Qwen2_5_VLVideoPixelInputs | Qwen2_5_VLVideoEmbeddingInputs
@@ -367,7 +374,6 @@ class Qwen2_5_VisionAttention(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
- seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
@@ -427,8 +433,6 @@ class Qwen2_5_VisionAttention(nn.Module):
v,
cu_seqlens,
)
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
output, _ = self.proj(context_layer)
return output
@@ -440,9 +444,7 @@ class Qwen2_5_VisionAttention(nn.Module):
"cu_seqlens": 0,
"rotary_pos_emb_cos": 0,
"rotary_pos_emb_sin": 0,
- "seqlens": 0,
},
- mark_unbacked_dims={"seqlens": 0},
enable_if=should_torch_compile_mm_vit,
)
class Qwen2_5_VisionBlock(nn.Module):
@@ -493,7 +495,6 @@ class Qwen2_5_VisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
- seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
x_attn = self.attn(
self.norm1(x),
@@ -501,7 +502,6 @@ class Qwen2_5_VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
x_fused_norm, residual = self.norm2(x, residual=x_attn)
x = residual + self.mlp(x_fused_norm)
@@ -662,7 +662,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
- AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -814,17 +813,14 @@ class Qwen2_5_VisionTransformer(nn.Module):
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
+ ) -> torch.Tensor:
max_seqlen = torch.zeros([], device=cu_seqlens.device)
- seqlens = torch.zeros(1, device=cu_seqlens.device)
if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
- return max_seqlen, seqlens
+ return max_seqlen
@staticmethod
def invert_permutation(perm: torch.Tensor) -> torch.Tensor:
@@ -889,10 +885,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
# transformers
# pre-compute seqlens for window/full attn to reduce cuMemcpy operations
- max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen(cu_seqlens)
- max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen(
- cu_window_seqlens
- )
+ max_seqlen_full = self.compute_attn_mask_seqlen(cu_seqlens)
+ max_seqlen_window = self.compute_attn_mask_seqlen(cu_window_seqlens)
cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True)
cu_window_seqlens = cu_window_seqlens.to(device=self.device, non_blocking=True)
@@ -919,11 +913,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
if layer_num in self.fullatt_block_indexes:
cu_seqlens_now = cu_seqlens
max_seqlen_now = max_seqlen_full
- seqlens_now = seqlens_full
else:
cu_seqlens_now = cu_window_seqlens
max_seqlen_now = max_seqlen_window
- seqlens_now = seqlens_window
hidden_states = blk(
hidden_states,
@@ -931,7 +923,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen_now,
- seqlens=seqlens_now,
)
# For Qwen2.5-VL-3B, float16 will overflow at last block
@@ -1311,6 +1302,7 @@ class Qwen2_5_VLForConditionalGeneration(
type="video_embeds",
video_embeds=video_embeds,
video_grid_thw=video_grid_thw,
+ second_per_grid_ts=second_per_grid_ts,
)
def _process_image_input(
@@ -1422,7 +1414,13 @@ class Qwen2_5_VLForConditionalGeneration(
# Cast to long to match the original code
# https://github.com/huggingface/transformers/blob/41980ce93e775f6c88500c51c8db7946fc6a2add/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py#L491 # noqa
- second_per_grid_ts = video_input["second_per_grid_ts"].long()
+ second_per_grid_ts = video_input.get("second_per_grid_ts")
+ if second_per_grid_ts is None:
+ raise ValueError(
+ "second_per_grid_ts is required when video_pruning_rate > 0 "
+ "is enabled for video inputs, including the video_embeds path."
+ )
+ second_per_grid_ts = second_per_grid_ts.long()
tokens_per_second = self.config.vision_config.tokens_per_second
video_embeds_out = []
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 479a7871e364f..9d1d023aed172 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -348,7 +348,6 @@ class Qwen2VisionAttention(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
- AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -384,7 +383,6 @@ class Qwen2VisionAttention(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
- seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
# [s, b, c] --> [s, b, 3 * head * head_dim]
x, _ = self.qkv(x)
@@ -445,20 +443,6 @@ class Qwen2VisionAttention(nn.Module):
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- from xformers import ops as xops
- from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-
- attn_bias = BlockDiagonalMask.from_seqlens(
- q_seqlen=seqlens, kv_seqlen=None, device=q.device
- )
-
- context_layer = xops.memory_efficient_attention_forward(
- q, k, v, attn_bias=attn_bias, p=0, scale=None
- )
- context_layer = rearrange(
- context_layer, "b s h d -> s b (h d)"
- ).contiguous()
output, _ = self.proj(context_layer)
return output
@@ -509,7 +493,6 @@ class Qwen2VisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
- seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
@@ -517,7 +500,6 @@ class Qwen2VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
x = x + self.mlp(self.norm2(x))
@@ -728,18 +710,14 @@ class Qwen2VisionTransformer(nn.Module):
sin_combined = sin[pos_ids].flatten(1)
return cos_combined, sin_combined
- def compute_attn_mask_seqlen(
- self, cu_seqlens: torch.Tensor
- ) -> tuple[int | None, list[int] | None]:
- max_seqlen, seqlens = None, None
+ def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
+ max_seqlen = None
if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
- return max_seqlen, seqlens
+ return max_seqlen
def forward(
self,
@@ -771,7 +749,7 @@ class Qwen2VisionTransformer(nn.Module):
x = x.unsqueeze(1)
# pre-compute seqlens for attn mask to reduce cuMemcpy operations
- max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+ max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
for blk in self.blocks:
x = blk(
@@ -780,7 +758,6 @@ class Qwen2VisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
# adapter
diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py
index 54ef56f83344e..f5f88f66eff91 100755
--- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py
+++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py
@@ -68,11 +68,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems
from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataItems
from vllm.multimodal.processing import (
- BaseMultiModalProcessor,
MultiModalPromptUpdates,
PlaceholderFeaturesInfo,
PromptReplacement,
PromptUpdate,
+ PromptUpdateDetails,
)
from vllm.sequence import IntermediateTensors
@@ -87,7 +87,6 @@ from .qwen2_5_omni_thinker import (
Qwen2_5OmniConditionalGenerationMixin,
Qwen2_5OmniThinkerDummyInputsBuilder,
Qwen2_5OmniThinkerMultiModalProcessor,
- Qwen2_5OmniThinkerProcessingInfo,
)
from .qwen2_5_vl import (
Qwen2_5_VisionAttention,
@@ -224,7 +223,6 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
- seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
@@ -232,7 +230,6 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
x = x + self.mlp(self.norm2(x))
@@ -500,14 +497,11 @@ class Qwen3Omni_VisionTransformer(nn.Module):
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
+ ) -> torch.Tensor:
max_seqlen = torch.zeros([], device=cu_seqlens.device)
- seqlens = torch.zeros(1, device=cu_seqlens.device)
if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
- return max_seqlen, seqlens
+ return max_seqlen
def forward(
self,
@@ -533,7 +527,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
hidden_states = hidden_states.unsqueeze(1)
rotary_pos_emb_cos = rotary_pos_emb_cos.to(hidden_states.device)
rotary_pos_emb_sin = rotary_pos_emb_sin.to(hidden_states.device)
- max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+ max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
hidden_states_list = []
deepstack_visual_indexes = self.deepstack_visual_indexes
@@ -545,7 +539,6 @@ class Qwen3Omni_VisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
if (
deepstack_visual_indexes is not None
@@ -813,24 +806,8 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
else:
use_audio_in_video = False
- if use_audio_in_video and "video" in mm_item_counts:
- assert "audio" in mm_item_counts
- mm_item_counts["audio"] -= mm_item_counts["video"]
-
- # Special case with `use_audio_in_video=True`
- if use_audio_in_video:
- if is_update_applied:
- prompt_ids = self._get_raw_input_ids(prompt_ids, use_audio_in_video)
- (
- prompt_ids,
- mm_placeholders,
- ) = self._apply_prompt_updates(
- prompt_ids,
- mm_prompt_updates,
- )
- self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
# normal case with `use_audio_in_video=False`
- elif is_update_applied:
+ if is_update_applied:
mm_placeholders = self._find_mm_placeholders(
prompt_ids,
mm_prompt_updates,
@@ -840,10 +817,24 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
mm_item_counts,
)
else:
- prompt_ids, mm_placeholders = self._apply_prompt_updates(
- prompt_ids,
- mm_prompt_updates,
- )
+ if use_audio_in_video and "audio" in mm_prompt_updates:
+ filtered_updates = {
+ k: v for k, v in mm_prompt_updates.items() if k != "audio"
+ }
+ prompt_ids, mm_placeholders = self._apply_prompt_updates(
+ prompt_ids,
+ filtered_updates,
+ )
+ # Derive audio placeholders from video placeholders
+ mm_placeholders = self._derive_audio_from_video_placeholders(
+ mm_placeholders, mm_prompt_updates
+ )
+ else:
+ prompt_ids, mm_placeholders = self._apply_prompt_updates(
+ prompt_ids,
+ mm_prompt_updates,
+ )
+
self._validate_mm_placeholders(
mm_placeholders,
mm_item_counts,
@@ -968,7 +959,9 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
def get_replacement_qwen2_use_audio_in_video(item_idx: int):
nonlocal audio_in_video_item_idx
- audio_num_features = audio_output_lengths[audio_item_idx + item_idx]
+ audio_num_features = audio_output_lengths[
+ audio_in_video_item_idx + item_idx
+ ]
video_grid_thw = out_mm_data["video_grid_thw"][item_idx]
audio_in_video_item_idx += 1
@@ -977,14 +970,17 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
if second_per_grid_ts:
video_second_per_grid_t = second_per_grid_ts[item_idx]
else:
- video_second_per_grid_t = 1.0
+ video_second_per_grid_t = 2.0
- return self.get_updates_use_audio_in_video(
+ placeholder = self.get_updates_use_audio_in_video(
thinker_config=thinker_config,
audio_len=audio_num_features,
video_grid_thw=video_grid_thw,
video_second_per_grid_t=video_second_per_grid_t,
)
+ return PromptUpdateDetails.select_token_id(
+ placeholder, embed_token_id=video_token_id
+ )
video_replacement_fn = (
get_replacement_qwen2_use_audio_in_video
@@ -1010,14 +1006,50 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
),
]
- def _validate_mm_placeholders(
+ def _derive_audio_from_video_placeholders(
self,
- mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
- mm_item_counts: Mapping[str, int],
- ) -> None:
- BaseMultiModalProcessor[
- Qwen2_5OmniThinkerProcessingInfo
- ]._validate_mm_placeholders(self, mm_placeholders, mm_item_counts)
+ placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
+ mm_prompt_updates: MultiModalPromptUpdates,
+ ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
+ """
+ Helper to derive audio placeholders from video placeholders when
+ use_audio_in_video=True.
+ """
+ if "video" not in placeholders:
+ return placeholders
+
+ # Validate audio and video counts match
+ num_videos = len(placeholders["video"])
+ num_audios = len(mm_prompt_updates.get("audio", []))
+ if num_audios != num_videos:
+ raise ValueError(
+ f"use_audio_in_video requires equal number of audio and video items, "
+ f"got {num_audios=}, {num_videos=}"
+ )
+
+ tokenizer = self.info.get_tokenizer()
+ processor = self.info.get_hf_processor()
+ audio_token_id = tokenizer.get_vocab()[processor.audio_token]
+
+ result_placeholders = dict(placeholders)
+ audio_placeholders = []
+
+ # Each video is paired with one audio
+ for video_idx, video_placeholder in enumerate(placeholders["video"]):
+ # Create is_embed mask selecting only audio tokens
+ audio_is_embed = torch.tensor(video_placeholder.tokens) == audio_token_id
+
+ audio_placeholder = PlaceholderFeaturesInfo(
+ modality="audio",
+ item_idx=video_idx,
+ start_idx=video_placeholder.start_idx,
+ tokens=video_placeholder.tokens,
+ is_embed=audio_is_embed,
+ )
+ audio_placeholders.append(audio_placeholder)
+
+ result_placeholders["audio"] = audio_placeholders
+ return result_placeholders
def _get_raw_input_ids(
self,
@@ -1460,7 +1492,11 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
)
if not len(second_per_grid_ts) and len(video_grid_thw):
- second_per_grids = torch.ones(len(video_grid_thw), dtype=torch.float32)
+ second_per_grid_ts = 2.0
+ second_per_grids = (
+ torch.ones(len(video_grid_thw), dtype=torch.float32)
+ * second_per_grid_ts
+ )
else:
second_per_grids = torch.tensor(second_per_grid_ts, dtype=torch.float32)
diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index 90c4894d33e88..4cd6fa14c32df 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -235,7 +235,6 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
- seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
@@ -243,7 +242,6 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
x = x + self.mlp(self.norm2(x))
@@ -391,7 +389,6 @@ class Qwen3_VisionTransformer(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
- AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -531,17 +528,14 @@ class Qwen3_VisionTransformer(nn.Module):
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
+ ) -> torch.Tensor:
max_seqlen = torch.zeros([], device=cu_seqlens.device)
- seqlens = torch.zeros(1, device=cu_seqlens.device)
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
- return max_seqlen, seqlens
+ return max_seqlen
def forward(
self,
@@ -569,7 +563,7 @@ class Qwen3_VisionTransformer(nn.Module):
cu_seqlens = torch.from_numpy(cu_seqlens)
hidden_states = hidden_states.unsqueeze(1)
- max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+ max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
deepstack_feature_lists = []
@@ -580,7 +574,6 @@ class Qwen3_VisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
if layer_num in self.deepstack_visual_indexes:
deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num)
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 4943987606201..a0d8a78a2ae76 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -287,8 +287,16 @@ _MULTIMODAL_MODELS = {
"GraniteSpeechForConditionalGeneration",
),
"H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
+ "HunYuanVLForConditionalGeneration": (
+ "hunyuan_vision",
+ "HunYuanVLForConditionalGeneration",
+ ),
"InternVLChatModel": ("internvl", "InternVLChatModel"),
"NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
+ "OpenCUAForConditionalGeneration": (
+ "opencua",
+ "OpenCUAForConditionalGeneration",
+ ),
"InternS1ForConditionalGeneration": (
"interns1",
"InternS1ForConditionalGeneration",
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index f9bf242b7194e..75b6bc77e4c1c 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -277,12 +277,7 @@ class CudaPlatformBase(Platform):
except ImportError:
pass
- if cls.has_device_capability(100):
- # xFormers doesn't support Blackwell, fall back to SDPA
- # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
- return AttentionBackendEnum.TORCH_SDPA
- else:
- return AttentionBackendEnum.XFORMERS
+ return AttentionBackendEnum.TORCH_SDPA
@classmethod
def get_valid_backends(
@@ -412,9 +407,6 @@ class CudaPlatformBase(Platform):
# We have found some valid backends. Select the one with the
# highest priority.
- logger.info(
- "Valid backends: %s", [b[0].name for b in valid_backends_priorities]
- )
sorted_indices = sorted(
range(len(valid_backends_priorities)),
key=lambda i: valid_backends_priorities[i][1],
@@ -422,8 +414,9 @@ class CudaPlatformBase(Platform):
selected_index = sorted_indices[0]
selected_backend = valid_backends_priorities[selected_index][0]
logger.info(
- "Using %s backend.",
+ "Using %s attention backend out of potential backends: %s",
selected_backend.name,
+ [b[0].name for b in valid_backends_priorities],
)
return selected_backend.get_path()
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 0471c20429b1d..1e6b53021f888 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -134,6 +134,11 @@ class Platform:
_global_graph_pool: Any | None = None
+ @property
+ def pass_key(self) -> str:
+ """Inductor config key for the PassManager custom pass"""
+ return "post_grad_custom_post_pass"
+
@property
def supported_dtypes(self) -> list[torch.dtype]:
"""Returns the supported dtypes for the current platform."""
@@ -177,6 +182,21 @@ class Platform:
# all ROCm platforms for now.
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
+ @classmethod
+ def get_pass_manager_cls(cls) -> str:
+ """
+ Get the pass manager class for this platform.
+ It will be registered as a custom pass under the current_platform.pass_key.
+ """
+ return "vllm.compilation.pass_manager.PostGradPassManager"
+
+ @classmethod
+ def get_compile_backend(cls) -> str:
+ """
+ Get the custom compile backend for current platform.
+ """
+ return cls.simple_compile_backend
+
@classmethod
def device_id_to_physical_device_id(cls, device_id: int):
# Treat empty device control env var as unset. This is a valid
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index f9005fd7d044c..0483f6c06ada8 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -264,28 +264,66 @@ class RocmPlatform(Platform):
if selected_backend == AttentionBackendEnum.FLEX_ATTENTION:
logger.info("Using FlexAttention backend.")
- return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
- if (
- rocm_aiter_ops.is_mha_enabled()
- ) or selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
- logger.info("Using Aiter Flash Attention backend.")
- return AttentionBackendEnum.ROCM_AITER_FA.get_path()
- if (
- rocm_aiter_ops.is_triton_unified_attn_enabled()
- ) or selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
- logger.info("Using Aiter Unified Attention backend.")
- return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
- if (
- envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
- or selected_backend == AttentionBackendEnum.ROCM_ATTN
- ):
- # rocm specific backend, with aiter and/or
- # triton prefix-prefill
- logger.info("Using Rocm Attention backend.")
+ return AttentionBackendEnum.FLEX_ATTENTION.get_path()
+
+ if selected_backend == AttentionBackendEnum.TRITON_ATTN:
+ logger.info("Using Triton Attention backend on V1 engine.")
+ return AttentionBackendEnum.TRITON_ATTN.get_path()
+
+ if selected_backend == AttentionBackendEnum.ROCM_ATTN:
+ logger.info("Using Rocm Attention backend on V1 engine.")
return AttentionBackendEnum.ROCM_ATTN.get_path()
- # default case, using triton unified attention
- logger.info("Using Triton Attention backend.")
- return AttentionBackendEnum.TRITON_ATTN.get_path()
+
+ if selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
+ if on_gfx9():
+ logger.info("Using Aiter Flash Attention backend on V1 engine.")
+ return AttentionBackendEnum.ROCM_AITER_FA.get_path()
+ else:
+ raise ValueError(
+ f"The selected backend, {selected_backend.name}, "
+ "is only supported on gfx9 architectures."
+ )
+
+ if selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
+ logger.info("Using Aiter Unified Attention backend on V1 engine.")
+ return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
+
+ # Handle automatic backend selection based on environment variables
+ if selected_backend is None:
+ # Priority 1: Check for AITER Unified Attention (must check before MHA)
+ if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
+ logger.info("Using Aiter Unified Attention backend on V1 engine.")
+ return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
+
+ # Priority 2: Check for AITER MHA (Flash Attention)
+ # Only use if explicitly enabled (not just VLLM_ROCM_USE_AITER=1)
+ if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
+ logger.info("Using Aiter Flash Attention backend on V1 engine.")
+ return AttentionBackendEnum.ROCM_AITER_FA.get_path()
+
+ # Priority 3: Check for ROCM_ATTN (prefill-decode split)
+ if envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION:
+ logger.info("Using Rocm Attention backend on V1 engine.")
+ return AttentionBackendEnum.ROCM_ATTN.get_path()
+
+ # Priority 4: Check for AITER enabled without specific flags
+ # This defaults to AITER FA only if MHA is not explicitly disabled
+ if (
+ envs.VLLM_ROCM_USE_AITER
+ and on_gfx9()
+ and envs.VLLM_ROCM_USE_AITER_MHA is not False
+ ):
+ logger.info("Using Aiter Flash Attention backend on V1 engine.")
+ return AttentionBackendEnum.ROCM_AITER_FA.get_path()
+
+ # Default: Triton Unified Attention
+ logger.info("Using Triton Attention backend on V1 engine.")
+ return AttentionBackendEnum.TRITON_ATTN.get_path()
+
+ raise RuntimeError(
+ f"Attention backend {selected_backend.name} is not supported on "
+ "ROCm. Note that V0 attention backends have been removed."
+ )
@classmethod
def set_device(cls, device: torch.device) -> None:
diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py
index 5c3dfa8ac9cbc..d1aab98c274e1 100644
--- a/vllm/pooling_params.py
+++ b/vllm/pooling_params.py
@@ -57,7 +57,7 @@ class PoolingParams(
## Internal use only
task: PoolingTask | None = None
requires_token_ids: bool = False
- skip_reading_prefix_cache: bool = None
+ skip_reading_prefix_cache: bool | None = None
extra_kwargs: dict[str, Any] | None = None
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index fbbe3d4cabb9a..8de961e62db1b 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -3,7 +3,6 @@
"""Sampling parameters for text generation."""
import copy
-import warnings
from dataclasses import field
from enum import Enum, IntEnum
from functools import cached_property
@@ -100,19 +99,6 @@ class StructuredOutputsParams:
)
-@dataclass
-class GuidedDecodingParams(StructuredOutputsParams):
- def __post_init__(self):
- warnings.warn(
- "GuidedDecodingParams is deprecated. This will be removed in "
- "v0.12.0 or v1.0.0, which ever is soonest. Please use "
- "StructuredOutputsParams instead.",
- DeprecationWarning,
- stacklevel=2,
- )
- return super().__post_init__()
-
-
class RequestOutputKind(Enum):
# Return entire output so far in every RequestOutput
CUMULATIVE = 0
@@ -234,8 +220,6 @@ class SamplingParams(
# Fields used to construct logits processors
structured_outputs: StructuredOutputsParams | None = None
"""Parameters for configuring structured outputs."""
- guided_decoding: GuidedDecodingParams | None = None
- """Deprecated alias for structured_outputs."""
logit_bias: dict[int, float] | None = None
"""If provided, the engine will construct a logits processor that applies
these logit biases."""
@@ -254,7 +238,7 @@ class SamplingParams(
generated token can complete the sequence."""
_bad_words_token_ids: list[list[int]] | None = None
- skip_reading_prefix_cache: bool = None
+ skip_reading_prefix_cache: bool | None = None
@staticmethod
def from_optional(
@@ -283,7 +267,6 @@ class SamplingParams(
truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None,
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
structured_outputs: StructuredOutputsParams | None = None,
- guided_decoding: GuidedDecodingParams | None = None,
logit_bias: dict[int, float] | dict[str, float] | None = None,
allowed_token_ids: list[int] | None = None,
extra_args: dict[str, Any] | None = None,
@@ -295,16 +278,6 @@ class SamplingParams(
int(token): min(100.0, max(-100.0, bias))
for token, bias in logit_bias.items()
}
- if guided_decoding is not None:
- warnings.warn(
- "guided_decoding is deprecated. This will be removed in "
- "v0.12.0 or v1.0.0, which ever is soonest. Please use "
- "structured_outputs instead.",
- DeprecationWarning,
- stacklevel=2,
- )
- structured_outputs = guided_decoding
- guided_decoding = None
return SamplingParams(
n=1 if n is None else n,
@@ -387,17 +360,6 @@ class SamplingParams(
# eos_token_id is added to this by the engine
self._all_stop_token_ids.update(self.stop_token_ids)
- if self.guided_decoding is not None:
- warnings.warn(
- "guided_decoding is deprecated. This will be removed in "
- "v0.12.0 or v1.0.0, which ever is soonest. Please use "
- "structured_outputs instead.",
- DeprecationWarning,
- stacklevel=2,
- )
- self.structured_outputs = self.guided_decoding
- self.guided_decoding = None
-
if self.skip_reading_prefix_cache is None:
# If prefix caching is enabled,
# the output of prompt logprobs may less than n_prompt_tokens,
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 9eac7bb50afa6..66680f410cb3c 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -27,7 +27,7 @@ from huggingface_hub.utils import (
RevisionNotFoundError,
)
from packaging.version import Version
-from transformers import DeepseekV3Config, GenerationConfig, PretrainedConfig
+from transformers import GenerationConfig, PretrainedConfig
from transformers.configuration_utils import ALLOWED_LAYER_TYPES
from transformers.models.auto.image_processing_auto import get_image_processor_config
from transformers.models.auto.modeling_auto import (
@@ -42,7 +42,10 @@ from vllm.logger import init_logger
from vllm.transformers_utils.config_parser_base import ConfigParserBase
from vllm.transformers_utils.utils import (
check_gguf_file,
+ is_gguf,
+ is_remote_gguf,
parse_safetensors_file_metadata,
+ split_remote_gguf,
)
if envs.VLLM_USE_MODELSCOPE:
@@ -84,8 +87,9 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
afmoe="AfmoeConfig",
chatglm="ChatGLMConfig",
deepseek_vl_v2="DeepseekVLV2Config",
- deepseek_v32=DeepseekV3Config,
+ deepseek_v32="DeepseekV3Config",
flex_olmo="FlexOlmoConfig",
+ hunyuan_vl="HunYuanVLConfig",
kimi_linear="KimiLinearConfig",
kimi_vl="KimiVLConfig",
RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct)
@@ -204,7 +208,19 @@ class MistralConfigParser(ConfigParserBase):
from vllm.transformers_utils.configs.mistral import adapt_config_dict
- config = adapt_config_dict(config_dict)
+ # Get missing fields from HF config if available
+ try:
+ hf_config_dict, _ = PretrainedConfig.get_config_dict(
+ model,
+ revision=revision,
+ code_revision=code_revision,
+ token=_get_hf_token(),
+ **kwargs,
+ )
+ except OSError: # Not found
+ hf_config_dict = {}
+
+ config = adapt_config_dict(config_dict, defaults=hf_config_dict)
# Mistral configs may define sliding_window as list[int]. Convert it
# to int and add the layer_types list[str] to make it HF compatible
@@ -440,51 +456,55 @@ def set_default_rope_theta(config: PretrainedConfig, default_theta: float) -> No
def patch_rope_parameters(config: PretrainedConfig) -> None:
"""Provide backwards compatibility for RoPE."""
- # Retrieve rope_parameters differently based on Transformers version
+ # Patch rope_parameters differently based on Transformers version
if Version(version("transformers")) >= Version("5.0.0.dev0"):
- from transformers.modeling_rope_utils import RopeParameters
-
- rope_parameters: RopeParameters | dict[str, RopeParameters] | None = getattr(
- config, "rope_parameters", None
+ from transformers.modeling_rope_utils import (
+ rope_config_validation,
+ standardize_rope_params,
)
- elif hasattr(config, "rope_parameters"):
- # We are in Transformers v4 and rope_parameters
- # has already been patched for this config
- return
+
+ # When Transformers v5 is installed, legacy rope_theta may be present
+ # when using custom code models written for Transformers v4
+ if (rope_theta := getattr(config, "rope_theta", None)) is not None:
+ standardize_rope_params(config, rope_theta=rope_theta)
+ rope_config_validation(config)
+ # Delete rope_theta to avoid confusion in downstream code
+ del config.rope_theta
else:
- # Convert Transformers v4 rope_theta and rope_scaling into rope_parameters
- rope_theta: float | None = getattr(config, "rope_theta", None)
- rope_scaling: dict | None = getattr(config, "rope_scaling", None)
- rope_parameters = rope_scaling
- # Move rope_theta into rope_parameters
- if rope_theta is not None:
- rope_parameters = rope_parameters or {"rope_type": "default"}
- rope_parameters["rope_theta"] = rope_theta
- # Add original_max_position_embeddings if present
- if rope_parameters and (
- ompe := getattr(config, "original_max_position_embeddings", None)
- ):
- rope_parameters["original_max_position_embeddings"] = ompe
- # Write back to config
- config.rope_parameters = rope_parameters
+ # When Transformers v4 is installed, legacy rope_scaling may be present
+ if (rope_scaling := getattr(config, "rope_scaling", None)) is not None:
+ config.rope_parameters = rope_scaling
+ # When Transformers v4 is installed, legacy rope_theta may be present
+ if (rope_theta := getattr(config, "rope_theta", None)) is not None:
+ if not hasattr(config, "rope_parameters"):
+ config.rope_parameters = {"rope_type": "default"}
+ config.rope_parameters["rope_theta"] = rope_theta
# No RoPE parameters to patch
- if rope_parameters is None:
+ if not hasattr(config, "rope_parameters"):
return
+ # Add original_max_position_embeddings if present
+ if ompe := getattr(config, "original_max_position_embeddings", None):
+ config.rope_parameters["original_max_position_embeddings"] = ompe
+
# Handle nested rope_parameters in interleaved sliding attention models
- if set(rope_parameters.keys()).issubset(ALLOWED_LAYER_TYPES):
- for rope_parameters_layer_type in rope_parameters.values():
+ if set(config.rope_parameters.keys()).issubset(ALLOWED_LAYER_TYPES):
+ for rope_parameters_layer_type in config.rope_parameters.values():
patch_rope_parameters_dict(rope_parameters_layer_type)
else:
- patch_rope_parameters_dict(rope_parameters)
+ patch_rope_parameters_dict(config.rope_parameters)
def patch_rope_parameters_dict(rope_parameters: dict[str, Any]) -> None:
if "rope_type" in rope_parameters and "type" in rope_parameters:
rope_type = rope_parameters["rope_type"]
rope_type_legacy = rope_parameters["type"]
- if rope_type != rope_type_legacy:
+ if (rope_type_legacy == "su" and rope_type == "longrope") or (
+ rope_type_legacy == "mrope" and rope_type == "default"
+ ):
+ pass # No action needed
+ elif rope_type != rope_type_legacy:
raise ValueError(
f"Found conflicts between 'rope_type={rope_type}' (modern "
f"field) and 'type={rope_type_legacy}' (legacy field). "
@@ -537,6 +557,23 @@ def thinker_uses_mrope(config: PretrainedConfig) -> bool:
return uses_mrope(thinker_text_config)
+def uses_xdrope_dim(config: PretrainedConfig) -> int:
+ """Detect if the model with this config uses XD-ROPE."""
+ xdrope_section = getattr(config, "xdrope_section", None)
+ if xdrope_section is not None and isinstance(xdrope_section, list):
+ return len(xdrope_section)
+ rope_scaling = getattr(config, "rope_scaling", None)
+ if rope_scaling is None:
+ return 0
+
+ if isinstance(rope_scaling, dict) and "xdrope_section" in rope_scaling:
+ xdrope_section = rope_scaling["xdrope_section"]
+ if xdrope_section is not None and isinstance(xdrope_section, list):
+ return len(xdrope_section)
+
+ return 0
+
+
def is_encoder_decoder(config: PretrainedConfig) -> bool:
"""Detect if the model with this config is used as an encoder/decoder."""
@@ -599,10 +636,12 @@ def maybe_override_with_speculators(
Returns:
Tuple of (resolved_model, resolved_tokenizer, speculative_config)
"""
- is_gguf = check_gguf_file(model)
- if is_gguf:
+ if check_gguf_file(model):
kwargs["gguf_file"] = Path(model).name
gguf_model_repo = Path(model).parent
+ elif is_remote_gguf(model):
+ repo_id, _ = split_remote_gguf(model)
+ gguf_model_repo = Path(repo_id)
else:
gguf_model_repo = None
kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE
@@ -648,10 +687,18 @@ def get_config(
) -> PretrainedConfig:
# Separate model folder from file path for GGUF models
- is_gguf = check_gguf_file(model)
- if is_gguf:
- kwargs["gguf_file"] = Path(model).name
- model = Path(model).parent
+ _is_gguf = is_gguf(model)
+ _is_remote_gguf = is_remote_gguf(model)
+ if _is_gguf:
+ if check_gguf_file(model):
+ # Local GGUF file
+ kwargs["gguf_file"] = Path(model).name
+ model = Path(model).parent
+ elif _is_remote_gguf:
+ # Remote GGUF - extract repo_id from repo_id:quant_type format
+ # The actual GGUF file will be downloaded later by GGUFModelLoader
+ # Keep model as repo_id:quant_type for download, but use repo_id for config
+ model, _ = split_remote_gguf(model)
if config_format == "auto":
try:
@@ -659,10 +706,25 @@ def get_config(
# Transformers implementation.
if file_or_path_exists(model, MISTRAL_CONFIG_NAME, revision=revision):
config_format = "mistral"
- elif is_gguf or file_or_path_exists(
+ elif (_is_gguf and not _is_remote_gguf) or file_or_path_exists(
model, HF_CONFIG_NAME, revision=revision
):
config_format = "hf"
+ # Remote GGUF models must have config.json in repo,
+ # otherwise the config can't be parsed correctly.
+ # FIXME(Isotr0py): Support remote GGUF repos without config.json
+ elif _is_remote_gguf and not file_or_path_exists(
+ model, HF_CONFIG_NAME, revision=revision
+ ):
+ err_msg = (
+ "Could not find config.json for remote GGUF model repo. "
+ "To load remote GGUF model through `:`, "
+ "ensure your model has config.json (HF format) file. "
+ "Otherwise please specify --hf-config-path "
+ "in engine args to fetch config from unquantized hf model."
+ )
+ logger.error(err_msg)
+ raise ValueError(err_msg)
else:
raise ValueError(
"Could not detect config format for no config file found. "
@@ -683,9 +745,6 @@ def get_config(
"'config.json'.\n"
" - For Mistral models: ensure the presence of a "
"'params.json'.\n"
- "3. For GGUF: pass the local path of the GGUF checkpoint.\n"
- " Loading GGUF from a remote repo directly is not yet "
- "supported.\n"
).format(model=model)
raise ValueError(error_message) from e
@@ -699,7 +758,7 @@ def get_config(
**kwargs,
)
# Special architecture mapping check for GGUF models
- if is_gguf:
+ if _is_gguf:
if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
raise RuntimeError(f"Can't get gguf config for {config.model_type}.")
model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
@@ -859,6 +918,8 @@ def get_pooling_config(model: str, revision: str | None = "main") -> dict | None
A dictionary containing the pooling type and whether
normalization is used, or None if no pooling configuration is found.
"""
+ if is_remote_gguf(model):
+ model, _ = split_remote_gguf(model)
modules_file_name = "modules.json"
@@ -1078,6 +1139,8 @@ def get_hf_image_processor_config(
# Separate model folder from file path for GGUF models
if check_gguf_file(model):
model = Path(model).parent
+ elif is_remote_gguf(model):
+ model, _ = split_remote_gguf(model)
return get_image_processor_config(
model, token=hf_token, revision=revision, **kwargs
)
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index dcae05a15fec3..109f2b6986514 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -5,8 +5,13 @@ Model configs may be defined in this directory for the following reasons:
- There is no configuration file defined by HF Hub or Transformers library.
- There is a need to override the existing config to support vLLM.
+- The HF model_type isn't recognized by the Transformers library but can
+ be mapped to an existing Transformers config, such as
+ deepseek-ai/DeepSeek-V3.2-Exp.
"""
+from transformers import DeepseekV3Config
+
from vllm.transformers_utils.configs.afmoe import AfmoeConfig
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config
@@ -18,6 +23,11 @@ from vllm.transformers_utils.configs.eagle import EAGLEConfig
# `FalconConfig` class from the official HuggingFace transformers library.
from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.flex_olmo import FlexOlmoConfig
+from vllm.transformers_utils.configs.hunyuan_vl import (
+ HunYuanVLConfig,
+ HunYuanVLTextConfig,
+ HunYuanVLVisionConfig,
+)
from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.kimi_linear import KimiLinearConfig
from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig
@@ -44,9 +54,13 @@ __all__ = [
"AfmoeConfig",
"ChatGLMConfig",
"DeepseekVLV2Config",
+ "DeepseekV3Config",
"DotsOCRConfig",
"EAGLEConfig",
"FlexOlmoConfig",
+ "HunYuanVLConfig",
+ "HunYuanVLTextConfig",
+ "HunYuanVLVisionConfig",
"RWConfig",
"JAISConfig",
"Lfm2MoeConfig",
diff --git a/vllm/transformers_utils/configs/hunyuan_vl.py b/vllm/transformers_utils/configs/hunyuan_vl.py
new file mode 100644
index 0000000000000..a826ed9b5155d
--- /dev/null
+++ b/vllm/transformers_utils/configs/hunyuan_vl.py
@@ -0,0 +1,322 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# adapted from https://github.com/ManaEstras/transformers/blob/v4.57.1.hyvl/src/transformers/models/hunyuan_vl/configuration_hunyuan_vl.py
+
+from transformers import PretrainedConfig
+
+
+class HunYuanVLVisionConfig(PretrainedConfig):
+ model_type = "hunyuan_vl"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ hidden_act="gelu",
+ hidden_size=1152,
+ intermediate_size=4304,
+ interpolate_mode="bilinear",
+ rms_norm_eps=1e-05,
+ learnable_mlp_pooling_size=0,
+ num_attention_heads=16,
+ num_key_value_heads=None,
+ num_channels=3,
+ num_hidden_layers=27,
+ out_hidden_size=4096,
+ patch_size=16,
+ remove_prenorm=True,
+ spatial_merge_size=2,
+ temporal_patch_size=1,
+ resize_resolution=2048,
+ img_max_token_num=4096,
+ max_image_size=2048,
+ video_max_image_size=768,
+ video_min_image_size=256,
+ min_image_size=512,
+ anyres_vit_max_image_size=2048,
+ max_vit_seq_len=16384,
+ text_hidden_size=3072,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_act = hidden_act
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.interpolate_mode = interpolate_mode
+ self.learnable_mlp_pooling_size = learnable_mlp_pooling_size
+ self.num_attention_heads = num_attention_heads
+ if not num_key_value_heads:
+ self.num_key_value_heads = num_attention_heads
+ else:
+ self.num_key_value_heads = num_key_value_heads
+ self.num_channels = num_channels
+ self.num_hidden_layers = num_hidden_layers
+ self.out_hidden_size = out_hidden_size
+ self.patch_size = patch_size
+ self.remove_prenorm = remove_prenorm
+ self.spatial_merge_size = spatial_merge_size
+ self.temporal_patch_size = temporal_patch_size
+ self.rms_norm_eps = rms_norm_eps
+
+ self.resize_resolution = resize_resolution
+ self.img_max_token_num = img_max_token_num
+ self.max_image_size = max_image_size
+ self.min_image_size = min_image_size
+ self.video_max_image_size = video_max_image_size
+ self.video_min_image_size = video_min_image_size
+ self.anyres_vit_max_image_size = anyres_vit_max_image_size
+ self.max_vit_seq_len = max_vit_seq_len
+ self.text_hidden_size = text_hidden_size
+
+
+class HunYuanVLTextConfig(PretrainedConfig):
+ r"""
+    This is the configuration class to store the configuration of a [`HunYuanVLTextConfig`]. It is used to instantiate a
+ HunYuan model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the HunYuan-7B.
+ Hunyuan-7B-Instruct [tencent/Hunyuan-7B-Instruct](https://huggingface.co/tencent/Hunyuan-7B-Instruct).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 290943):
+ Vocabulary size of the HunYuan model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`HunYuanVLTextConfig`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 11008):
+ Dimension of the MLP representations or shared MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ eod_token_id (int, *optional*, defaults to 3):
+ Token ID representing the end-of-document marker. Used to indicate the termination of a text sequence.
+ Example: In multi-document processing, this token helps the model distinguish between separate documents.
+ pretraining_tp (`int`, *optional*, defaults to 1):
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
+ issue](https://github.com/pytorch/pytorch/issues/76232).
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
+ these scaling strategies behave:
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
+ experimental feature, subject to breaking API changes in future versions.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ head_dim (`int`, *optional*, defaults to 128):
+ The attention head dimension.
+ """ # noqa: E501
+
+ model_type = "hunyuan_vl_text"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=290943,
+ hidden_size=4096,
+ intermediate_size: int = 11008,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ hidden_act="silu",
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ eod_token_id=3,
+ pretraining_tp=1,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ head_dim=None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.head_dim = head_dim
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.pretraining_tp = pretraining_tp
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ # self._rope_scaling_validation() # TODO: Need validation?
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+ def _rope_scaling_validation(self):
+ """
+ Validate the `rope_scaling` configuration.
+ """
+ if self.rope_scaling is None:
+ return
+
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
+ raise ValueError(
+ "`rope_scaling` must be a dictionary with with two fields, `type` and "
+ f"`factor` or `type` and `alpha`, got {self.rope_scaling}"
+ )
+ rope_scaling_type = self.rope_scaling.get("type", None)
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
+ rope_scaling_alpha = self.rope_scaling.get("alpha", None)
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
+ raise ValueError(
+ "`rope_scaling`'s type field must be one of ['linear', 'dynamic'], "
+ f"got {rope_scaling_type}"
+ )
+ if rope_scaling_factor is None and rope_scaling_alpha is None:
+ raise ValueError(
+ "`rope_scaling`'s factor or alpha field must be have one, "
+ "got both of none"
+ )
+ if rope_scaling_factor is not None and (
+ not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0
+ ):
+ raise ValueError(
+ "`rope_scaling`'s factor field must be a float > 1.0, "
+ f"got {rope_scaling_factor}"
+ )
+ if rope_scaling_alpha is not None and (
+ not isinstance(rope_scaling_alpha, float) or rope_scaling_alpha <= 1.0
+ ):
+ raise ValueError(
+ "`rope_scaling`'s alpha field must be a float > 1.0, "
+ f"got {rope_scaling_alpha}"
+ )
+
+
+class HunYuanVLConfig(PretrainedConfig):
+ model_type = "hunyuan_vl"
+ sub_configs = {
+ "vision_config": HunYuanVLVisionConfig,
+ "text_config": HunYuanVLTextConfig,
+ }
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ text_config=None,
+ vision_config=None,
+ im_start_id=120118,
+ im_end_id=120119,
+ image_token_id=120120,
+ im_newline_id=120121,
+ video_start_id=120122,
+ video_end_id=120123,
+ **kwargs,
+ ):
+ # We need to init super() here so that it does not reset values
+ # that are in text config to the BaseClass defaults. The Base
+ # config has many text related defaults and not all defaults are
+ # same as for `HunYuanVLTextConfig`.
+ super().__init__(**kwargs)
+
+ if isinstance(vision_config, dict):
+ self.vision_config = self.sub_configs["vision_config"](**vision_config)
+ elif vision_config is None:
+ self.vision_config = self.sub_configs["vision_config"]()
+
+ if isinstance(text_config, dict):
+ self.text_config = self.sub_configs["text_config"](**text_config)
+ elif text_config is None:
+ # For BC use all kwargs to init `TextConfig`
+ self.text_config = self.sub_configs["text_config"](**kwargs)
+
+ self.image_token_id = image_token_id
+ self.im_start_id = im_start_id
+ self.im_end_id = im_end_id
+ self.im_newline_id = im_newline_id
+ self.video_start_id = video_start_id
+ self.video_end_id = video_end_id
+
+ self.vision_config.text_hidden_size = self.text_config.hidden_size
+
+ # Attention implementation to use. It sets it recursively on sub-configs
+ # so we call it again in the end.
+ self._attn_implementation = kwargs.pop("attn_implementation", None)
+
+ def __setattr__(self, key, value):
+ if (
+ (text_config := super().__getattribute__("__dict__").get("text_config"))
+ is not None
+ and key not in ["dtype", "_attn_implementation_internal"]
+ and key in text_config.__dict__
+ ):
+ setattr(text_config, key, value)
+ else:
+ super().__setattr__(key, value)
+
+ def __getattribute__(self, key):
+ if "text_config" in super().__getattribute__("__dict__") and key not in [
+ "_name_or_path",
+ "model_type",
+ "dtype",
+ "_attn_implementation_internal",
+ ]:
+ text_config = super().__getattribute__("text_config")
+ if key in text_config.__dict__:
+ return getattr(text_config, key)
+
+ return super().__getattribute__(key)
diff --git a/vllm/transformers_utils/configs/mistral.py b/vllm/transformers_utils/configs/mistral.py
index 8da4ab35c56c3..966737aad0867 100644
--- a/vllm/transformers_utils/configs/mistral.py
+++ b/vllm/transformers_utils/configs/mistral.py
@@ -9,14 +9,18 @@ from vllm.logger import init_logger
logger = init_logger(__name__)
-def adapt_config_dict(config_dict: dict[str, Any], **kwargs) -> PretrainedConfig:
- config_dict.update(kwargs)
+def adapt_config_dict(
+ config_dict: dict[str, Any],
+ defaults: dict[str, Any],
+) -> PretrainedConfig:
config_dict = _remap_general_mistral_args(config_dict)
if bool(config_dict.get("quantization")):
config_dict = _remap_mistral_quantization_args(config_dict)
- if bool(config_dict.get("moe")):
+ if config_dict.get("model_type") == "mamba":
+ config_dict["architectures"] = ["Mamba2ForCausalLM"]
+ elif bool(config_dict.get("moe")):
config_dict["architectures"] = ["MixtralForCausalLM"]
else:
config_dict["architectures"] = ["MistralForCausalLM"]
@@ -52,6 +56,9 @@ def adapt_config_dict(config_dict: dict[str, Any], **kwargs) -> PretrainedConfig
if is_audio:
config_dict = _remap_mistral_audio_args(config_dict)
+ for k, v in defaults.items():
+ config_dict.setdefault(k, v)
+
config = PretrainedConfig.from_dict(config_dict)
logger.debug("Initialized config %s", config)
diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py
index 8deacb5b07913..63cdf63370342 100644
--- a/vllm/transformers_utils/processor.py
+++ b/vllm/transformers_utils/processor.py
@@ -18,7 +18,7 @@ from transformers.processing_utils import ProcessorMixin
from transformers.video_processing_utils import BaseVideoProcessor
from typing_extensions import TypeVar
-from vllm.transformers_utils.utils import check_gguf_file, convert_model_repo_to_path
+from vllm.transformers_utils.utils import convert_model_repo_to_path, is_gguf
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
if TYPE_CHECKING:
@@ -236,8 +236,8 @@ def cached_processor_from_config(
processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin,
**kwargs: Any,
) -> _P:
- if check_gguf_file(model_config.model):
- assert not check_gguf_file(model_config.tokenizer), (
+ if is_gguf(model_config.model):
+ assert not is_gguf(model_config.tokenizer), (
"For multimodal GGUF models, the original tokenizer "
"should be used to correctly load processor."
)
@@ -350,8 +350,8 @@ def cached_image_processor_from_config(
model_config: "ModelConfig",
**kwargs: Any,
):
- if check_gguf_file(model_config.model):
- assert not check_gguf_file(model_config.tokenizer), (
+ if is_gguf(model_config.model):
+ assert not is_gguf(model_config.tokenizer), (
"For multimodal GGUF models, the original tokenizer "
"should be used to correctly load image processor."
)
diff --git a/vllm/transformers_utils/processors/__init__.py b/vllm/transformers_utils/processors/__init__.py
index 76b6d3dc9c99a..b49fdbe9ce776 100644
--- a/vllm/transformers_utils/processors/__init__.py
+++ b/vllm/transformers_utils/processors/__init__.py
@@ -9,7 +9,15 @@ reasons:
"""
from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor
+from vllm.transformers_utils.processors.hunyuan_vl import HunYuanVLProcessor
+from vllm.transformers_utils.processors.hunyuan_vl_image import HunYuanVLImageProcessor
from vllm.transformers_utils.processors.ovis import OvisProcessor
from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor
-__all__ = ["DeepseekVLV2Processor", "OvisProcessor", "Ovis2_5Processor"]
+__all__ = [
+ "DeepseekVLV2Processor",
+ "HunYuanVLProcessor",
+ "HunYuanVLImageProcessor",
+ "OvisProcessor",
+ "Ovis2_5Processor",
+]
diff --git a/vllm/transformers_utils/processors/hunyuan_vl.py b/vllm/transformers_utils/processors/hunyuan_vl.py
new file mode 100644
index 0000000000000..615a8bff85912
--- /dev/null
+++ b/vllm/transformers_utils/processors/hunyuan_vl.py
@@ -0,0 +1,233 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# adapted from https://github.com/ManaEstras/transformers/blob/v4.57.1.hyvl/src/transformers/models/hunyuan_vl/processing_hunyuan_vl.py
+
+import numpy as np
+import torch
+from transformers import AutoProcessor
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.image_utils import ImageInput
+from transformers.processing_utils import ProcessorMixin
+from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
+from transformers.video_utils import VideoInput
+
+
+class HunYuanVLProcessor(ProcessorMixin):
+ attributes = ["image_processor", "tokenizer"]
+ valid_kwargs = ["chat_template"]
+ image_processor_class = "AutoImageProcessor"
+ tokenizer_class = "AutoTokenizer" # ("AutoTokenizer", None)
+
+ def __init__(
+ self,
+ image_processor=None,
+ tokenizer=None,
+ video_processor=None,
+ chat_template=None,
+ **kwargs,
+ ):
+ # TODO Fix the init
+ self.tokenizer = tokenizer
+ self.image_token_id = 120120 # self.tokenizer.image_token_id
+ self.image_token = self.tokenizer.convert_ids_to_tokens(self.image_token_id)
+ self.im_start_token_id = 120118 # self.tokenizer.im_start_id
+ self.im_start_token = self.tokenizer.convert_ids_to_tokens(
+ self.im_start_token_id
+ )
+ self.im_end_token_id = 120119 # self.tokenizer.im_end_id
+ self.im_end_token = self.tokenizer.convert_ids_to_tokens(self.im_end_token_id)
+ self.placeholder_token = self.tokenizer.convert_ids_to_tokens(
+ self.tokenizer.vocab_size - 1
+ )
+ self.pad_id = 120002 # self.tokenizer.pad_token_id
+
+ super().__init__(
+ image_processor, tokenizer, video_processor, chat_template=chat_template
+ )
+
+ def __call__(
+ self,
+ images: ImageInput = None,
+ text: TextInput
+ | PreTokenizedInput
+ | list[TextInput]
+ | list[PreTokenizedInput] = None,
+ videos: VideoInput = None,
+ **kwargs,
+ ) -> BatchFeature:
+ image_inputs = {}
+ if images is not None:
+ image_inputs = self.image_processor(images=images)
+ image_grid_thw = image_inputs["image_grid_thw"]
+
+ if not isinstance(text, list):
+ text = [text]
+
+ text = text.copy() # below lines change text in-place
+
+ image_tokens_cumsum = [0]
+ if images is not None:
+ index = 0
+ for i in range(len(text)):
+ while self.image_token in text[i]:
+ grid_h, grid_w = image_grid_thw[index][-2:]
+ patch_h = grid_h // self.image_processor.merge_size
+ patch_w = grid_w // self.image_processor.merge_size
+ num_image_tokens = patch_h * (patch_w + 1) + 2
+ image_tokens_cumsum.append(
+ image_tokens_cumsum[-1] + num_image_tokens
+ )
+ # text[i] = text[i].replace(self.image_token, self.im_start_token + self.placeholder_token * num_image_tokens + self.im_end_token, 1) # noqa: E501
+ text[i] = text[i].replace(
+ self.image_token, self.placeholder_token * num_image_tokens, 1
+ )
+ index += 1
+ text[i] = text[i].replace(self.placeholder_token, self.image_token)
+ # text[i] = self.tokenizer.bos_token + text[i]
+
+ text_inputs = self.tokenizer(text, add_special_tokens=False, **kwargs)
+ self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
+
+ input_ids = text_inputs["input_ids"]
+ position_ids = torch.arange(len(input_ids[0]))
+ position_ids_w = torch.arange(len(input_ids[0]))
+ position_ids_h = torch.arange(len(input_ids[0]))
+ position_ids_t = torch.arange(len(input_ids[0]))
+
+ if images is not None:
+ image_token_pos_indices = torch.where(input_ids[0] == self.image_token_id)[
+ 0
+ ]
+ for i in range(len(image_grid_thw)):
+ grid_h, grid_w = image_grid_thw[i][-2:]
+ patch_h = grid_h // self.image_processor.merge_size
+ patch_w = grid_w // self.image_processor.merge_size
+ start_pos = image_token_pos_indices[image_tokens_cumsum[i]].item() + 1
+ replace_num = (patch_w + 1) * patch_h
+ position_ids_w[start_pos : start_pos + replace_num] = torch.tensor(
+ list(range(patch_w + 1)) * patch_h, dtype=torch.int64
+ )
+ patch_h_list = []
+ for h in range(patch_h):
+ patch_h_list += [h] * (patch_w + 1)
+ position_ids_h[start_pos : start_pos + replace_num] = torch.tensor(
+ patch_h_list, dtype=torch.int64
+ )
+ position_ids_t[start_pos : start_pos + replace_num] = 0
+
+ position_ids = torch.stack(
+ [position_ids, position_ids_w, position_ids_h, position_ids_t]
+ ).unsqueeze(0)
+ text_inputs["position_ids"] = position_ids
+
+ attention_mask = input_ids.ne(self.pad_id)
+ text_inputs["attention_mask"] = attention_mask
+ text_inputs["imgs_pos"] = [self.get_imgs_pos(input_ids)]
+ # image_inputs["imgs"] = [[image_inputs["pixel_values"]]]
+
+ return_tensors = kwargs.pop("return_tensors", None)
+ return BatchFeature(
+ data={**text_inputs, **image_inputs},
+ tensor_type=return_tensors,
+ )
+
+ def batch_decode(self, *args, **kwargs):
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ return self.tokenizer.decode(*args, **kwargs)
+
+ def post_process_image_text_to_text(
+ self,
+ generated_outputs,
+ skip_special_tokens=True,
+ clean_up_tokenization_spaces=False,
+ **kwargs,
+ ):
+ assert 0
+
+ def apply_chat_template(self, *args, **kwargs):
+ token_ids = self.tokenizer.apply_chat_template(*args, **kwargs)
+ return token_ids
+
+ def get_imgs_pos(self, doc_ids):
+ doc_ids = np.array(doc_ids, dtype=np.int64)
+ img_begin_index = np.where(doc_ids == self.im_start_token_id)[0]
+ img_end_index = np.where(doc_ids == self.im_end_token_id)[0]
+ imgs_pos = np.concatenate(
+ (
+ np.reshape(img_begin_index + 1, (-1, 1)),
+ np.reshape(img_end_index, (-1, 1)),
+ ),
+ axis=-1,
+ ).tolist()
+ return imgs_pos
+
+ @property
+ def model_input_names(self):
+ tokenizer_input_names = self.tokenizer.model_input_names
+ image_processor_input_names = self.image_processor.model_input_names
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
+
+
+def split_image_into_patch_blocks(
+ pixel_values: torch.Tensor, # shape: [batch_size, 3, H, W]
+ patch_size: int = 16, # e.g. 16
+ adaptor_patch_div: int = 4, # e.g. 4 --> each patch_size is cut into 4x4 small regions, i.e. patch_size // 4 # noqa: E501
+) -> torch.Tensor:
+ """
+ Split the input image tensor (supporting batch) into large patches of size `patch_size`,
+ and then further divide each large patch into smaller regions of size
+ (patch_size // adaptor_patch_div) x (patch_size // adaptor_patch_div).
+ Each small region is extracted as a tensor of shape [3, patch_size, patch_size].
+ The final output contains all such small region tensors.
+
+ Args:
+ pixel_values: Input image tensor of shape [batch_size, 3, H, W].
+ patch_size: Size of the large patch, e.g., 16.
+ adaptor_patch_div: Each large patch is divided into
+ (patch_size // adaptor_patch_div) x (patch_size // adaptor_patch_div)
+ smaller regions.
+
+ Returns:
+ patches: A tensor of shape [N, 3, patch_size, patch_size],
+ where N = batch_size * (H // patch_size) * (W // patch_size) * (patch_size // adaptor_patch_div)^2.
+ Each element in the batch corresponds to one small image region.
+ """ # noqa: E501
+ batch_size, channels, height, width = pixel_values.shape
+ assert channels == 3, "Pixel values must have 3 channels in dim=1"
+ assert height % patch_size == 0 and width % patch_size == 0, (
+ "H and W must be divisible by patch_size"
+ )
+
+ patch_height_num = height // patch_size
+ patch_width_num = width // patch_size
+
+ # Reshape to [B, 3, ph, ps, pw, ps]
+ img = pixel_values.reshape(
+ batch_size, 3, patch_height_num, patch_size, patch_width_num, patch_size
+ )
+
+ # Further split each psxps patch into (ps//aps)x(ps//aps) small regions
+ img = img.reshape(
+ batch_size,
+ 3,
+ patch_height_num,
+ patch_size // adaptor_patch_div, # ps // aps
+ adaptor_patch_div,
+ patch_width_num,
+ patch_size // adaptor_patch_div, # ps // aps
+ adaptor_patch_div,
+ )
+
+ # Permute to group the small regions: [B, ph, pw, ps//aps, ps//aps, 3, aps, aps]
+ img = img.permute(0, 2, 5, 3, 6, 1, 4, 7)
+
+ # Reshape into [B * ph * pw * (ps//aps)^2, 3, patch_size, patch_size]
+ patches = img.reshape(-1, 3, patch_size, patch_size)
+
+ return patches
+
+
+AutoProcessor.register("HunYuanVLProcessor", HunYuanVLProcessor)
diff --git a/vllm/transformers_utils/processors/hunyuan_vl_image.py b/vllm/transformers_utils/processors/hunyuan_vl_image.py
new file mode 100644
index 0000000000000..0a7e7865c783a
--- /dev/null
+++ b/vllm/transformers_utils/processors/hunyuan_vl_image.py
@@ -0,0 +1,477 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# adapted from https://github.com/ManaEstras/transformers/blob/v4.57.1.hyvl/src/transformers/models/hunyuan_vl/image_processing_hunyuan_vl.py
+"""Image processor class for HunYuanVL."""
+
+# isort conflicts with ruff for transformers imports
+# isort: skip_file
+import math
+
+import numpy as np
+import torchvision.transforms as transforms
+from transformers import AutoImageProcessor
+from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
+from transformers.image_transforms import (
+ convert_to_rgb,
+)
+from transformers.image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ make_flat_list_of_images,
+ make_list_of_images,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from transformers.utils import TensorType, logging
+from transformers.video_utils import VideoInput, make_batched_videos
+
+logger = logging.get_logger(__name__)
+
+
+def smart_resize(
+ height: int,
+ width: int,
+ factor: int = 16,
+ min_pixels: int = 512 * 512,
+ max_pixels: int = 2048 * 2048,
+):
+ """Rescales the image so that the following conditions are met:
+
+ 1. Both dimensions (height and width) are divisible by 'factor'.
+
+ 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+
+ 3. The aspect ratio of the image is maintained as closely as possible.
+
+ """
+ if max(height, width) / min(height, width) > 200:
+ raise ValueError(
+ "absolute aspect ratio must be smaller than 200, got "
+ f"{max(height, width) / min(height, width)}"
+ )
+ h_bar = round(height / factor) * factor
+ w_bar = round(width / factor) * factor
+ if h_bar * w_bar > max_pixels:
+ beta = math.sqrt((height * width) / max_pixels)
+ h_bar = max(factor, math.floor(height / beta / factor) * factor)
+ w_bar = max(factor, math.floor(width / beta / factor) * factor)
+ elif h_bar * w_bar < min_pixels:
+ beta = math.sqrt(min_pixels / (height * width))
+ h_bar = math.ceil(height * beta / factor) * factor
+ w_bar = math.ceil(width * beta / factor) * factor
+ return h_bar, w_bar
+
+
+class HunYuanVLImageProcessor(BaseImageProcessor):
+ model_input_names = [
+ "pixel_values",
+ "image_grid_thw",
+ "pixel_values_videos",
+ "video_grid_thw",
+ ]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: dict[str, int] | None = None,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_rescale: bool = True,
+ rescale_factor: int | float = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: float | list[float] | None = None,
+ image_std: float | list[float] | None = None,
+ do_convert_rgb: bool = True,
+ min_pixels: int | None = None,
+ max_pixels: int | None = None,
+ patch_size: int = 16,
+ temporal_patch_size: int = 2,
+ merge_size: int = 2,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ if size is not None and (
+ "shortest_edge" not in size or "longest_edge" not in size
+ ):
+ raise ValueError(
+ "size must contain 'shortest_edge' and 'longest_edge' keys."
+ )
+ else:
+ size = {"shortest_edge": 512 * 512, "longest_edge": 2048 * 2048}
+ # backward compatibility: override size with min_pixels and max_pixels
+ # if they are provided.
+ if min_pixels is not None:
+ size["shortest_edge"] = min_pixels
+ if max_pixels is not None:
+ size["longest_edge"] = max_pixels
+ self.min_pixels = size["shortest_edge"]
+ self.max_pixels = size["longest_edge"]
+ self.size = size
+
+ self.do_resize = do_resize
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
+
+ self.patch_size = patch_size
+ self.temporal_patch_size = temporal_patch_size
+ self.merge_size = merge_size
+ self.do_convert_rgb = do_convert_rgb
+
+ # hard-code
+
+ def _preprocess(
+ self,
+ images: ImageInput | VideoInput,
+ do_resize: bool | None = None,
+ size: dict[str, int] | None = None,
+ resample: PILImageResampling = None,
+ do_rescale: bool | None = None,
+ rescale_factor: float | None = None,
+ do_normalize: bool | None = None,
+ image_mean: float | list[float] | None = None,
+ image_std: float | list[float] | None = None,
+ patch_size: int = 16,
+ temporal_patch_size: int = 2,
+ merge_size: int = 2,
+ do_convert_rgb: bool | None = None,
+ data_format: ChannelDimension | None = ChannelDimension.FIRST,
+ input_data_format: str | ChannelDimension | None = None,
+ ):
+ """
+ Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
+
+ Args:
+ images (`ImageInput`):
+ Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present.
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Scale factor to use if rescaling the image.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
+ patch_size (`int`, *optional*, defaults to `self.patch_size`):
+ The spatial patch size of the vision encoder.
+ temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
+ The temporal patch size of the vision encoder.
+ merge_size (`int`, *optional*, defaults to `self.merge_size`):
+ The merge size of the vision encoder to llm encoder.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """ # noqa: E501
+ images = make_list_of_images(images)
+
+ if do_convert_rgb:
+ images = [convert_to_rgb(image) for image in images]
+
+ width, height = images[0].width, images[0].height
+ resized_width, resized_height = width, height
+ processed_images = []
+ for image in images:
+ if do_resize:
+ resized_width, resized_height = smart_resize(
+ width,
+ height,
+ factor=patch_size * merge_size,
+ min_pixels=self.min_pixels,
+ max_pixels=self.max_pixels,
+ )
+ image = image.resize((resized_width, resized_height))
+
+ if do_normalize:
+ image = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize(self.image_mean, self.image_std),
+ ]
+ )(image)
+ processed_images.append(image)
+
+ patches = np.array(processed_images)
+ channel = patches.shape[1]
+ grid_t = patches.shape[0] // temporal_patch_size
+ grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
+ patches = patches.reshape(
+ 1,
+ channel,
+ grid_h // merge_size,
+ merge_size,
+ patch_size,
+ grid_w // merge_size,
+ merge_size,
+ patch_size,
+ )
+ patches = patches.transpose(0, 2, 3, 5, 6, 1, 4, 7)
+ flatten_patches = patches.reshape(
+ 1 * grid_h * grid_w, channel * patch_size * patch_size
+ )
+
+ return flatten_patches, (grid_t, grid_h, grid_w)
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ videos: VideoInput = None,
+ do_resize: bool | None = None,
+ size: dict[str, int] | None = None,
+ min_pixels: int | None = None,
+ max_pixels: int | None = None,
+ resample: PILImageResampling = None,
+ do_rescale: bool | None = None,
+ rescale_factor: float | None = None,
+ do_normalize: bool | None = None,
+ image_mean: float | list[float] | None = None,
+ image_std: float | list[float] | None = None,
+ patch_size: int | None = None,
+ temporal_patch_size: int | None = None,
+ merge_size: int | None = None,
+ do_convert_rgb: bool | None = None,
+ return_tensors: str | TensorType | None = None,
+ data_format: ChannelDimension | None = ChannelDimension.FIRST,
+ input_data_format: str | ChannelDimension | None = None,
+ ):
+ """
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ videos (`VideoInput`):
+ Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If
+ passing in videos with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
+ the longest edge resized to keep the input aspect ratio.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+ has an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
+ `True`.
+ min_pixels (`int`, *optional*, defaults to `self.min_pixels`):
+ The min pixels of the image to resize the image.
+ max_pixels (`int`, *optional*, defaults to `self.max_pixels`):
+ The max pixels of the image to resize the image.
+ patch_size (`int`, *optional*, defaults to `self.patch_size`):
+ The spatial patch size of the vision encoder.
+ temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
+ The temporal patch size of the vision encoder.
+ merge_size (`int`, *optional*, defaults to `self.merge_size`):
+ The merge size of the vision encoder to llm encoder.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+ """ # noqa: E501
+ min_pixels = min_pixels if min_pixels is not None else self.min_pixels
+ max_pixels = max_pixels if max_pixels is not None else self.max_pixels
+
+ if size is not None:
+ if "shortest_edge" not in size or "longest_edge" not in size:
+ raise ValueError(
+ "size must contain 'shortest_edge' and 'longest_edge' keys."
+ )
+ min_pixels = size["shortest_edge"]
+ elif min_pixels is not None and max_pixels is not None:
+ # backward compatibility: override size with min_pixels and max_pixels
+ # if they are provided.
+ size = {"shortest_edge": min_pixels, "longest_edge": max_pixels}
+ else:
+ size = {**self.size}
+
+ do_resize = do_resize if do_resize is not None else self.do_resize
+
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = (
+ rescale_factor if rescale_factor is not None else self.rescale_factor
+ )
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ patch_size = patch_size if patch_size is not None else self.patch_size
+ temporal_patch_size = (
+ temporal_patch_size
+ if temporal_patch_size is not None
+ else self.temporal_patch_size
+ )
+ merge_size = merge_size if merge_size is not None else self.merge_size
+ do_convert_rgb = (
+ do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+ )
+
+ if images is not None:
+ images = make_flat_list_of_images(images)
+
+ if images is not None and not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ validate_preprocess_arguments(
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ data = {}
+ if images is not None:
+ pixel_values, vision_grid_thws = [], []
+ for image in images:
+ patches, image_grid_thw = self._preprocess(
+ image,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ patch_size=patch_size,
+ temporal_patch_size=temporal_patch_size,
+ merge_size=merge_size,
+ data_format=data_format,
+ do_convert_rgb=do_convert_rgb,
+ input_data_format=input_data_format,
+ )
+ pixel_values.extend(patches)
+ vision_grid_thws.append(image_grid_thw)
+ pixel_values = np.array(pixel_values)
+ vision_grid_thws = np.array(vision_grid_thws)
+ data.update(
+ {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}
+ )
+
+ # kept for BC only and should be removed after v5.0
+ if videos is not None:
+ logger.warning(
+ "`HunYuanVLV1ImageProcessor` works only with image inputs "
+ "and doesn't process videos anymore. "
+ "This is a deprecated behavior and will be removed in v5.0. "
+ "Your videos should be forwarded to `HunYuanVLV1VideoProcessor`. "
+ )
+ videos = make_batched_videos(videos)
+ pixel_values_videos, vision_grid_thws_videos = [], []
+ for images in videos:
+ patches, video_grid_thw = self._preprocess(
+ images,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ patch_size=patch_size,
+ temporal_patch_size=temporal_patch_size,
+ merge_size=merge_size,
+ data_format=data_format,
+ do_convert_rgb=do_convert_rgb,
+ input_data_format=input_data_format,
+ )
+ pixel_values_videos.extend(patches)
+ vision_grid_thws_videos.append(video_grid_thw)
+ data.update(
+ {
+ "pixel_values_videos": np.array(pixel_values_videos),
+ "video_grid_thw": np.array(vision_grid_thws_videos),
+ }
+ )
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
+ """
+ A utility that returns number of image patches for a given image size.
+
+ Args:
+ height (`int`):
+ Height of the input image.
+ width (`int`):
+ Width of the input image.
+ images_kwargs (`dict`, *optional*):
+ Any kwargs to override defaults of the image processor.
+ Returns:
+ `int`: Number of image patches per image.
+ """
+ min_pixels = (
+ images_kwargs["min_pixels"]
+ if "min_pixels" in images_kwargs
+ else self.size["shortest_edge"]
+ )
+ max_pixels = (
+ images_kwargs["max_pixels"]
+ if "max_pixels" in images_kwargs
+ else self.size["longest_edge"]
+ )
+ patch_size = images_kwargs.get("patch_size", self.patch_size)
+ merge_size = images_kwargs.get("merge_size", self.merge_size)
+
+ factor = patch_size * merge_size
+ resized_height, resized_width = smart_resize(
+ height, width, factor, min_pixels=min_pixels, max_pixels=max_pixels
+ )
+ grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
+ return grid_h * (grid_w + 1) + 2
+
+
+AutoImageProcessor.register("HunYuanVLImageProcessor", HunYuanVLImageProcessor)
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index 233076741503d..f0e0ba8ef4246 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -20,7 +20,12 @@ from vllm.transformers_utils.config import (
list_filtered_repo_files,
)
from vllm.transformers_utils.tokenizers import MistralTokenizer
-from vllm.transformers_utils.utils import check_gguf_file
+from vllm.transformers_utils.utils import (
+ check_gguf_file,
+ is_gguf,
+ is_remote_gguf,
+ split_remote_gguf,
+)
if TYPE_CHECKING:
from vllm.config import ModelConfig
@@ -180,10 +185,12 @@ def get_tokenizer(
kwargs["truncation_side"] = "left"
# Separate model folder from file path for GGUF models
- is_gguf = check_gguf_file(tokenizer_name)
- if is_gguf:
- kwargs["gguf_file"] = Path(tokenizer_name).name
- tokenizer_name = Path(tokenizer_name).parent
+ if is_gguf(tokenizer_name):
+ if check_gguf_file(tokenizer_name):
+ kwargs["gguf_file"] = Path(tokenizer_name).name
+ tokenizer_name = Path(tokenizer_name).parent
+ elif is_remote_gguf(tokenizer_name):
+ tokenizer_name, _ = split_remote_gguf(tokenizer_name)
# if `tokenizer_mode` == "auto", check if tokenizer can be loaded via Mistral format
# first to use official Mistral tokenizer if possible.
diff --git a/vllm/transformers_utils/utils.py b/vllm/transformers_utils/utils.py
index 901a64d9d2633..45a873c9f7001 100644
--- a/vllm/transformers_utils/utils.py
+++ b/vllm/transformers_utils/utils.py
@@ -9,6 +9,8 @@ from os import PathLike
from pathlib import Path
from typing import Any
+from gguf import GGMLQuantizationType
+
import vllm.envs as envs
from vllm.logger import init_logger
@@ -46,6 +48,57 @@ def check_gguf_file(model: str | PathLike) -> bool:
return False
+@cache
+def is_remote_gguf(model: str | Path) -> bool:
+ """Check if the model is a remote GGUF model."""
+ model = str(model)
+ return (
+ (not is_cloud_storage(model))
+ and (not model.startswith(("http://", "https://")))
+ and ("/" in model and ":" in model)
+ and is_valid_gguf_quant_type(model.rsplit(":", 1)[1])
+ )
+
+
+def is_valid_gguf_quant_type(gguf_quant_type: str) -> bool:
+ """Check if the quant type is a valid GGUF quant type."""
+ return getattr(GGMLQuantizationType, gguf_quant_type, None) is not None
+
+
+def split_remote_gguf(model: str | Path) -> tuple[str, str]:
+ """Split the model into repo_id and quant type."""
+ model = str(model)
+ if is_remote_gguf(model):
+ parts = model.rsplit(":", 1)
+ return (parts[0], parts[1])
+ raise ValueError(
+ "Wrong GGUF model or invalid GGUF quant type: %s.\n"
+ "- It should be in repo_id:quant_type format.\n"
+ "- Valid GGMLQuantizationType values: %s",
+ model,
+ GGMLQuantizationType._member_names_,
+ )
+
+
+def is_gguf(model: str | Path) -> bool:
+ """Check if the model is a GGUF model.
+
+ Args:
+ model: Model name, path, or Path object to check.
+
+ Returns:
+ True if the model is a GGUF model, False otherwise.
+ """
+ model = str(model)
+
+ # Check if it's a local GGUF file
+ if check_gguf_file(model):
+ return True
+
+ # Check if it's a remote GGUF model (repo_id:quant_type format)
+ return is_remote_gguf(model)
+
+
def modelscope_list_repo_files(
repo_id: str,
revision: str | None = None,
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index 3ef44e7703204..fddcc27204307 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -49,13 +49,14 @@ STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
# Possible string values of STR_BACKEND_ENV_VAR
# register, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
-STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"
+MASK_64_BITS = (1 << 64) - 1
+
def random_uuid() -> str:
- return str(uuid.uuid4().hex)
+ return f"{uuid.uuid4().int & MASK_64_BITS:016x}" # 16 hex chars
def length_from_prompt_token_ids_or_embeds(
diff --git a/vllm/utils/argparse_utils.py b/vllm/utils/argparse_utils.py
index 3d105a3685b37..692e756d19634 100644
--- a/vllm/utils/argparse_utils.py
+++ b/vllm/utils/argparse_utils.py
@@ -73,14 +73,6 @@ class FlexibleArgumentParser(ArgumentParser):
# Enable the deprecated kwarg for Python 3.12 and below
def parse_known_args(self, args=None, namespace=None):
- if args is not None and "--disable-log-requests" in args:
- # Special case warning because the warning below won't trigger
- # if –-disable-log-requests because its value is default.
- logger.warning_once(
- "argument '--disable-log-requests' is deprecated and "
- "replaced with '--enable-log-requests'. This will be "
- "removed in v0.12.0."
- )
namespace, args = super().parse_known_args(args, namespace)
for action in FlexibleArgumentParser._deprecated:
if (
diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py
index f1254352c0585..590bf91b0d057 100644
--- a/vllm/v1/attention/backends/cpu_attn.py
+++ b/vllm/v1/attention/backends/cpu_attn.py
@@ -25,7 +25,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
-_CPU_ARCH_PREFER_MIXED_BATCH = (CpuArchEnum.X86,)
+_CPU_ARCH_PREFER_MIXED_BATCH = (CpuArchEnum.X86, CpuArchEnum.ARM)
class CPUAttentionBackend(AttentionBackend):
@@ -491,6 +491,9 @@ def _get_attn_isa(dtype: torch.dtype, block_size: int) -> str:
if supports_amx and dtype in (torch.bfloat16,) and block_size % 32 == 0:
return "amx"
elif block_size % 32 == 0:
- return "vec"
+ if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
+ return "neon"
+ else:
+ return "vec"
else:
return "vec16"
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 9fa6b1dfd19dd..a9a4af5ac1183 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -32,7 +32,7 @@ if is_flash_attn_varlen_func_available():
get_scheduler_metadata,
reshape_and_cache_flash,
)
-from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
from vllm.config.cache import CacheDType
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
@@ -56,11 +56,26 @@ logger = init_logger(__name__)
class FlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
- # NOTE(tdoublep): while in principle, FA supports
- # MultipleOf(16), these are the block sizes that do not
- # suffer from the NaN propagation problem described here:
- # https://github.com/Dao-AILab/flash-attention/issues/1974
- supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
+
+ @staticmethod
+ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+ vllm_config = get_current_vllm_config()
+ model_config = vllm_config.model_config
+ cache_config = vllm_config.cache_config
+ if (
+ model_config
+ and model_config.is_hybrid
+ and (
+ cache_config.mamba_ssm_cache_dtype == "float32"
+ or cache_config.mamba_cache_dtype == "float32"
+ )
+ ):
+ # NOTE(tdoublep): while in principle, FA supports
+ # MultipleOf(16), these are the block sizes that do not
+ # suffer from the NaN propagation problem described here:
+ # https://github.com/Dao-AILab/flash-attention/issues/1974
+ return [16, 32, 64]
+ return [MultipleOf(16)]
@staticmethod
def get_name() -> str:
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index e3f499216d7f1..8159f4096107f 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -16,7 +16,6 @@ from flashinfer import (
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
from flashinfer.utils import FP4Tensor
-from typing_extensions import override
from vllm import envs
from vllm.attention.backends.abstract import (
@@ -275,10 +274,6 @@ class BatchDCPPrefillWrapper:
class FlashInferBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
- # Note: Not sure for all platforms,
- # but on Blackwell, only support a page size of
- # 16, 32, 64
- supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
@@ -286,6 +281,12 @@ class FlashInferBackend(AttentionBackend):
"fp8_e5m2",
]
+ @staticmethod
+ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+ # Note: Not sure for all platforms, but on Blackwell,
+ # only support a page size of 16, 32, 64.
+ return [16, 32, 64]
+
@staticmethod
def get_name() -> str:
return "FLASHINFER"
@@ -566,7 +567,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
@classmethod
- @override
def get_cudagraph_support(
cls: type["FlashInferMetadataBuilder"],
vllm_config: VllmConfig,
diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py
index 1900c50849eca..004baa2d09cde 100644
--- a/vllm/v1/attention/backends/linear_attn.py
+++ b/vllm/v1/attention/backends/linear_attn.py
@@ -7,6 +7,7 @@ import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (
+ AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills,
@@ -35,6 +36,8 @@ class LinearAttentionMetadata:
class LinearAttentionMetadataBuilder(AttentionMetadataBuilder[LinearAttentionMetadata]):
reorder_batch_threshold: int = 1
+ _cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
+
def __init__(
self,
kv_cache_spec: AttentionSpec,
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 43aef8a7cca91..87a3aac21d2c3 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -340,6 +340,8 @@ class MLACommonPrefillMetadata:
max_seq_lens: list[int]
seq_lens: torch.Tensor
workspace: torch.Tensor
+ token_to_seq: torch.Tensor
+ chunk_total_token: list[int]
# for mla DCP
padded_local_chunk_seq_lens: list[list[int]] | None = None
@@ -839,6 +841,19 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
torch.cumsum(
chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32
)
+ chunk_total_token = cu_seq_lens_cpu[:, -1]
+
+ max_token_num_over_chunk = chunk_total_token.max().item()
+ token_to_seq_tensor_cpu = torch.zeros(
+ [num_chunks, max_token_num_over_chunk], dtype=torch.int32
+ )
+ range_idx = torch.arange(num_prefills, dtype=torch.int32)
+ for i in range(num_chunks):
+ chunk_token_to_seq_tensor = torch.repeat_interleave(
+ range_idx, chunk_seq_lens[i]
+ )
+ chunk_len = chunk_token_to_seq_tensor.shape[0]
+ token_to_seq_tensor_cpu[i, :chunk_len] = chunk_token_to_seq_tensor
if self.dcp_world_size > 1:
local_context_lens_allranks = get_dcp_local_seq_lens(
@@ -906,6 +921,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
seq_lens=chunk_seq_lens,
+ token_to_seq=token_to_seq_tensor_cpu.to(
+ device, non_blocking=True
+ ),
+ chunk_total_token=chunk_total_token.tolist(),
workspace=self.chunked_prefill_workspace,
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
local_context_lens_allranks=local_context_lens_allranks.tolist(),
@@ -922,6 +941,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
seq_lens=chunk_seq_lens,
+ token_to_seq=token_to_seq_tensor_cpu.to(
+ device, non_blocking=True
+ ),
+ chunk_total_token=chunk_total_token,
workspace=self.chunked_prefill_workspace,
)
@@ -1638,16 +1661,15 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
output = None
iters = len(prefill_metadata.chunked_context.seq_tot)
workspace = prefill_metadata.chunked_context.workspace
-
for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i]
-
ops.gather_and_maybe_dequant_cache(
src_cache=kv_c_and_k_pe_cache,
dst=workspace,
block_table=prefill_metadata.block_table,
cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
- batch_size=attn_metadata.num_prefills,
+ token_to_seq=prefill_metadata.chunked_context.token_to_seq[i],
+ num_tokens=prefill_metadata.chunked_context.chunk_total_token[i],
kv_cache_dtype=self.kv_cache_dtype,
scale=k_scale,
seq_starts=prefill_metadata.chunked_context.starts[i],
diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py
index 60cb5022a55eb..5e3fbc0abf083 100644
--- a/vllm/v1/attention/backends/mla/cutlass_mla.py
+++ b/vllm/v1/attention/backends/mla/cutlass_mla.py
@@ -36,13 +36,16 @@ class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
class CutlassMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
- supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [128]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
"fp8_e4m3",
]
+ @staticmethod
+ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+ return [128]
+
@staticmethod
def get_name() -> str:
return "CUTLASS_MLA"
diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py
index 12639edc8b9a1..d369814c10b6f 100644
--- a/vllm/v1/attention/backends/mla/flashattn_mla.py
+++ b/vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -41,9 +41,12 @@ logger = init_logger(__name__)
class FlashAttnMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
- supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
+ @staticmethod
+ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+ return [MultipleOf(16)]
+
@staticmethod
def get_name() -> str:
return "FLASH_ATTN_MLA"
diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py
index 52bb19e039e45..f02a4bb1ef35a 100644
--- a/vllm/v1/attention/backends/mla/flashinfer_mla.py
+++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py
@@ -35,13 +35,16 @@ class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
class FlashInferMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
- supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [32, 64]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
"fp8_e4m3",
]
+ @staticmethod
+ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+ return [32, 64]
+
@staticmethod
def get_name() -> str:
return "FLASHINFER_MLA"
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index 3aab1f9bb7fb6..74a4cd8430250 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -39,13 +39,16 @@ logger = init_logger(__name__)
class FlashMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
- supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
"fp8_e4m3",
]
+ @staticmethod
+ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+ return [64]
+
@staticmethod
def get_name() -> str:
return "FLASHMLA"
diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py
index 3f2cc8c38327e..1eee1d225293b 100644
--- a/vllm/v1/attention/backends/mla/flashmla_sparse.py
+++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py
@@ -55,9 +55,12 @@ structured as:
class FlashMLASparseBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
- supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "fp8_ds_mla"]
+ @staticmethod
+ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+ return [64]
+
@staticmethod
def get_name() -> str:
return "FLASHMLA_SPARSE"
diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py
index d38361e0fcbf8..77f1ba00d5b04 100644
--- a/vllm/v1/attention/backends/mla/indexer.py
+++ b/vllm/v1/attention/backends/mla/indexer.py
@@ -24,9 +24,9 @@ logger = init_logger(__name__)
class DeepseekV32IndexerBackend(AttentionBackend):
- supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [
- 1 if current_platform.is_rocm() else 64
- ]
+ @staticmethod
+ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+ return [1 if current_platform.is_rocm() else 64]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index 6ccc1a341d56c..00a0a77a1c2f7 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -21,7 +21,9 @@ from vllm.v1.kv_cache_interface import AttentionSpec
class AiterMLABackend(MLACommonBackend):
- supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [1]
+ @staticmethod
+ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+ return [1]
@staticmethod
def get_name() -> str:
@@ -47,6 +49,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
paged_kv_last_page_len: torch.Tensor | None = None
# The query indptr, shape : [num_decode + 1]
qo_indptr: torch.Tensor | None = None
+ # The dtype of MLA out tensor
+ attn_out_dtype: torch.dtype = torch.bfloat16
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
@@ -72,6 +76,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
)
self.compilation_config = vllm_config.compilation_config
+ self.decode_attn_out_dtype = vllm_config.model_config.dtype
# kernel block size is always 1.
max_num_pages_per_req = vllm_config.model_config.max_model_len
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
@@ -160,6 +165,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
paged_kv_last_page_len=paged_kv_last_page_len,
qo_indptr=qo_indptr,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
+ attn_out_dtype=self.decode_attn_out_dtype,
)
return attn_metadata
@@ -240,7 +246,11 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
assert isinstance(q, torch.Tensor)
B = q.shape[0]
o = torch.zeros(
- B, self.num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device
+ B,
+ self.num_heads,
+ self.kv_lora_rank,
+ dtype=attn_metadata.decode.attn_out_dtype,
+ device=q.device,
)
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
@@ -258,6 +268,8 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
attn_metadata.decode.paged_kv_indptr,
attn_metadata.decode.paged_kv_indices,
attn_metadata.decode.paged_kv_last_page_len,
+ q_scale=layer._q_scale,
+ kv_scale=layer._k_scale,
)
return o, None
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index ea611848b0e81..ea911af3d19ce 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -447,7 +447,10 @@ class AiterFlashAttentionMetadataBuilder(
class AiterFlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
- supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
+
+ @staticmethod
+ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+ return [MultipleOf(16)]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
@@ -514,12 +517,9 @@ class AiterFlashAttentionImpl(AttentionImpl):
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
- if attn_type != AttentionType.DECODER:
+ if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
raise NotImplementedError(
- "Encoder self-attention and "
- "encoder/decoder cross-attention "
- "are not implemented for "
- "FlashAttentionImpl"
+ "Encoder self-attention is not implemented for FlashAttentionImpl"
)
def extend_forward(
@@ -675,7 +675,14 @@ class AiterFlashAttentionImpl(AttentionImpl):
# performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = kv_cache.unbind(0)
- if self.kv_sharing_target_layer_name is None:
+ # key and value may be None in the case of cross attention. They are
+ # calculated once based on the output from the encoder and then cached
+ # in KV cache.
+ if (
+ self.kv_sharing_target_layer_name is None
+ and key is not None
+ and value is not None
+ ):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping
@@ -701,8 +708,10 @@ class AiterFlashAttentionImpl(AttentionImpl):
# decode:extend:prefill
query = query[:num_actual_tokens]
- key = key[:num_actual_tokens]
- value = value[:num_actual_tokens]
+ if key is not None:
+ key = key[:num_actual_tokens]
+ if value is not None:
+ value = value[:num_actual_tokens]
output_actual_tokens = output[:num_actual_tokens]
diff --git a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py
index b2639c0df0412..16fb52ab501c1 100644
--- a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py
+++ b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py
@@ -142,7 +142,14 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
key_cache, value_cache = kv_cache.unbind(0)
- if self.kv_sharing_target_layer_name is None:
+ # key and value may be None in the case of cross attention. They are
+ # calculated once based on the output from the encoder and then cached
+ # in KV cache.
+ if (
+ self.kv_sharing_target_layer_name is None
+ and key is not None
+ and value is not None
+ ):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
ops.reshape_and_cache_flash(
@@ -169,7 +176,10 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
- descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+ descale_shape = (
+ cu_seqlens_q.shape[0] - 1,
+ key.shape[1] if key is not None else self.num_kv_heads,
+ )
self.unified_attention(
q=query[:num_actual_tokens],
diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py
index 6dfdfc19ccba1..868143cc192e7 100644
--- a/vllm/v1/attention/backends/rocm_attn.py
+++ b/vllm/v1/attention/backends/rocm_attn.py
@@ -238,12 +238,9 @@ class RocmAttentionImpl(AttentionImpl):
RocmAttentionBackend.validate_head_size(head_size)
- if attn_type != AttentionType.DECODER:
+ if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
raise NotImplementedError(
- "Encoder self-attention and "
- "encoder/decoder cross-attention "
- "are not implemented for "
- "RocmAttentionImpl"
+ "Encoder self-attention is not implemented for RocmAttentionImpl"
)
self.fp8_dtype = current_platform.fp8_dtype()
diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py
index 1bf38ed225a4c..523f759e05a21 100644
--- a/vllm/v1/attention/backends/tree_attn.py
+++ b/vllm/v1/attention/backends/tree_attn.py
@@ -31,7 +31,10 @@ logger = init_logger(__name__)
class TreeAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
- supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
+
+ @staticmethod
+ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+ return [MultipleOf(16)]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 09c36043c8c86..d051a89f03bb4 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -154,7 +154,6 @@ class TritonAttentionBackend(AttentionBackend):
torch.bfloat16,
torch.float32,
]
- supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
@@ -162,6 +161,10 @@ class TritonAttentionBackend(AttentionBackend):
"fp8_e5m2",
]
+ @staticmethod
+ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+ return [MultipleOf(16)]
+
@staticmethod
def get_name() -> str:
return "TRITON_ATTN"
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 540a8e2b1d016..cebfe8a3ff04e 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -89,7 +89,8 @@ class CommonAttentionMetadata:
num_logits_indices: int | None = None
# Needed by CrossAttentionBuilder
- encoder_seq_lens: np.ndarray | None = None
+ encoder_seq_lens: torch.Tensor | None = None
+ encoder_seq_lens_cpu: np.ndarray | None = None
dcp_local_seq_lens: torch.Tensor | None = None
dcp_local_seq_lens_cpu: torch.Tensor | None = None
diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py
deleted file mode 100644
index d15d79417cc61..0000000000000
--- a/vllm/v1/attention/backends/xformers.py
+++ /dev/null
@@ -1,417 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Attention layer with XFormersAttention."""
-
-from dataclasses import dataclass
-from typing import ClassVar, Optional
-
-import torch
-
-from vllm.attention.backends.abstract import (
- AttentionBackend,
- AttentionImpl,
- AttentionType,
- MultipleOf,
-)
-from vllm.attention.ops.triton_unified_attention import unified_attention
-from vllm.config import VllmConfig
-from vllm.logger import init_logger
-from vllm.v1.attention.backends.utils import (
- AttentionMetadataBuilder,
- CommonAttentionMetadata,
- split_decodes_and_prefills,
-)
-from vllm.v1.kv_cache_interface import AttentionSpec
-
-try:
- from xformers import ops as xops
- from xformers.ops.fmha.attn_bias import (
- AttentionBias,
- PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
- )
-
- XFORMERS_AVAILABLE = True
-except ImportError:
- XFORMERS_AVAILABLE = False
-
-from vllm import _custom_ops as ops
-
-logger = init_logger(__name__)
-
-
-class XFormersAttentionBackend(AttentionBackend):
- accept_output_buffer: bool = True
- supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
- supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
-
- @classmethod
- def get_supported_head_sizes(cls) -> list[int]:
- return [
- 32,
- 40,
- 48,
- 56,
- 64,
- 72,
- 80,
- 88,
- 96,
- 104,
- 112,
- 120,
- 128,
- 136,
- 144,
- 152,
- 160,
- 168,
- 176,
- 184,
- 192,
- 200,
- 208,
- 216,
- 224,
- 232,
- 240,
- 248,
- 256,
- ]
-
- @staticmethod
- def get_name() -> str:
- return "XFORMERS"
-
- @staticmethod
- def get_impl_cls() -> type["XFormersAttentionImpl"]:
- return XFormersAttentionImpl
-
- @staticmethod
- def get_kv_cache_shape(
- num_blocks: int,
- block_size: int,
- num_kv_heads: int,
- head_size: int,
- cache_dtype_str: str = "auto",
- ) -> tuple[int, ...]:
- if block_size % 16 != 0:
- raise ValueError("Block size must be a multiple of 16.")
- return (2, num_blocks, block_size, num_kv_heads, head_size)
-
- @staticmethod
- def get_builder_cls() -> type["XFormersAttentionMetadataBuilder"]:
- return XFormersAttentionMetadataBuilder
-
- @staticmethod
- def use_cascade_attention(*args, **kwargs) -> bool:
- return False
-
-
-@dataclass
-class XFormersAttentionMetadata:
- num_actual_tokens: int # Number of tokens excluding padding.
- max_query_len: int
- query_start_loc: torch.Tensor
- max_seq_len: int
- seq_lens: torch.Tensor
- block_table: torch.Tensor
- slot_mapping: torch.Tensor
-
- num_prefill_tokens: int = 0
- num_decode_tokens: int = 0
- num_prefills: int = 0
- num_decodes: int = 0
-
- # Biases for different attention types.
- attn_bias: Optional["AttentionBias"] = None
-
- # Self-attention prefill/decode metadata cache
- _cached_prefill_metadata: Optional["XFormersAttentionMetadata"] = None
- _cached_decode_metadata: Optional["XFormersAttentionMetadata"] = None
-
- @property
- def prefill_metadata(self) -> Optional["XFormersAttentionMetadata"]:
- if self.num_prefills == 0:
- return None
-
- if self._cached_prefill_metadata is not None:
- # Recover cached prefill-phase attention
- # metadata structure
- return self._cached_prefill_metadata
-
- q_start_loc = self.query_start_loc[self.num_decodes :]
- q_seqlens = torch.diff(q_start_loc)
- kv_seqlens = self.seq_lens[self.num_decodes :]
- # Construct & cache prefill-phase attention metadata structure
- self._cached_prefill_metadata = XFormersAttentionMetadata(
- num_actual_tokens=self.num_prefill_tokens,
- max_query_len=int(q_seqlens.max().item()),
- query_start_loc=q_start_loc - q_start_loc[0],
- max_seq_len=int(kv_seqlens.max().item()),
- seq_lens=kv_seqlens,
- block_table=self.block_table[self.num_decodes :],
- slot_mapping=self.slot_mapping[self.num_decode_tokens :],
- )
- return self._cached_prefill_metadata
-
- @property
- def decode_metadata(self) -> Optional["XFormersAttentionMetadata"]:
- if self.num_decode_tokens == 0:
- return None
-
- if self._cached_decode_metadata is not None:
- # Recover cached decode-phase attention
- # metadata structure
- return self._cached_decode_metadata
-
- q_start_loc = self.query_start_loc
- q_seqlens = torch.diff(q_start_loc)
- decode_kv_seqlens = self.seq_lens[: self.num_decodes]
- # Construct & cache decode-phase attention metadata structure
- self._cached_decode_metadata = XFormersAttentionMetadata(
- num_actual_tokens=self.num_decode_tokens,
- max_query_len=int(q_seqlens[: self.num_decodes].max().item()),
- query_start_loc=q_start_loc[: self.num_decodes + 1],
- max_seq_len=int(decode_kv_seqlens.max().item()),
- seq_lens=decode_kv_seqlens,
- block_table=self.block_table[: self.num_decodes],
- slot_mapping=self.slot_mapping[: self.num_decode_tokens],
- attn_bias=self.attn_bias,
- )
- return self._cached_decode_metadata
-
-
-class XFormersAttentionMetadataBuilder(
- AttentionMetadataBuilder[XFormersAttentionMetadata]
-):
- reorder_batch_threshold: int = 1
-
- def __init__(
- self,
- kv_cache_spec: AttentionSpec,
- layer_names: list[str],
- vllm_config: VllmConfig,
- device: torch.device,
- ):
- super().__init__(kv_cache_spec, layer_names, vllm_config, device)
-
- assert XFORMERS_AVAILABLE
- self.block_size = kv_cache_spec.block_size
- self._num_decodes = 0
- self._num_decode_tokens = 0
-
- def build(
- self,
- common_prefix_len: int,
- common_attn_metadata: CommonAttentionMetadata,
- fast_build: bool = False,
- ) -> XFormersAttentionMetadata:
- num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
- split_decodes_and_prefills(
- common_attn_metadata, decode_threshold=self.reorder_batch_threshold
- )
- )
-
- num_actual_tokens = common_attn_metadata.num_actual_tokens
- q_start_loc = common_attn_metadata.query_start_loc
- q_seqlens = torch.diff(q_start_loc)
- max_query_len = common_attn_metadata.max_query_len
- kv_seqlens = common_attn_metadata.seq_lens
- max_seq_len = common_attn_metadata.max_seq_len
- block_table = common_attn_metadata.block_table_tensor
- slot_mapping = common_attn_metadata.slot_mapping
-
- bias = None
- if num_decodes > 0:
- # Construct the decoder bias.
- decode_q_seqlens = q_seqlens[:num_decodes]
- decode_kv_seqlens = kv_seqlens[:num_decodes]
- bias = PagedBlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
- q_seqlen=decode_q_seqlens.tolist(),
- kv_seqlen=decode_kv_seqlens.tolist(),
- page_size=self.block_size,
- block_tables=block_table[:num_decodes],
- device=block_table.device,
- )
-
- return XFormersAttentionMetadata(
- num_actual_tokens=num_actual_tokens,
- num_prefill_tokens=num_prefill_tokens,
- num_decode_tokens=num_decode_tokens,
- num_prefills=num_prefills,
- num_decodes=num_decodes,
- max_query_len=max_query_len,
- query_start_loc=q_start_loc,
- max_seq_len=max_seq_len,
- seq_lens=kv_seqlens,
- block_table=block_table,
- slot_mapping=slot_mapping,
- attn_bias=bias,
- )
-
-
-class XFormersAttentionImpl(AttentionImpl):
- def __init__(
- self,
- num_heads: int,
- head_size: int,
- scale: float,
- num_kv_heads: int,
- alibi_slopes: list[float] | None,
- sliding_window: int | None,
- kv_cache_dtype: str,
- logits_soft_cap: float | None = None,
- attn_type: AttentionType = AttentionType.DECODER,
- kv_sharing_target_layer_name: str | None = None,
- ) -> None:
- if kv_sharing_target_layer_name is not None:
- raise NotImplementedError("KV sharing is not supported in V0.")
- if alibi_slopes is not None:
- raise NotImplementedError("XFormers does not support alibi slopes yet.")
- self.num_heads = num_heads
- self.head_size = head_size
- self.scale = float(scale)
- self.num_kv_heads = num_kv_heads
- self.num_queries_per_kv = self.num_heads // self.num_kv_heads
- self.kv_cache_dtype = kv_cache_dtype
- self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
- if alibi_slopes is not None:
- alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
- self.alibi_slopes = alibi_slopes
- if sliding_window is None:
- self.sliding_window = (-1, -1)
- else:
- self.sliding_window = (sliding_window - 1, 0)
- if logits_soft_cap is None:
- # Setting logits_soft_cap to 0 means no soft cap.
- logits_soft_cap = 0
- self.logits_soft_cap = logits_soft_cap
-
- if attn_type != AttentionType.DECODER:
- raise NotImplementedError(
- "Encoder self-attention and "
- "encoder/decoder cross-attention "
- "are not implemented for "
- "XFormersAttentionImpl."
- )
-
- def forward(
- self,
- layer: torch.nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: XFormersAttentionMetadata,
- output: torch.Tensor | None = None,
- output_scale: torch.Tensor | None = None,
- output_block_scale: torch.Tensor | None = None,
- ) -> torch.Tensor:
- """Forward pass with XFormers.
-
- Args:
- query: shape = [num_tokens, num_heads, head_size]
- key: shape = [num_tokens, num_kv_heads, head_size]
- value: shape = [num_tokens, num_kv_heads, head_size]
- kv_cache: shape =
- [2, num_blocks, block_size, num_kv_heads, head_size]
- attn_metadata: Metadata for attention.
- Returns:
- shape = [num_tokens, num_heads * head_size]
- """
- assert output is not None, "Output tensor must be provided."
-
- if output_scale is not None or output_block_scale is not None:
- raise NotImplementedError(
- "fused output quantization is not yet supported"
- " for XFormersAttentionImpl"
- )
-
- if attn_metadata is None:
- # Profiling run.
- return output.fill_(0)
-
- # Cache the input KVs.
- key_cache, value_cache = kv_cache.unbind(0)
- if self.kv_sharing_target_layer_name is None:
- # Reshape the input keys and values and store them in the cache.
- # Skip this if sharing KV cache with an earlier attention layer.
- # NOTE(woosuk): Here, key and value are padded while slot_mapping is
- # not padded. However, we don't need to do key[:num_actual_tokens]
- # and value[:num_actual_tokens] because the reshape_and_cache_flash
- # op uses the slot_mapping's shape to determine the number of
- # actual tokens.
- ops.reshape_and_cache_flash(
- key,
- value,
- key_cache,
- value_cache,
- attn_metadata.slot_mapping,
- self.kv_cache_dtype,
- layer._k_scale,
- layer._v_scale,
- )
-
- num_actual_tokens = attn_metadata.num_actual_tokens
- num_decode_tokens = attn_metadata.num_decode_tokens
- if prefill_meta := attn_metadata.prefill_metadata:
- descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, key.shape[1])
- unified_attention(
- q=query[num_decode_tokens:num_actual_tokens],
- k=key_cache,
- v=value_cache,
- out=output[num_decode_tokens:num_actual_tokens],
- cu_seqlens_q=prefill_meta.query_start_loc,
- max_seqlen_q=prefill_meta.max_query_len,
- seqused_k=prefill_meta.seq_lens,
- max_seqlen_k=prefill_meta.max_seq_len,
- softmax_scale=self.scale,
- causal=True,
- alibi_slopes=self.alibi_slopes,
- window_size=self.sliding_window,
- block_table=prefill_meta.block_table,
- softcap=self.logits_soft_cap,
- q_descale=None, # Not supported
- k_descale=layer._k_scale.expand(descale_shape),
- v_descale=layer._v_scale.expand(descale_shape),
- )
-
- if decode_meta := attn_metadata.decode_metadata:
- # Query for decode. KV is not needed because it is already cached.
- decode_query = query[:num_decode_tokens]
- # Reshape query to [1, B_T, G, H, D].
- q = decode_query.view(
- 1, -1, self.num_kv_heads, self.num_queries_per_kv, self.head_size
- )
- # Reshape the k and v caches to [1, Bkv_T, G, H, D]
- cache_k = key_cache.view(
- 1, -1, self.num_kv_heads, 1, self.head_size
- ).expand(
- 1,
- -1,
- self.num_kv_heads,
- self.num_queries_per_kv,
- self.head_size,
- )
- cache_v = value_cache.view(
- 1, -1, self.num_kv_heads, 1, self.head_size
- ).expand(
- 1,
- -1,
- self.num_kv_heads,
- self.num_queries_per_kv,
- self.head_size,
- )
-
- attn_bias = decode_meta.attn_bias
- output[:num_decode_tokens] = xops.memory_efficient_attention_forward(
- q,
- cache_k,
- cache_v,
- attn_bias=attn_bias,
- p=0.0,
- scale=self.scale,
- ).view(decode_query.shape)
-
- # Reshape the output tensor.
- return output
diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py
index 55710ad5cc693..8b0e8fd3a2410 100644
--- a/vllm/v1/core/block_pool.py
+++ b/vllm/v1/core/block_pool.py
@@ -13,6 +13,8 @@ from vllm.distributed.kv_events import (
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import (
BlockHash,
+ BlockHashList,
+ BlockHashListWithBlockSize,
BlockHashWithGroupId,
ExternalBlockHash,
FreeKVCacheBlockQueue,
@@ -133,6 +135,10 @@ class BlockPool:
Args:
num_gpu_blocks: The number of blocks in the pool.
enable_caching: Whether to enable prefix caching.
+ hash_block_size: The block size with which the block hashes are computed.
+ The actual block size usually equals hash_block_size, but in cases
+ where different KV cache groups have different block sizes, the
+ actual block size can be a multiple of hash_block_size.
enable_kv_cache_events: Whether to enable kv cache events.
"""
@@ -140,11 +146,13 @@ class BlockPool:
self,
num_gpu_blocks: int,
enable_caching: bool,
+ hash_block_size: int,
enable_kv_cache_events: bool = False,
):
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
self.num_gpu_blocks = num_gpu_blocks
self.enable_caching = enable_caching
+ self.hash_block_size = hash_block_size
# All kv-cache blocks.
self.blocks: list[KVCacheBlock] = [
KVCacheBlock(idx) for idx in range(num_gpu_blocks)
@@ -223,8 +231,20 @@ class BlockPool:
return
new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
assert len(request.block_hashes) >= num_full_blocks
- new_block_hashes = request.block_hashes[num_cached_blocks:]
+ if block_size == self.hash_block_size:
+ # Common case.
+ block_hashes: BlockHashList = request.block_hashes
+ else:
+ # block_size is a multiple of hash_block_size. This happens when
+ # different KV cache groups have different block sizes.
+ assert block_size % self.hash_block_size == 0
+ # Recalculate block_hashes at the granularity of block_size, using
+ # the original block_hashes (at the granularity of hash_block_size).
+ block_hashes = BlockHashListWithBlockSize(
+ request.block_hashes, self.hash_block_size, block_size
+ )
+ new_block_hashes = block_hashes[num_cached_blocks:]
new_hashes: list[ExternalBlockHash] | None = (
[] if self.enable_kv_cache_events else None
)
diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py
index 1531b61f88fe2..fd1ec8e27fba2 100644
--- a/vllm/v1/core/kv_cache_coordinator.py
+++ b/vllm/v1/core/kv_cache_coordinator.py
@@ -2,15 +2,25 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Sequence
+from math import lcm
from vllm.v1.core.block_pool import BlockPool
-from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
+from vllm.v1.core.kv_cache_utils import (
+ BlockHash,
+ BlockHashList,
+ BlockHashListWithBlockSize,
+ KVCacheBlock,
+)
from vllm.v1.core.single_type_kv_cache_manager import (
CrossAttentionManager,
FullAttentionManager,
get_manager_for_kv_cache_spec,
)
-from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheSpec
+from vllm.v1.kv_cache_interface import (
+ FullAttentionSpec,
+ KVCacheConfig,
+ KVCacheSpec,
+)
from vllm.v1.request import Request
@@ -28,13 +38,17 @@ class KVCacheCoordinator(ABC):
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
+ hash_block_size: int,
):
self.kv_cache_config = kv_cache_config
self.max_model_len = max_model_len
self.enable_caching = enable_caching
self.block_pool = BlockPool(
- kv_cache_config.num_blocks, enable_caching, enable_kv_cache_events
+ kv_cache_config.num_blocks,
+ enable_caching,
+ hash_block_size,
+ enable_kv_cache_events,
)
# Needs special handling for find_longest_cache_hit if eagle is enabled
@@ -213,6 +227,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
+ hash_block_size: int,
):
super().__init__(
kv_cache_config,
@@ -222,6 +237,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
+ hash_block_size=hash_block_size,
)
self.num_single_type_manager = len(self.single_type_managers)
@@ -255,6 +271,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
+ hash_block_size: int,
):
super().__init__(
kv_cache_config,
@@ -264,6 +281,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
+ hash_block_size=hash_block_size,
)
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec
self.block_size = self.kv_cache_spec.block_size
@@ -273,6 +291,11 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
self.block_size *= dcp_world_size
if pcp_world_size > 1:
self.block_size *= pcp_world_size
+ # For models using only Mamba, block_size is set to max_model_len when
+ # prefix caching is disabled, and hash_block_size validation is skipped.
+ assert not enable_caching or (hash_block_size == self.block_size), (
+ "UnitaryKVCacheCoordinator assumes hash_block_size == block_size"
+ )
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
"UnitaryKVCacheCoordinator assumes only one kv cache group"
)
@@ -289,6 +312,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
block_pool=self.block_pool,
kv_cache_spec=self.kv_cache_spec,
use_eagle=self.use_eagle,
+ alignment_tokens=self.block_size,
dcp_world_size=self.dcp_world_size,
pcp_world_size=self.pcp_world_size,
)
@@ -313,6 +337,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
+ hash_block_size: int,
):
super().__init__(
kv_cache_config,
@@ -322,7 +347,17 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
+ hash_block_size=hash_block_size,
)
+ # hash_block_size: the block size used to compute block hashes.
+ # The actual block size usually equals hash_block_size, but in cases where
+ # different KV cache groups have different block sizes, the actual block size
+ # can be a multiple of hash_block_size.
+ self.hash_block_size = hash_block_size
+ assert all(
+ g.kv_cache_spec.block_size % hash_block_size == 0
+ for g in kv_cache_config.kv_cache_groups
+ ), "block_size must be divisible by hash_block_size"
assert dcp_world_size == 1, "DCP not support hybrid attn now."
assert pcp_world_size == 1, "PCP not support hybrid attn now."
self.verify_and_split_kv_cache_groups()
@@ -373,14 +408,12 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
self.other_spec = other_spec
self.full_attention_block_size = self.full_attention_spec.block_size
self.other_block_size = self.other_spec.block_size
-
- if self.enable_caching:
- # this requirement is only needed for the prefix caching logic
- divisible = self.other_block_size % self.full_attention_block_size
- assert divisible == 0, (
- "KVCacheCoordinator assumes the block_size of full "
- "attention layers is divisible by other layers now."
- )
+ # The LCM of the block sizes of full attention and other attention.
+ # The cache hit length must be a multiple of the LCM of the block sizes
+ # to make sure the cache hit length is a multiple of the block size of
+ # each attention type. We require this because we don't support partial
+ # block cache hit yet.
+ self.lcm_block_size = lcm(self.full_attention_block_size, self.other_block_size)
if max(self.full_attention_group_ids) < min(self.other_group_ids):
self.full_attn_first = True
@@ -414,25 +447,48 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
- The number of tokens of the longest cache hit.
"""
# First, find the longest cache hit for full attention.
+ if self.full_attention_spec.block_size == self.hash_block_size:
+ # Common case.
+ full_attention_block_hashes: BlockHashList = block_hashes
+ else:
+ # block_size is a multiple of hash_block_size. This happens when different
+ # KV cache groups have different block sizes. In this case, we need to
+ # recalculate block_hashes at the granularity of block_size, using the
+ # original block_hashes (at the granularity of hash_block_size).
+ full_attention_block_hashes = BlockHashListWithBlockSize(
+ block_hashes, self.hash_block_size, self.full_attention_spec.block_size
+ )
hit_blocks_full_attn = self.full_attention_manager_cls.find_longest_cache_hit(
- block_hashes=block_hashes,
+ block_hashes=full_attention_block_hashes,
max_length=max_cache_hit_length,
kv_cache_group_ids=self.full_attention_group_ids,
block_pool=self.block_pool,
kv_cache_spec=self.full_attention_spec,
use_eagle=self.use_eagle,
+ alignment_tokens=self.lcm_block_size,
)
hit_length = len(hit_blocks_full_attn[0]) * self.full_attention_block_size
# Next, find the cache hit for the other attention WITHIN
# the cache hit of full attention.
+ if self.other_spec.block_size == self.hash_block_size:
+ # Common case.
+ other_block_hashes: BlockHashList = block_hashes
+ else:
+ # Similar to the full attention case, here we need to recalculate
+ # block_hashes at the granularity of block_size, using the original
+ # block_hashes (at the granularity of hash_block_size).
+ other_block_hashes = BlockHashListWithBlockSize(
+ block_hashes, self.hash_block_size, self.other_spec.block_size
+ )
hit_blocks_other_attn = self.other_attention_cls.find_longest_cache_hit(
- block_hashes=block_hashes,
+ block_hashes=other_block_hashes,
max_length=hit_length,
kv_cache_group_ids=self.other_group_ids,
block_pool=self.block_pool,
kv_cache_spec=self.other_spec,
use_eagle=self.use_eagle,
+ alignment_tokens=self.lcm_block_size,
)
hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size
@@ -466,6 +522,7 @@ def get_kv_cache_coordinator(
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
+ hash_block_size: int,
) -> KVCacheCoordinator:
if not enable_caching:
return KVCacheCoordinatorNoPrefixCache(
@@ -473,8 +530,9 @@ def get_kv_cache_coordinator(
max_model_len,
use_eagle,
enable_kv_cache_events,
- dcp_world_size=dcp_world_size,
- pcp_world_size=pcp_world_size,
+ dcp_world_size,
+ pcp_world_size,
+ hash_block_size,
)
if len(kv_cache_config.kv_cache_groups) == 1:
return UnitaryKVCacheCoordinator(
@@ -483,8 +541,9 @@ def get_kv_cache_coordinator(
use_eagle,
enable_caching,
enable_kv_cache_events,
- dcp_world_size=dcp_world_size,
- pcp_world_size=pcp_world_size,
+ dcp_world_size,
+ pcp_world_size,
+ hash_block_size,
)
return HybridKVCacheCoordinator(
kv_cache_config,
@@ -492,6 +551,7 @@ def get_kv_cache_coordinator(
use_eagle,
enable_caching,
enable_kv_cache_events,
- dcp_world_size=dcp_world_size,
- pcp_world_size=pcp_world_size,
+ dcp_world_size,
+ pcp_world_size,
+ hash_block_size,
)
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
index 2012c3fef88bc..b061e5cc831dd 100644
--- a/vllm/v1/core/kv_cache_manager.py
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -95,6 +95,7 @@ class KVCacheManager:
self,
kv_cache_config: KVCacheConfig,
max_model_len: int,
+ hash_block_size: int,
enable_caching: bool = True,
use_eagle: bool = False,
log_stats: bool = False,
@@ -107,28 +108,11 @@ class KVCacheManager:
self.enable_caching = enable_caching
self.use_eagle = use_eagle
self.log_stats = log_stats
- # FIXME: make prefix cache stats conditional on log_stats
+ # FIXME: make prefix cache stats conditional on log_stats. We still need
+ # this comment because when the log stats is enabled there are still
+ # potential configs we could expose in the future.
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
- self.block_size: int | None = None
- if self.enable_caching:
- assert (
- len(
- set(
- g.kv_cache_spec.block_size
- for g in kv_cache_config.kv_cache_groups
- )
- )
- == 1
- ), "Only one block size is supported for now"
- self.block_size = kv_cache_config.kv_cache_groups[
- 0
- ].kv_cache_spec.block_size
-
- if dcp_world_size * pcp_world_size > 1:
- assert len(kv_cache_config.kv_cache_groups) == 1
- self.block_size *= dcp_world_size * pcp_world_size
-
self.coordinator = get_kv_cache_coordinator(
kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len,
@@ -137,6 +121,7 @@ class KVCacheManager:
enable_kv_cache_events=enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
+ hash_block_size=hash_block_size,
)
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
self.block_pool = self.coordinator.block_pool
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index b18ba8e8b2c7b..602eb81beb010 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -5,9 +5,9 @@
import copy
import os
from collections import defaultdict
-from collections.abc import Callable, Iterable, Sequence
-from dataclasses import dataclass
-from typing import Any, NewType, TypeAlias
+from collections.abc import Callable, Iterable, Iterator, Sequence
+from dataclasses import dataclass, replace
+from typing import Any, NewType, TypeAlias, overload
from vllm import envs
from vllm.config import VllmConfig
@@ -825,11 +825,11 @@ def get_num_blocks(
return num_blocks
-def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int:
+def get_uniform_page_size(kv_cache_specs: Iterable[KVCacheSpec]) -> int:
"""
Get the page size of the KV cache.
"""
- page_sizes = set(layer.page_size_bytes for layer in kv_cache_spec.values())
+ page_sizes = {layer.page_size_bytes for layer in kv_cache_specs}
assert len(page_sizes) == 1
return page_sizes.pop()
@@ -882,6 +882,46 @@ def is_kv_cache_page_size_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool
return len(page_sizes) == 1
+def unify_kv_cache_spec_page_size(
+ kv_cache_spec: dict[str, KVCacheSpec],
+) -> dict[str, KVCacheSpec]:
+ """
+ Unify the page size of the given KVCacheSpec. If the page size of all layers
+ are the same, return the original KVCacheSpec. If not same, unify the page
+ size by increasing the block size of layers with smaller page size. Raise
+ NotImplementedError if failed to unify the page size.
+
+ Args:
+ kv_cache_spec: The KVCacheSpec of each attention layer in the model
+
+ Returns:
+ The updated KVCacheSpec with the same page_size_bytes.
+ """
+ page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
+ if len(page_sizes) <= 1:
+ # All layers have the same page size, no need to unify.
+ return kv_cache_spec
+
+ max_page_size = max(page_sizes)
+ new_kv_cache_spec = {}
+ for layer_name, layer_spec in kv_cache_spec.items():
+ if layer_spec.page_size_bytes == max_page_size:
+ new_kv_cache_spec[layer_name] = layer_spec
+ else:
+ layer_page_size = layer_spec.page_size_bytes
+ if max_page_size % layer_page_size != 0:
+ raise NotImplementedError(
+ "The page size of the layer is not divisible by the "
+ "maximum page size. Cannot unify by adjusting block_size."
+ )
+ ratio = max_page_size // layer_page_size
+ new_block_size = layer_spec.block_size * ratio
+ new_spec = replace(layer_spec, block_size=new_block_size)
+ assert new_spec.page_size_bytes == max_page_size
+ new_kv_cache_spec[layer_name] = new_spec
+ return new_kv_cache_spec
+
+
def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
# kv_cache_spec is an empty dict for attention free models
return not kv_cache_spec
@@ -971,7 +1011,16 @@ def _get_kv_cache_groups_uniform_page_size(
# is the minimum number of layers among all attention types. Need a better
# strategy if we want to support more complex patterns (e.g., 20 full + 30
# sw, where the group size should be 10).
- group_size = min([len(layers) for layers in same_type_layers.values()])
+ min_num_layers = min([len(layers) for layers in same_type_layers.values()])
+ group_size = min_num_layers
+ max_num_layers = max([len(layers) for layers in same_type_layers.values()])
+ if max_num_layers < min_num_layers * 1.25:
+ # If the maximum number of layers is not much larger than the minimum,
+ # use the maximum number of layers as the group size to avoid too many padding
+ # layers. A typical example is gpt-oss-20b + eagle, with 12 sw + 13 full. We
+ # pad it to (13 sw, 13 full) instead of (12 sw, 24 full). 1.25 is just a
+ # magic number to avoid too many padding layers.
+ group_size = max_num_layers
grouped_layers = []
for layers in same_type_layers.values():
num_padding_layers = group_size - len(layers) % group_size
@@ -1001,7 +1050,6 @@ def _get_kv_cache_groups_uniform_page_size(
def get_kv_cache_config_from_groups(
vllm_config: VllmConfig,
kv_cache_groups: list[KVCacheGroupSpec],
- kv_cache_specs: dict[str, KVCacheSpec],
available_memory: int,
) -> KVCacheConfig:
"""
@@ -1011,7 +1059,6 @@ def get_kv_cache_config_from_groups(
Args:
vllm_config: The global VllmConfig
kv_cache_groups: The KV cache groups
- kv_cache_specs: The KV cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes
Returns:
The generated KVCacheConfig
@@ -1055,7 +1102,9 @@ def get_kv_cache_config_from_groups(
# full.1, sw.2: share another Tensor with size=available_memory//2
group_size = max(len(group.layer_names) for group in kv_cache_groups)
- page_size = get_uniform_page_size(kv_cache_specs)
+ page_size = get_uniform_page_size(
+ [group.kv_cache_spec for group in kv_cache_groups]
+ )
assert group_size > 0, "group_size must be greater than 0"
num_blocks = get_num_blocks(
vllm_config, group_size, available_memory, page_size
@@ -1157,7 +1206,8 @@ def get_kv_cache_groups(
# This returns an empty list to allow for the KVCacheManager to handle
# attention free models.
return []
- elif is_kv_cache_spec_uniform(kv_cache_spec):
+
+ if is_kv_cache_spec_uniform(kv_cache_spec):
# KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for
# each layer.
@@ -1167,14 +1217,16 @@ def get_kv_cache_groups(
# full attention, or all layers are sliding window attention with the
# same window size). Put all layers into one group.
return _get_kv_cache_groups_uniform_type(uniform_spec)
- elif is_kv_cache_page_size_uniform(kv_cache_spec):
- # Model contains multiple attention types, but KV cache of all layers
- # have the same physical memory per block per layer. Split the layers
- # into groups with the same number of layers, and thus same total page
- # size.
- return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
- raise NotImplementedError
+ # As KVCacheManager can only allocate memory of one size, we need to unify
+ # the page size of the layers. For cases that cannot be unified, this function
+ # will raise an error.
+ kv_cache_spec = unify_kv_cache_spec_page_size(kv_cache_spec)
+ # Model contains multiple attention types, but KV cache of all layers
+ # have the same physical memory per block per layer. Split the layers
+ # into groups with the same number of layers, and thus same total page
+ # size.
+ return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
def generate_scheduler_kv_cache_config(
@@ -1318,10 +1370,7 @@ def get_kv_cache_configs(
) == len(kv_cache_spec_one_worker), "Some layers are not assigned to any group."
kv_cache_configs.append(
get_kv_cache_config_from_groups(
- vllm_config,
- kv_cache_groups_one_worker,
- kv_cache_spec_one_worker,
- available_memory_one_worker,
+ vllm_config, kv_cache_groups_one_worker, available_memory_one_worker
)
)
@@ -1344,3 +1393,79 @@ def get_kv_cache_configs(
_report_kv_cache_config(vllm_config, kv_cache_config)
return kv_cache_configs
+
+
+class BlockHashListWithBlockSize:
+ """
+ Convert block-hash granularity from `hash_block_size` to `target_block_size`.
+ Used when KV cache groups have different block sizes: `hash_block_size`
+ is the size used to compute the original `block_hashes`; `target_block_size`
+ is the group's actual block size.
+
+ Currently, only scaling up by an integer factor is supported (i.e.,
+ `target_block_size` is a multiple of `hash_block_size`). Conversion is
+ performed lazily on access for efficiency, by concatenating consecutive
+ hashes at `hash_block_size` to form each hash at `target_block_size`.
+
+ Example (`hash_block_size` = 16, `target_block_size` = 32):
+ concatenating two 16-size hashes yields one 32-size hash:
+
+ Block hashes with block_size 16:
+ | Token Range | 0-15 | 16-31 | 32-47 | 48-63 |
+ |-------------|------|-------|-------|-------|
+ | Hash | A | B | C | D |
+
+ Block hashes with block_size 32:
+ | Token Range | 0-31 | 32-63 |
+ |-------------|------|-------|
+ | Hash | AB | CD |
+
+ Args:
+ block_hashes: Block hashes to convert, computed at `hash_block_size`.
+ hash_block_size: Block size at which `block_hashes` were computed.
+ target_block_size: Desired block size; must be a multiple of `hash_block_size`.
+ """
+
+ def __init__(
+ self,
+ block_hashes: list[BlockHash],
+ hash_block_size: int,
+ target_block_size: int,
+ ):
+ self.block_hashes = block_hashes
+ assert target_block_size % hash_block_size == 0
+ self.scale_factor = target_block_size // hash_block_size
+
+ def __len__(self) -> int:
+ return len(self.block_hashes) // self.scale_factor
+
+ @overload
+ def __getitem__(self, idx: int) -> BlockHash: ...
+
+ @overload
+ def __getitem__(self, idx: slice) -> list[BlockHash]: ...
+
+ def __getitem__(self, idx):
+ if isinstance(idx, int):
+ return self._get_value_at(idx)
+
+ if isinstance(idx, slice):
+ start, stop, step = idx.indices(len(self))
+ return [self._get_value_at(i) for i in range(start, stop, step)]
+
+ raise TypeError(f"Invalid index type: {type(idx)!r}")
+
+ def __iter__(self) -> Iterator[BlockHash]:
+ for i in range(len(self)):
+ yield self._get_value_at(i)
+
+ def _get_value_at(self, idx: int) -> BlockHash:
+ base = idx * self.scale_factor
+ end = base + self.scale_factor
+ merged_hash: bytes = self.block_hashes[base]
+ for i in range(base + 1, end):
+ merged_hash += self.block_hashes[i]
+ return BlockHash(merged_hash)
+
+
+BlockHashList = list[BlockHash] | BlockHashListWithBlockSize
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 9195b112d8690..bea2f865bad46 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -180,12 +180,13 @@ class Scheduler(SchedulerInterface):
self.kv_cache_manager = KVCacheManager(
kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len,
- enable_caching=bool(self.cache_config.enable_prefix_caching),
+ enable_caching=self.cache_config.enable_prefix_caching,
use_eagle=self.use_eagle,
log_stats=self.log_stats,
enable_kv_cache_events=self.enable_kv_cache_events,
dcp_world_size=self.dcp_world_size,
pcp_world_size=self.pcp_world_size,
+ hash_block_size=self.block_size,
)
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
@@ -470,6 +471,7 @@ class Scheduler(SchedulerInterface):
skipped_waiting_requests.prepend_request(request)
continue
+ request.num_external_computed_tokens = ext_tokens
num_external_computed_tokens = ext_tokens
# Total computed tokens (local + external).
@@ -507,9 +509,9 @@ class Scheduler(SchedulerInterface):
not self.scheduler_config.enable_chunked_prefill
and num_new_tokens > token_budget
):
- self.waiting.pop_request()
- skipped_waiting_requests.prepend_request(request)
- continue
+ # If chunked_prefill is disabled,
+ # we can stop the scheduling here.
+ break
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
@@ -576,9 +578,6 @@ class Scheduler(SchedulerInterface):
new_computed_blocks + new_blocks,
num_external_computed_tokens,
)
- self._update_connector_prefix_cache_stats(
- request, num_external_computed_tokens
- )
# Request was already popped from self.waiting
# unless it was re-added above due to new_blocks being None.
@@ -590,6 +589,8 @@ class Scheduler(SchedulerInterface):
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
continue
+ self._update_connector_prefix_cache_stats(request)
+
req_index += 1
self.running.append(request)
if self.log_stats:
@@ -1380,15 +1381,13 @@ class Scheduler(SchedulerInterface):
# KV Connector Related Methods
########################################################################
- def _update_connector_prefix_cache_stats(
- self, request: Request, num_external_tokens: int
- ) -> None:
+ def _update_connector_prefix_cache_stats(self, request: Request) -> None:
if self.connector_prefix_cache_stats is None:
return
self.connector_prefix_cache_stats.record(
num_tokens=request.num_tokens,
- num_hits=num_external_tokens,
+ num_hits=request.num_external_computed_tokens,
preempted=request.num_preemptions > 0,
)
@@ -1571,9 +1570,11 @@ class Scheduler(SchedulerInterface):
marked_invalid_block = True
# Truncate the computed tokens at the first failed block
request.num_computed_tokens = idx * self.block_size
- total_affected_tokens += (
+ num_affected_tokens = (
req_num_computed_tokens - request.num_computed_tokens
)
+ total_affected_tokens += num_affected_tokens
+ request.num_external_computed_tokens -= num_affected_tokens
if is_affected:
if not marked_invalid_block:
diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py
index d90ec550f7666..4aeb17a156bb3 100644
--- a/vllm/v1/core/single_type_kv_cache_manager.py
+++ b/vllm/v1/core/single_type_kv_cache_manager.py
@@ -7,7 +7,7 @@ from collections.abc import Sequence
from vllm.utils.math_utils import cdiv
from vllm.v1.core.block_pool import BlockPool
-from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
+from vllm.v1.core.kv_cache_utils import BlockHashList, KVCacheBlock
from vllm.v1.kv_cache_interface import (
ChunkedLocalAttentionSpec,
CrossAttentionSpec,
@@ -207,12 +207,13 @@ class SingleTypeKVCacheManager(ABC):
@abstractmethod
def find_longest_cache_hit(
cls,
- block_hashes: list[BlockHash],
+ block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
+ alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
@@ -232,6 +233,11 @@ class SingleTypeKVCacheManager(ABC):
block_pool: The block pool.
kv_cache_spec: The kv cache spec.
use_eagle: Whether to use eagle.
+ alignment_tokens: The returned cache hit length (in tokens) should
+ be a multiple of this value (in tokens). By default, it should
+ be set to the block_size.
+ dcp_world_size: The world size of decode context parallelism.
+ pcp_world_size: The world size of prefill context parallelism.
Returns:
A list of cached blocks with skipped blocks replaced by null block
@@ -299,17 +305,18 @@ class FullAttentionManager(SingleTypeKVCacheManager):
@classmethod
def find_longest_cache_hit(
cls,
- block_hashes: list[BlockHash],
+ block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
+ alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(
- kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec)
+ kv_cache_spec, FullAttentionSpec | ChunkedLocalAttentionSpec
), (
"FullAttentionManager can only be used for full attention "
"and chunked local attention groups"
@@ -333,6 +340,13 @@ class FullAttentionManager(SingleTypeKVCacheManager):
else:
break
if use_eagle and computed_blocks[0]:
+ # Need to drop the last matched block if eagle is enabled.
+ for computed in computed_blocks:
+ computed.pop()
+ while (
+ block_size != alignment_tokens # Faster for common case.
+ and len(computed_blocks[0]) * block_size % alignment_tokens != 0
+ ):
for computed in computed_blocks:
computed.pop()
return computed_blocks
@@ -359,12 +373,13 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
@classmethod
def find_longest_cache_hit(
cls,
- block_hashes: list[BlockHash],
+ block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
+ alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
@@ -396,6 +411,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
[block_pool.null_block] * max_num_blocks
for _ in range(len(kv_cache_group_ids))
)
+ block_size = kv_cache_spec.block_size
num_contiguous_blocks = 0
match_found = False
# Search from right to left and early stop when a match is found.
@@ -403,6 +419,15 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
if cached_block := block_pool.get_cached_block(
block_hashes[i], kv_cache_group_ids
):
+ # Skip prefix matching check if the block is not aligned with
+ # `alignment_tokens`.
+ if (
+ num_contiguous_blocks == 0
+ and block_size != alignment_tokens # Faster for common case.
+ and (i + 1) * block_size % alignment_tokens != 0
+ ):
+ continue
+ # Add the cached block to the computed blocks.
for computed, cached in zip(computed_blocks, cached_block):
computed[i] = cached
num_contiguous_blocks += 1
@@ -421,7 +446,16 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
for computed in computed_blocks:
del computed[num_contiguous_blocks:]
+ while (
+ block_size != alignment_tokens # Faster for common case.
+ and len(computed_blocks[0]) * block_size % alignment_tokens != 0
+ ):
+ for computed in computed_blocks:
+ computed.pop()
if use_eagle and computed_blocks[0]:
+ assert kv_cache_spec.block_size == alignment_tokens, (
+ "alignment_tokens is not compatible with eagle now"
+ )
for computed in computed_blocks:
computed.pop()
return computed_blocks
@@ -475,12 +509,13 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
@classmethod
def find_longest_cache_hit(
cls,
- block_hashes: list[BlockHash],
+ block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
+ alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
@@ -511,6 +546,10 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
block_pool: The block pool.
kv_cache_spec: The kv cache spec.
use_eagle: Whether to use eagle.
+ dcp_world_size: The world size of decode context parallelism.
+ pcp_world_size: The world size of prefill context parallelism.
+ alignment_tokens: The returned cache hit length (in tokens) should
+ be a multiple of this value (in tokens).
Returns:
A list of cached blocks
@@ -524,6 +563,10 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
)
assert dcp_world_size == 1, "DCP not support chunked local attn now."
assert pcp_world_size == 1, "PCP not support chunked local attn now."
+ assert kv_cache_spec.block_size == alignment_tokens, (
+ "KV cache groups with different block sizes are not compatible with "
+ "chunked local attention now"
+ )
max_num_blocks = max_length // kv_cache_spec.block_size
if max_length > 0:
local_attention_start_idx = (
@@ -612,12 +655,13 @@ class MambaManager(SingleTypeKVCacheManager):
@classmethod
def find_longest_cache_hit(
cls,
- block_hashes: list[BlockHash],
+ block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
+ alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
@@ -630,12 +674,21 @@ class MambaManager(SingleTypeKVCacheManager):
[] for _ in range(len(kv_cache_group_ids))
)
- max_num_blocks = max_length // kv_cache_spec.block_size
+ block_size = kv_cache_spec.block_size
+ max_num_blocks = max_length // block_size
# Search from right to left and early stop when a match is found.
for i in range(max_num_blocks - 1, -1, -1):
if cached_block := block_pool.get_cached_block(
block_hashes[i], kv_cache_group_ids
):
+ # When Mamba prefix caching is enabled, `block_size` will be aligned
+ # across full attention layers and Mamba layers to ensure the
+ # prefix hit length is aligned at block boundaries.
+ if (
+ block_size != alignment_tokens # Faster for common case.
+ and (i + 1) * block_size % alignment_tokens != 0
+ ):
+ continue
for computed, cached in zip(computed_blocks, cached_block):
# the hit length logic later assumes:
# hit_length = len(hit_blocks_other_attn[0])
@@ -708,12 +761,13 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
@classmethod
def find_longest_cache_hit(
cls,
- block_hashes: list[BlockHash],
+ block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
+ alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index 3f621d77c0241..ce2aae77108da 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -72,6 +72,14 @@ class EngineCoreRequest(
trace_headers: Mapping[str, str] | None = None
+ @property
+ def params(self) -> SamplingParams | PoolingParams:
+ """Return the processed params (sampling or pooling)."""
+ if self.sampling_params is not None:
+ return self.sampling_params
+ assert self.pooling_params is not None
+ return self.pooling_params
+
class EngineCoreEventType(enum.IntEnum):
"""The type of engine core request event."""
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index c64b3cccfc652..827a2736af284 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -31,7 +31,6 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_
from vllm.usage.usage_lib import UsageContext
from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.collection_utils import as_list
-from vllm.utils.func_utils import deprecate_kwargs
from vllm.utils.math_utils import cdiv
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient
@@ -195,12 +194,6 @@ class AsyncLLM(EngineClient):
self.profiler = None
@classmethod
- @deprecate_kwargs(
- "disable_log_requests",
- additional_message=(
- "This argument will have no effect. Use `enable_log_requests` instead."
- ),
- )
def from_vllm_config(
cls,
vllm_config: VllmConfig,
@@ -213,7 +206,6 @@ class AsyncLLM(EngineClient):
client_addresses: dict[str, str] | None = None,
client_count: int = 1,
client_index: int = 0,
- disable_log_requests: bool = True, # Deprecated, will be removed
) -> "AsyncLLM":
# Create the LLMEngine.
return cls(
@@ -321,14 +313,15 @@ class AsyncLLM(EngineClient):
elif isinstance(prompt, Mapping):
prompt_text = cast(str | None, prompt.get("prompt"))
+ # Use cloned params that may have been updated in process_inputs()
+ params = request.params
+
if is_pooling or params.n == 1:
await self._add_request(request, prompt_text, None, 0, queue)
return queue
- # Get the updated SamplingParams from the request, which
- # were cloned/updated in processor.process_inputs above.
- parent_params = request.sampling_params
- assert parent_params is not None
+ parent_params = params
+ assert isinstance(parent_params, SamplingParams)
# Fan out child requests (for n>1).
parent_request = ParentRequest(request_id, parent_params)
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index e403cea87788b..dffe05445ee46 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -250,6 +250,9 @@ class LLMEngine:
elif isinstance(prompt, Mapping):
prompt_text = cast(str | None, prompt.get("prompt"))
+ # Use cloned params that may have been updated in process_inputs()
+ params = request.params
+
n = params.n if isinstance(params, SamplingParams) else 1
if n == 1:
@@ -262,10 +265,10 @@ class LLMEngine:
# Fan out child requests (for n>1).
parent_req = ParentRequest(request_id, params)
for idx in range(n):
- request_id, params = parent_req.get_child_info(idx)
+ request_id, child_params = parent_req.get_child_info(idx)
child_request = request if idx == n - 1 else copy(request)
child_request.request_id = request_id
- child_request.sampling_params = params
+ child_request.sampling_params = child_params
# Make a new RequestState and queue.
self.output_processor.add_request(
diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py
index e2d82241ce210..bd18a152ffc08 100644
--- a/vllm/v1/metrics/loggers.py
+++ b/vllm/v1/metrics/loggers.py
@@ -440,57 +440,6 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
# Setting default values
self.record_sleep_state()
- # GPU cache
- #
- # Deprecated in 0.9.2 - Renamed as vllm:kv_cache_usage_perc
- # With 0.11.x you can enable with --show-hidden-metrics-for-version=0.10
- # TODO: remove in 0.12.0
- if self.show_hidden_metrics:
- gauge_gpu_cache_usage = self._gauge_cls(
- name="vllm:gpu_cache_usage_perc",
- documentation=(
- "GPU KV-cache usage. 1 means 100 percent usage."
- "DEPRECATED: Use vllm:kv_cache_usage_perc instead."
- ),
- multiprocess_mode="mostrecent",
- labelnames=labelnames,
- )
- self.gauge_gpu_cache_usage = make_per_engine(
- gauge_gpu_cache_usage, engine_indexes, model_name
- )
-
- # Deprecated in 0.9.2 - Renamed as vllm:prefix_cache_queries
- # With 0.11.x you can enable with --show-hidden-metrics-for-version=0.10
- # TODO: remove in 0.12.0
- if self.show_hidden_metrics:
- counter_gpu_prefix_cache_queries = self._counter_cls(
- name="vllm:gpu_prefix_cache_queries",
- documentation=(
- "GPU prefix cache queries, in terms of number of queried"
- "tokens. DEPRECATED: Use vllm:prefix_cache_queries instead."
- ),
- labelnames=labelnames,
- )
- self.counter_gpu_prefix_cache_queries = make_per_engine(
- counter_gpu_prefix_cache_queries, engine_indexes, model_name
- )
-
- # Deprecated in 0.9.2 - Renamed as vllm:prefix_cache_hits
- # With 0.11.x you can enable with --show-hidden-metrics-for-version=0.10
- # TODO: remove in 0.12.0
- if self.show_hidden_metrics:
- counter_gpu_prefix_cache_hits = self._counter_cls(
- name="vllm:gpu_prefix_cache_hits",
- documentation=(
- "GPU prefix cache hits, in terms of number of cached "
- "tokens. DEPRECATED: Use vllm:prefix_cache_hits instead."
- ),
- labelnames=labelnames,
- )
- self.counter_gpu_prefix_cache_hits = make_per_engine(
- counter_gpu_prefix_cache_hits, engine_indexes, model_name
- )
-
gauge_kv_cache_usage = self._gauge_cls(
name="vllm:kv_cache_usage_perc",
documentation="KV-cache usage. 1 means 100 percent usage.",
@@ -735,39 +684,41 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
)
# Deprecated in 0.11 - Renamed as vllm:inter_token_latency_seconds
- # TODO: in 0.12, only enable if show_hidden_metrics=True
- histogram_time_per_output_token = self._histogram_cls(
- name="vllm:time_per_output_token_seconds",
- documentation=(
- "Histogram of time per output token in seconds."
- "DEPRECATED: Use vllm:inter_token_latency_seconds instead."
- ),
- buckets=[
- 0.01,
- 0.025,
- 0.05,
- 0.075,
- 0.1,
- 0.15,
- 0.2,
- 0.3,
- 0.4,
- 0.5,
- 0.75,
- 1.0,
- 2.5,
- 5.0,
- 7.5,
- 10.0,
- 20.0,
- 40.0,
- 80.0,
- ],
- labelnames=labelnames,
- )
- self.histogram_time_per_output_token = make_per_engine(
- histogram_time_per_output_token, engine_indexes, model_name
- )
+ # With 0.12.x you can enable with --show-hidden-metrics-for-version=0.11
+ # TODO: remove in 0.13.0
+ if self.show_hidden_metrics:
+ histogram_time_per_output_token = self._histogram_cls(
+ name="vllm:time_per_output_token_seconds",
+ documentation=(
+ "Histogram of time per output token in seconds."
+ "DEPRECATED: Use vllm:inter_token_latency_seconds instead."
+ ),
+ buckets=[
+ 0.01,
+ 0.025,
+ 0.05,
+ 0.075,
+ 0.1,
+ 0.15,
+ 0.2,
+ 0.3,
+ 0.4,
+ 0.5,
+ 0.75,
+ 1.0,
+ 2.5,
+ 5.0,
+ 7.5,
+ 10.0,
+ 20.0,
+ 40.0,
+ 80.0,
+ ],
+ labelnames=labelnames,
+ )
+ self.histogram_time_per_output_token = make_per_engine(
+ histogram_time_per_output_token, engine_indexes, model_name
+ )
histogram_inter_token_latency = self._histogram_cls(
name="vllm:inter_token_latency_seconds",
@@ -966,20 +917,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
self.gauge_scheduler_waiting[engine_idx].set(
scheduler_stats.num_waiting_reqs
)
- if self.show_hidden_metrics:
- self.gauge_gpu_cache_usage[engine_idx].set(
- scheduler_stats.kv_cache_usage
- )
self.gauge_kv_cache_usage[engine_idx].set(scheduler_stats.kv_cache_usage)
- if self.show_hidden_metrics:
- self.counter_gpu_prefix_cache_queries[engine_idx].inc(
- scheduler_stats.prefix_cache_stats.queries
- )
- self.counter_gpu_prefix_cache_hits[engine_idx].inc(
- scheduler_stats.prefix_cache_stats.hits
- )
-
self.counter_prefix_cache_queries[engine_idx].inc(
scheduler_stats.prefix_cache_stats.queries
)
@@ -1050,7 +989,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
self.histogram_time_to_first_token[engine_idx].observe(ttft)
for itl in iteration_stats.inter_token_latencies_iter:
self.histogram_inter_token_latency[engine_idx].observe(itl)
- self.histogram_time_per_output_token[engine_idx].observe(itl)
+ if self.show_hidden_metrics:
+ self.histogram_time_per_output_token[engine_idx].observe(itl)
for finished_request in iteration_stats.finished_requests:
self.counter_request_success[finished_request.finish_reason][
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index 3d92906fbf4b1..366cdadf5a583 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -121,6 +121,9 @@ class Request:
# The number of requests being preempted by the scheduler
self.num_preemptions = 0
+ # The number of tokens that have been computed remotely.
+ self.num_external_computed_tokens = 0
+
self.block_hashes: list[BlockHash] = []
self.get_hash_new_full_blocks: Callable[[], list[BlockHash]] | None = None
if block_hasher is not None:
diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py
index 39c63fe31ad2c..c75b4f0543c0d 100644
--- a/vllm/v1/sample/sampler.py
+++ b/vllm/v1/sample/sampler.py
@@ -81,7 +81,10 @@ class Sampler(nn.Module):
if logprobs_mode == "raw_logprobs":
raw_logprobs = self.compute_logprobs(logits)
elif logprobs_mode == "raw_logits":
- raw_logprobs = logits.clone()
+ if logits.dtype == torch.float32:
+ raw_logprobs = logits.clone()
+ else:
+ raw_logprobs = logits.to(torch.float32)
# Use float32 for the logits.
logits = logits.to(torch.float32)
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 0df9cd3214e53..7b9037c03d4f0 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -40,6 +40,7 @@ from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import _SAMPLING_EPS
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer
+from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
logger = init_logger(__name__)
@@ -65,6 +66,7 @@ class EagleProposer:
self.dtype = vllm_config.model_config.dtype
self.max_model_len = vllm_config.model_config.max_model_len
self.block_size = vllm_config.cache_config.block_size
+ self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
self.token_arange_np = np.arange(self.max_num_tokens)
@@ -83,12 +85,15 @@ class EagleProposer:
self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None
self.attn_layer_names: list[str] = []
self.indexer_layer_names: list[str] = []
+ self.eagle3_use_aux_hidden_state: bool = (
+ self._get_eagle3_use_aux_hidden_state_from_config()
+ )
self.use_cuda_graph = False
- compilation_config = self.vllm_config.compilation_config
- if compilation_config.mode == CompilationMode.VLLM_COMPILE:
- cudagraph_mode = compilation_config.cudagraph_mode
+ self.compilation_config = self.vllm_config.compilation_config
+ if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
+ cudagraph_mode = self.compilation_config.cudagraph_mode
if cudagraph_mode != CUDAGraphMode.NONE and not cudagraph_mode.has_mode(
CUDAGraphMode.PIECEWISE
):
@@ -103,13 +108,6 @@ class EagleProposer:
and not self.speculative_config.enforce_eager
)
- self.cudagraph_batch_sizes = (
- (sorted(self.vllm_config.compilation_config.cudagraph_capture_sizes))
- if self.use_cuda_graph
- else []
- )
-
- self.use_cuda_graph = self.use_cuda_graph and bool(self.cudagraph_batch_sizes)
# persistent buffers for cuda graph
self.input_ids = torch.zeros(
self.max_num_tokens, dtype=torch.int32, device=device
@@ -275,12 +273,24 @@ class EagleProposer:
assert draft_indexer_metadata is not None
per_layer_attn_metadata[layer_name] = draft_indexer_metadata
+ num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
+ num_tokens_unpadded=num_tokens,
+ num_tokens_padded=num_tokens,
+ )
+
cudagraph_runtime_mode = CUDAGraphMode.NONE
- if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]:
- num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
+ if (
+ self.use_cuda_graph
+ and num_tokens_dp_padded
+ <= self.compilation_config.max_cudagraph_capture_size
+ ):
+ num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens_dp_padded)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
else:
- num_input_tokens = num_tokens
+ num_input_tokens = num_tokens_dp_padded
+ if num_tokens_across_dp is not None:
+ num_tokens_across_dp[self.dp_rank] = num_input_tokens
+
# copy inputs to buffer for cudagraph
self._set_positions(num_tokens, target_positions)
self.hidden_states[:num_tokens] = target_hidden_states
@@ -304,6 +314,7 @@ class EagleProposer:
per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
+ num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
):
ret_hidden_states = self.model(
@@ -366,12 +377,23 @@ class EagleProposer:
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
- if self.use_cuda_graph and batch_size <= self.cudagraph_batch_sizes[-1]:
- input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
+ batch_size_dp_padded, batch_size_across_dp = self._pad_batch_across_dp(
+ num_tokens_unpadded=batch_size,
+ num_tokens_padded=batch_size,
+ )
+
+ if (
+ self.use_cuda_graph
+ and batch_size_dp_padded
+ <= self.compilation_config.max_cudagraph_capture_size
+ ):
+ input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size_dp_padded)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
else:
- input_batch_size = batch_size
+ input_batch_size = batch_size_dp_padded
cudagraph_runtime_mode = CUDAGraphMode.NONE
+ if batch_size_across_dp is not None:
+ batch_size_across_dp[self.dp_rank] = input_batch_size
common_attn_metadata.num_actual_tokens = batch_size
common_attn_metadata.max_query_len = 1
@@ -472,6 +494,7 @@ class EagleProposer:
per_layer_attn_metadata,
self.vllm_config,
num_tokens=input_batch_size,
+ num_tokens_across_dp=batch_size_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
):
ret_hidden_states = self.model(
@@ -777,7 +800,10 @@ class EagleProposer:
self.positions[:num_tokens] = tree_positions.view(-1)
self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1)
- if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]:
+ if (
+ self.use_cuda_graph
+ and num_tokens <= self.compilation_config.max_cudagraph_capture_size
+ ):
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
else:
@@ -1029,11 +1055,11 @@ class EagleProposer:
elif (
isinstance(target_embed_tokens.weight, torch.Tensor)
and isinstance(self.model.model.embed_tokens.weight, torch.Tensor)
- and torch.allclose(
+            # NOTE: Offload to CPU for comparison to avoid extra GPU memory
+            # usage in CI testing environments with limited GPU memory
+ and torch.equal(
target_embed_tokens.weight.cpu(),
self.model.model.embed_tokens.weight.cpu(),
- rtol=1e-5,
- atol=1e-7,
)
):
share_embeddings = True
@@ -1079,8 +1105,11 @@ class EagleProposer:
hasattr(target_language_model, "lm_head")
and isinstance(target_language_model.lm_head.weight, torch.Tensor)
and isinstance(self.model.lm_head.weight, torch.Tensor)
+            # NOTE: Offload to CPU for comparison to avoid extra GPU memory
+            # usage in CI testing environments with limited GPU memory
and torch.equal(
- target_language_model.lm_head.weight, self.model.lm_head.weight
+ target_language_model.lm_head.weight.cpu(),
+ self.model.lm_head.weight.cpu(),
)
):
share_lm_head = True
@@ -1111,33 +1140,56 @@ class EagleProposer:
self,
num_tokens: int,
use_cudagraphs=True,
+ is_graph_capturing=False,
) -> None:
# Determine if CUDA graphs should be used for this run.
cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph
- if cudagraphs_enabled and num_tokens <= self.cudagraph_batch_sizes[-1]:
- num_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
- with set_forward_context(
- None,
- self.vllm_config,
- num_tokens=num_tokens,
- cudagraph_runtime_mode=(
- CUDAGraphMode.PIECEWISE if cudagraphs_enabled else CUDAGraphMode.NONE
- ),
+ # FIXME: when using tree-based specdec, adjust number of forward-passes
+ # according to the depth of the tree.
+ for fwd_idx in range(
+ self.num_speculative_tokens if not is_graph_capturing else 1
):
- if self.supports_mm_inputs:
- input_ids = None
- inputs_embeds = self.inputs_embeds[:num_tokens]
- else:
- input_ids = self.input_ids[:num_tokens]
- inputs_embeds = None
+ if fwd_idx <= 1:
+ num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
+ num_tokens_unpadded=num_tokens,
+ num_tokens_padded=num_tokens,
+ )
+ if (
+ cudagraphs_enabled
+ and num_tokens_dp_padded
+ <= self.compilation_config.max_cudagraph_capture_size
+ ):
+ num_input_tokens = self.vllm_config.pad_for_cudagraph(
+ num_tokens_dp_padded
+ )
+ else:
+ num_input_tokens = num_tokens_dp_padded
+ if num_tokens_across_dp is not None:
+ num_tokens_across_dp[self.dp_rank] = num_input_tokens
- self.model(
- input_ids=input_ids,
- positions=self._get_positions(num_tokens),
- hidden_states=self.hidden_states[:num_tokens],
- inputs_embeds=inputs_embeds,
- )
+ with set_forward_context(
+ None,
+ self.vllm_config,
+ num_tokens=num_input_tokens,
+ num_tokens_across_dp=num_tokens_across_dp,
+ cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE
+ if cudagraphs_enabled
+ else CUDAGraphMode.NONE,
+ ):
+ if self.supports_mm_inputs:
+ input_ids = None
+ inputs_embeds = self.inputs_embeds[:num_input_tokens]
+ else:
+ input_ids = self.input_ids[:num_input_tokens]
+ inputs_embeds = None
+
+ self.model(
+ input_ids=input_ids,
+ positions=self._get_positions(num_input_tokens),
+ hidden_states=self.hidden_states[:num_input_tokens],
+ inputs_embeds=inputs_embeds,
+ )
def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
"""Find and return the attention metadata builders for EAGLE layers.
@@ -1164,6 +1216,22 @@ class EagleProposer:
)
return builder
+ def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
+ """
+ Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary
+        hidden states and directly use the last layer output just like eagle1.
+ They might indicate this by setting "use_aux_hidden_state" to False
+ inside the "eagle_config" dict of their hf_config.
+ """
+ if self.method != "eagle3":
+ return False
+ # Assume that eagle3 heads use aux hidden states by default
+ use_aux_hidden_state = True
+ eagle_config = getattr(self.draft_model_config.hf_config, "eagle_config", None)
+ if eagle_config is not None:
+ use_aux_hidden_state = eagle_config.get("use_aux_hidden_state", True)
+ return use_aux_hidden_state
+
def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
"""
Validate that all eagle layers belong to the same KVCacheGroup.
@@ -1187,6 +1255,28 @@ class EagleProposer:
== 1
), "All eagle layers should belong to the same kv cache group"
+ def _pad_batch_across_dp(
+ self,
+ num_tokens_unpadded: int,
+ num_tokens_padded: int,
+    ) -> tuple[int, torch.Tensor | None]:
+ # TODO(Flechman): support DBO ubatching
+ ubatch_slices, num_toks_across_dp = coordinate_batch_across_dp(
+ num_tokens_unpadded=num_tokens_unpadded,
+ parallel_config=self.vllm_config.parallel_config,
+ allow_microbatching=False,
+ allow_dp_padding=self.use_cuda_graph,
+ num_tokens_padded=num_tokens_padded,
+ uniform_decode=None,
+ num_scheduled_tokens_per_request=None,
+ )
+ assert ubatch_slices is None, "DBO ubatching not implemented for EAGLE"
+
+ num_tokens_dp_padded = num_tokens_padded
+ if num_toks_across_dp is not None:
+ num_tokens_dp_padded = int(num_toks_across_dp[self.dp_rank].item())
+ return num_tokens_dp_padded, num_toks_across_dp
+
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py
index 76e17f3797a1a..37ec0fb97e06b 100644
--- a/vllm/v1/worker/block_table.py
+++ b/vllm/v1/worker/block_table.py
@@ -84,7 +84,7 @@ class BlockTable:
self.pcp_world_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group().rank_in_group
except AssertionError:
- # DCP might not be initialized in testing
+ # PCP might not be initialized in testing
self.pcp_world_size = 1
self.pcp_rank = 0
try:
@@ -268,6 +268,11 @@ class MultiGroupBlockTable:
# (max_model_len//dcp_world_size) tokens in kvcache,
# so the block_size which used for calc max_num_blocks_per_req
# must be multiplied by dcp_world_size.
+ try:
+ pcp_world_size = get_pcp_group().world_size
+ except AssertionError:
+ # PCP might not be initialized in testing
+ pcp_world_size = 1
try:
dcp_world_size = get_dcp_group().world_size
except AssertionError:
@@ -280,12 +285,14 @@ class MultiGroupBlockTable:
f"must match block_sizes length ({len(block_sizes)})"
)
+ total_cp_world_size = dcp_world_size * pcp_world_size
+
self.block_tables = [
BlockTable(
block_size,
max_num_reqs,
max(
- cdiv(max_model_len, block_size * dcp_world_size),
+ cdiv(max_model_len, block_size * total_cp_world_size),
1 + num_speculative_tokens,
),
max_num_batched_tokens,
diff --git a/vllm/v1/worker/gpu/async_utils.py b/vllm/v1/worker/gpu/async_utils.py
index e523090aa2172..421fb29a7f87f 100644
--- a/vllm/v1/worker/gpu/async_utils.py
+++ b/vllm/v1/worker/gpu/async_utils.py
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
-import numpy as np
import torch
from vllm.v1.outputs import (
@@ -18,7 +17,7 @@ class AsyncOutput(AsyncModelRunnerOutput):
self,
model_runner_output: ModelRunnerOutput,
sampler_output: SamplerOutput,
- num_sampled_tokens: np.ndarray,
+ num_sampled_tokens: torch.Tensor,
copy_stream: torch.cuda.Stream,
copy_event: torch.cuda.Event,
):
@@ -52,6 +51,7 @@ class AsyncOutput(AsyncModelRunnerOutput):
)
else:
self.logprobs_tensors = None
+ self.num_sampled_tokens = num_sampled_tokens.to("cpu", non_blocking=True)
self.prompt_logprobs_dict: dict[str, LogprobsTensors | None] = {}
if self.model_runner_output.prompt_logprobs_dict:
for k, v in self.model_runner_output.prompt_logprobs_dict.items():
@@ -63,6 +63,7 @@ class AsyncOutput(AsyncModelRunnerOutput):
def get_output(self) -> ModelRunnerOutput:
self.copy_event.synchronize()
+ num_sampled_tokens_np = self.num_sampled_tokens.numpy()
# NOTE(woosuk): The following code is to ensure compatibility with
# the existing model runner.
@@ -71,7 +72,7 @@ class AsyncOutput(AsyncModelRunnerOutput):
sampled_token_ids: list[list[int]] = self.sampled_token_ids.tolist()
num_reqs = len(sampled_token_ids)
for i in range(num_reqs):
- del sampled_token_ids[i][self.num_sampled_tokens[i] :]
+ del sampled_token_ids[i][num_sampled_tokens_np[i] :]
self.model_runner_output.sampled_token_ids = sampled_token_ids
if self.logprobs_tensors is not None:
diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py
index 222db565dff17..4510a1c5ca1e9 100644
--- a/vllm/v1/worker/gpu/attn_utils.py
+++ b/vllm/v1/worker/gpu/attn_utils.py
@@ -3,6 +3,7 @@
from collections.abc import Sequence
from typing import Any, cast
+import numpy as np
import torch
from vllm.attention.backends.abstract import AttentionBackend
@@ -145,8 +146,9 @@ def build_attn_metadata(
num_reqs: int,
num_tokens: int,
query_start_loc: CpuGpuBuffer,
- seq_lens: CpuGpuBuffer,
- num_computed_tokens_cpu: torch.Tensor,
+ seq_lens: torch.Tensor,
+ seq_lens_np: np.ndarray,
+ num_computed_tokens_cpu: torch.Tensor | None,
block_tables: Sequence[torch.Tensor],
slot_mappings: torch.Tensor,
kv_cache_config: KVCacheConfig,
@@ -154,9 +156,9 @@ def build_attn_metadata(
query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1]
max_query_len = int(query_start_loc.np[: num_reqs + 1].max())
- seq_lens_gpu = seq_lens.gpu[:num_reqs]
- seq_lens_cpu = seq_lens.cpu[:num_reqs]
- max_seq_len = int(seq_lens.np[:num_reqs].max())
+ seq_lens = seq_lens[:num_reqs]
+ seq_lens_cpu = torch.from_numpy(seq_lens_np)
+ max_seq_len = int(seq_lens_np.max())
attn_metadata: dict[str, Any] = {}
kv_cache_groups = kv_cache_config.kv_cache_groups
@@ -167,7 +169,7 @@ def build_attn_metadata(
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc_gpu,
query_start_loc_cpu=query_start_loc_cpu,
- seq_lens=seq_lens_gpu,
+ seq_lens=seq_lens,
seq_lens_cpu=seq_lens_cpu,
max_seq_len=max_seq_len,
num_computed_tokens_cpu=num_computed_tokens_cpu,
diff --git a/vllm/v1/worker/gpu/block_table.py b/vllm/v1/worker/gpu/block_table.py
index ff24e88ede2c0..b31e9b179d26c 100644
--- a/vllm/v1/worker/gpu/block_table.py
+++ b/vllm/v1/worker/gpu/block_table.py
@@ -3,10 +3,9 @@
from collections.abc import Iterable
import torch
-import triton
-import triton.language as tl
from vllm.attention.backends.utils import PAD_SLOT_ID
+from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py
index 31a706475243c..ba783e2d0c6fb 100644
--- a/vllm/v1/worker/gpu/cudagraph_utils.py
+++ b/vllm/v1/worker/gpu/cudagraph_utils.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import gc
-from contextlib import contextmanager
+from unittest.mock import patch
import numpy as np
import torch
@@ -27,9 +26,11 @@ class CudaGraphManager:
device: torch.device,
):
self.vllm_config = vllm_config
+ self.scheduler_config = vllm_config.scheduler_config
self.device = device
self.max_model_len = vllm_config.model_config.max_model_len
+ self.max_num_reqs = self.scheduler_config.max_num_seqs
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.compilation_config = vllm_config.compilation_config
assert self.compilation_config is not None
@@ -39,9 +40,11 @@ class CudaGraphManager:
else:
self.cudagraph_mode = self.compilation_config.cudagraph_mode
if self.compilation_config.cudagraph_capture_sizes is not None:
- self.cudagraph_sizes = sorted(
- self.compilation_config.cudagraph_capture_sizes
- )
+ cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
+ # Limit the cudagraph sizes to the max decode batch size.
+ self.cudagraph_sizes = [
+ x for x in cudagraph_sizes if x <= self.max_num_reqs
+ ]
else:
self.cudagraph_sizes = []
self.padded_sizes = self._init_padded_sizes()
@@ -54,9 +57,10 @@ class CudaGraphManager:
if not self.cudagraph_mode.has_full_cudagraphs():
# Full cuda graphs are not used.
return {}
+ if not self.cudagraph_sizes:
+ return {}
padded_sizes: dict[int, int] = {}
- assert len(self.cudagraph_sizes) > 0
for i in range(1, self.cudagraph_sizes[-1] + 1):
for x in self.cudagraph_sizes:
if i <= x:
@@ -97,14 +101,16 @@ class CudaGraphManager:
# Prepare dummy inputs.
input_ids = input_buffers.input_ids.gpu[:batch_size]
- positions = input_buffers.positions.gpu[:batch_size]
+ positions = input_buffers.positions[:batch_size]
input_buffers.query_start_loc.np[: batch_size + 1] = np.arange(batch_size + 1)
input_buffers.query_start_loc.np[batch_size:] = batch_size
input_buffers.query_start_loc.copy_to_gpu()
- input_buffers.seq_lens.np[:batch_size] = self.max_model_len
- input_buffers.seq_lens.np[batch_size:] = 0
- input_buffers.seq_lens.copy_to_gpu()
+ # HACK(woosuk): To optimize warmup time, we use 1 (instead of max_model_len)
+ # for seq_lens. This leads to a mismatch between seq_lens (GPU) and
+ # seq_lens_np (CPU), which might cause issues in some attention backends.
+ input_buffers.seq_lens[:batch_size] = 1
+ input_buffers.seq_lens[batch_size:] = 0
input_block_tables = [x[:batch_size] for x in block_tables.input_block_tables]
slot_mappings = block_tables.slot_mappings[:, :batch_size]
@@ -115,6 +121,7 @@ class CudaGraphManager:
num_tokens=batch_size,
query_start_loc=input_buffers.query_start_loc,
seq_lens=input_buffers.seq_lens,
+ seq_lens_np=np.full(batch_size, self.max_model_len, dtype=np.int32),
num_computed_tokens_cpu=None, # FIXME
block_tables=input_block_tables,
slot_mappings=slot_mappings,
@@ -135,6 +142,7 @@ class CudaGraphManager:
attn_metadata,
self.vllm_config,
num_tokens=batch_size,
+ cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp,
):
hidden_states = model(
@@ -143,15 +151,16 @@ class CudaGraphManager:
)
if self.hidden_states is None:
self.hidden_states = torch.empty_like(hidden_states)
- torch.cuda.synchronize()
# Capture the graph.
graph = torch.cuda.CUDAGraph()
with (
+ patch("torch.cuda.empty_cache", lambda: None),
set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=batch_size,
+ cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp,
),
torch.cuda.graph(graph, self.pool),
@@ -178,7 +187,7 @@ class CudaGraphManager:
if is_global_first_rank():
sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
- with freeze_gc(), graph_capture(device=self.device):
+ with graph_capture(device=self.device):
for batch_size in sizes_to_capture:
self.capture_graph(
batch_size,
@@ -194,13 +203,3 @@ class CudaGraphManager:
self.graphs[batch_size].replay()
assert self.hidden_states is not None
return self.hidden_states[:batch_size]
-
-
-@contextmanager
-def freeze_gc():
- gc.collect()
- gc.freeze()
- try:
- yield
- finally:
- gc.unfreeze()
diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py
index 89f375649146f..2a7048ae3c0e0 100644
--- a/vllm/v1/worker/gpu/input_batch.py
+++ b/vllm/v1/worker/gpu/input_batch.py
@@ -4,12 +4,10 @@ from dataclasses import dataclass
from typing import Any
import numba
-import numba.types as types
import numpy as np
import torch
-import triton
-import triton.language as tl
+from vllm.triton_utils import tl, triton
from vllm.utils import random_uuid
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
@@ -33,9 +31,13 @@ class InputBuffers:
self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32)
self.input_ids = self._make_buffer(max_num_tokens, dtype=torch.int32)
- self.positions = self._make_buffer(max_num_tokens, dtype=torch.int64)
+ self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
- self.seq_lens = self._make_buffer(max_num_reqs, dtype=torch.int32)
+ self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
+ self.cu_num_logits = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
+
+ # Spec decoding.
+ self.next_prefill_tokens = self._make_buffer(max_num_reqs, dtype=torch.int32)
# Structured outputs.
self.bitmask_indices = self._make_buffer(max_num_reqs, dtype=torch.int32)
@@ -65,6 +67,7 @@ class InputBatch:
# sum(num_scheduled_tokens)
num_tokens: int
num_tokens_after_padding: int
+ num_draft_tokens: int
# [num_reqs + 1]
query_start_loc: torch.Tensor
@@ -81,8 +84,10 @@ class InputBatch:
# layer_name -> Metadata
attn_metadata: dict[str, Any]
- # [num_reqs]
+ # [total_num_logits]
logits_indices: torch.Tensor
+ # [num_reqs + 1]
+ cu_num_logits: torch.Tensor
@classmethod
def make_dummy(
@@ -108,15 +113,18 @@ class InputBatch:
query_start_loc_np = input_buffers.query_start_loc.np[: num_reqs + 1]
query_start_loc = input_buffers.query_start_loc.copy_to_gpu()[: num_reqs + 1]
# seq_len equals to query_len
- input_buffers.seq_lens.np[:num_reqs] = num_scheduled_tokens
- input_buffers.seq_lens.np[num_reqs:] = 0
- seq_lens_np = input_buffers.seq_lens.np[:num_reqs]
- seq_lens = input_buffers.seq_lens.copy_to_gpu()[:num_reqs]
+ seq_lens_np = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
+ seq_lens_np[-1] += num_tokens % num_reqs
+ input_buffers.seq_lens[:num_reqs] = num_tokens // num_reqs
+ input_buffers.seq_lens[num_reqs - 1] += num_tokens % num_reqs
+ input_buffers.seq_lens[num_reqs:] = 0
+ seq_lens = input_buffers.seq_lens[:num_reqs]
input_ids = input_buffers.input_ids.copy_to_gpu(num_tokens)
- positions = input_buffers.positions.copy_to_gpu(num_tokens)
+ positions = input_buffers.positions[:num_tokens]
# attn_metadata = defaultdict(lambda: None)
logits_indices = query_start_loc[1:] - 1
+ cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32)
return cls(
req_ids=req_ids,
num_reqs=num_reqs,
@@ -125,6 +133,7 @@ class InputBatch:
num_scheduled_tokens=num_scheduled_tokens,
num_tokens=num_tokens,
num_tokens_after_padding=num_tokens,
+ num_draft_tokens=0,
query_start_loc=query_start_loc,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens,
@@ -133,133 +142,256 @@ class InputBatch:
positions=positions,
attn_metadata=None, # type: ignore
logits_indices=logits_indices,
+ cu_num_logits=cu_num_logits,
)
-# NOTE: With the type annotations, this function is pre-compiled
-# before the first call.
-@numba.jit(
- [
- types.none(
- types.int32[:], # idx_mapping
- types.int32[:, :], # token_ids
- types.int32[:], # num_computed_tokens
- types.int32[:], # num_scheduled_tokens
- types.int32[:], # input_ids
- types.int64[:], # positions
- types.int32[:], # query_start_loc
- types.int32[:], # seq_lens
- )
- ],
- nopython=True,
- cache=True,
-)
-def _prepare_inputs(
- idx_mapping: np.ndarray, # batch_idx -> req_idx
- token_ids: np.ndarray, # [N, max_model_len]
- num_computed_tokens: np.ndarray, # [N]
- num_scheduled_tokens: np.ndarray, # [B]
- input_ids: np.ndarray, # [num_input_tokens]
- positions: np.ndarray, # [num_input_tokens]
+@numba.njit(cache=True)
+def _prepare_prefill_inputs(
+ idx_mapping: np.ndarray, # [B]
+ query_lens: np.ndarray, # [B]
query_start_loc: np.ndarray, # [B + 1]
- seq_lens: np.ndarray, # [B]
+ prefill_token_ids: np.ndarray, # [N, max_model_len]
+ num_computed_prefill_tokens: np.ndarray, # [N]
+ input_ids: np.ndarray, # [num_input_tokens]
) -> None:
- num_reqs = num_scheduled_tokens.shape[0]
- query_start_loc[0] = 0
-
- cu_num_tokens = 0
+ num_reqs = idx_mapping.shape[0]
+ query_starts = query_start_loc[:num_reqs]
+ query_ends = query_start_loc[1 : num_reqs + 1]
+ starts = num_computed_prefill_tokens[idx_mapping]
+ ends = starts + query_lens
for i in range(num_reqs):
- req_idx = idx_mapping[i]
- query_len = num_scheduled_tokens[i]
- start = num_computed_tokens[req_idx]
- end = start + query_len
- seq_lens[i] = end
-
- start_idx = cu_num_tokens
- end_idx = start_idx + query_len
- input_ids[start_idx:end_idx] = token_ids[req_idx, start:end]
- positions[start_idx:end_idx] = np.arange(start, end, dtype=np.int64)
-
- cu_num_tokens = end_idx
- query_start_loc[i + 1] = cu_num_tokens
-
- # Pad the inputs for CUDA graphs.
- # Note: pad query_start_loc to be non-decreasing, as kernels
- # like FlashAttention requires that
- query_start_loc[num_reqs + 1 :].fill(cu_num_tokens)
- # Fill unused with 0 for full cuda graph mode.
- seq_lens[num_reqs:].fill(0)
+ input_ids[query_starts[i] : query_ends[i]] = prefill_token_ids[
+ idx_mapping[i], starts[i] : ends[i]
+ ]
-def prepare_inputs(
+def prepare_prefill_inputs(
idx_mapping: np.ndarray,
- prefill_token_ids: np.ndarray,
- num_computed_tokens: np.ndarray,
num_scheduled_tokens: np.ndarray,
- input_ids: CpuGpuBuffer,
- positions: CpuGpuBuffer,
- query_start_loc: CpuGpuBuffer,
- seq_lens: CpuGpuBuffer,
- num_tokens: int,
+ query_start_loc: np.ndarray,
+ prefill_token_ids: np.ndarray,
+ num_computed_prefill_tokens: np.ndarray,
+ input_ids: np.ndarray,
) -> None:
- _prepare_inputs(
+ _prepare_prefill_inputs(
idx_mapping,
- prefill_token_ids,
- num_computed_tokens,
num_scheduled_tokens,
- input_ids.np,
- positions.np,
- query_start_loc.np,
- seq_lens.np,
+ query_start_loc,
+ prefill_token_ids,
+ num_computed_prefill_tokens,
+ input_ids,
)
- input_ids.copy_to_gpu(num_tokens)
- positions.copy_to_gpu(num_tokens)
- # NOTE(woosuk): We should copy the whole query_start_loc and seq_lens
- # tensors from CPU to GPU, because they may include paddings needed
- # for full CUDA graph mode.
- query_start_loc.copy_to_gpu()
- seq_lens.copy_to_gpu()
- return
@triton.jit
-def _combine_last_token_ids_kernel(
+def _prepare_pos_seq_lens_kernel(
+ pos_ptr,
+ seq_lens_ptr,
+ idx_mapping_ptr,
+ query_start_loc_ptr,
+ num_computed_tokens_ptr,
+ max_num_reqs,
+ BLOCK_SIZE: tl.constexpr,
+):
+ req_id = tl.program_id(0)
+ num_reqs = tl.num_programs(0) - 1
+ if req_id == num_reqs:
+ # Pad unused seq_lens as 0 for full CUDA graphs.
+ for i in tl.range(num_reqs, max_num_reqs, BLOCK_SIZE):
+ block = i + tl.arange(0, BLOCK_SIZE)
+ mask = block < max_num_reqs
+ tl.store(seq_lens_ptr + block, 0, mask=mask)
+ return
+
+ req_state_idx = tl.load(idx_mapping_ptr + req_id)
+ num_computed_tokens = tl.load(num_computed_tokens_ptr + req_state_idx)
+
+ start = tl.load(query_start_loc_ptr + req_id)
+ end = tl.load(query_start_loc_ptr + req_id + 1)
+ query_len = end - start
+
+ seq_len = num_computed_tokens + query_len
+ tl.store(seq_lens_ptr + req_id, seq_len)
+
+ for i in tl.range(0, query_len, BLOCK_SIZE):
+ block = i + tl.arange(0, BLOCK_SIZE)
+ mask = block < query_len
+ pos = num_computed_tokens + block
+ tl.store(pos_ptr + start + block, pos, mask=mask)
+
+
+def prepare_pos_seq_lens(
+ idx_mapping: torch.Tensor,
+ query_start_loc: torch.Tensor,
+ num_computed_tokens: torch.Tensor,
+ pos: torch.Tensor,
+ seq_lens: torch.Tensor,
+) -> None:
+ num_reqs = idx_mapping.shape[0]
+ # NOTE(woosuk): We do +1 because the last thread block is used
+ # to pad unused seq_lens as 0 for full CUDA graphs.
+ _prepare_pos_seq_lens_kernel[(num_reqs + 1,)](
+ pos,
+ seq_lens,
+ idx_mapping,
+ query_start_loc,
+ num_computed_tokens,
+ seq_lens.shape[0],
+ BLOCK_SIZE=1024,
+ )
+
+
+@triton.jit
+def _combine_sampled_and_draft_tokens_kernel(
input_ids_ptr,
idx_mapping_ptr,
- last_token_ids_ptr,
+ last_sampled_tokens_ptr,
query_start_loc_ptr,
seq_lens_ptr,
prefill_len_ptr,
+ draft_tokens_ptr,
+ draft_tokens_stride,
+ cu_num_logits_ptr,
+ logits_indices_ptr,
+ BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
+ # Get the number of logits and draft tokens.
+ cu_num_logits_start = tl.load(cu_num_logits_ptr + batch_idx)
+ cu_num_logits_end = tl.load(cu_num_logits_ptr + batch_idx + 1)
+ num_logits = cu_num_logits_end - cu_num_logits_start
+ num_draft_tokens = num_logits - 1
+
+ # Compute the logits indices.
+ block = tl.arange(0, BLOCK_SIZE)
+ query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
+ logits_start = query_end - num_logits
+ tl.store(
+ logits_indices_ptr + cu_num_logits_start + block,
+ logits_start + block,
+ mask=block < num_logits,
+ )
+
seq_len = tl.load(seq_lens_ptr + batch_idx)
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
if seq_len <= prefill_len:
- # Handling prefill tokens.
+ # Handling prefill tokens. No sampled or draft tokens.
return
- last_token_id = tl.load(last_token_ids_ptr + req_state_idx)
- end = tl.load(query_start_loc_ptr + batch_idx + 1)
- tl.store(input_ids_ptr + end - 1, last_token_id)
+ # Write the last sampled token ID to input_ids.
+ last_token_id = tl.load(last_sampled_tokens_ptr + req_state_idx)
+ tl.store(input_ids_ptr + query_end - num_logits, last_token_id)
+
+ # Write the draft tokens (if any) to input_ids.
+ if num_draft_tokens > 0:
+ mask = block < num_draft_tokens
+ draft_tokens = tl.load(
+ draft_tokens_ptr + req_state_idx * draft_tokens_stride + block,
+ mask=mask,
+ )
+ tl.store(
+ input_ids_ptr + query_end - num_draft_tokens + block,
+ draft_tokens,
+ mask=mask,
+ )
-def combine_last_token_ids(
+def combine_sampled_and_draft_tokens(
input_ids: torch.Tensor,
idx_mapping: torch.Tensor,
- last_token_ids: torch.Tensor,
+ last_sampled_tokens: torch.Tensor,
query_start_loc: torch.Tensor,
seq_lens: torch.Tensor,
prefill_len: torch.Tensor,
+ draft_tokens: torch.Tensor,
+ cu_num_logits: torch.Tensor,
+ num_logits: int,
) -> torch.Tensor:
num_reqs = seq_lens.shape[0]
- _combine_last_token_ids_kernel[(num_reqs,)](
+ num_speculative_steps = draft_tokens.shape[-1]
+
+ logits_indices = torch.empty(
+ num_logits,
+ dtype=torch.int64,
+ device=input_ids.device,
+ )
+ _combine_sampled_and_draft_tokens_kernel[(num_reqs,)](
input_ids,
idx_mapping,
- last_token_ids,
+ last_sampled_tokens,
query_start_loc,
seq_lens,
prefill_len,
+ draft_tokens,
+ draft_tokens.stride(0),
+ cu_num_logits,
+ logits_indices,
+ # NOTE(woosuk): Add 1 to ensure the block can cover the last sampled token
+ # in addition to all draft tokens.
+ BLOCK_SIZE=triton.next_power_of_2(num_speculative_steps + 1),
+ )
+ return logits_indices
+
+
+@triton.jit
+def _post_update_kernel(
+ idx_mapping_ptr,
+ num_computed_tokens_ptr,
+ last_sampled_tokens_ptr,
+ sampled_tokens_ptr,
+ sampled_tokens_stride,
+ num_sampled_ptr,
+ num_rejected_ptr,
+ query_start_loc_ptr,
+):
+ req_id = tl.program_id(0)
+ req_state_idx = tl.load(idx_mapping_ptr + req_id)
+
+ num_sampled = tl.load(num_sampled_ptr + req_id)
+ if num_sampled > 0:
+ token_id = tl.load(
+ sampled_tokens_ptr + req_id * sampled_tokens_stride + num_sampled - 1
+ )
+ tl.store(last_sampled_tokens_ptr + req_state_idx, token_id)
+
+ query_start = tl.load(query_start_loc_ptr + req_id)
+ query_end = tl.load(query_start_loc_ptr + req_id + 1)
+ query_len = query_end - query_start
+ num_rejected = tl.load(num_rejected_ptr + req_id)
+
+ num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
+ num_computed += query_len - num_rejected
+ tl.store(num_computed_tokens_ptr + req_state_idx, num_computed)
+
+
+def post_update(
+ # [num_reqs]
+ idx_mapping: torch.Tensor,
+ # [max_num_reqs]
+ num_computed_tokens: torch.Tensor,
+ # [max_num_reqs]
+ last_sampled_tokens: torch.Tensor,
+ # [num_reqs, num_speculative_steps + 1]
+ sampled_tokens: torch.Tensor,
+ # [num_reqs]
+ num_sampled: torch.Tensor,
+ # [num_reqs]
+ num_rejected: torch.Tensor,
+ # [num_reqs + 1]
+ query_start_loc: torch.Tensor,
+) -> None:
+ num_reqs = idx_mapping.shape[0]
+ _post_update_kernel[(num_reqs,)](
+ idx_mapping,
+ num_computed_tokens,
+ last_sampled_tokens,
+ sampled_tokens,
+ sampled_tokens.stride(0),
+ num_sampled,
+ num_rejected,
+ query_start_loc,
+ num_warps=1,
)
- return input_ids
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 9ca37ff282d82..e34a45f979807 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -39,10 +39,17 @@ from vllm.v1.worker.gpu.dp_utils import get_batch_metadata_across_dp
from vllm.v1.worker.gpu.input_batch import (
InputBatch,
InputBuffers,
- combine_last_token_ids,
- prepare_inputs,
+ combine_sampled_and_draft_tokens,
+ post_update,
+ prepare_pos_seq_lens,
+ prepare_prefill_inputs,
)
from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs
+from vllm.v1.worker.gpu.spec_decode import init_speculator
+from vllm.v1.worker.gpu.spec_decode.rejection_sample import (
+ get_num_rejected,
+ rejection_sample,
+)
from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
@@ -94,14 +101,26 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if self.use_async_scheduling:
self.input_prep_event = torch.cuda.Event()
self.structured_outputs_event = torch.cuda.Event()
+ self.spec_decode_event = torch.cuda.Event()
else:
self.input_prep_event = None
self.structured_outputs_event = None
+ self.spec_decode_event = None
+
+ if self.speculative_config is not None:
+ self.do_spec_decode = True
+ self.num_speculative_steps = self.speculative_config.num_speculative_tokens
+ self.speculator = init_speculator(self.vllm_config, self.device)
+ else:
+ self.do_spec_decode = False
+ self.num_speculative_steps = 0
+ self.speculator = None
self.req_states = RequestState(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_batched_tokens=self.max_num_tokens,
+ num_speculative_steps=self.num_speculative_steps,
vocab_size=self.vocab_size,
device=self.device,
pin_memory=self.pin_memory,
@@ -142,6 +161,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.vllm_config,
self.device,
)
+ if self.do_spec_decode:
+ self.speculator.load_model(self.model)
time_after_load = time.perf_counter()
self.model_memory_usage = m.consumed_memory
@@ -179,6 +200,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.vllm_config,
self.device,
)
+ # TODO(woosuk): Support other backends.
+ if not all(b.get_name() == "FLASH_ATTN" for b in self.attn_backends.values()):
+ raise NotImplementedError("Only FLASH_ATTN backend is supported currently.")
self.kv_caches: list[torch.Tensor] = []
init_kv_cache(
@@ -196,8 +220,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
slot_mappings = self.block_tables.get_dummy_slot_mappings(
input_batch.num_tokens
)
- num_computed_tokens_cpu = torch.zeros(
- input_batch.num_reqs, dtype=torch.int32, device="cpu"
+ num_computed_tokens = torch.zeros(
+ input_batch.num_reqs, dtype=torch.int32, device=self.device
)
attn_metadata = build_attn_metadata(
attn_metadata_builders=self.attn_metadata_builders,
@@ -205,7 +229,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_tokens=input_batch.num_tokens,
query_start_loc=self.input_buffers.query_start_loc,
seq_lens=self.input_buffers.seq_lens,
- num_computed_tokens_cpu=num_computed_tokens_cpu,
+ seq_lens_np=input_batch.seq_lens_np,
+ num_computed_tokens_cpu=num_computed_tokens,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
@@ -270,6 +295,35 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits = self.model.compute_logits(hidden_states)
self.sampler(logits, sampling_metadata)
+ @torch.inference_mode()
+ def _dummy_speculator_run(
+ self,
+ hidden_states: torch.Tensor,
+ aux_hidden_states: list[torch.Tensor] | None,
+ ) -> None:
+ num_tokens = hidden_states.shape[0]
+ num_reqs = min(num_tokens, self.max_num_reqs)
+ input_batch = InputBatch.make_dummy(
+ num_reqs=num_reqs,
+ num_tokens=num_tokens,
+ input_buffers=self.input_buffers,
+ device=self.device,
+ )
+ sampling_metadata = SamplingMetadata.make_dummy(
+ num_reqs=num_reqs,
+ device=self.device,
+ )
+ num_sampled = torch.ones(num_reqs, dtype=torch.int32, device=self.device)
+ num_rejected = torch.zeros(num_reqs, dtype=torch.int32, device=self.device)
+ self.propose_draft(
+ input_batch=input_batch,
+ sampling_metadata=sampling_metadata,
+ last_hidden_states=hidden_states,
+ aux_hidden_states=aux_hidden_states,
+ num_sampled=num_sampled,
+ num_rejected=num_rejected,
+ )
+
@torch.inference_mode()
def profile_run(self) -> None:
hidden_states, sample_hidden_states = self._dummy_run(
@@ -277,6 +331,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
skip_attn=True,
)
self._dummy_sampler_run(sample_hidden_states)
+ if self.do_spec_decode:
+ self._dummy_speculator_run(hidden_states, None)
torch.cuda.synchronize()
del hidden_states, sample_hidden_states
gc.collect()
@@ -298,6 +354,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return 0
start_time = time.perf_counter()
+ gc.collect()
+ torch.cuda.empty_cache()
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
with self.maybe_setup_dummy_loras(self.lora_config):
@@ -367,6 +425,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cu_num_new_blocks[i].append(x + len(block_ids))
new_block_ids[i].extend(block_ids)
overwrite.append(True)
+ # Update the GPU tensors for request states.
+ if scheduler_output.scheduled_new_reqs:
+ self.req_states.prefill_len.copy_to_gpu()
# Add new blocks for the existing requests.
cached_reqs = scheduler_output.scheduled_cached_reqs
@@ -417,48 +478,96 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
idx_mapping_np = idx_mapping.np[:num_reqs]
idx_mapping = idx_mapping.copy_to_gpu(num_reqs)
+ # Get the number of draft tokens for each request.
+ if not scheduler_output.scheduled_spec_decode_tokens:
+            # No draft tokens scheduled (common case).
+ total_num_draft_tokens = 0
+ total_num_logits = num_reqs
+ cu_num_logits = torch.arange(
+ num_reqs + 1, device=self.device, dtype=torch.int32
+ )
+ else:
+ draft_tokens = scheduler_output.scheduled_spec_decode_tokens
+ num_draft_tokens = np.array(
+ [
+ len(draft_tokens[req_id]) if req_id in draft_tokens else 0
+ for req_id in req_ids
+ ],
+ dtype=np.int32,
+ )
+ total_num_draft_tokens = int(num_draft_tokens.sum())
+ total_num_logits = num_reqs + total_num_draft_tokens
+
+ np.cumsum(
+ num_draft_tokens + 1,
+ out=self.input_buffers.cu_num_logits.np[1 : num_reqs + 1],
+ )
+ cu_num_logits = self.input_buffers.cu_num_logits.copy_to_gpu(num_reqs + 1)
+
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables = self.block_tables.gather_block_tables(idx_mapping)
- prepare_inputs(
- idx_mapping_np,
- self.req_states.prefill_token_ids,
- self.req_states.num_computed_tokens,
+ # Get query_start_loc.
+ np.cumsum(
num_scheduled_tokens,
- self.input_buffers.input_ids,
- self.input_buffers.positions,
- self.input_buffers.query_start_loc,
- self.input_buffers.seq_lens,
- num_tokens,
+ out=self.input_buffers.query_start_loc.np[1 : num_reqs + 1],
)
+ # Pad for full CUDA graph mode.
+ # Some attention backends like FA3 require query_start_loc to be non-decreasing.
+ self.input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens
+ self.input_buffers.query_start_loc.copy_to_gpu()
+ query_start_loc_gpu = self.input_buffers.query_start_loc.gpu[: num_reqs + 1]
+ query_start_loc_np = self.input_buffers.query_start_loc.np[: num_reqs + 1]
- query_start_loc = self.input_buffers.query_start_loc
- query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
- query_start_loc_np = query_start_loc.np[: num_reqs + 1]
- seq_lens_gpu = self.input_buffers.seq_lens.gpu[:num_reqs]
- seq_lens_np = self.input_buffers.seq_lens.np[:num_reqs]
+ # Copy prefill tokens from CPU to GPU.
+ prepare_prefill_inputs(
+ idx_mapping_np,
+ num_scheduled_tokens,
+ query_start_loc_np,
+ self.req_states.prefill_token_ids,
+ self.req_states.num_computed_prefill_tokens,
+ self.input_buffers.input_ids.np,
+ )
+ self.input_buffers.input_ids.copy_to_gpu(num_tokens)
- # Some input token ids are directly read from the last sampled tokens.
- combine_last_token_ids(
+ # Prepare positions and seq_lens.
+ prepare_pos_seq_lens(
+ idx_mapping,
+ query_start_loc_gpu,
+ self.req_states.num_computed_tokens,
+ self.input_buffers.positions,
+ self.input_buffers.seq_lens,
+ )
+ seq_lens = self.input_buffers.seq_lens[:num_reqs]
+
+ # Some input token ids are directly read from the last sampled tokens
+ # and draft tokens. Also, get the logits indices to sample tokens from.
+ logits_indices = combine_sampled_and_draft_tokens(
self.input_buffers.input_ids.gpu,
idx_mapping,
self.req_states.last_sampled_tokens,
query_start_loc_gpu,
- seq_lens_gpu,
- self.req_states.prefill_len.copy_to_gpu(),
+ seq_lens,
+ self.req_states.prefill_len.gpu,
+ self.req_states.draft_tokens,
+ cu_num_logits,
+ total_num_logits,
)
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings = self.block_tables.compute_slot_mappings(
- query_start_loc_gpu, self.input_buffers.positions.gpu[:num_tokens]
+ query_start_loc_gpu, self.input_buffers.positions[:num_tokens]
)
- num_computed_tokens_cpu = torch.from_numpy(
- self.req_states.num_computed_tokens[idx_mapping_np]
- )
-
- # Logits indices to sample next token from.
- logits_indices = query_start_loc_gpu[1:] - 1
+ # Get num_computed_tokens.
+ # HACK(woosuk): Here, we use num_computed_tokens on GPU instead of
+ # num_computed_tokens_cpu. This works for most cases.
+ num_computed_tokens = self.req_states.num_computed_tokens[idx_mapping]
+ # HACK(woosuk): Only GPU has the exact seq_lens because at this point
+ # CPU does not know how many draft tokens are accepted/rejected in the
+ # previous step. Therefore, we use max_model_len to be safe.
+        # NOTE(woosuk): This only works for the FA3 backend.
+ seq_lens_np = np.full(num_reqs, self.max_model_len, dtype=np.int32)
# Layer name -> attention metadata.
attn_metadata = build_attn_metadata(
@@ -467,14 +576,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_tokens=num_tokens,
query_start_loc=self.input_buffers.query_start_loc,
seq_lens=self.input_buffers.seq_lens,
- num_computed_tokens_cpu=num_computed_tokens_cpu,
+ seq_lens_np=seq_lens_np,
+ num_computed_tokens_cpu=num_computed_tokens,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
)
input_ids = self.input_buffers.input_ids.gpu[:num_tokens_after_padding]
- positions = self.input_buffers.positions.gpu[:num_tokens_after_padding]
+ positions = self.input_buffers.positions[:num_tokens_after_padding]
return InputBatch(
req_ids=req_ids,
num_reqs=num_reqs,
@@ -483,14 +593,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_scheduled_tokens=num_scheduled_tokens,
num_tokens=num_tokens,
num_tokens_after_padding=num_tokens_after_padding,
+ num_draft_tokens=total_num_draft_tokens,
query_start_loc=query_start_loc_gpu,
query_start_loc_np=query_start_loc_np,
- seq_lens=seq_lens_gpu,
+ seq_lens=seq_lens,
seq_lens_np=seq_lens_np,
input_ids=input_ids,
positions=positions,
attn_metadata=attn_metadata,
logits_indices=logits_indices,
+ cu_num_logits=cu_num_logits,
)
def sample(
@@ -499,11 +611,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_batch: InputBatch,
sampling_metadata: SamplingMetadata,
grammar_output: GrammarOutput | None,
- ) -> SamplerOutput:
+ ) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]:
sample_hidden_states = hidden_states[input_batch.logits_indices]
logits = self.model.compute_logits(sample_hidden_states)
if grammar_output is not None:
# Apply grammar bitmask to the logits in-place.
+ # TODO(woosuk): Make compatible with spec decoding.
+ assert input_batch.num_draft_tokens == 0
with async_barrier(self.structured_outputs_event):
apply_grammar_bitmask(
logits,
@@ -512,8 +626,35 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
grammar_output.grammar_bitmask,
self.input_buffers,
)
+
+ # Sample tokens and compute logprobs (if needed).
sampler_output = self.sampler(logits, sampling_metadata)
- return sampler_output
+
+ # Get the number of sampled tokens.
+ prefill_len = self.req_states.prefill_len.gpu[input_batch.idx_mapping]
+ is_chunked_prefilling = input_batch.seq_lens < prefill_len
+ if input_batch.num_draft_tokens == 0:
+ # No draft tokens (common case).
+ # 0 if chunked-prefilling, 1 if not.
+ num_sampled = (~is_chunked_prefilling).int()
+ num_rejected = torch.zeros_like(num_sampled)
+ else:
+ # Draft tokens for spec decoding.
+ input_ids = input_batch.input_ids[input_batch.logits_indices]
+ sampled_tokens, num_sampled = rejection_sample(
+ sampler_output.sampled_token_ids,
+ input_ids,
+ input_batch.cu_num_logits,
+ self.num_speculative_steps,
+ )
+ num_sampled *= ~is_chunked_prefilling
+ num_rejected = get_num_rejected(
+ input_batch.cu_num_logits,
+ num_sampled,
+ )
+ sampler_output.sampled_token_ids = sampled_tokens
+ # TODO(woosuk): Support logprobs with spec decoding.
+ return sampler_output, num_sampled, num_rejected
def compute_prompt_logprobs(
self,
@@ -526,11 +667,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# No request asks for prompt logprobs.
return {}
- num_computed_tokens = self.req_states.num_computed_tokens[idx_mapping_np]
prompt_lens = self.req_states.prompt_len[idx_mapping_np]
# NOTE(woosuk): -1 because the last prompt token's hidden state is not
# needed for prompt logprobs.
- includes_prompt = num_computed_tokens < prompt_lens - 1
+ computed_prefill = self.req_states.num_computed_prefill_tokens[idx_mapping_np]
+ includes_prompt = computed_prefill < prompt_lens - 1
# NOTE(woosuk): If the request was resumed after preemption, its prompt
# logprobs must have been computed before preemption. Skip.
resumed_after_prompt = (
@@ -549,8 +690,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
token_ids[n - 1] = 0
# Handle chunked prompts.
- seq_lens = self.input_buffers.seq_lens.np[: input_batch.num_reqs]
- is_prompt_chunked = seq_lens < prompt_lens
+ pos_after_step = computed_prefill + input_batch.num_scheduled_tokens
+ is_prompt_chunked = pos_after_step < prompt_lens
prefill_token_ids = self.req_states.prefill_token_ids
query_start_loc = self.input_buffers.query_start_loc.np
for i, req_id in enumerate(input_batch.req_ids):
@@ -560,7 +701,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
continue
# The prompt is chunked. Get the next prompt token.
req_idx = input_batch.idx_mapping_np[i]
- next_prompt_token = int(prefill_token_ids[req_idx, seq_lens[i]])
+ next_prompt_token = int(prefill_token_ids[req_idx, pos_after_step[i]])
idx = int(query_start_loc[i + 1] - 1)
# Set the next prompt token.
# NOTE(woosuk): This triggers a GPU operation.
@@ -616,48 +757,67 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def postprocess(
self,
- sampler_output: SamplerOutput,
- prompt_logprobs_dict: dict[str, LogprobsTensors],
input_batch: InputBatch,
- ) -> AsyncOutput | ModelRunnerOutput:
- # Store the last sampled token ids.
- self.req_states.last_sampled_tokens[input_batch.idx_mapping] = (
- sampler_output.sampled_token_ids
- )
- # Get the number of sampled tokens.
- # 0 if chunked-prefilling, 1 if not.
- idx_mapping_np = input_batch.idx_mapping_np
- is_chunked_prefilling = (
- input_batch.seq_lens_np < self.req_states.num_tokens[idx_mapping_np]
- )
- num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)
- # Increment the number of tokens.
- self.req_states.num_tokens[idx_mapping_np] += num_sampled_tokens
- # Increment the number of computed tokens.
- self.req_states.num_computed_tokens[idx_mapping_np] += (
- input_batch.num_scheduled_tokens
+ sampled_tokens: torch.Tensor,
+ num_sampled: torch.Tensor,
+ num_rejected: torch.Tensor,
+ ) -> None:
+ # Update the number of computed tokens.
+ post_update(
+ input_batch.idx_mapping,
+ self.req_states.num_computed_tokens,
+ self.req_states.last_sampled_tokens,
+ sampled_tokens,
+ num_sampled,
+ num_rejected,
+ input_batch.query_start_loc,
)
- model_runner_output = ModelRunnerOutput(
- req_ids=input_batch.req_ids,
- req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
- sampled_token_ids=None, # type: ignore
- logprobs=None,
- prompt_logprobs_dict=prompt_logprobs_dict, # type: ignore
- pooler_output=[],
- kv_connector_output=None,
- num_nans_in_logits=None,
+ # Update the number of computed prefill tokens.
+ idx_mapping_np = input_batch.idx_mapping_np
+ computed_prefill = self.req_states.num_computed_prefill_tokens
+ # TODO(woosuk): Simplify this.
+ computed_prefill[idx_mapping_np] = np.minimum(
+ computed_prefill[idx_mapping_np] + input_batch.num_scheduled_tokens,
+ self.req_states.prefill_len.np[idx_mapping_np],
)
- async_output = AsyncOutput(
- model_runner_output=model_runner_output,
- sampler_output=sampler_output,
- num_sampled_tokens=num_sampled_tokens,
- copy_stream=self.output_copy_stream,
- copy_event=self.output_copy_event,
+
+ @torch.inference_mode()
+ def propose_draft(
+ self,
+ input_batch: InputBatch,
+ sampling_metadata: SamplingMetadata,
+ last_hidden_states: torch.Tensor,
+ aux_hidden_states: list[torch.Tensor] | None,
+ num_sampled: torch.Tensor,
+ num_rejected: torch.Tensor,
+ ) -> torch.Tensor:
+ num_reqs = input_batch.num_reqs
+ idx_mapping_np = input_batch.idx_mapping_np
+ with async_barrier(self.spec_decode_event):
+ self.input_buffers.next_prefill_tokens.np[:num_reqs] = (
+ self.req_states.prefill_token_ids[
+ idx_mapping_np,
+ self.req_states.num_computed_prefill_tokens[idx_mapping_np],
+ ]
+ )
+ next_prefill_tokens = self.input_buffers.next_prefill_tokens.copy_to_gpu(
+ num_reqs
+ )
+
+ assert self.speculator is not None
+ draft_tokens = self.speculator.propose(
+ input_batch,
+ sampling_metadata,
+ last_hidden_states,
+ aux_hidden_states,
+ num_sampled,
+ num_rejected,
+ self.req_states.last_sampled_tokens,
+ next_prefill_tokens,
)
- if self.use_async_scheduling:
- return async_output
- return async_output.get_output()
+ self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
+ return draft_tokens
def get_cudagraph_and_dp_padding(
self,
@@ -750,6 +910,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
sampling_metadata = self.req_states.make_sampling_metadata(
input_batch.idx_mapping_np, pos
)
+ if input_batch.num_draft_tokens > 0:
+ sampling_metadata = self.req_states.expand_sampling_metadata(
+ sampling_metadata, input_batch.cu_num_logits
+ )
if self.lora_config:
# Activate LoRA adapters.
@@ -781,6 +945,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
else:
# Run PyTorch model in eager mode.
+ # TODO(woosuk): Support piecewise CUDA graph.
with set_forward_context(
input_batch.attn_metadata,
self.vllm_config,
@@ -806,13 +971,50 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.execute_model_state = None # type: ignore
assert sampling_metadata is not None
- sampler_output = self.sample(
+ sampler_output, num_sampled, num_rejected = self.sample(
hidden_states, input_batch, sampling_metadata, grammar_output
)
prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch)
- output = self.postprocess(
- sampler_output,
- prompt_logprobs_dict,
- input_batch,
+
+ # Prepare the model runner output.
+ model_runner_output = ModelRunnerOutput(
+ req_ids=input_batch.req_ids,
+ # NOTE(woosuk): req_id_to_index is unused in this model runner.
+            # It is kept only for compatibility with the existing model runner and scheduler.
+ req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
+ sampled_token_ids=None, # type: ignore
+ logprobs=None,
+ prompt_logprobs_dict=prompt_logprobs_dict, # type: ignore
+ pooler_output=[],
+ kv_connector_output=None,
+ num_nans_in_logits=None,
)
- return output
+ async_output = AsyncOutput(
+ model_runner_output=model_runner_output,
+ sampler_output=sampler_output,
+ num_sampled_tokens=num_sampled,
+ copy_stream=self.output_copy_stream,
+ copy_event=self.output_copy_event,
+ )
+
+ # Postprocess results and update request states.
+ # NOTE: This is intentionally done after creating the AsyncOutput,
+ # ensuring that `copy_event` is recorded before calling postprocess.
+ # This sequencing may slightly reduce latency as async D2H copy does not
+ # need to wait for the postprocess to finish.
+ self.postprocess(
+ input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
+ )
+ if self.do_spec_decode:
+ _ = self.propose_draft(
+ input_batch,
+ sampling_metadata,
+ hidden_states,
+ None, # aux_hidden_states
+ num_sampled,
+ num_rejected,
+ )
+
+ if self.use_async_scheduling:
+ return async_output
+ return async_output.get_output()
diff --git a/vllm/v1/worker/gpu/sampler.py b/vllm/v1/worker/gpu/sampler.py
index 55f98ca6bb6a3..d8676079ab951 100644
--- a/vllm/v1/worker/gpu/sampler.py
+++ b/vllm/v1/worker/gpu/sampler.py
@@ -3,10 +3,9 @@
from collections.abc import Callable
import torch
-import triton
-import triton.language as tl
from vllm.config.model import LogprobsMode
+from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.worker.gpu.states import SamplingMetadata
@@ -69,102 +68,106 @@ class Sampler:
sampled = gumbel_sample(
logits,
- is_greedy,
+ sampling_metadata.temperature,
sampling_metadata.seeds,
sampling_metadata.pos,
+ apply_temperature=False,
)
return sampled, logits if return_logits else None
@triton.jit
def _gumbel_sample_kernel(
- sampled_ptr,
+ local_argmax_ptr,
+ local_argmax_stride,
+ local_max_ptr,
+ local_max_stride,
logits_ptr,
logits_stride,
seeds_ptr,
pos_ptr,
- is_greedy_ptr,
+ temp_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
+ APPLY_TEMPERATURE: tl.constexpr,
):
req_idx = tl.program_id(0)
- is_greedy = tl.load(is_greedy_ptr + req_idx)
+ block_idx = tl.program_id(1)
+ block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ mask = block < vocab_size
+ logits = tl.load(
+ logits_ptr + req_idx * logits_stride + block,
+ mask=mask,
+ other=float("-inf"),
+ )
+ logits = logits.to(tl.float32)
- if is_greedy:
- # Greedy sampling. Don't apply gumbel noise.
- max_val = float("-inf")
- max_idx = 0
- for i in range(0, vocab_size, BLOCK_SIZE):
- block = i + tl.arange(0, BLOCK_SIZE)
- mask = block < vocab_size
- logits = tl.load(
- logits_ptr + req_idx * logits_stride + block,
- mask=mask,
- other=float("-inf"),
- )
-
- idx = tl.argmax(logits, axis=0)
- value = tl.max(logits, axis=0)
- is_greater = value > max_val
- max_val = tl.where(is_greater, value, max_val)
- max_idx = tl.where(is_greater, i + idx, max_idx)
- tl.store(sampled_ptr + req_idx, max_idx)
- return
-
- # Random sampling.
- # Calculate gumbel seed.
- seed = tl.load(seeds_ptr + req_idx)
- pos = tl.load(pos_ptr + req_idx)
- gumbel_seed = tl.randint(seed, pos)
-
- max_val = float("-inf")
- max_idx = 0
- for i in range(0, vocab_size, BLOCK_SIZE):
- block = i + tl.arange(0, BLOCK_SIZE)
- mask = block < vocab_size
+ temp = tl.load(temp_ptr + req_idx).to(tl.float32)
+ if temp != 0.0:
+ # Calculate the seed for gumbel noise.
+ seed = tl.load(seeds_ptr + req_idx)
+ pos = tl.load(pos_ptr + req_idx)
+ gumbel_seed = tl.randint(seed, pos)
# Generate gumbel noise.
r = tl.rand(gumbel_seed, block).to(tl.float64)
gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20)
gumbel_noise = gumbel_noise.to(tl.float32)
+ # Apply temperature.
+ if APPLY_TEMPERATURE:
+ # NOTE(woosuk): Use div_rn to match the behavior of torch.
+ logits = tl.div_rn(logits, temp)
+
# Apply gumbel noise.
- logits = tl.load(logits_ptr + req_idx * logits_stride + block, mask=mask)
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
- # Argmax to get the sampled token.
- idx = tl.argmax(logits, axis=0)
- value = tl.max(logits, axis=0)
- is_greater = value > max_val
- max_val = tl.where(is_greater, value, max_val)
- max_idx = tl.where(is_greater, i + idx, max_idx)
- tl.store(sampled_ptr + req_idx, max_idx)
+ idx = tl.argmax(logits, axis=0)
+ token_id = block_idx * BLOCK_SIZE + idx
+ value = tl.max(logits, axis=0)
+ tl.store(local_argmax_ptr + req_idx * local_argmax_stride + block_idx, token_id)
+ tl.store(local_max_ptr + req_idx * local_max_stride + block_idx, value)
def gumbel_sample(
logits: torch.Tensor, # [num_reqs, vocab_size]
- is_greedy: torch.Tensor, # [num_reqs]
+ temperature: torch.Tensor, # [num_reqs]
seed: torch.Tensor, # [num_reqs]
pos: torch.Tensor, # [num_reqs]
+ apply_temperature: bool,
) -> torch.Tensor:
num_reqs, vocab_size = logits.shape
- # NOTE(woosuk): Use int64 for later indexing.
- sampled = torch.empty(
+ BLOCK_SIZE = 1024
+ num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
+ local_argmax = torch.empty(
num_reqs,
+ num_blocks,
dtype=torch.int64,
device=logits.device,
)
- _gumbel_sample_kernel[(num_reqs,)](
- sampled,
+ local_max = torch.empty(
+ num_reqs,
+ num_blocks,
+ dtype=torch.float32,
+ device=logits.device,
+ )
+ _gumbel_sample_kernel[(num_reqs, num_blocks)](
+ local_argmax,
+ local_argmax.stride(0),
+ local_max,
+ local_max.stride(0),
logits,
logits.stride(0),
seed,
pos,
- is_greedy,
+ temperature,
vocab_size,
- num_warps=8,
- BLOCK_SIZE=16384, # type: ignore
+ BLOCK_SIZE=BLOCK_SIZE,
+ APPLY_TEMPERATURE=apply_temperature,
)
+ # NOTE(woosuk): Use int64 for later indexing.
+ max_block_idx = local_max.argmax(dim=-1, keepdim=True)
+ sampled = local_argmax.gather(dim=-1, index=max_block_idx).view(-1)
return sampled
diff --git a/vllm/v1/worker/gpu/spec_decode/__init__.py b/vllm/v1/worker/gpu/spec_decode/__init__.py
new file mode 100644
index 0000000000000..15b85204e05ce
--- /dev/null
+++ b/vllm/v1/worker/gpu/spec_decode/__init__.py
@@ -0,0 +1,18 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+
+from vllm.config import VllmConfig
+
+
+def init_speculator(
+ vllm_config: VllmConfig,
+ device: torch.device,
+):
+ speculative_config = vllm_config.speculative_config
+ assert speculative_config is not None
+ if speculative_config.use_eagle():
+ from vllm.v1.worker.gpu.spec_decode.eagle import EagleSpeculator
+
+ return EagleSpeculator(vllm_config, device)
+ raise NotImplementedError(f"{speculative_config.method} is not supported yet.")
diff --git a/vllm/v1/worker/gpu/spec_decode/eagle.py b/vllm/v1/worker/gpu/spec_decode/eagle.py
new file mode 100644
index 0000000000000..3c8621cc69c97
--- /dev/null
+++ b/vllm/v1/worker/gpu/spec_decode/eagle.py
@@ -0,0 +1,209 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+import torch.nn as nn
+
+from vllm.config import VllmConfig
+from vllm.config.compilation import CUDAGraphMode
+from vllm.forward_context import set_forward_context
+from vllm.model_executor.model_loader import get_model
+from vllm.triton_utils import tl, triton
+from vllm.v1.worker.gpu.input_batch import InputBatch
+from vllm.v1.worker.gpu.sampler import gumbel_sample
+from vllm.v1.worker.gpu.states import SamplingMetadata
+
+
+class EagleSpeculator:
+ def __init__(self, vllm_config: VllmConfig, device: torch.device):
+ self.vllm_config = vllm_config
+ self.device = device
+
+ self.speculative_config = vllm_config.speculative_config
+ assert self.speculative_config is not None
+ self.method = self.speculative_config.method
+ self.num_speculative_steps = self.speculative_config.num_speculative_tokens
+ self.draft_model_config = self.speculative_config.draft_model_config
+
+ self.scheduler_config = vllm_config.scheduler_config
+ self.max_num_reqs = self.scheduler_config.max_num_seqs
+ self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
+
+ self.input_ids = torch.zeros(
+ self.max_num_tokens, dtype=torch.int32, device=device
+ )
+ self.positions = torch.zeros(
+ self.max_num_tokens, dtype=torch.int64, device=device
+ )
+
+ def load_model(self, target_model: nn.Module) -> None:
+ from vllm.compilation.backends import set_model_tag
+
+ with set_model_tag("eagle_head"):
+ self.model = get_model(
+ vllm_config=self.vllm_config, model_config=self.draft_model_config
+ )
+
+ share_lm_head = True
+ if share_lm_head and hasattr(target_model, "lm_head"):
+ if hasattr(self.model, "lm_head"):
+ del self.model.lm_head
+ self.model.lm_head = target_model.lm_head
+
+ @torch.inference_mode()
+ def propose(
+ self,
+ input_batch: InputBatch,
+ sampling_metadata: SamplingMetadata,
+ # [num_tokens, hidden_size]
+ last_hidden_states: torch.Tensor,
+ # num_layers x [num_tokens, hidden_size]
+ aux_hidden_states: list[torch.Tensor] | None,
+ # [num_reqs]
+ num_sampled: torch.Tensor,
+ # [num_reqs]
+ num_rejected: torch.Tensor,
+ # [max_num_reqs, 1]
+ last_sampled: torch.Tensor,
+ # [num_reqs]
+ next_prefill_tokens: torch.Tensor,
+ ) -> torch.Tensor:
+ # NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the
+ # number of rejected tokens, we maintain the size of eagle's input_ids and
+ # hidden_states the same as the target model's. This means, we pad each
+ # request's query length to include any rejected positions. By doing so,
+ # we can also reuse the attention metadata (e.g., query_start_loc,
+ # seq_lens) of the target model.
+ if aux_hidden_states:
+ assert self.method == "eagle3"
+ hidden_states = self.model.combine_hidden_states(
+ torch.cat(aux_hidden_states, dim=-1)
+ )
+ else:
+ hidden_states = last_hidden_states
+
+ # Get the input ids and last token indices for the speculator.
+ last_token_indices = prepare_eagle_inputs(
+ self.input_ids,
+ input_batch,
+ num_sampled,
+ num_rejected,
+ last_sampled,
+ next_prefill_tokens,
+ )
+ input_ids = self.input_ids[: input_batch.num_tokens_after_padding]
+
+ # Prefill: Run the eagle speculator with eager mode.
+ with set_forward_context(
+ input_batch.attn_metadata,
+ self.vllm_config,
+ num_tokens=input_batch.num_tokens_after_padding,
+ cudagraph_runtime_mode=CUDAGraphMode.NONE,
+ ):
+ ret_hidden_states = self.model(
+ input_ids=input_ids,
+ positions=input_batch.positions,
+ hidden_states=hidden_states,
+ )
+ if self.method == "mtp":
+ last_hidden_states = ret_hidden_states
+ hidden_states = ret_hidden_states
+ else:
+ last_hidden_states, hidden_states = ret_hidden_states
+ sample_hidden_states = last_hidden_states[last_token_indices]
+ logits = self.model.compute_logits(sample_hidden_states)
+
+ num_reqs = input_batch.num_reqs
+ cu_num_logits = input_batch.cu_num_logits[:num_reqs]
+ temperature = sampling_metadata.temperature[cu_num_logits]
+ seed = sampling_metadata.seeds[cu_num_logits]
+ # NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
+ # used for draft and target sampling.
+ pos = input_batch.positions[last_token_indices] + 1
+ # NOTE(woosuk): For draft sampling, we only consider the temperature
+ # and ignore the other sampling parameters such as top_k and top_p,
+ # for simplicity and performance.
+ # While this may slightly degrade the acceptance rate, it does not
+ # affect the output distribution after rejection sampling.
+ draft_tokens = gumbel_sample(
+ logits, temperature, seed, pos, apply_temperature=True
+ )
+ if self.num_speculative_steps == 1:
+ # Early exit.
+ return draft_tokens.view(-1, 1)
+ raise NotImplementedError("num_speculative_steps > 1 is not supported yet.")
+
+
+@triton.jit
+def _prepare_eagle_inputs_kernel(
+ last_token_indices_ptr,
+ eagle_input_ids_ptr,
+ target_input_ids_ptr,
+ idx_mapping_ptr,
+ last_sampled_ptr,
+ next_prefill_tokens_ptr,
+ num_sampled_ptr,
+ num_rejected_ptr,
+ query_start_loc_ptr,
+ BLOCK_SIZE: tl.constexpr,
+):
+ batch_idx = tl.program_id(0)
+ query_start = tl.load(query_start_loc_ptr + batch_idx)
+ query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
+ query_len = query_end - query_start
+
+ # Get the true query length and next token after accounting for rejected tokens.
+ num_rejected = tl.load(num_rejected_ptr + batch_idx)
+ query_len -= num_rejected
+
+ num_sampled = tl.load(num_sampled_ptr + batch_idx)
+ if num_sampled > 0:
+ req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
+ next_token = tl.load(last_sampled_ptr + req_state_idx).to(tl.int32)
+ else:
+ # Chunked prefilling.
+ # Get the next prefill token.
+ next_token = tl.load(next_prefill_tokens_ptr + batch_idx)
+
+ # Shift target_input_ids by one.
+ for i in range(1, query_len, BLOCK_SIZE):
+ block = i + tl.arange(0, BLOCK_SIZE)
+ mask = block < query_len
+ input_ids = tl.load(target_input_ids_ptr + query_start + block, mask=mask)
+ tl.store(eagle_input_ids_ptr + query_start + block - 1, input_ids, mask=mask)
+
+ last_token_index = query_start + query_len - 1
+ tl.store(last_token_indices_ptr + batch_idx, last_token_index)
+ tl.store(eagle_input_ids_ptr + last_token_index, next_token)
+
+
+def prepare_eagle_inputs(
+ eagle_input_ids: torch.Tensor,
+ input_batch: InputBatch,
+ # [num_reqs]
+ num_sampled: torch.Tensor,
+ # [num_reqs]
+ num_rejected: torch.Tensor,
+ # [max_num_reqs, 1]
+ last_sampled: torch.Tensor,
+ # [max_num_reqs]
+ next_prefill_tokens: torch.Tensor,
+) -> torch.Tensor:
+ num_reqs = input_batch.num_reqs
+ last_token_indices = torch.empty(
+ num_reqs,
+ dtype=torch.int64,
+ device=eagle_input_ids.device,
+ )
+ _prepare_eagle_inputs_kernel[(num_reqs,)](
+ last_token_indices,
+ eagle_input_ids,
+ input_batch.input_ids,
+ input_batch.idx_mapping,
+ last_sampled,
+ next_prefill_tokens,
+ num_sampled,
+ num_rejected,
+ input_batch.query_start_loc,
+ BLOCK_SIZE=1024,
+ )
+ return last_token_indices
diff --git a/vllm/v1/worker/gpu/spec_decode/rejection_sample.py b/vllm/v1/worker/gpu/spec_decode/rejection_sample.py
new file mode 100644
index 0000000000000..43c6ac518bccc
--- /dev/null
+++ b/vllm/v1/worker/gpu/spec_decode/rejection_sample.py
@@ -0,0 +1,83 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+
+from vllm.triton_utils import tl, triton
+
+
+@triton.jit
+def _rejection_sample_kernel(
+ sampled_ptr, # [num_reqs, num_speculative_steps + 1]
+ sampled_stride,
+ num_sampled_ptr, # [num_reqs]
+ target_sampled_ptr, # [num_draft_tokens + num_reqs]
+ input_ids_ptr, # [num_draft_tokens + num_reqs]
+ cu_num_logits_ptr, # [num_reqs + 1]
+):
+ req_idx = tl.program_id(0)
+ start_idx = tl.load(cu_num_logits_ptr + req_idx)
+ end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
+ num_tokens = end_idx - start_idx
+
+ num_sampled = 0
+ rejected = False
+ for i in range(num_tokens - 1):
+ if not rejected:
+ target_sampled = tl.load(target_sampled_ptr + start_idx + i)
+ draft_sampled = tl.load(input_ids_ptr + start_idx + i + 1)
+ tl.store(sampled_ptr + req_idx * sampled_stride + i, target_sampled)
+ num_sampled += 1
+ if target_sampled != draft_sampled:
+ rejected = True
+ if not rejected:
+ target_sampled = tl.load(target_sampled_ptr + start_idx + num_tokens - 1)
+ tl.store(
+ sampled_ptr + req_idx * sampled_stride + num_tokens - 1, target_sampled
+ )
+ num_sampled += 1
+ tl.store(num_sampled_ptr + req_idx, num_sampled)
+
+
+def rejection_sample(
+ # [num_draft_tokens + num_reqs]
+ target_sampled: torch.Tensor,
+ # [num_draft_tokens + num_reqs]
+ input_ids: torch.Tensor,
+ # [num_reqs + 1]
+ cu_num_logits: torch.Tensor,
+ num_speculative_steps: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ num_reqs = cu_num_logits.shape[0] - 1
+ sampled = torch.empty(
+ num_reqs,
+ num_speculative_steps + 1,
+ dtype=target_sampled.dtype,
+ device=target_sampled.device,
+ )
+ num_sampled = torch.empty(
+ num_reqs,
+ dtype=torch.int32,
+ device=target_sampled.device,
+ )
+ _rejection_sample_kernel[(num_reqs,)](
+ sampled,
+ sampled.stride(0),
+ num_sampled,
+ target_sampled,
+ input_ids,
+ cu_num_logits,
+ num_warps=1,
+ )
+ return sampled, num_sampled
+
+
+@torch.compile(dynamic=True)
+def get_num_rejected(
+ cu_num_logits: torch.Tensor,
+ num_sampled: torch.Tensor,
+) -> torch.Tensor:
+ num_logits = cu_num_logits[1:] - cu_num_logits[:-1]
+ num_rejected = num_logits - num_sampled
+ # No token is rejected for chunked prefills.
+ num_rejected *= num_sampled > 0
+ return num_rejected
diff --git a/vllm/v1/worker/gpu/states.py b/vllm/v1/worker/gpu/states.py
index 5d05c3f57790a..513d45d95d7cd 100644
--- a/vllm/v1/worker/gpu/states.py
+++ b/vllm/v1/worker/gpu/states.py
@@ -7,6 +7,7 @@ import torch
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
+from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.utils import CpuGpuBuffer
@@ -63,6 +64,7 @@ class RequestState:
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
+ num_speculative_steps: int,
vocab_size: int,
device: torch.device,
pin_memory: bool,
@@ -70,6 +72,7 @@ class RequestState:
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_batched_tokens = max_num_batched_tokens
+ self.num_speculative_steps = num_speculative_steps
self.vocab_size = vocab_size
self.device = device
self.pin_memory = pin_memory
@@ -85,8 +88,12 @@ class RequestState:
dtype=np.int32,
)
self.prefill_len = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
- self.num_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
- self.num_computed_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
+
+ # Number of computed tokens.
+ self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
+ self.num_computed_tokens = torch.zeros(
+ self.max_num_reqs, dtype=torch.int32, device=device
+ )
# Last sampled tokens.
self.last_sampled_tokens = torch.zeros(
@@ -96,6 +103,14 @@ class RequestState:
device=device,
)
+ # Draft tokens.
+ self.draft_tokens = torch.zeros(
+ self.max_num_reqs,
+ self.num_speculative_steps,
+ dtype=torch.int64,
+ device=device,
+ )
+
# LoRA.
self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32)
self.lora_ids.fill(NO_LORA_ID)
@@ -145,7 +160,10 @@ class RequestState:
)
self.prefill_len.np[req_idx] = prefill_len
self.prefill_token_ids[req_idx, :prefill_len] = prefill_token_ids
- self.num_tokens[req_idx] = prefill_len
+
+ self.num_computed_prefill_tokens[req_idx] = num_computed_tokens
+ # FIXME(woosuk): This triggers a GPU operation whenever adding a new request.
+ # Optimize this.
self.num_computed_tokens[req_idx] = num_computed_tokens
if lora_request is not None:
@@ -219,6 +237,17 @@ class RequestState:
max_num_logprobs=max_num_logprobs,
)
+ def expand_sampling_metadata(
+ self,
+ sampling_metadata: SamplingMetadata,
+ cu_num_logits: torch.Tensor,
+ ) -> SamplingMetadata:
+ # For draft tokens, we need to expand the sampling param tensors as
+ # each request samples multiple tokens in each step.
+ return expand_sampling_metadata(
+ sampling_metadata, cu_num_logits, self.num_speculative_steps
+ )
+
def make_lora_inputs(
self,
req_ids: list[str],
@@ -263,3 +292,75 @@ class Param:
class ExtraData:
lora_request: LoRARequest | None
in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list)
+
+
+# NOTE(woosuk): Re-compilation can happen at runtime since top_p and top_k can be None.
+@triton.jit
+def _expand_sampling_metadata_kernel(
+ temp_ptr,
+ expanded_temp_ptr,
+ top_p_ptr,
+ expanded_top_p_ptr,
+ top_k_ptr,
+ expanded_top_k_ptr,
+ seeds_ptr,
+ expanded_seeds_ptr,
+ cu_num_logits_ptr,
+ BLOCK_SIZE: tl.constexpr,
+):
+ req_idx = tl.program_id(0)
+ start_idx = tl.load(cu_num_logits_ptr + req_idx)
+ end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
+ num_tokens = end_idx - start_idx
+
+ block = tl.arange(0, BLOCK_SIZE)
+ mask = block < num_tokens
+
+ temp = tl.load(temp_ptr + req_idx)
+ tl.store(expanded_temp_ptr + start_idx + block, temp, mask=mask)
+
+ if top_p_ptr is not None:
+ top_p = tl.load(top_p_ptr + req_idx)
+ tl.store(expanded_top_p_ptr + start_idx + block, top_p, mask=mask)
+
+ if top_k_ptr is not None:
+ top_k = tl.load(top_k_ptr + req_idx)
+ tl.store(expanded_top_k_ptr + start_idx + block, top_k, mask=mask)
+
+ seed = tl.load(seeds_ptr + req_idx)
+ tl.store(expanded_seeds_ptr + start_idx + block, seed, mask=mask)
+
+
+def expand_sampling_metadata(
+ sampling_metadata: SamplingMetadata,
+ cu_num_logits: torch.Tensor,
+ num_speculative_steps: int,
+) -> SamplingMetadata:
+ total_num_logits = sampling_metadata.pos.shape[0]
+ create_empty = lambda x: x.new_empty(total_num_logits) if x is not None else None
+ expanded_temp = create_empty(sampling_metadata.temperature)
+ expanded_top_p = create_empty(sampling_metadata.top_p)
+ expanded_top_k = create_empty(sampling_metadata.top_k)
+ expanded_seeds = create_empty(sampling_metadata.seeds)
+
+ num_reqs = cu_num_logits.shape[0] - 1
+ _expand_sampling_metadata_kernel[(num_reqs,)](
+ sampling_metadata.temperature,
+ expanded_temp,
+ sampling_metadata.top_p,
+ expanded_top_p,
+ sampling_metadata.top_k,
+ expanded_top_k,
+ sampling_metadata.seeds,
+ expanded_seeds,
+ cu_num_logits,
+ BLOCK_SIZE=triton.next_power_of_2(num_speculative_steps + 1),
+ )
+ return SamplingMetadata(
+ temperature=expanded_temp,
+ top_p=expanded_top_p,
+ top_k=expanded_top_k,
+ seeds=expanded_seeds,
+ pos=sampling_metadata.pos,
+ max_num_logprobs=sampling_metadata.max_num_logprobs,
+ )
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 7b4bc1d2a2241..e7991baeaa1b8 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -43,6 +43,8 @@ class CachedRequestState:
mrope_positions: torch.Tensor | None = None
mrope_position_delta: int | None = None
+ xdrope_positions: torch.Tensor | None = None
+
lora_request: LoRARequest | None = None
prompt_embeds: torch.Tensor | None = None
@@ -219,9 +221,6 @@ class InputBatch:
self.generators: dict[int, torch.Generator] = {}
self.num_logprobs: dict[str, int] = {}
- # NOTE(rob): num_prompt_logprobs only includes reqs
- # that are currently in the prefill phase.
- self.num_prompt_logprobs: dict[str, int] = {}
# To accumulate prompt logprobs tensor chunks across prefill steps.
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
@@ -385,12 +384,6 @@ class InputBatch:
if sampling_params.logprobs == -1
else sampling_params.logprobs
)
- if sampling_params.prompt_logprobs is not None:
- self.num_prompt_logprobs[req_id] = (
- self.vocab_size
- if sampling_params.prompt_logprobs == -1
- else sampling_params.prompt_logprobs
- )
if sampling_params.allowed_token_ids:
self.has_allowed_token_ids.add(req_id)
@@ -488,7 +481,6 @@ class InputBatch:
self.repetition_penalties_reqs.discard(req_id)
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
- self.num_prompt_logprobs.pop(req_id, None)
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
self.has_allowed_token_ids.discard(req_id)
@@ -535,7 +527,7 @@ class InputBatch:
# NOTE: the following is unsafe
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
# self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
- # instead, we need to temporiarily copy the data for one of the indices
+ # instead, we need to temporarily copy the data for one of the indices
# TODO(lucas): optimize this by only copying valid indices
tmp = self.token_ids_cpu[i1, ...].copy()
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
@@ -972,10 +964,6 @@ class InputBatch:
def max_num_logprobs(self) -> int | None:
return max(self.num_logprobs.values()) if self.num_logprobs else None
- @property
- def no_prompt_logprob(self) -> bool:
- return not self.num_prompt_logprobs
-
@property
def no_allowed_token_ids(self) -> bool:
return len(self.has_allowed_token_ids) == 0
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index e786cd8bc7c97..e78d3c71af77a 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -50,16 +50,21 @@ from vllm.distributed.parallel_state import (
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
-from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
+from vllm.model_executor.layers.rotary_embedding import (
+ MRotaryEmbedding,
+ XDRotaryEmbedding,
+)
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.interfaces import (
SupportsMRoPE,
SupportsMultiModal,
+ SupportsXDRoPE,
is_mixture_of_experts,
supports_eagle3,
supports_mrope,
supports_multimodal_pruning,
supports_transcription,
+ supports_xdrope,
)
from vllm.model_executor.models.interfaces_base import (
VllmModelForPooling,
@@ -324,6 +329,7 @@ class GPUModelRunner(
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope
+ self.uses_xdrope_dim = model_config.uses_xdrope_dim
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
model_config
)
@@ -375,7 +381,9 @@ class GPUModelRunner(
elif self.speculative_config.use_eagle():
self.drafter = EagleProposer(self.vllm_config, self.device, self)
if self.speculative_config.method == "eagle3":
- self.use_aux_hidden_state_outputs = True
+ self.use_aux_hidden_state_outputs = (
+ self.drafter.eagle3_use_aux_hidden_state
+ )
elif self.speculative_config.method == "medusa":
self.drafter = MedusaProposer(
vllm_config=self.vllm_config, device=self.device
@@ -393,6 +401,9 @@ class GPUModelRunner(
# Request states.
self.requests: dict[str, CachedRequestState] = {}
+ # NOTE(rob): num_prompt_logprobs only includes reqs
+ # that are currently in the prefill phase.
+ self.num_prompt_logprobs: dict[str, int] = {}
self.comm_stream = torch.cuda.Stream()
# Input Batch
@@ -464,6 +475,7 @@ class GPUModelRunner(
self.max_num_reqs + 1, dtype=torch.int32
)
self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
+ self.encoder_seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
if self.dcp_world_size > 1:
self.dcp_local_seq_lens = self._make_buffer(
self.max_num_reqs, dtype=torch.int32
@@ -507,6 +519,13 @@ class GPUModelRunner(
(3, self.max_num_tokens + 1), dtype=torch.int64
)
+        # Only relevant for models using XD-RoPE (e.g., HunYuan-VL)
+ if self.uses_xdrope_dim > 0:
+            # Similar to M-RoPE, but uses the model's assigned number of RoPE dimensions (4 by default).
+ self.xdrope_positions = self._make_buffer(
+ (self.uses_xdrope_dim, self.max_num_tokens + 1), dtype=torch.int64
+ )
+
# None in the first PP rank. The rest are set after load_model.
self.intermediate_tensors: IntermediateTensors | None = None
@@ -588,10 +607,14 @@ class GPUModelRunner(
if isinstance(num_tokens, int):
if self.uses_mrope:
return self.mrope_positions.gpu[:, :num_tokens]
+ if self.uses_xdrope_dim > 0:
+ return self.xdrope_positions.gpu[:, :num_tokens]
return self.positions.gpu[:num_tokens]
else:
if self.uses_mrope:
return self.mrope_positions.gpu[:, num_tokens]
+ if self.uses_xdrope_dim > 0:
+ return self.xdrope_positions.gpu[:, num_tokens]
return self.positions.gpu[num_tokens]
def _make_buffer(
@@ -687,6 +710,7 @@ class GPUModelRunner(
# Remove finished requests from the cached states.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
+ self.num_prompt_logprobs.pop(req_id, None)
# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
# scheduled_req_ids overlap. This happens when a request is aborted and
@@ -755,10 +779,21 @@ class GPUModelRunner(
)
self.requests[req_id] = req_state
+ if sampling_params and sampling_params.prompt_logprobs is not None:
+ self.num_prompt_logprobs[req_id] = (
+ self.input_batch.vocab_size
+ if sampling_params.prompt_logprobs == -1
+ else sampling_params.prompt_logprobs
+ )
+
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope:
self._init_mrope_positions(req_state)
+            # Only relevant for models using XD-RoPE (e.g., HunYuan-VL)
+ if self.uses_xdrope_dim > 0:
+ self._init_xdrope_positions(req_state)
+
reqs_to_add.append(req_state)
# Update the states of the running/resumed requests.
@@ -974,6 +1009,19 @@ class GPUModelRunner(
)
)
+ def _init_xdrope_positions(self, req_state: CachedRequestState):
+ model = self.get_model()
+ xdrope_model = cast(SupportsXDRoPE, model)
+ assert req_state.prompt_token_ids is not None, (
+ "XD-RoPE requires prompt_token_ids to be available."
+ )
+ assert supports_xdrope(model), "XD-RoPE support is not implemented."
+
+ req_state.xdrope_positions = xdrope_model.get_xdrope_input_positions(
+ req_state.prompt_token_ids,
+ req_state.mm_features,
+ )
+
def _extract_mm_kwargs(
self,
scheduler_output: "SchedulerOutput",
@@ -1155,21 +1203,35 @@ class GPUModelRunner(
def _get_encoder_seq_lens(
self,
- scheduled_encoder_inputs: dict[str, list[int]],
+ num_scheduled_tokens: dict[str, int],
kv_cache_spec: KVCacheSpec,
num_reqs: int,
- ) -> np.ndarray | None:
+ ) -> tuple[torch.Tensor | None, np.ndarray | None]:
if not isinstance(kv_cache_spec, CrossAttentionSpec):
- return None
+ return None, None
# Build encoder_seq_lens array mapping request indices to
# encoder lengths for inputs scheduled in this batch
- encoder_seq_lens = np.zeros(num_reqs, dtype=np.int32)
- for req_id in scheduled_encoder_inputs:
+ for req_id in num_scheduled_tokens:
req_index = self.input_batch.req_id_to_index[req_id]
- encoder_seq_lens[req_index] = self.max_encoder_len
+ req_state = self.requests[req_id]
+ if req_state.mm_features is None:
+ self.encoder_seq_lens.np[req_index] = 0
+ continue
- return encoder_seq_lens
+            # Get the total number of encoder input tokens for running encoder requests,
+            # whether or not encoding has finished, so that cross-attention knows how
+            # many encoder tokens to attend to.
+ encoder_input_tokens = sum(
+ feature.mm_position.length for feature in req_state.mm_features
+ )
+ self.encoder_seq_lens.np[req_index] = encoder_input_tokens
+
+ self.encoder_seq_lens.copy_to_gpu(num_reqs)
+ encoder_seq_lens = self.encoder_seq_lens.gpu[:num_reqs]
+ encoder_seq_lens_cpu = self.encoder_seq_lens.np[:num_reqs]
+
+ return encoder_seq_lens, encoder_seq_lens_cpu
def _prepare_inputs(
self,
@@ -1218,6 +1280,11 @@ class GPUModelRunner(
if self.uses_mrope:
self._calc_mrope_positions(scheduler_output)
+ # Calculate XD-RoPE positions.
+        # Only relevant for models using XD-RoPE (e.g., HunYuan-VL)
+ if self.uses_xdrope_dim > 0:
+ self._calc_xdrope_positions(scheduler_output)
+
# Get token indices.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
@@ -1351,6 +1418,12 @@ class GPUModelRunner(
self.mrope_positions.cpu[:, :total_num_scheduled_tokens],
non_blocking=True,
)
+ elif self.uses_xdrope_dim > 0:
+            # Only relevant for models using XD-RoPE (e.g., HunYuan-VL)
+ self.xdrope_positions.gpu[:, :total_num_scheduled_tokens].copy_(
+ self.xdrope_positions.cpu[:, :total_num_scheduled_tokens],
+ non_blocking=True,
+ )
else:
# Common case (1D positions)
self.positions.copy_to_gpu(total_num_scheduled_tokens)
@@ -1424,7 +1497,7 @@ class GPUModelRunner(
logits_indices: torch.Tensor | None = None,
use_spec_decode: bool = False,
for_cudagraph_capture: bool = False,
- scheduled_encoder_inputs: dict[str, list[int]] | None = None,
+ num_scheduled_tokens: dict[str, int] | None = None,
cascade_attn_prefix_lens: list[list[int]] | None = None,
) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]:
"""
@@ -1489,8 +1562,8 @@ class GPUModelRunner(
for kv_cache_gid, kv_cache_group in enumerate(
self.kv_cache_config.kv_cache_groups
):
- encoder_seq_lens = self._get_encoder_seq_lens(
- scheduled_encoder_inputs or {},
+ encoder_seq_lens, encoder_seq_lens_cpu = self._get_encoder_seq_lens(
+ num_scheduled_tokens or {},
kv_cache_group.kv_cache_spec,
num_reqs,
)
@@ -1533,6 +1606,7 @@ class GPUModelRunner(
num_logits_indices=num_logits_indices,
causal=True,
encoder_seq_lens=encoder_seq_lens,
+ encoder_seq_lens_cpu=encoder_seq_lens_cpu,
dcp_local_seq_lens=dcp_local_seq_lens,
dcp_local_seq_lens_cpu=dcp_local_seq_lens_cpu,
)
@@ -1780,6 +1854,53 @@ class GPUModelRunner(
mrope_pos_ptr += completion_part_len
+ def _calc_xdrope_positions(self, scheduler_output: "SchedulerOutput"):
+ xdrope_pos_ptr = 0
+ for index, req_id in enumerate(self.input_batch.req_ids):
+ req = self.requests[req_id]
+ assert req.xdrope_positions is not None
+
+ num_computed_tokens = self.input_batch.num_computed_tokens_cpu[index]
+ num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
+ num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
+ req.prompt_token_ids, req.prompt_embeds
+ )
+
+ if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
+ prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens)
+ completion_part_len = max(0, num_scheduled_tokens - prompt_part_len)
+ else:
+ prompt_part_len = num_scheduled_tokens
+ completion_part_len = 0
+
+ assert num_scheduled_tokens == prompt_part_len + completion_part_len
+
+ if prompt_part_len > 0:
+ # prompt's xdrope_positions are pre-computed
+ dst_start = xdrope_pos_ptr
+ dst_end = xdrope_pos_ptr + prompt_part_len
+ src_start = num_computed_tokens
+ src_end = num_computed_tokens + prompt_part_len
+
+ self.xdrope_positions.cpu[:, dst_start:dst_end] = req.xdrope_positions[
+ :, src_start:src_end
+ ]
+ xdrope_pos_ptr += prompt_part_len
+
+ if completion_part_len > 0:
+ # compute completion's xdrope_positions on-the-fly
+ dst_start = xdrope_pos_ptr
+ dst_end = xdrope_pos_ptr + completion_part_len
+
+ XDRotaryEmbedding.get_next_input_positions_tensor(
+ out=self.xdrope_positions.np,
+ out_offset=dst_start,
+ context_len=num_computed_tokens + prompt_part_len,
+ num_new_tokens=completion_part_len,
+ )
+
+ xdrope_pos_ptr += completion_part_len
+
def _calc_spec_decode_metadata(
self,
num_draft_tokens: np.ndarray,
@@ -2024,6 +2145,7 @@ class GPUModelRunner(
req_start_idx = 0
should_sync_mrope_positions = False
+ should_sync_xdrope_positions = False
for req_id in self.input_batch.req_ids:
mm_embeds_req: list[torch.Tensor] = []
@@ -2097,6 +2219,10 @@ class GPUModelRunner(
self._calc_mrope_positions(scheduler_output)
self.mrope_positions.copy_to_gpu(total_num_scheduled_tokens)
+ if should_sync_xdrope_positions:
+ self._calc_xdrope_positions(scheduler_output)
+ self.xdrope_positions.copy_to_gpu(total_num_scheduled_tokens)
+
return mm_embeds, is_mm_embed
def get_model(self) -> nn.Module:
@@ -2371,8 +2497,11 @@ class GPUModelRunner(
input_ids = self.input_ids.gpu[:num_input_tokens]
inputs_embeds = None
model_kwargs = self._init_model_kwargs(num_input_tokens)
+
if self.uses_mrope:
positions = self.mrope_positions.gpu[:, :num_input_tokens]
+ elif self.uses_xdrope_dim > 0:
+ positions = self.xdrope_positions.gpu[:, :num_input_tokens]
else:
positions = self.positions.gpu[:num_input_tokens]
@@ -2466,7 +2595,9 @@ class GPUModelRunner(
num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
sampled_token_ids = sampler_output.sampled_token_ids
+ logprobs_tensors = sampler_output.logprobs_tensors
invalid_req_indices = []
+ cu_num_new_tokens: list[int] | None = None
if not self.use_async_scheduling:
# Get the valid generated tokens.
max_gen_len = sampled_token_ids.shape[-1]
@@ -2479,6 +2610,12 @@ class GPUModelRunner(
sampled_token_ids,
self.input_batch.vocab_size,
)
+ if logprobs_tensors:
+ # Needed for extracting logprobs when spec decoding.
+ # This must be done prior to discarding sampled tokens.
+ cu_num_new_tokens = [0]
+ for toks in valid_sampled_token_ids:
+ cu_num_new_tokens.append(cu_num_new_tokens[-1] + len(toks))
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[int(i)].clear()
@@ -2506,10 +2643,6 @@ class GPUModelRunner(
# the sampled tokens back, because there's no direct communication
# between the first-stage worker and the last-stage worker.
req_ids = self.input_batch.req_ids
- logprobs_tensors = sampler_output.logprobs_tensors
- cu_num_accepted_tokens = (
- [0] if spec_decode_metadata and logprobs_tensors else None
- )
for req_idx in range(num_sampled_tokens):
if self.use_async_scheduling:
sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None
@@ -2518,11 +2651,6 @@ class GPUModelRunner(
num_sampled_ids: int = len(sampled_ids) if sampled_ids else 0
- if cu_num_accepted_tokens is not None:
- cu_num_accepted_tokens.append(
- cu_num_accepted_tokens[-1] + num_sampled_ids
- )
-
if not sampled_ids:
continue
@@ -2544,7 +2672,7 @@ class GPUModelRunner(
req_state.output_token_ids.extend(sampled_ids)
logprobs_lists = (
- logprobs_tensors.tolists(cu_num_accepted_tokens)
+ logprobs_tensors.tolists(cu_num_new_tokens)
if not self.use_async_scheduling and logprobs_tensors is not None
else None
)
@@ -2672,7 +2800,7 @@ class GPUModelRunner(
scheduler_output, self.vllm_config
)
if self.cache_config.kv_sharing_fast_prefill:
- assert not self.input_batch.num_prompt_logprobs, (
+ assert not self.num_prompt_logprobs, (
"--kv-sharing-fast-prefill produces incorrect "
"logprobs for prompt tokens, tokens, please disable "
"it when the requests need prompt logprobs"
@@ -2716,7 +2844,7 @@ class GPUModelRunner(
ubatch_slices=ubatch_slices,
logits_indices=logits_indices,
use_spec_decode=use_spec_decode,
- scheduled_encoder_inputs=scheduler_output.scheduled_encoder_inputs,
+ num_scheduled_tokens=scheduler_output.num_scheduled_tokens,
cascade_attn_prefix_lens=cascade_attn_prefix_lens,
)
)
@@ -3360,6 +3488,8 @@ class GPUModelRunner(
old_global_expert_indices,
rank_mapping,
)
+ if self.eplb_state.is_async:
+ self.eplb_state.start_async_loop(rank_mapping=rank_mapping)
if (
self.vllm_config.compilation_config.mode
@@ -3437,7 +3567,7 @@ class GPUModelRunner(
hidden_states: torch.Tensor,
num_scheduled_tokens: dict[str, int],
) -> dict[str, LogprobsTensors | None]:
- num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
+ num_prompt_logprobs_dict = self.num_prompt_logprobs
if not num_prompt_logprobs_dict:
return {}
@@ -3448,7 +3578,10 @@ class GPUModelRunner(
# maintainable loop over optimal performance.
completed_prefill_reqs = []
for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
- num_tokens = num_scheduled_tokens[req_id]
+ num_tokens = num_scheduled_tokens.get(req_id)
+ if num_tokens is None:
+ # This can happen if the request was preempted in prefill stage.
+ continue
# Get metadata for this request.
request = self.requests[req_id]
@@ -3629,6 +3762,7 @@ class GPUModelRunner(
create_mixed_batch: bool = False,
remove_lora: bool = True,
activate_lora: bool = False,
+ is_graph_capturing: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Run a dummy forward pass to warm up/profile run or capture the
@@ -3727,6 +3861,31 @@ class GPUModelRunner(
dp_rank = self.parallel_config.data_parallel_rank
num_tokens_after_padding = int(num_tokens_across_dp[dp_rank])
+ # filter out the valid batch descriptor
+ _cg_mode, batch_descriptor = (
+ self.cudagraph_dispatcher.dispatch(
+ BatchDescriptor(
+ num_tokens=num_tokens_after_padding,
+ uniform_decode=uniform_decode,
+ has_lora=activate_lora and self.lora_config is not None,
+ )
+ )
+ if not is_profile
+ else (CUDAGraphMode.NONE, None)
+ )
+ if cudagraph_runtime_mode is not None:
+ # we allow forcing NONE when the dispatcher disagrees to support
+ # warm ups for cudagraph capture
+ assert (
+ cudagraph_runtime_mode == CUDAGraphMode.NONE
+ or cudagraph_runtime_mode == _cg_mode
+ ), (
+ f"Cudagraph runtime mode mismatch at dummy_run. "
+ f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}."
+ )
+ else:
+ cudagraph_runtime_mode = _cg_mode
+
attn_metadata: PerLayerAttnMetadata | None = None
# If force_attention is True, we always capture attention. Otherwise,
@@ -3782,6 +3941,8 @@ class GPUModelRunner(
if self.uses_mrope:
positions = self.mrope_positions.gpu[:, :num_tokens_after_padding]
+ elif self.uses_xdrope_dim > 0:
+ positions = self.xdrope_positions.gpu[:, :num_tokens_after_padding]
else:
positions = self.positions.gpu[:num_tokens_after_padding]
@@ -3801,31 +3962,6 @@ class GPUModelRunner(
num_tokens_after_padding, None, False
)
- # filter out the valid batch descriptor
- _cg_mode, batch_descriptor = (
- self.cudagraph_dispatcher.dispatch(
- BatchDescriptor(
- num_tokens=num_tokens_after_padding,
- uniform_decode=uniform_decode,
- has_lora=activate_lora and self.lora_config is not None,
- )
- )
- if not is_profile
- else (CUDAGraphMode.NONE, None)
- )
- if cudagraph_runtime_mode is not None:
- # we allow forcing NONE when the dispatcher disagrees to support
- # warm ups for cudagraph capture
- assert (
- cudagraph_runtime_mode == CUDAGraphMode.NONE
- or cudagraph_runtime_mode == _cg_mode
- ), (
- f"Cudagraph runtime mode mismatch at dummy_run. "
- f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}."
- )
- else:
- cudagraph_runtime_mode = _cg_mode
-
if ubatch_slices is not None:
# Adjust values to reflect a single ubatch.
# TODO(sage,lucas): this is cruft that should be addressed in
@@ -3862,7 +3998,7 @@ class GPUModelRunner(
if self.speculative_config and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
use_cudagraphs = (
- cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE
+ cudagraph_runtime_mode.has_mode(CUDAGraphMode.PIECEWISE)
and not self.speculative_config.enforce_eager
)
@@ -3876,6 +4012,7 @@ class GPUModelRunner(
self.drafter.dummy_run(
num_tokens,
use_cudagraphs=use_cudagraphs,
+ is_graph_capturing=is_graph_capturing,
)
# This is necessary to avoid blocking DP.
@@ -4108,14 +4245,18 @@ class GPUModelRunner(
# NOTE: This happens when encoder cache needs to store
# the embeddings that encoder outputs are scattered onto.
# In this case we create dummy embeddings of size
- # (encode_budget, hidden_size) and scatter encoder
- # output into it.
+ # (max_tokens_for_modality, hidden_size) and scatter
+ # encoder output into it.
encoder_output_shape = dummy_encoder_outputs[0].shape
- if encoder_output_shape[0] < encoder_budget:
+ max_mm_tokens_per_item = mm_budget.max_tokens_by_modality[
+ dummy_modality
+ ]
+ if encoder_output_shape[0] < max_mm_tokens_per_item:
+ encoder_hidden_size = encoder_output_shape[-1]
expanded_outputs = []
for output in dummy_encoder_outputs:
expanded = output.new_zeros(
- (encoder_budget, encoder_output_shape[-1])
+ (max_mm_tokens_per_item, encoder_hidden_size)
)
num_tokens = output.shape[0]
expanded[:num_tokens].copy_(output)
@@ -4308,6 +4449,7 @@ class GPUModelRunner(
skip_eplb=True,
remove_lora=False,
activate_lora=activate_lora,
+ is_graph_capturing=True,
)
self.maybe_remove_all_loras(self.lora_config)
@@ -4618,7 +4760,7 @@ class GPUModelRunner(
"""
for backend in backends:
is_supported = False
- for supported_size in backend.supported_kernel_block_sizes:
+ for supported_size in backend.get_supported_kernel_block_sizes():
if isinstance(supported_size, int):
if block_size == supported_size:
is_supported = True
@@ -4649,7 +4791,7 @@ class GPUModelRunner(
all_int_supported_sizes = set(
supported_size
for backend in backends
- for supported_size in backend.supported_kernel_block_sizes
+ for supported_size in backend.get_supported_kernel_block_sizes()
if isinstance(supported_size, int)
)
diff --git a/vllm/v1/worker/tpu_input_batch.py b/vllm/v1/worker/tpu_input_batch.py
index 6bf4f91931849..2ed65ca9d31cd 100644
--- a/vllm/v1/worker/tpu_input_batch.py
+++ b/vllm/v1/worker/tpu_input_batch.py
@@ -149,9 +149,6 @@ class InputBatch:
self.generators: dict[int, torch.Generator] = {}
self.num_logprobs: dict[str, int] = {}
- # NOTE(rob): num_prompt_logprobs only includes reqs
- # that are currently in the prefill phase.
- self.num_prompt_logprobs: dict[str, int] = {}
# To accumulate prompt logprobs tensor chunks across prefill steps.
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
@@ -256,8 +253,6 @@ class InputBatch:
if sampling_params.logprobs is not None:
self.num_logprobs[req_id] = sampling_params.logprobs
- if sampling_params.prompt_logprobs is not None:
- self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
if sampling_params.logit_bias is not None:
self.logit_bias[req_index] = sampling_params.logit_bias
@@ -317,7 +312,6 @@ class InputBatch:
self.repetition_penalties_reqs.discard(req_id)
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
- self.num_prompt_logprobs.pop(req_id, None)
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
# LoRA
@@ -584,10 +578,6 @@ class InputBatch:
def max_num_logprobs(self) -> int | None:
return max(self.num_logprobs.values()) if self.num_logprobs else None
- @property
- def no_prompt_logprob(self) -> bool:
- return not self.num_prompt_logprobs
-
@property
def no_allowed_token_ids(self) -> bool:
return len(self.has_allowed_token_ids) == 0
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 5f6012ec614c2..72d4474b89627 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -247,6 +247,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Request states.
self.requests: dict[str, CachedRequestState] = {}
+ # NOTE(rob): num_prompt_logprobs only includes reqs
+ # that are currently in the prefill phase.
+ self.num_prompt_logprobs: dict[str, int] = {}
# Initialize input batch early to avoid AttributeError in _update_states
self.input_batch = InputBatch(
@@ -420,6 +423,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Remove finished requests from the cached states.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
+ self.num_prompt_logprobs.pop(req_id, None)
# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
@@ -477,6 +481,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
lora_request=new_req_data.lora_request,
)
+ if sampling_params and sampling_params.prompt_logprobs is not None:
+ self.num_prompt_logprobs[req_id] = (
+ self.input_batch.vocab_size
+ if sampling_params.prompt_logprobs == -1
+ else sampling_params.prompt_logprobs
+ )
+
req_ids_to_add.append(req_id)
# Update the states of the running/resumed requests.