diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test-arm.sh b/.buildkite/scripts/hardware_ci/run-cpu-test-arm.sh
index d0036f24c8d04..b5f6b2494792f 100755
--- a/.buildkite/scripts/hardware_ci/run-cpu-test-arm.sh
+++ b/.buildkite/scripts/hardware_ci/run-cpu-test-arm.sh
@@ -7,53 +7,51 @@ set -ex
# allow to bind to different cores
CORE_RANGE=${CORE_RANGE:-0-16}
OMP_CORE_RANGE=${OMP_CORE_RANGE:-0-16}
-NUMA_NODE=${NUMA_NODE:-0}
-export CMAKE_BUILD_PARALLEL_LEVEL=32
+export CMAKE_BUILD_PARALLEL_LEVEL=16
# Setup cleanup
remove_docker_container() {
set -e;
- docker rm -f cpu-test-"$NUMA_NODE" || true;
+ docker rm -f cpu-test || true;
}
trap remove_docker_container EXIT
remove_docker_container
# Try building the docker image
-numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE" --target vllm-test -f docker/Dockerfile.cpu .
+docker build --tag cpu-test --target vllm-test -f docker/Dockerfile.cpu .
-# Run the image, setting --shm-size=4g for tensor parallel.
-docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=16 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
+# Run the image
+docker run -itd --cpuset-cpus="$CORE_RANGE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=16 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test cpu-test
function cpu_tests() {
set -e
- export NUMA_NODE=$2
- docker exec cpu-test-"$NUMA_NODE" bash -c "
+ docker exec cpu-test bash -c "
set -e
pip list"
# offline inference
- docker exec cpu-test-"$NUMA_NODE" bash -c "
+ docker exec cpu-test bash -c "
set -e
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m"
# Run kernel tests
- docker exec cpu-test-"$NUMA_NODE" bash -c "
+ docker exec cpu-test bash -c "
set -e
pytest -x -v -s tests/kernels/test_onednn.py
pytest -x -v -s tests/kernels/attention/test_cpu_attn.py"
# basic online serving
- docker exec cpu-test-"$NUMA_NODE" bash -c '
+ docker exec cpu-test bash -c '
set -e
- VLLM_CPU_OMP_THREADS_BIND=$E2E_OMP_THREADS vllm serve meta-llama/Llama-3.2-3B-Instruct --max-model-len 2048 &
+ VLLM_CPU_OMP_THREADS_BIND=$E2E_OMP_THREADS vllm serve Qwen/Qwen3-0.6B --max-model-len 2048 &
server_pid=$!
timeout 600 bash -c "until curl localhost:8000/v1/models; do sleep 1; done" || exit 1
vllm bench serve \
--backend vllm \
--dataset-name random \
- --model meta-llama/Llama-3.2-3B-Instruct \
+ --model Qwen/Qwen3-0.6B \
--num-prompts 20 \
--endpoint /v1/completions
kill -s SIGTERM $server_pid &'
@@ -61,4 +59,4 @@ function cpu_tests() {
# All of CPU tests are expected to be finished less than 40 mins.
export -f cpu_tests
-timeout 2h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
+timeout 2h bash -c cpu_tests
diff --git a/.buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep.sh b/.buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh
similarity index 82%
rename from .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep.sh
rename to .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh
index 0d06f53a183d0..6a1bef275d047 100644
--- a/.buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep.sh
+++ b/.buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh
@@ -1,10 +1,12 @@
#!/usr/bin/env bash
set -euxo pipefail
-# args: [THRESHOLD] [NUM_QUESTIONS] [START_PORT]
+# args: [THRESHOLD] [NUM_QUESTIONS] [START_PORT] [DATA_PARALLEL_SIZE] [TENSOR_PARALLEL_SIZE]
THRESHOLD=${1:-0.8}
NUM_Q=${2:-1319}
PORT=${3:-8020}
+DATA_PARALLEL_SIZE=${4:-2}
+TENSOR_PARALLEL_SIZE=${5:-2}
OUT_DIR=${OUT_DIR:-/tmp/vllm-scheduled}
mkdir -p "${OUT_DIR}"
@@ -45,8 +47,10 @@ for BACK in "${BACKENDS[@]}"; do
VLLM_ALL2ALL_BACKEND=$BACK \
vllm serve "$MODEL" \
--enforce-eager \
- --tensor-parallel-size 2 \
- --data-parallel-size 2 \
+ --enable-eplb \
+ --eplb-config '{"window_size":10, "step_interval":100, "num_redundant_experts":0, "log_balancedness":true}' \
+ --tensor-parallel-size ${TENSOR_PARALLEL_SIZE} \
+ --data-parallel-size ${DATA_PARALLEL_SIZE} \
--enable-expert-parallel \
--trust-remote-code \
--max-model-len 2048 \
diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml
index f098e23866eb3..4ddf11c0b268f 100644
--- a/.buildkite/test-amd.yaml
+++ b/.buildkite/test-amd.yaml
@@ -1486,4 +1486,4 @@ steps:
num_gpus: 4
working_dir: "/vllm-workspace"
commands:
- - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep.sh 0.8 200 8020
+ - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 7a46e919f93bf..375645fde7477 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -192,6 +192,7 @@ steps:
# test with internal dp
- python3 ../examples/offline_inference/data_parallel.py --enforce-eager
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
+ - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_internal_lb_dp.py
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_hybrid_lb_dp.py
@@ -631,6 +632,7 @@ steps:
# we can only upgrade after this is resolved
# TODO(jerryzh168): resolve the above comment
- uv pip install --system torchao==0.13.0 --index-url https://download.pytorch.org/whl/cu129
+ - uv pip install --system conch-triton-kernels
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py
- label: LM Eval Small Models # 53min
@@ -902,11 +904,12 @@ steps:
- label: Transformers Nightly Models Test
working_dir: "/vllm-workspace/"
optional: true
+ soft_fail: true
commands:
- pip install --upgrade git+https://github.com/huggingface/transformers
- - pytest -v -s tests/models/test_initialization.py -k 'not (Ultravox or Phi4Multimodal or MiniCPMO or Lfm2Moe or RobertaForSequenceClassification or Ovis2_5 or DeepseekOCR or KimiVL)'
+ - pytest -v -s tests/models/test_initialization.py
- pytest -v -s tests/models/test_transformers.py
- # - pytest -v -s tests/models/multimodal/processing/
+ - pytest -v -s tests/models/multimodal/processing/
- pytest -v -s tests/models/multimodal/test_mapping.py
- python3 examples/offline_inference/basic/chat.py
- python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl
@@ -1116,6 +1119,7 @@ steps:
# https://github.com/NVIDIA/nccl/issues/1838
- export NCCL_CUMEM_HOST_ENABLE=0
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
+ - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
- pytest -v -s entrypoints/llm/test_collective_rpc.py
@@ -1340,11 +1344,20 @@ steps:
commands:
- bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh 0.25 200 8010
-- label: Qwen3-30B-A3B-FP8-block Accuracy
+- label: Qwen3-30B-A3B-FP8-block Accuracy (H100)
timeout_in_minutes: 60
gpu: h100
optional: true
num_gpus: 4
working_dir: "/vllm-workspace"
commands:
- - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep.sh 0.8 200 8020
+ - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020
+
+- label: Qwen3-30B-A3B-FP8-block Accuracy (B200)
+ timeout_in_minutes: 60
+ gpu: b200
+ optional: true
+ num_gpus: 2
+ working_dir: "/vllm-workspace"
+ commands:
+ - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1
\ No newline at end of file
diff --git a/.github/workflows/cleanup_pr_body.yml b/.github/workflows/cleanup_pr_body.yml
index c3e132a536a42..861290ea43c87 100644
--- a/.github/workflows/cleanup_pr_body.yml
+++ b/.github/workflows/cleanup_pr_body.yml
@@ -13,7 +13,7 @@ jobs:
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
- name: Set up Python
uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
diff --git a/.github/workflows/issue_autolabel.yml b/.github/workflows/issue_autolabel.yml
index 7d565ef9f2e45..629966b959330 100644
--- a/.github/workflows/issue_autolabel.yml
+++ b/.github/workflows/issue_autolabel.yml
@@ -105,6 +105,31 @@ jobs:
}
],
},
+ cpu: {
+ // Keyword search - matches whole words only (with word boundaries)
+ keywords: [
+ {
+ term: "CPU Backend",
+ searchIn: "title"
+ },
+ {
+ term: "x86",
+ searchIn: "title"
+ },
+ {
+ term: "ARM",
+ searchIn: "title"
+ },
+ {
+ term: "Apple Silicon",
+ searchIn: "title"
+ },
+ {
+ term: "IBM Z",
+ searchIn: "title"
+ },
+ ],
+ },
// Add more label configurations here as needed
// example: {
// keywords: [...],
diff --git a/.github/workflows/macos-smoke-test.yml b/.github/workflows/macos-smoke-test.yml
index a183033c9adde..3a12c4b3a8300 100644
--- a/.github/workflows/macos-smoke-test.yml
+++ b/.github/workflows/macos-smoke-test.yml
@@ -12,7 +12,7 @@ jobs:
timeout-minutes: 30
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v6
- uses: astral-sh/setup-uv@v7
with:
diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml
index e21d13b8161f3..d5e70f30ef638 100644
--- a/.github/workflows/pre-commit.yml
+++ b/.github/workflows/pre-commit.yml
@@ -16,7 +16,7 @@ jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
- uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
with:
python-version: "3.12"
diff --git a/CMakeLists.txt b/CMakeLists.txt
index a4cf51d17e982..d88ba3aa66303 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -136,7 +136,7 @@ elseif(HIP_FOUND)
# ROCm 5.X and 6.X
if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND
- NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM})
+ Torch_VERSION VERSION_LESS ${TORCH_SUPPORTED_VERSION_ROCM})
message(WARNING "Pytorch version >= ${TORCH_SUPPORTED_VERSION_ROCM} "
"expected for ROCm build, saw ${Torch_VERSION} instead.")
endif()
@@ -604,12 +604,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(SRCS
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
- "csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu")
+ "csrc/quantization/fp4/nvfp4_experts_quant.cu"
+ "csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu"
+ "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1")
+ list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else()
message(STATUS "Not building NVFP4 as no compatible archs were found.")
diff --git a/README.md b/README.md
index 033e1035d8916..abbb63158f166 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,7 @@ Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundatio
*Latest News* 🔥
+- [2025/11] We hosted [vLLM Bangkok Meetup](https://luma.com/v0f647nv). We explored vLLM and LMCache inference and low-resource language adaptation with speakers from Embedded LLM, AMD, and Red Hat. Please find the meetup slides [here](https://drive.google.com/drive/folders/1H0DS57F8HQ5q3kSOSoRmucPJWL3E0A_X?usp=sharing).
- [2025/11] We hosted [the first vLLM Europe Meetup in Zurich](https://luma.com/0gls27kb) focused on quantization, distributed inference, and reinforcement learning at scale with speakers from Mistral, IBM, and Red Hat. Please find the meetup slides [here](https://docs.google.com/presentation/d/1UC9PTLCHYXQpOmJDSFg6Sljra3iVXzc09DeEI7dnxMc/edit?usp=sharing) and recording [here](https://www.youtube.com/watch?v=6m6ZE6yVEDI)
- [2025/11] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/xSrYXjNgr1HbCP4ExYNG1w) focusing on distributed inference and diverse accelerator support with vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1nQJ8ZkLSjKxvu36sSHaceVXtttbLvvu-?usp=drive_link).
- [2025/10] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/__xb4OyOsImz-9eAVrdlcg) focused on hands-on vLLM inference optimization! Please find the meetup slides [here](https://drive.google.com/drive/folders/1KqwjsFJLfEsC8wlDugnrR61zsWHt94Q6).
diff --git a/benchmarks/kernels/deepgemm/README.md b/benchmarks/kernels/deepgemm/README.md
index 41e68e047be82..a28c6956be0e9 100644
--- a/benchmarks/kernels/deepgemm/README.md
+++ b/benchmarks/kernels/deepgemm/README.md
@@ -2,7 +2,7 @@
This directory includes benchmarks between DeepSeek's DeepGEMM block fp8 kernels against vLLM's existing triton and CUTLASS-based kernels.
-Currently this just includes dense GEMMs and only works on Hopper GPUs.
+Currently, this just includes dense GEMMs and only works on Hopper GPUs.
## Setup
diff --git a/csrc/attention/merge_attn_states.cu b/csrc/attention/merge_attn_states.cu
index 229d9862fb670..27d1e990c611e 100644
--- a/csrc/attention/merge_attn_states.cu
+++ b/csrc/attention/merge_attn_states.cu
@@ -16,7 +16,8 @@ __global__ void merge_attn_states_kernel(
scalar_t* output, float* output_lse, const scalar_t* prefix_output,
const float* prefix_lse, const scalar_t* suffix_output,
const float* suffix_lse, const uint num_tokens, const uint num_heads,
- const uint head_size) {
+ const uint head_size, const uint prefix_head_stride,
+ const uint output_head_stride) {
using pack_128b_t = uint4;
const uint pack_size = 16 / sizeof(scalar_t);
const uint threads_per_head = head_size / pack_size;
@@ -34,11 +35,13 @@ __global__ void merge_attn_states_kernel(
const uint head_idx = token_head_idx % num_heads;
const uint pack_offset = pack_idx * pack_size; // (0~15)*8, etc.
- const uint head_offset =
- token_idx * num_heads * head_size + head_idx * head_size;
- const scalar_t* prefix_head_ptr = prefix_output + head_offset;
- const scalar_t* suffix_head_ptr = suffix_output + head_offset;
- scalar_t* output_head_ptr = output + head_offset;
+ const uint src_head_offset = token_idx * num_heads * prefix_head_stride +
+ head_idx * prefix_head_stride;
+ const uint dst_head_offset = token_idx * num_heads * output_head_stride +
+ head_idx * output_head_stride;
+ const scalar_t* prefix_head_ptr = prefix_output + src_head_offset;
+ const scalar_t* suffix_head_ptr = suffix_output + src_head_offset;
+ scalar_t* output_head_ptr = output + dst_head_offset;
float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
@@ -140,7 +143,7 @@ __global__ void merge_attn_states_kernel(
reinterpret_cast(prefix_lse.data_ptr()), \
reinterpret_cast(suffix_output.data_ptr()), \
reinterpret_cast(suffix_lse.data_ptr()), num_tokens, \
- num_heads, head_size); \
+ num_heads, head_size, prefix_head_stride, output_head_stride); \
}
/*@brief Merges the attention states from prefix and suffix
@@ -166,17 +169,11 @@ void merge_attn_states_launcher(torch::Tensor& output,
const uint num_tokens = output.size(0);
const uint num_heads = output.size(1);
const uint head_size = output.size(2);
+ const uint prefix_head_stride = prefix_output.stride(1);
+ const uint output_head_stride = output.stride(1);
const uint pack_size = 16 / sizeof(scalar_t);
TORCH_CHECK(head_size % pack_size == 0,
"headsize must be multiple of pack_size:", pack_size);
- TORCH_CHECK(output.stride(-2) == head_size && output.stride(-1) == 1,
- "output heads must be contiguous in memory");
- TORCH_CHECK(
- prefix_output.stride(-2) == head_size && prefix_output.stride(-1) == 1,
- "prefix_output heads must be contiguous in memory");
- TORCH_CHECK(
- suffix_output.stride(-2) == head_size && suffix_output.stride(-1) == 1,
- "suffix_output heads must be contiguous in memory");
float* output_lse_ptr = nullptr;
if (output_lse.has_value()) {
output_lse_ptr = output_lse.value().data_ptr();
diff --git a/csrc/cache.h b/csrc/cache.h
index b162a4a2bc31f..f2a5ec0acf5cd 100644
--- a/csrc/cache.h
+++ b/csrc/cache.h
@@ -41,11 +41,12 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double scale, const std::string& kv_cache_dtype);
void gather_and_maybe_dequant_cache(
- torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
- torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
- torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
- torch::Tensor const& cu_seq_lens, // [BATCH+1]
- int64_t batch_size, const std::string& kv_cache_dtype,
+ torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
+ torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
+ torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
+ torch::Tensor const& cu_seq_lens, // [BATCH+1]
+ torch::Tensor const& token_to_seq, // [MAX_TOKEN_ACROSS_CHUNKS]
+ int64_t num_tokens, const std::string& kv_cache_dtype,
torch::Tensor const& scale,
std::optional seq_starts = std::nullopt);
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 32960cc8073bb..8a5457206c706 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -905,91 +905,79 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
namespace vllm {
// grid is launched with dimensions (batch, num_splits)
-template
+template
__global__ void gather_and_maybe_dequant_cache(
- const cache_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
- // ENTRIES...]
- scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...]
- const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
- const int32_t* __restrict__ cu_seq_lens, // [BATCH+1]
- const int32_t block_size, const int32_t entry_size,
+ const cache_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
+ // ENTRIES...]
+ scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...]
+ const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
+ const int32_t* __restrict__ cu_seq_lens, // [BATCH+1]
+ const int32_t* __restrict__ token_to_seq, // [MAX_TOKEN_ACROSS_CHUNK]
+ const int32_t num_tokens, const int32_t block_size,
const int64_t block_table_stride, const int64_t cache_block_stride,
const int64_t cache_entry_stride, const int64_t dst_entry_stride,
const float* __restrict__ scale,
const int32_t* __restrict__ seq_starts) { // Optional: starting offsets per
// batch
+ constexpr int vec_size = sizeof(float4) / sizeof(scalar_t);
+ using ltype = vllm::vec_n_t;
+ using stype = vllm::vec_n_t;
+ // We are adding this for code readability which will be optimized out when
+ // build in release.
+ assert(CTA_SIZE == blockDim.x);
- const int64_t bid = blockIdx.x; // Batch ID
- const int32_t num_splits = gridDim.y;
- const int32_t split = blockIdx.y;
- const int32_t seq_start = cu_seq_lens[bid];
- const int32_t seq_end = cu_seq_lens[bid + 1];
- const int32_t seq_len = seq_end - seq_start;
- const int32_t tot_blocks = cuda_utils::ceil_div(seq_len, block_size);
- const int32_t split_blocks = cuda_utils::ceil_div(tot_blocks, num_splits);
+#pragma unroll
+ for (int token_id = blockIdx.x; token_id < num_tokens;
+ token_id += gridDim.x) {
+ int64_t batch_id = token_to_seq[token_id];
+ int64_t batch_start = cu_seq_lens[batch_id];
+ int64_t batch_end = cu_seq_lens[batch_id + 1];
+ int32_t batch_offset = token_id - batch_start;
- const int32_t split_start = split * split_blocks;
- const int32_t split_end = min((split + 1) * split_blocks, tot_blocks);
+ if (token_id >= batch_end) return;
+ int32_t offset = 0;
+ if (seq_starts != nullptr) {
+ offset = seq_starts[batch_id];
+ }
+ batch_offset += offset;
+ int32_t block_table_id = batch_offset / block_size;
+ int32_t slot_id = batch_offset % block_size;
+ int32_t block_table_offset = batch_id * block_table_stride + block_table_id;
+ int32_t block_id = block_table[block_table_offset];
+ int64_t cache_offset =
+ block_id * cache_block_stride + slot_id * cache_entry_stride;
+ constexpr int32_t vec_iter_cnt = ENTRY_SIZE / vec_size;
+ scalar_t* dst_ = dst + token_id * dst_entry_stride;
+ cache_t* src_ = const_cast(src_cache) + cache_offset;
- const bool is_active_split = (split_start < tot_blocks);
- const bool is_last_split = (split_end == tot_blocks);
-
- if (!is_active_split) return;
-
- int32_t full_blocks_end = split_end;
- int32_t partial_block_size = 0;
-
- // Adjust the pointer for the block_table for this batch.
- // If seq_starts is provided, compute an offset based on (seq_starts[bid] /
- // page_size)
- const int32_t batch_offset = bid * block_table_stride;
- int32_t offset = 0;
- if (seq_starts != nullptr) {
- offset = seq_starts[bid] / block_size;
- }
- const int32_t* batch_block_table = block_table + batch_offset + offset;
-
- // Adjust dst pointer based on the cumulative sequence lengths.
- dst += seq_start * dst_entry_stride;
-
- if (is_last_split) {
- partial_block_size = seq_len % block_size;
- if (partial_block_size) full_blocks_end -= 1;
- }
-
- auto copy_entry = [&](const cache_t* __restrict__ _src,
- scalar_t* __restrict__ _dst) {
- for (int i = threadIdx.x; i < entry_size; i += blockDim.x) {
+#pragma unroll
+ for (int idx = threadIdx.x; idx < vec_iter_cnt; idx += CTA_SIZE) {
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
- _dst[i] = static_cast(_src[i]);
+ reinterpret_cast(dst_)[idx] =
+ static_cast(reinterpret_cast(src_)[idx]);
} else {
- _dst[i] =
- fp8::scaled_convert(_src[i], *scale);
+ ltype loaded_val = reinterpret_cast(src_)[idx];
+ stype store_val;
+#pragma unroll
+ for (int j = 0; j < vec_size; ++j) {
+ store_val.val[j] = fp8::scaled_convert(
+ loaded_val.val[j], *scale);
+ }
+ reinterpret_cast(dst_)[idx] = store_val;
}
}
- };
-
- const auto loop_end =
- std::min((int64_t)full_blocks_end, block_table_stride - offset);
- for (int pid = split_start; pid < loop_end; ++pid) {
- auto block_id = batch_block_table[pid];
- auto block_start_ptr = src_cache + block_id * cache_block_stride;
- auto block_dst_ptr = dst + pid * block_size * dst_entry_stride;
- for (int eid = 0; eid < block_size; ++eid) {
- copy_entry(block_start_ptr + eid * cache_entry_stride,
- block_dst_ptr + eid * dst_entry_stride);
- }
- }
-
- if (partial_block_size) {
- if (offset + full_blocks_end < block_table_stride) {
- auto block_id = batch_block_table[full_blocks_end];
- auto block_start_ptr = src_cache + block_id * cache_block_stride;
- auto block_dst_ptr =
- dst + full_blocks_end * block_size * dst_entry_stride;
- for (int eid = 0; eid < partial_block_size; ++eid) {
- copy_entry(block_start_ptr + eid * cache_entry_stride,
- block_dst_ptr + eid * dst_entry_stride);
+ // process tail
+ constexpr int32_t tail_cnt = ENTRY_SIZE % vec_size;
+ dst_ = dst_ + ENTRY_SIZE - tail_cnt;
+ src_ = src_ + ENTRY_SIZE - tail_cnt;
+#pragma unroll
+ for (int idx = threadIdx.x; idx < tail_cnt; idx += CTA_SIZE) {
+ if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+ dst_[idx] = static_cast(src_[idx]);
+ } else {
+ dst_[idx] =
+ fp8::scaled_convert(src_[idx], *scale);
}
}
}
@@ -1001,34 +989,38 @@ __global__ void gather_and_maybe_dequant_cache(
// SCALAR_T is the data type of the destination tensor.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
-#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \
- vllm::gather_and_maybe_dequant_cache \
- <<>>( \
- reinterpret_cast(src_cache.data_ptr()), \
- reinterpret_cast(dst.data_ptr()), \
- block_table.data_ptr(), cu_seq_lens.data_ptr(), \
- block_size, entry_size, block_table_stride, cache_block_stride, \
- cache_entry_stride, dst_entry_stride, \
- reinterpret_cast(scale.data_ptr()), seq_starts_ptr);
+#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \
+ vllm::gather_and_maybe_dequant_cache \
+ <<>>( \
+ reinterpret_cast(src_cache.data_ptr()), \
+ reinterpret_cast(dst.data_ptr()), \
+ block_table.data_ptr(), cu_seq_lens.data_ptr(), \
+ token_to_seq.data_ptr(), num_tokens, block_size, \
+ block_table_stride, cache_block_stride, cache_entry_stride, \
+ dst_entry_stride, reinterpret_cast(scale.data_ptr()), \
+ seq_starts_ptr);
// Gather sequences from the cache into the destination tensor.
// - cu_seq_lens contains the cumulative sequence lengths for each batch
// - block_table contains the cache block indices for each sequence
+// - token_to_seq contains the back mapping from token_id to batch_id
// - Optionally, seq_starts (if provided) offsets the starting block index by
// (seq_starts[bid] / page_size)
void gather_and_maybe_dequant_cache(
- torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
- torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
- torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
- torch::Tensor const& cu_seq_lens, // [BATCH+1]
- int64_t batch_size, const std::string& kv_cache_dtype,
+ torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
+ torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
+ torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
+ torch::Tensor const& cu_seq_lens, // [BATCH+1]
+ torch::Tensor const& token_to_seq, // [MAX_TOKEN_ACROSS_CHUNKS]
+ int64_t num_tokens, const std::string& kv_cache_dtype,
torch::Tensor const& scale,
std::optional seq_starts = std::nullopt) {
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int32_t block_size = src_cache.size(1);
- int32_t entry_size = src_cache.flatten(2, -1).size(2);
+ int32_t head_dim = dst.size(-1);
TORCH_CHECK(block_table.dtype() == torch::kInt32,
"block_table must be int32");
@@ -1038,6 +1030,9 @@ void gather_and_maybe_dequant_cache(
TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
"seq_starts must be int32");
}
+ TORCH_CHECK(head_dim == 576,
+ "gather_and_maybe_dequant_cache only support the head_dim to 576 "
+ "for better performance")
TORCH_CHECK(src_cache.device() == dst.device(),
"src_cache and dst must be on the same device");
@@ -1055,10 +1050,9 @@ void gather_and_maybe_dequant_cache(
int64_t cache_entry_stride = src_cache.stride(1);
int64_t dst_entry_stride = dst.stride(0);
- // Decide on the number of splits based on the batch size.
- int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
- dim3 grid(batch_size, num_splits);
- dim3 block(1024);
+ constexpr int32_t thread_block_size = 64;
+ dim3 grid(num_tokens);
+ dim3 block(thread_block_size);
const int32_t* seq_starts_ptr =
seq_starts.has_value() ? seq_starts.value().data_ptr() : nullptr;
diff --git a/csrc/cpu/cpu_attn_impl.hpp b/csrc/cpu/cpu_attn_impl.hpp
index 12c6f5d3015cc..98f55d7c014be 100644
--- a/csrc/cpu/cpu_attn_impl.hpp
+++ b/csrc/cpu/cpu_attn_impl.hpp
@@ -847,7 +847,7 @@ struct VecTypeTrait {
};
#endif
-#if !defined(__powerpc__)
+#if !defined(__powerpc__) && !defined(__s390x__)
template <>
struct VecTypeTrait {
using vec_t = vec_op::FP16Vec16;
diff --git a/csrc/cpu/cpu_types_vxe.hpp b/csrc/cpu/cpu_types_vxe.hpp
index 51bca37e699b9..9efd8b7ec14a4 100644
--- a/csrc/cpu/cpu_types_vxe.hpp
+++ b/csrc/cpu/cpu_types_vxe.hpp
@@ -4,6 +4,7 @@
#include
#include
+#include
#include
namespace vec_op {
@@ -174,8 +175,9 @@ struct FP32Vec8 : public Vec {
}
explicit FP32Vec8(const BF16Vec8& v) {
- reg.val[0] = (__vector float)vec_mergeh(zero, v.reg);
- reg.val[1] = (__vector float)vec_mergel(zero, v.reg);
+ // On big-endian s390x, place BF16 first to get correct byte order
+ reg.val[0] = (__vector float)vec_mergeh(v.reg, zero);
+ reg.val[1] = (__vector float)vec_mergel(v.reg, zero);
}
float reduce_sum() const {
@@ -189,51 +191,257 @@ struct FP32Vec8 : public Vec {
}
FP32Vec8 exp() const {
- // TODO: Vectorize this
- AliasReg ar;
- ar.reg = reg;
- f32x4x4_t ret;
- ret.val[0][0] = std::exp(ar.values[0]);
- ret.val[0][1] = std::exp(ar.values[1]);
- ret.val[0][2] = std::exp(ar.values[2]);
- ret.val[0][3] = std::exp(ar.values[3]);
- ret.val[1][0] = std::exp(ar.values[4]);
- ret.val[1][1] = std::exp(ar.values[5]);
- ret.val[1][2] = std::exp(ar.values[6]);
- ret.val[1][3] = std::exp(ar.values[7]);
- return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
+ f32x4x2_t out;
+
+ const __vector float log2e = vec_splats(1.44269504088896341f);
+ const __vector float one = vec_splats(1.0f);
+ const __vector float min_x = vec_splats(-87.3f);
+ const __vector float max_x = vec_splats(88.7f);
+
+ // 5th-degree minimax polynomial for 2^r (r in [0,1))
+ const __vector float c1 = vec_splats(0.6931471805599453f);
+ const __vector float c2 = vec_splats(0.240226506959101f);
+ const __vector float c3 = vec_splats(0.05550410866482158f);
+ const __vector float c4 = vec_splats(0.009618129107628477f);
+ const __vector float c5 = vec_splats(0.0013333558146428443f);
+
+ for (int i = 0; i < 2; i++) {
+ __vector float x = reg.val[i];
+
+ x = vec_max(x, min_x);
+ x = vec_min(x, max_x);
+
+ __vector float y = vec_mul(x, log2e);
+
+ __vector float kf = vec_floor(y);
+ __vector float r = vec_sub(y, kf);
+
+ __vector signed int k = vec_signed(kf);
+ const __vector signed int min_k = vec_splats((signed int)-126);
+ const __vector signed int max_k = vec_splats((signed int)127);
+ k = vec_min(vec_max(k, min_k), max_k);
+
+ // Build 2^k from exponent bits
+ __vector signed int exp_int = vec_add(k, vec_splats((signed int)127));
+ __vector unsigned int bits = (__vector unsigned int)exp_int;
+ bits = vec_sl(bits, vec_splats((unsigned int)23));
+ __vector float pow2k = (__vector float)bits;
+
+ // Improved minimax polynomial
+ __vector float poly = vec_madd(c5, r, c4);
+ poly = vec_madd(poly, r, c3);
+ poly = vec_madd(poly, r, c2);
+ poly = vec_madd(poly, r, c1);
+ poly = vec_madd(poly, r, one);
+
+ out.val[i] = vec_mul(pow2k, poly);
+ }
+
+ return FP32Vec8(out);
}
FP32Vec8 tanh() const {
- // TODO: Vectorize this
- AliasReg ar;
- ar.reg = reg;
- f32x4x4_t ret;
- ret.val[0][0] = std::tanh(ar.values[0]);
- ret.val[0][1] = std::tanh(ar.values[1]);
- ret.val[0][2] = std::tanh(ar.values[2]);
- ret.val[0][3] = std::tanh(ar.values[3]);
- ret.val[1][0] = std::tanh(ar.values[4]);
- ret.val[1][1] = std::tanh(ar.values[5]);
- ret.val[1][2] = std::tanh(ar.values[6]);
- ret.val[1][3] = std::tanh(ar.values[7]);
- return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
+ // tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
+ const __vector float one = vec_splats(1.0f);
+ const __vector float two = vec_splats(2.0f);
+ const __vector float zero = vec_splats(0.0f);
+ const __vector float sat =
+ vec_splats(9.0f); // beyond this, tanh(x) ~ sign(x)
+
+ f32x4x2_t out;
+
+ for (int i = 0; i < 2; i++) {
+ __vector float x = reg.val[i];
+ __vector float ax = vec_abs(x);
+
+ // sign(x): +1 or -1
+ __vector float sign = vec_sel(vec_splats(-1.0f), one, vec_cmpgt(x, zero));
+
+ // saturation mask: |x| > sat
+ __vector __bool int saturated = vec_cmpgt(ax, sat);
+
+ // 2x
+ __vector float two_x = vec_mul(x, two);
+
+ // Build a temporary FP32Vec8 with both lanes = 2x, reuse exp()
+ f32x4x2_t tmp;
+ tmp.val[0] = two_x;
+ tmp.val[1] = two_x;
+ FP32Vec8 exp_2x_vec(tmp);
+
+ FP32Vec8 e2x = exp_2x_vec.exp();
+ __vector float e = e2x.reg.val[i];
+
+ // tanh(x) = (e - 1) / (e + 1)
+ __vector float num = vec_sub(e, one);
+ __vector float den = vec_add(e, one);
+
+ __vector float t = vec_div(num, den);
+
+ // For large |x|, clamp to sign(x)
+ out.val[i] = vec_sel(t, sign, saturated);
+ }
+
+ return FP32Vec8(out);
}
FP32Vec8 er() const {
- // TODO: Vectorize this
- AliasReg ar;
- ar.reg = reg;
- f32x4x4_t ret;
- ret.val[0][0] = std::erf(ar.values[0]);
- ret.val[0][1] = std::erf(ar.values[1]);
- ret.val[0][2] = std::erf(ar.values[2]);
- ret.val[0][3] = std::erf(ar.values[3]);
- ret.val[1][0] = std::erf(ar.values[4]);
- ret.val[1][1] = std::erf(ar.values[5]);
- ret.val[1][2] = std::erf(ar.values[6]);
- ret.val[1][3] = std::erf(ar.values[7]);
- return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
+ // A&S 7.1.26 approximation:
+ // erf(x) = sign(x) * (1 - ((((a5*t + a4)*t + a3)*t + a2)*t + a1) * t *
+ // exp(-x^2)) t = 1 / (1 + p*|x|), p = 0.3275911
+
+ const __vector float one = vec_splats(1.0f);
+ const __vector float zero = vec_splats(0.0f);
+ const __vector float p = vec_splats(0.3275911f);
+
+ // Polynomial coeffs
+ const __vector float a1 = vec_splats(0.254829592f);
+ const __vector float a2 = vec_splats(-0.284496736f);
+ const __vector float a3 = vec_splats(1.421413741f);
+ const __vector float a4 = vec_splats(-1.453152027f);
+ const __vector float a5 = vec_splats(1.061405429f);
+
+ // Threshold where erf(x) ~ sign(x)
+ const __vector float sat = vec_splats(6.0f);
+
+ f32x4x2_t out;
+
+ for (int lane = 0; lane < 2; lane++) {
+ __vector float x = reg.val[lane];
+ __vector float ax = vec_abs(x);
+
+ // sign(x)
+ __vector float sign = vec_sel(vec_splats(-1.0f), one, vec_cmpgt(x, zero));
+
+ // |x| > 6 → erf(x) = ±1
+ __vector __bool int saturated = vec_cmpgt(ax, sat);
+
+ // t = 1 / (1 + p * |x|)
+ __vector float t = vec_madd(p, ax, one);
+ t = vec_div(one, t);
+
+ // poly = a5
+ __vector float poly = a5;
+ poly = vec_madd(poly, t, a4);
+ poly = vec_madd(poly, t, a3);
+ poly = vec_madd(poly, t, a2);
+ poly = vec_madd(poly, t, a1);
+
+ // full polynomial: poly = poly * t
+ poly = vec_mul(poly, t);
+
+ // Compute exp(-x^2)
+ __vector float x2 = vec_mul(x, x);
+ __vector float neg_x2 = vec_neg(x2);
+
+ f32x4x2_t tmp;
+ tmp.val[0] = neg_x2;
+ tmp.val[1] = neg_x2;
+ FP32Vec8 exp_neg_x2(tmp);
+
+ FP32Vec8 e = exp_neg_x2.exp();
+ __vector float ex = e.reg.val[lane];
+
+ // erf(x) = sign * (1 - poly * exp(-x^2))
+ __vector float term = vec_mul(poly, ex);
+ __vector float y = vec_sub(one, term);
+ y = vec_mul(y, sign);
+
+ // saturated → ±1
+ __vector float sat_val = vec_mul(sign, one);
+ out.val[lane] = vec_sel(y, sat_val, saturated);
+ }
+
+ return FP32Vec8(out);
+ }
+ // Elementwise sigmoid(x) = 1 / (1 + exp(-x))
+ FP32Vec8 sigmoid() const {
+ const __vector float one = vec_splats(1.0f);
+
+ f32x4x2_t neg;
+ for (int i = 0; i < 2; ++i) {
+ neg.val[i] = vec_neg(reg.val[i]);
+ }
+
+ FP32Vec8 neg_x(neg);
+ FP32Vec8 e = neg_x.exp(); // exp(-x)
+
+ f32x4x2_t denom;
+ for (int i = 0; i < 2; ++i) {
+ denom.val[i] = vec_add(one, e.reg.val[i]);
+ }
+
+ FP32Vec8 denom_vec(denom);
+ FP32Vec8 one_vec(1.0f);
+
+ return one_vec / denom_vec;
+ }
+
+ // Tanh-based GELU:
+ // gelu(x) = 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x^3)))
+ FP32Vec8 gelu_tanh() const {
+ const __vector float k_s2pi = vec_splats(0.7978845608028654f); // √(2/π)
+ const __vector float k_0_0447 = vec_splats(0.044715f);
+
+ f32x4x2_t x2, x3, inner;
+ for (int i = 0; i < 2; ++i) {
+ __vector float x = reg.val[i];
+ x2.val[i] = vec_mul(x, x); // x^2
+ x3.val[i] = vec_mul(x2.val[i], x); // x^3
+ __vector float t = vec_madd(k_0_0447, x3.val[i], x); // x + 0.044715*x^3
+ inner.val[i] = vec_mul(k_s2pi, t); // √(2/π)*(...)
+ }
+
+ FP32Vec8 inner_vec(inner);
+ FP32Vec8 t = inner_vec.tanh(); // tanh part
+
+ FP32Vec8 one_vec(1.0f);
+ FP32Vec8 half_vec(0.5f);
+
+ FP32Vec8 x_vec(*this);
+ return x_vec * half_vec * (one_vec + t);
+ }
+
+ // Erf-based GELU:
+ // gelu(x) = 0.5 * x * (1 + erf(x / √2))
+ FP32Vec8 gelu_erf() const {
+ const __vector float inv_sqrt2 = vec_splats(0.7071067811865476f); // 1/√2
+ FP32Vec8 x_vec(*this);
+
+ f32x4x2_t scaled;
+ for (int i = 0; i < 2; ++i) {
+ scaled.val[i] = vec_mul(reg.val[i], inv_sqrt2);
+ }
+ FP32Vec8 x_scaled(scaled);
+
+ FP32Vec8 erf_x = x_scaled.er();
+
+ FP32Vec8 one_vec(1.0f);
+ FP32Vec8 half_vec(0.5f);
+
+ return x_vec * half_vec * (one_vec + erf_x);
+ }
+
+ // Elementwise reciprocal: 1/x (scalar per lane, for correctness)
+ FP32Vec8 rcp() const {
+ AliasReg in, out;
+ in.reg = reg;
+
+ for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+ out.values[i] = 1.0f / in.values[i];
+ }
+ return FP32Vec8(out.reg);
+ }
+
+ // Elementwise rsqrt(x) = 1 / sqrt(x) (scalar per lane, for correctness)
+ FP32Vec8 rsqrt() const {
+ AliasReg in, out;
+ in.reg = reg;
+
+ for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+ out.values[i] = 1.0f / std::sqrt(in.values[i]);
+ }
+ return FP32Vec8(out.reg);
}
FP32Vec8 operator*(const FP32Vec8& b) const {
@@ -316,10 +524,11 @@ struct FP32Vec16 : public Vec {
}
explicit FP32Vec16(const BF16Vec16& v) {
- reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]);
- reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]);
- reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]);
- reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]);
+ // On big-endian s390x, place BF16 first to get correct byte order
+ reg.val[0] = (__vector float)vec_mergeh(v.reg.val[0], zero);
+ reg.val[1] = (__vector float)vec_mergel(v.reg.val[0], zero);
+ reg.val[2] = (__vector float)vec_mergeh(v.reg.val[1], zero);
+ reg.val[3] = (__vector float)vec_mergel(v.reg.val[1], zero);
}
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
@@ -376,6 +585,23 @@ struct FP32Vec16 : public Vec {
return result;
}
+ FP32Vec16 max(const FP32Vec16& b) const {
+ return FP32Vec16(f32x4x4_t({vec_max(reg.val[0], b.reg.val[0]),
+ vec_max(reg.val[1], b.reg.val[1]),
+ vec_max(reg.val[2], b.reg.val[2]),
+ vec_max(reg.val[3], b.reg.val[3])}));
+ }
+
+ float reduce_max() const {
+ AliasReg ar;
+ ar.reg = reg;
+ float result = ar.values[0];
+ unroll_loop([&result, &ar](int i) {
+ if (ar.values[i] > result) result = ar.values[i];
+ });
+ return result;
+ }
+
void save(float* ptr) const {
vec_xst(reg.val[0], 0, ptr);
vec_xst(reg.val[1], 16, ptr);
@@ -402,15 +628,14 @@ struct VecType {
using vec_type = BF16Vec8;
};
+// On s390x, FP16 (Half) is not natively supported, use FP32 vectors instead
+using FP16Vec16 = FP32Vec16;
+
template
void storeFP32(float v, T* ptr) {
*ptr = v;
}
-inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
- acc = acc + a * b;
-}
-
namespace c10 {
struct BFloat16 {
uint16_t value; // Assume BFloat16 is defined as a struct containing a 16-bit
@@ -429,6 +654,79 @@ inline void storeFP32(float v, c10::BFloat16* ptr) {
#define __VEC_CLASS_FP_NAN (1 << 6)
#endif
+// Optimized FMA (Fused Multiply-Add) implementations using IBM Z vector
+// intrinsics
+
+// FP32Vec4 FMA: acc = acc + (a * b) or equivalently acc = fma(a, b, acc)
+FORCE_INLINE void fma(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) {
+ acc.reg = vec_madd(a.reg, b.reg, acc.reg);
+}
+
+// FP32Vec8 FMA: acc = acc + (a * b)
+FORCE_INLINE void fma(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) {
+ acc.reg.val[0] = vec_madd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_madd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+}
+
+// FP32Vec16 FMA: acc = acc + (a * b)
+FORCE_INLINE void fma(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
+ acc.reg.val[0] = vec_madd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_madd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+ acc.reg.val[2] = vec_madd(a.reg.val[2], b.reg.val[2], acc.reg.val[2]);
+ acc.reg.val[3] = vec_madd(a.reg.val[3], b.reg.val[3], acc.reg.val[3]);
+}
+
+// Multiply-Subtract: acc = acc - (a * b)
+FORCE_INLINE void fms(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) {
+ acc.reg = vec_msub(a.reg, b.reg, acc.reg);
+}
+
+FORCE_INLINE void fms(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) {
+ acc.reg.val[0] = vec_msub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_msub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+}
+
+FORCE_INLINE void fms(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
+ acc.reg.val[0] = vec_msub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_msub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+ acc.reg.val[2] = vec_msub(a.reg.val[2], b.reg.val[2], acc.reg.val[2]);
+ acc.reg.val[3] = vec_msub(a.reg.val[3], b.reg.val[3], acc.reg.val[3]);
+}
+
+// Negative Multiply-Add: acc = -(a * b) + acc
+FORCE_INLINE void nfma(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) {
+ acc.reg = vec_nmadd(a.reg, b.reg, acc.reg);
+}
+
+FORCE_INLINE void nfma(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) {
+ acc.reg.val[0] = vec_nmadd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_nmadd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+}
+
+FORCE_INLINE void nfma(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
+ acc.reg.val[0] = vec_nmadd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_nmadd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+ acc.reg.val[2] = vec_nmadd(a.reg.val[2], b.reg.val[2], acc.reg.val[2]);
+ acc.reg.val[3] = vec_nmadd(a.reg.val[3], b.reg.val[3], acc.reg.val[3]);
+}
+
+// Negative Multiply-Subtract: acc = -(a * b) - acc
+FORCE_INLINE void nfms(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) {
+ acc.reg = vec_nmsub(a.reg, b.reg, acc.reg);
+}
+
+FORCE_INLINE void nfms(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) {
+ acc.reg.val[0] = vec_nmsub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_nmsub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+}
+
+FORCE_INLINE void nfms(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
+ acc.reg.val[0] = vec_nmsub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_nmsub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+ acc.reg.val[2] = vec_nmsub(a.reg.val[2], b.reg.val[2], acc.reg.val[2]);
+ acc.reg.val[3] = vec_nmsub(a.reg.val[3], b.reg.val[3], acc.reg.val[3]);
+}
+
const static __vector unsigned char omask = {2, 3, 6, 7, 10, 11, 14, 15,
18, 19, 22, 23, 26, 27, 30, 31};
const static __vector unsigned int bias = {0x00007fff, 0x00007fff, 0x00007fff,
@@ -441,13 +739,24 @@ const static __vector unsigned int one = {1, 1, 1, 1};
inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
__vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
__vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
+ __vector unsigned int lsb0 = inp0 >> sh16;
+ __vector unsigned int lsb1 = inp1 >> sh16;
+ lsb0 = lsb0 & one;
+ lsb1 = lsb1 & one;
+ __vector unsigned int rnd0 = lsb0 + bias;
+ __vector unsigned int rnd1 = lsb1 + bias;
+ inp0 = inp0 + rnd0;
+ inp1 = inp1 + rnd1;
int cc;
__vector __bool int sel0 =
vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc);
__vector __bool int sel1 =
vec_fp_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN, &cc);
- inp0 = vec_sel(inp0, nan, sel0) >> sh16;
- inp1 = vec_sel(inp1, nan, sel1) >> sh16;
+ inp0 = vec_sel(inp0, nan, sel0);
+ inp1 = vec_sel(inp1, nan, sel1);
+ inp0 = inp0 >> sh16;
+ inp1 = inp1 >> sh16;
+
reg = (__vector signed short)vec_perm(inp0, inp1, omask);
}
@@ -456,6 +765,22 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
__vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
__vector unsigned int inp2 = (__vector unsigned int)(v.reg.val[2]);
__vector unsigned int inp3 = (__vector unsigned int)(v.reg.val[3]);
+ __vector unsigned int lsb0 = inp0 >> sh16;
+ __vector unsigned int lsb1 = inp1 >> sh16;
+ __vector unsigned int lsb2 = inp2 >> sh16;
+ __vector unsigned int lsb3 = inp3 >> sh16;
+ lsb0 = lsb0 & one;
+ lsb1 = lsb1 & one;
+ lsb2 = lsb2 & one;
+ lsb3 = lsb3 & one;
+ __vector unsigned int rnd0 = lsb0 + bias;
+ __vector unsigned int rnd1 = lsb1 + bias;
+ __vector unsigned int rnd2 = lsb2 + bias;
+ __vector unsigned int rnd3 = lsb3 + bias;
+ inp0 = inp0 + rnd0;
+ inp1 = inp1 + rnd1;
+ inp2 = inp2 + rnd2;
+ inp3 = inp3 + rnd3;
int cc;
__vector __bool int sel0 =
vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc);
@@ -465,15 +790,164 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
vec_fp_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN, &cc);
__vector __bool int sel3 =
vec_fp_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN, &cc);
- inp0 = vec_sel(inp0, nan, sel0) >> sh16;
- inp1 = vec_sel(inp1, nan, sel1) >> sh16;
- inp2 = vec_sel(inp2, nan, sel2) >> sh16;
- inp3 = vec_sel(inp3, nan, sel3) >> sh16;
+ inp0 = vec_sel(inp0, nan, sel0);
+ inp1 = vec_sel(inp1, nan, sel1);
+ inp2 = vec_sel(inp2, nan, sel2);
+ inp3 = vec_sel(inp3, nan, sel3);
+ inp0 = inp0 >> sh16;
+ inp1 = inp1 >> sh16;
+ inp2 = inp2 >> sh16;
+ inp3 = inp3 >> sh16;
+
reg.val[0] = (__vector signed short)vec_perm(inp0, inp1, omask);
reg.val[1] = (__vector signed short)vec_perm(inp2, inp3, omask);
}
-inline void prefetch(const void* addr) { void __dcbt(const void* addr); }
+// 1D softmax over `n` elements in `input`, writes result to `output`.
+// Uses FP32Vec8 for main body, scalar tail handling.
+// Requirement: n > 0
+FORCE_INLINE void softmax_fp32vec8(float* output, const float* input, int n) {
+ if (n <= 0) return;
+
+ // ---------- Pass 1: find max ----------
+ float max_val = -std::numeric_limits::infinity();
+ int i = 0;
+
+ for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
+ FP32Vec8 v(input + i);
+ FP32Vec8::AliasReg ar;
+ ar.reg = v.reg;
+ for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
+ if (ar.values[j] > max_val) max_val = ar.values[j];
+ }
+ }
+ for (; i < n; ++i) {
+ if (input[i] > max_val) max_val = input[i];
+ }
+
+ // ---------- Pass 2: compute exp(x - max) and sum ----------
+ float sum = 0.0f;
+ i = 0;
+
+ for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
+ float tmp[FP32Vec8::VEC_ELEM_NUM];
+ for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
+ tmp[j] = input[i + j] - max_val;
+ }
+
+ FP32Vec8 v(tmp);
+ FP32Vec8 e = v.exp();
+
+ FP32Vec8::AliasReg ar;
+ ar.reg = e.reg;
+ for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
+ output[i + j] = ar.values[j];
+ sum += ar.values[j];
+ }
+ }
+
+ // Tail
+ for (; i < n; ++i) {
+ float x = input[i] - max_val;
+ float ex = std::exp(x); // scalar tail
+ output[i] = ex;
+ sum += ex;
+ }
+
+ // ---------- Pass 3: normalize ----------
+ float inv_sum = 1.0f / sum;
+ i = 0;
+
+ for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
+ float tmp[FP32Vec8::VEC_ELEM_NUM];
+ for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
+ tmp[j] = output[i + j] * inv_sum;
+ }
+ FP32Vec8 v(tmp);
+ v.save(output + i);
+ }
+
+ for (; i < n; ++i) {
+ output[i] *= inv_sum;
+ }
+}
+
+// 1D RMSNorm kernel:
+// input: x[0..n-1]
+// weight: w[0..n-1] (gamma), may be nullptr
+// output: y[i] = x[i] * inv_rms * (weight[i] if weight != nullptr else 1)
+// eps: small epsilon for numerical stability
+FORCE_INLINE void rmsnorm_fp32vec8(float* output, const float* input,
+ const float* weight, int n, float eps) {
+ if (n <= 0) return;
+
+ // ---------- Pass 1: compute sum of squares ----------
+ float sum_sq = 0.0f;
+ int i = 0;
+
+ for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
+ FP32Vec8 x_vec(input + i);
+
+ FP32Vec8 sq = x_vec * x_vec;
+
+ FP32Vec8::AliasReg ar;
+ ar.reg = sq.reg;
+ for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
+ sum_sq += ar.values[j];
+ }
+ }
+
+ // Tail
+ for (; i < n; ++i) {
+ float v = input[i];
+ sum_sq += v * v;
+ }
+
+ float mean_sq = sum_sq / static_cast(n);
+ float inv_rms = 1.0f / std::sqrt(mean_sq + eps);
+
+ // ---------- Pass 2: scale (and apply weight if given) ----------
+ const float inv_rms_f = inv_rms;
+ i = 0;
+
+ if (weight) {
+ // with gamma
+ for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
+ FP32Vec8 x_vec(input + i);
+
+ float wtmp[FP32Vec8::VEC_ELEM_NUM];
+ for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
+ wtmp[j] = weight[i + j];
+ }
+ FP32Vec8 w_vec(wtmp);
+
+ FP32Vec8 scale_vec(inv_rms_f);
+ FP32Vec8 y = x_vec * scale_vec * w_vec;
+ y.save(output + i);
+ }
+
+ for (; i < n; ++i) {
+ output[i] = input[i] * inv_rms_f * weight[i];
+ }
+ } else {
+ // without gamma
+ for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
+ FP32Vec8 x_vec(input + i);
+ FP32Vec8 scale_vec(inv_rms_f);
+ FP32Vec8 y = x_vec * scale_vec;
+ y.save(output + i);
+ }
+
+ for (; i < n; ++i) {
+ output[i] = input[i] * inv_rms_f;
+ }
+ }
+}
+
+// Prefetch data to cache for better memory access performance
+FORCE_INLINE void prefetch(const void* addr) {
+ __builtin_prefetch(addr, 0, 3); // 0=read, 3=high temporal locality
+}
}; // namespace vec_op
diff --git a/csrc/moe/marlin_moe_wna16/marlin_template.h b/csrc/moe/marlin_moe_wna16/marlin_template.h
index dd86a9a5ba6e9..4dbca30da57a1 100644
--- a/csrc/moe/marlin_moe_wna16/marlin_template.h
+++ b/csrc/moe/marlin_moe_wna16/marlin_template.h
@@ -489,14 +489,16 @@ __global__ void Marlin(
#pragma unroll
for (int i = 0; i < 4; i++) {
int idx = tid4 * 4 + i;
- idx = idx < block_num_valid_tokens ? idx : 0;
- if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
- sh_block_topk_weights[idx] = __hmul2(
- global_scale, Dtype::num2num2(Dtype::float2num(
- topk_weights_ptr[sh_block_sorted_ids[idx]])));
- } else {
- sh_block_topk_weights[idx] = Dtype::num2num2(
- Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]]));
+ if (idx < block_num_valid_tokens) {
+ if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
+ sh_block_topk_weights[idx] =
+ __hmul2(global_scale,
+ Dtype::num2num2(Dtype::float2num(
+ topk_weights_ptr[sh_block_sorted_ids[idx]])));
+ } else {
+ sh_block_topk_weights[idx] = Dtype::num2num2(
+ Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]]));
+ }
}
}
}
diff --git a/csrc/ops.h b/csrc/ops.h
index f8bdc61aaa8ec..4bb7857b15032 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -52,14 +52,13 @@ void paged_attention_v2(
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
-#ifndef USE_ROCM
void merge_attn_states(torch::Tensor& output,
std::optional output_lse,
const torch::Tensor& prefix_output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse);
-
+#ifndef USE_ROCM
void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
diff --git a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
index 5b007e5ea3283..6744402783832 100644
--- a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
+++ b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
@@ -22,6 +22,7 @@
#include
#include
#include
+#include "cutlass_extensions/common.hpp"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
@@ -173,7 +174,7 @@ void run_get_group_gemm_starts(
}
template
-void run_fp4_blockwise_scaled_group_mm(
+void run_fp4_blockwise_scaled_group_mm_sm100(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
@@ -343,17 +344,225 @@ void run_fp4_blockwise_scaled_group_mm(
auto can_implement_status = gemm_op.can_implement(args);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
- "Failed to implement GEMM");
+ "Failed to implement GEMM: status=", (int)can_implement_status);
// Run the GEMM
auto status = gemm_op.initialize(args, workspace.data_ptr());
- TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
+ TORCH_CHECK(status == cutlass::Status::kSuccess,
+ "Failed to initialize GEMM: status=", (int)status,
+ " workspace_size=", workspace_size, " num_experts=", num_experts,
+ " M=", M, " N=", N, " K=", K);
status = gemm_op.run(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
+void run_fp4_blockwise_scaled_group_mm_sm120(
+ torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
+ const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
+ const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
+ const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
+ int N, int K) {
+ using ProblemShape =
+ cutlass::gemm::GroupProblemShape>;
+ using ElementType = cutlass::float_e2m1_t;
+ using ElementSFType = cutlass::float_ue4m3_t;
+ using ElementA = cutlass::nv_float4_t;
+ using ElementB = cutlass::nv_float4_t;
+
+ // NOTE: For SM120 it seems templating the output type is not supported and
+ // we need to hardcode the output type to bfloat16
+ using ElementC = cutlass::bfloat16_t;
+ using ElementD = ElementC;
+ using ElementAccumulator = float;
+ // Layout definitions
+ using LayoutA = cutlass::layout::RowMajor;
+ using LayoutB = cutlass::layout::ColumnMajor;
+ using LayoutC = cutlass::layout::RowMajor;
+ using LayoutD = LayoutC;
+
+ // Alignment constraints
+ static constexpr int AlignmentA = 32;
+ static constexpr int AlignmentB = 32;
+ static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value;
+ static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value;
+
+ // Architecture definitions
+ using ArchTag = cutlass::arch::Sm120;
+ using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
+
+ using ClusterShape = Shape<_1, _1, _1>;
+ using MmaTileShape = Shape<_128, _128, _128>;
+
+ using FusionOperation = cutlass::epilogue::fusion::LinearCombination<
+ ElementD, ElementAccumulator, ElementC, ElementAccumulator>;
+
+ using CollectiveEpilogue =
+ typename cutlass::epilogue::collective::CollectiveBuilder<
+ ArchTag, OperatorClass, MmaTileShape, ClusterShape,
+ cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
+ ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD,
+ LayoutD*, AlignmentD,
+ cutlass::epilogue::collective::EpilogueScheduleAuto,
+ FusionOperation>::CollectiveOp;
+
+ using CollectiveMainloop =
+ typename cutlass::gemm::collective::CollectiveBuilder<
+ ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB,
+ LayoutB*, AlignmentB, ElementAccumulator, MmaTileShape, ClusterShape,
+ cutlass::gemm::collective::StageCountAutoCarveout(
+ sizeof(typename CollectiveEpilogue::SharedStorage))>,
+ cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;
+
+ using GemmKernel =
+ cutlass::gemm::kernel::GemmUniversal;
+
+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter;
+ using StrideA = typename Gemm::GemmKernel::InternalStrideA;
+ using StrideB = typename Gemm::GemmKernel::InternalStrideB;
+ using StrideC = typename Gemm::GemmKernel::InternalStrideC;
+ using StrideD = typename Gemm::GemmKernel::InternalStrideD;
+
+ using LayoutSFA =
+ typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
+ using LayoutSFB =
+ typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
+ using ScaleConfig =
+ typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
+
+ using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
+ int num_experts = static_cast(expert_offsets.size(0));
+ auto options_int =
+ torch::TensorOptions().dtype(torch::kInt64).device(a.device());
+
+ torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
+ torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
+ torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
+ torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
+ torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
+ torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
+ torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
+ torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
+ torch::Tensor c_strides1 =
+ torch::full({num_experts}, output.stride(0), options_int);
+ torch::Tensor a_strides1 =
+ torch::full({num_experts}, a.stride(0) * 2, options_int);
+ torch::Tensor b_strides1 =
+ torch::full({num_experts}, b.stride(1) * 2, options_int);
+
+ run_get_group_gemm_starts(
+ a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
+ layout_sfa, layout_sfb, a, b, output, a_blockscale, b_blockscales, alphas,
+ expert_offsets, sf_offsets, problem_sizes, M, N, K);
+
+ // Create an instance of the GEMM
+ Gemm gemm_op;
+
+ // Initialize problem_sizes_as_shapes correctly
+ UnderlyingProblemShape* problem_sizes_as_shapes =
+ static_cast(problem_sizes.data_ptr());
+
+ // Set the Scheduler info
+ cutlass::KernelHardwareInfo hw_info;
+ using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
+ typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
+ scheduler.raster_order = RasterOrderOptions::AlongM;
+ hw_info.device_id = a.get_device();
+ static std::unordered_map cached_sm_counts;
+ if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
+ cached_sm_counts[hw_info.device_id] =
+ cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
+ hw_info.device_id);
+ }
+ hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX);
+
+ // Mainloop Arguments
+ typename GemmKernel::MainloopArguments mainloop_args{
+ static_cast(a_ptrs.data_ptr()),
+ static_cast(a_strides1.data_ptr()),
+ static_cast(b_ptrs.data_ptr()),
+ static_cast(b_strides1.data_ptr()),
+ static_cast(a_scales_ptrs.data_ptr()),
+ reinterpret_cast(layout_sfa.data_ptr()),
+ static_cast(b_scales_ptrs.data_ptr()),
+ reinterpret_cast(layout_sfb.data_ptr())};
+
+ // Epilogue Arguments
+ typename GemmKernel::EpilogueArguments epilogue_args{
+ {}, // epilogue.thread
+ nullptr,
+ static_cast(c_strides1.data_ptr()),
+ static_cast(out_ptrs.data_ptr()),
+ static_cast(c_strides1.data_ptr())};
+ auto& fusion_args = epilogue_args.thread;
+ fusion_args.alpha_ptr_array =
+ reinterpret_cast(alpha_ptrs.data_ptr());
+ fusion_args.dAlpha = {_0{}, _0{}, 1};
+ fusion_args.beta = 0.0f;
+
+ // Gemm Arguments
+ typename GemmKernel::Arguments args{
+ cutlass::gemm::GemmUniversalMode::kGrouped,
+ {num_experts, problem_sizes_as_shapes, nullptr},
+ mainloop_args,
+ epilogue_args,
+ hw_info,
+ scheduler};
+
+ size_t workspace_size = Gemm::get_workspace_size(args);
+ auto const workspace_options =
+ torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
+ auto workspace = torch::empty(workspace_size, workspace_options);
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
+
+ auto can_implement_status = gemm_op.can_implement(args);
+ TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
+ "Failed to implement GEMM: status=", (int)can_implement_status);
+
+ // Run the GEMM
+ auto status = gemm_op.initialize(args, workspace.data_ptr());
+ TORCH_CHECK(status == cutlass::Status::kSuccess,
+ "Failed to initialize GEMM: status=", (int)status,
+ " workspace_size=", workspace_size, " num_experts=", num_experts,
+ " M=", M, " N=", N, " K=", K);
+
+ status = gemm_op.run(args, workspace.data_ptr(), stream);
+ TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
+}
+
+template
+void run_fp4_blockwise_scaled_group_mm(
+ torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
+ const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
+ const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
+ const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
+ int N, int K) {
+ int32_t version_num = get_sm_version_num();
+#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
+ if (version_num >= 120 && version_num < 130) {
+ run_fp4_blockwise_scaled_group_mm_sm120(
+ output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
+ expert_offsets, sf_offsets, M, N, K);
+ return;
+ }
+#endif
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
+ if (version_num >= 100 && version_num < 120) {
+ run_fp4_blockwise_scaled_group_mm_sm100(
+ output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
+ expert_offsets, sf_offsets, M, N, K);
+ return;
+ }
+#endif
+ TORCH_CHECK_NOT_IMPLEMENTED(
+ false,
+ "No compiled cutlass_fp4_group_mm kernel for CUDA device capability: ",
+ version_num, ". Required capability: 100 or 120");
+}
+
+#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
+ (defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
#endif
@@ -374,7 +583,8 @@ void cutlass_fp4_group_mm(
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) {
-#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
+#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
+ (defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
// Input validation
CHECK_INPUT(a, FLOAT4_E2M1X2, "a");
CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
@@ -408,6 +618,14 @@ void cutlass_fp4_group_mm(
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
expert_offsets, sf_offsets, M, N, K);
} else {
+ #if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
+ int32_t version_num = get_sm_version_num();
+ if (version_num >= 120 && version_num < 130) {
+ TORCH_CHECK_NOT_IMPLEMENTED(
+ false, "SM120 NVFP4 MOE only supports bfloat16 output, got: ",
+ output.scalar_type());
+ }
+ #endif
run_fp4_blockwise_scaled_group_mm(
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
expert_offsets, sf_offsets, M, N, K);
@@ -416,8 +634,8 @@ void cutlass_fp4_group_mm(
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_fp4_group_mm kernel, vLLM must "
- "be compiled with ENABLE_NVFP4_SM100 for SM100+ and CUDA "
- "12.8 or above.");
+ "be compiled with ENABLE_NVFP4_SM100 or ENABLE_NVFP4_SM120 for SM100/120 "
+ "and CUDA 12.8 or above.");
#endif
}
diff --git a/csrc/quantization/fp4/nvfp4_experts_quant.cu b/csrc/quantization/fp4/nvfp4_experts_quant.cu
index 6d385e0dd94e7..82c53c2375a31 100644
--- a/csrc/quantization/fp4/nvfp4_experts_quant.cu
+++ b/csrc/quantization/fp4/nvfp4_experts_quant.cu
@@ -307,7 +307,7 @@ constexpr auto FLOAT = at::ScalarType::Float;
constexpr auto INT = at::ScalarType::Int;
constexpr auto UINT8 = at::ScalarType::Byte;
-void scaled_fp4_experts_quant_sm100a(
+void scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
diff --git a/csrc/quantization/fp4/nvfp4_quant_entry.cu b/csrc/quantization/fp4/nvfp4_quant_entry.cu
index c2b39e5438805..fb6d22f035b99 100644
--- a/csrc/quantization/fp4/nvfp4_quant_entry.cu
+++ b/csrc/quantization/fp4/nvfp4_quant_entry.cu
@@ -24,8 +24,9 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
torch::Tensor const& input_sf);
#endif
-#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
-void scaled_fp4_experts_quant_sm100a(
+#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
+ (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
+void scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
@@ -54,8 +55,9 @@ void scaled_fp4_experts_quant(
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
-#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
- return scaled_fp4_experts_quant_sm100a(
+#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
+ (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
+ return scaled_fp4_experts_quant_sm1xxa(
output, output_scale, input, input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts);
#endif
diff --git a/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
index 1001af05ff003..c5012a8669317 100644
--- a/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
+++ b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
@@ -67,9 +67,9 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
std::optional const& bias);
#endif
-#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \
- defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100 || \
- defined(ENABLE_SCALED_MM_SM120) && ENABLE_SCALED_MM_SM120
+#if (defined(ENABLE_CUTLASS_MOE_SM90) && ENABLE_CUTLASS_MOE_SM90) || \
+ (defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100) || \
+ (defined(ENABLE_CUTLASS_MOE_SM120) && ENABLE_CUTLASS_MOE_SM120)
void get_cutlass_moe_mm_data_caller(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
@@ -284,8 +284,9 @@ void get_cutlass_moe_mm_data(
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
-#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
- (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
+#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
+ (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
+ (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, input_permutation,
output_permutation, num_experts, n, k,
@@ -296,7 +297,7 @@ void get_cutlass_moe_mm_data(
false,
"No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
"CUDA device capability: ",
- version_num, ". Required capability: 90 or 100");
+ version_num, ". Required capability: 90, 100, or 120");
}
void get_cutlass_moe_mm_problem_sizes(
@@ -304,8 +305,9 @@ void get_cutlass_moe_mm_problem_sizes(
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
const int64_t k, const std::optional& blockscale_offsets) {
int32_t version_num = get_sm_version_num();
-#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
- (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
+#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
+ (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
+ (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1,
problem_sizes2, num_experts, n, k,
blockscale_offsets);
@@ -315,7 +317,7 @@ void get_cutlass_moe_mm_problem_sizes(
false,
"No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm "
"kernel for CUDA device capability: ",
- version_num, ". Required capability: 90 or 100");
+ version_num, ". Required capability: 90, 100, or 120");
}
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
@@ -328,8 +330,9 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
-#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
- (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
+#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
+ (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
+ (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1,
problem_sizes2, expert_num_tokens,
num_local_experts, padded_m, n, k);
@@ -339,7 +342,7 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
false,
"No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
"for CUDA device capability: ",
- version_num, ". Required capability: 90 or 100");
+ version_num, ". Required capability: 90, 100, or 120");
}
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 5af74c2c2a6b0..e9c96bb8b56cf 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -63,7 +63,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
-#ifndef USE_ROCM
// Merge attn states
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case)
@@ -76,7 +75,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor suffix_output,"
" Tensor suffix_lse) -> ()");
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
-
+#ifndef USE_ROCM
ops.def(
"convert_vertical_slash_indexes("
" Tensor! block_count, Tensor! block_offset, "
@@ -695,7 +694,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
cache_ops.def(
"gather_and_maybe_dequant_cache(Tensor src_cache, Tensor! dst, "
" Tensor block_table, Tensor cu_seq_lens, "
- " int batch_size, "
+ " Tensor token_to_seq, "
+ " int num_tokens, "
" str kv_cache_dtype, "
" Tensor scale, Tensor? seq_starts) -> ()");
cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA,
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 1b937bbc1225e..eb7c105071c00 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -20,8 +20,8 @@ ARG PYTHON_VERSION=3.12
# glibc version is baked into the distro, and binaries built with one glibc
# version are not backwards compatible with OSes that use an earlier version.
ARG BUILD_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
-# TODO: Restore to base image after FlashInfer AOT wheel fixed
-ARG FINAL_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04
+# Using cuda base image with minimal dependencies necessary for JIT compilation (FlashInfer, DeepGEMM, EP kernels)
+ARG FINAL_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-base-ubuntu22.04
# By parameterizing the Deadsnakes repository URL, we allow third-party to use
# their own mirror. When doing so, we don't benefit from the transparent
@@ -85,7 +85,20 @@ ARG GET_PIP_URL
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
&& apt-get update -y \
- && apt-get install -y ccache software-properties-common git curl sudo python3-pip libibverbs-dev \
+ && apt-get install -y --no-install-recommends \
+ ccache \
+ software-properties-common \
+ git \
+ curl \
+ sudo \
+ python3-pip \
+ libibverbs-dev \
+ # Upgrade to GCC 10 to avoid https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92519
+ # as it was causing spam when compiling the CUTLASS kernels
+ gcc-10 \
+ g++-10 \
+ && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 110 --slave /usr/bin/g++ g++ /usr/bin/g++-10 \
+ && rm -rf /var/lib/apt/lists/* \
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
&& $HOME/.local/bin/uv venv /opt/venv --python ${PYTHON_VERSION} \
&& rm -f /usr/bin/python3 /usr/bin/python3-config /usr/bin/pip \
@@ -110,10 +123,6 @@ ENV UV_INDEX_STRATEGY="unsafe-best-match"
# Use copy mode to avoid hardlink failures with Docker cache mounts
ENV UV_LINK_MODE=copy
-# Upgrade to GCC 10 to avoid https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92519
-# as it was causing spam when compiling the CUTLASS kernels
-RUN apt-get install -y gcc-10 g++-10
-RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 110 --slave /usr/bin/g++ g++ /usr/bin/g++-10
RUN < torch_build_versions.txt
@@ -233,11 +205,6 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/vllm
--mount=type=cache,target=/root/.cache/uv \
uv pip install --system vllm-dist/*.whl --verbose
-# install xformers again for the new environment
-RUN --mount=type=bind,from=base,src=/workspace/xformers-dist,target=/vllm-workspace/xformers-dist \
- --mount=type=cache,target=/root/.cache/uv \
- uv pip install --system /vllm-workspace/xformers-dist/*.whl --verbose
-
ARG torch_cuda_arch_list='8.0;8.6;8.9;9.0'
# install package for build flashinfer
@@ -307,7 +274,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/nightly_torch_test.txt
# Logging to confirm the torch versions
-RUN pip freeze | grep -E 'torch|xformers|vllm|flashinfer'
+RUN pip freeze | grep -E 'torch|vllm|flashinfer'
# Logging to confirm all the packages are installed
RUN pip freeze
diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu
index 5d5b82c4fa5af..adac43c6accbe 100644
--- a/docker/Dockerfile.xpu
+++ b/docker/Dockerfile.xpu
@@ -1,4 +1,4 @@
-FROM intel/deep-learning-essentials:2025.1.3-0-devel-ubuntu24.04 AS vllm-base
+FROM intel/deep-learning-essentials:2025.2.2-0-devel-ubuntu24.04 AS vllm-base
RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \
echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list && \
@@ -25,10 +25,14 @@ RUN apt clean && apt-get update -y && \
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.12 1
-RUN apt install -y libze1 libze-dev libze-intel-gpu1 intel-opencl-icd libze-intel-gpu-raytracing
+RUN apt install -y libze1 libze-dev libze-intel-gpu1 intel-opencl-icd libze-intel-gpu-raytracing intel-ocloc
+
+# This oneccl contains the BMG support which is not the case for default version of oneapi 2025.2.
+RUN wget https://github.com/uxlfoundation/oneCCL/releases/download/2021.15.6/intel-oneccl-2021.15.6.9_offline.sh
+RUN bash intel-oneccl-2021.15.6.9_offline.sh -a --silent --eula accept && \
+ echo "source /opt/intel/oneapi/setvars.sh --force" >> /root/.bashrc && \
+ echo "source /opt/intel/oneapi/ccl/2021.15/env/vars.sh --force" >> /root/.bashrc
-RUN wget https://github.com/uxlfoundation/oneCCL/releases/download/2021.15.4/intel-oneccl-2021.15.4.11_offline.sh
-RUN bash intel-oneccl-2021.15.4.11_offline.sh -a --silent --eula accept && echo "source /opt/intel/oneapi/setvars.sh --force" >> /root/.bashrc
SHELL ["bash", "-c"]
CMD ["bash", "-c", "source /root/.bashrc && exec bash"]
@@ -72,6 +76,7 @@ RUN python3 -m pip install -e tests/vllm_test_utils
ENV NIXL_VERSION=0.7.0
RUN python3 /workspace/vllm/tools/install_nixl_from_source_ubuntu.py
+# remove torch bundled oneccl to avoid conflicts
RUN --mount=type=cache,target=/root/.cache/pip \
pip uninstall oneccl oneccl-devel -y
diff --git a/docs/assets/contributing/dockerfile-stages-dependency.png b/docs/assets/contributing/dockerfile-stages-dependency.png
index 57a33524a5169..b327eb2151f50 100644
Binary files a/docs/assets/contributing/dockerfile-stages-dependency.png and b/docs/assets/contributing/dockerfile-stages-dependency.png differ
diff --git a/docs/community/meetups.md b/docs/community/meetups.md
index 0735f452df960..d8cf4ecdd5a32 100644
--- a/docs/community/meetups.md
+++ b/docs/community/meetups.md
@@ -10,6 +10,7 @@ Stay tuned for upcoming meetups! Follow us on [Twitter/X](https://x.com/vllm_pro
Below you'll find slides and recordings from our previous meetups:
+- [vLLM Bangkok Meetup](https://luma.com/v0f647nv), November 21st 2025. [[Slides]](https://drive.google.com/drive/folders/1H0DS57F8HQ5q3kSOSoRmucPJWL3E0A_X?usp=sharing)
- [vLLM Zurich Meetup](https://luma.com/0gls27kb), November 6th 2025. [[Slides]](https://docs.google.com/presentation/d/1UC9PTLCHYXQpOmJDSFg6Sljra3iVXzc09DeEI7dnxMc/edit?usp=sharing) [[Recording]](https://www.youtube.com/watch?v=6m6ZE6yVEDI)
- [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/xSrYXjNgr1HbCP4ExYNG1w), November 1st 2025. [[Slides]](https://drive.google.com/drive/folders/1nQJ8ZkLSjKxvu36sSHaceVXtttbLvvu-?usp=drive_link)
- [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/__xb4OyOsImz-9eAVrdlcg), October 25th 2025. [[Slides]](https://drive.google.com/drive/folders/1KqwjsFJLfEsC8wlDugnrR61zsWHt94Q6)
diff --git a/docs/contributing/ci/update_pytorch_version.md b/docs/contributing/ci/update_pytorch_version.md
index 09fd85a466eed..735bb2e205332 100644
--- a/docs/contributing/ci/update_pytorch_version.md
+++ b/docs/contributing/ci/update_pytorch_version.md
@@ -98,21 +98,6 @@ to warm it up so that future builds are faster.
-## Update dependencies
-
-Several vLLM dependencies like xFormers depend on PyTorch and need
-to be updated accordingly. Rather than waiting for all of them to publish new
-releases (which would take too much time), they can be built from
-source to unblock the update process.
-
-### xFormers
-
-```bash
-export TORCH_CUDA_ARCH_LIST='7.5 8.0+PTX 9.0a'
-MAX_JOBS=16 uv pip install --system \
- --no-build-isolation "git+https://github.com/facebookresearch/xformers@v0.0.32.post2"
-```
-
## Update all the different vLLM platforms
Rather than attempting to update all vLLM platforms in a single pull request, it's more manageable
diff --git a/docs/contributing/model/basic.md b/docs/contributing/model/basic.md
index e828de0adf3c2..a68d1f0162a10 100644
--- a/docs/contributing/model/basic.md
+++ b/docs/contributing/model/basic.md
@@ -29,7 +29,7 @@ The initialization code should look like this:
```python
from torch import nn
from vllm.config import VllmConfig
- from vllm.attention import Attention
+ from vllm.attention.layer import Attention
class MyAttention(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str):
diff --git a/docs/deployment/integrations/kserve.md b/docs/deployment/integrations/kserve.md
index edf79fca4f93e..37b29aa1a4876 100644
--- a/docs/deployment/integrations/kserve.md
+++ b/docs/deployment/integrations/kserve.md
@@ -2,4 +2,4 @@
vLLM can be deployed with [KServe](https://github.com/kserve/kserve) on Kubernetes for highly scalable distributed model serving.
-Please see [this guide](https://kserve.github.io/website/latest/modelserving/v1beta1/llm/huggingface/) for more details on using vLLM with KServe.
+Please see [this guide](https://kserve.github.io/website/docs/model-serving/generative-inference/overview) for more details on using vLLM with KServe.
diff --git a/docs/design/cuda_graphs.md b/docs/design/cuda_graphs.md
index 66bf3b27d1f52..7baadf8ba23cb 100644
--- a/docs/design/cuda_graphs.md
+++ b/docs/design/cuda_graphs.md
@@ -84,12 +84,14 @@ See the following figures for a quick comparison between the previous and curren
```python
class BatchDescriptor(NamedTuple):
num_tokens: int
- uniform_decode: bool = False
+ num_reqs: int
+ uniform: bool = False
+ has_lora: bool = False
```
-where `num_tokens` can be the padded token length, and `uniform_decode` is determined by if `max_query_len` of a batch is equal to the desired `max_query_len` of a uniform_decode, and the num_scheduled_tokens is divisible by that desired `max_query_len`.
+where `num_tokens` can be the padded token length, and `uniform` indicates if all the requests have the same query lengths. Many attention backends only support full cudagraphs when the batches are uniform; pure decode batches are uniform but may not be query length 1 (i.e. `num_tokens == num_reqs`), this occurs in the validation pass of spec-decode where "decode" batches will have a query length of `1+num_spec_tokens`.
-The goal of this structure is to uniquely identify a (padded) batch with minimal possible items corresponding to a CUDA Graphs item. We are safe to exclude items like `uniform_query_len` because it is a constant at runtime for a certain setup currently. For example, it should be either `1` for a commonly pure decode or `1+num_spec_tokens` for a validation phase of speculative decode.
+The goal of this structure is to uniquely identify a (padded) batch with minimal possible items corresponding to a CUDA Graphs item.
!!! note
The prototype of `BatchDescriptor` may be extended for more general situations in the future, e.g., include more items, like `uniform_query_len` to support multiple different uniform decode lengths settings (), or other modifications needed to support CUDA Graphs for models whose inputs are not necessarily token length aware (for example, some multi-modal inputs).
diff --git a/docs/design/debug_vllm_compile.md b/docs/design/debug_vllm_compile.md
index 8912eb58f8ac7..408d2878309dd 100644
--- a/docs/design/debug_vllm_compile.md
+++ b/docs/design/debug_vllm_compile.md
@@ -151,6 +151,76 @@ To avoid this, please either:
2. wrap the branching logic into a custom operator. TorchDynamo does not
trace into custom operators.
+## Debugging constraint violations and dynamic shapes guards issues
+
+Dynamic-shape guards are a specific category of Dynamo guards. They are constraints that `torch.compile`
+attaches to dynamic dimensions (e.g., `seq_len`) to ensure the compiled artifact remains valid.
+These guards typically appear when framework code, custom passes, or user code branches based on
+dynamic shape values.
+
+**Example:**
+
+```python
+if x > 10:
+ # path A
+else:
+ # path B
+```
+
+This creates a guard `x > 10` or `x <= 10` depending on which path was traced.
+
+**vLLM's Assumption:**
+vLLM assumes that all guards added by torch.compile are safe to drop and will not
+constrain the compiled graph to specific input shapes. When this assumption is violated,
+it can cause issues that users need to debug.
+Some side effects that indicates this assumption is violated are runtime errors
+or `ConstraintViolationErrors`.
+
+A `ConstraintViolationErrors` will be thrown if a dynamic shape gets constrained to
+a single value. If you encounter a constraint violation error or suspect that a dynamic
+shapes guard is being added incorrectly, you can use stricter dynamic shape modes to
+help debug the issue:
+
+```sh
+# Online - using unbacked mode
+vllm serve meta-llama/Llama-3.2-1B -O.dynamic_shapes_config.type=unbacked
+
+# Online - using backed_size_oblivious mode
+vllm serve meta-llama/Llama-3.2-1B -O.dynamic_shapes_config.type=backed_size_oblivious
+```
+
+```py
+# Offline - using unbacked mode
+from vllm.config.compilation import CompilationConfig, DynamicShapesConfig, DynamicShapesType
+LLM(model, compilation_config=CompilationConfig(
+ dynamic_shapes_config=DynamicShapesConfig(type=DynamicShapesType.UNBACKED)
+))
+
+# Offline - using backed_size_oblivious mode
+from vllm.config.compilation import CompilationConfig, DynamicShapesConfig, DynamicShapesType
+LLM(model, compilation_config=CompilationConfig(
+ dynamic_shapes_config=DynamicShapesConfig(type=DynamicShapesType.BACKED_SIZE_OBLIVIOUS)
+))
+```
+
+These modes are stricter and reduce or eliminate the need of dynamic shapes guarding, which can help isolate issues:
+
+- `unbacked`: Uses unbacked symints which don't allow guards, making it easier to identify where guards are being incorrectly added
+- `backed_size_oblivious`: Uses a mode that is more strict about guarding.
+
+For more details on dynamic shapes modes, see [Dynamic shapes and vLLM guard dropping](torch_compile.md#dynamic-shapes-and-vllm-guard-dropping).
+
+### Printing guards
+
+To see all guards that are being added during compilation, you can use `TORCH_LOGS=+dynamic`:
+
+```sh
+TORCH_LOGS=+dynamic vllm serve meta-llama/Llama-3.2-1B
+```
+
+Look for `[guard added]` in the logs to see where guards are being added. This can help you identify which operations are
+causing guards to be added incorrectly.
+
## Debugging TorchInductor
TorchInductor takes a captured graph and then compiles it down to some Python code
diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md
index f0d5a3e934f39..e54a9e2bc5e77 100644
--- a/docs/design/moe_kernel_features.md
+++ b/docs/design/moe_kernel_features.md
@@ -60,7 +60,7 @@ Modular kernels are supported by the following `FusedMoEMethodBase` classes.
- [`ModelOptFp8MoEMethod`][vllm.model_executor.layers.quantization.modelopt.ModelOptFp8MoEMethod]
- [`Fp8MoEMethod`][vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod]
-- [`CompressedTensorsW4A4MoeMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW4A4MoeMethod]
+- [`CompressedTensorsW4A4Nvfp4MoeMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW4A4Nvfp4MoeMethod]
- [`CompressedTensorsW8A8Fp8MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW8A8Fp8MoEMethod]
- [`Mxfp4MoEMethod`][vllm.model_executor.layers.quantization.mxfp4.Mxfp4MoEMethod]
- [`UnquantizedFusedMoEMethod`][vllm.model_executor.layers.fused_moe.layer.UnquantizedFusedMoEMethod]
diff --git a/docs/design/optimization_levels.md b/docs/design/optimization_levels.md
new file mode 100644
index 0000000000000..940286071ef3c
--- /dev/null
+++ b/docs/design/optimization_levels.md
@@ -0,0 +1,69 @@
+
+
+# Optimization Levels
+
+## Overview
+
+vLLM now supports optimization levels (`-O0`, `-O1`, `-O2`, `-O3`). Optimization levels provide an intuitive mechnaism for users to trade startup time for performance. Higher levels have better performance but worse startup time. These optimization levels have associated defaults to help users get desired out of the box performance. Importantly, defaults set by optimization levels are purely defaults; explicit user settings will not be overwritten.
+
+## Level Summaries and Usage Examples
+```bash
+# CLI usage
+python -m vllm.entrypoints.api_server --model RedHatAI/Llama-3.2-1B-FP8 -O0
+
+# Python API usage
+from vllm.entrypoints.llm import LLM
+
+llm = LLM(
+ model="RedHatAI/Llama-3.2-1B-FP8",
+ optimization_level=0
+)
+```
+
+#### `-O1`: Quick Optimizations
+- **Startup**: Moderate startup time
+- **Performance**: Inductor compilation, CUDAGraphMode.PIECEWISE
+- **Use case**: Balance for most development scenarios
+
+```bash
+# CLI usage
+python -m vllm.entrypoints.api_server --model RedHatAI/Llama-3.2-1B-FP8 -O1
+
+# Python API usage
+from vllm.entrypoints.llm import LLM
+
+llm = LLM(
+ model="RedHatAI/Llama-3.2-1B-FP8",
+ optimization_level=1
+)
+```
+
+#### `-O2`: Full Optimizations (Default)
+- **Startup**: Longer startup time
+- **Performance**: `-O1` + CUDAGraphMode.FULL_AND_PIECEWISE
+- **Use case**: Production workloads where performance is important. This is the default use case. It is also very similar to the previous default. The primary difference is that noop & fusion flags are enabled.
+
+```bash
+# CLI usage (default, so optional)
+python -m vllm.entrypoints.api_server --model RedHatAI/Llama-3.2-1B-FP8 -O2
+
+# Python API usage
+from vllm.entrypoints.llm import LLM
+
+llm = LLM(
+ model="RedHatAI/Llama-3.2-1B-FP8",
+ optimization_level=2 # This is the default
+)
+```
+
+#### `-O3`: Full Optimization
+Still in development. Added infrastructure to prevent changing API in future
+release. Currently behaves the same O2.
+
+## Troubleshooting
+
+### Common Issues
+
+1. **Startup Time Too Long**: Use `-O0` or `-O1` for faster startup
+2. **Compilation Errors**: Use `debug_dump_path` for additional debugging information
+3. **Performance Issues**: Ensure using `-O2` for production
\ No newline at end of file
diff --git a/docs/design/torch_compile.md b/docs/design/torch_compile.md
index 27edc4f89201d..7b0b2c1e96978 100644
--- a/docs/design/torch_compile.md
+++ b/docs/design/torch_compile.md
@@ -29,6 +29,109 @@ A unique aspect of vLLM's `torch.compile` integration, is that we guarantee all
By default, the cache saves compiled artifacts as binary files. If you would like to interact with the generated code for debugging purposes, set the field `compile_cache_save_format=unpacked` in the compilation config, or omit this and set the env variable `VLLM_COMPILE_CACHE_SAVE_FORMAT=unpacked`.
+## Dynamic shapes and vllm guard dropping
+
+`torch.compile` is designed to guard on dynamic shapes with no hesitation
+when needed. This contradicts with vLLM's `torch.compile` approach of
+dropping the guards since many of those guards could be material.
+
+`torch.compile` provides two kinds of dynamic shapes: `backed` and `unbacked`.
+`torch.compile` guards on `backed` dynamic shapes and does not provide a
+guarantee that no guards will be added to them. User code, dynamo,
+inductor, and autograd all can add guards. Moreover, for 0/1
+specializations, backed symbols are specialized unconditionally to 0, 1,
+or >=2 even without encountering a branching on those ranges.
+
+On the contrary, `unbacked` dynamic shapes are guaranteed not to be guarded
+on and are not 0/1 specialized. However, there is a possibility of
+throwing a data dependent error when a branch that requires their value is
+encountered and no explicit unbacked handling is defined. The framework is
+converging to a state where it won't throw DDE but rather pick general
+paths. One downside of using unbacked is missed optimization opportunities
+due to either perf bugs or picking general paths, also using a fixed
+non-example input-based hint (this will be fixed soon with override_hint
+API). An example of picking general paths is assuming input not contiguous
+in functions call contiguous() and reshape() when can't be symbolically proven
+with a change of introducing a clone.
+
+`backed_size_oblivious` is a flag that enables treating backed symbols as
+unbacked wherever explicit handling for unbacked is defined. With this
+mode, 0/1 specializations are mostly avoided in framework code and the
+default 0/1 specialization does not happen. However, there is still no
+guarantee that torch.compile won't guard, especially due to user code or
+custom passes. `backed_size_oblivious` is experimental in PyTorch compile
+and could be deprecated. That said, it's a safer option to use than
+`backed` and the probability of reducing performance is lower than
+`unbacked`.
+
+### Configuring Dynamic Shapes
+
+The `DynamicShapesConfig` allows you to control the dynamic shapes behavior by
+setting the `type` field. You can choose between three modes:
+`BACKED`(default), `UNBACKED` , and `BACKED_SIZE_OBLIVIOUS`.
+
+#### Offline Inference Example (Using LLM class)
+
+When using the `LLM` class for offline inference, you can configure dynamic
+shapes through the `compilation_config` parameter:
+
+```python
+from vllm import LLM, SamplingParams
+from vllm.config.compilation import CompilationConfig, DynamicShapesConfig, DynamicShapesType
+
+# Example: Using backed_size_oblivious (experimental, safer than backed)
+llm = LLM(
+ model="meta-llama/Llama-3.2-1B",
+ compilation_config=CompilationConfig(
+ dynamic_shapes_config=DynamicShapesConfig(
+ type=DynamicShapesType.BACKED_SIZE_OBLIVIOUS
+ )
+ )
+)
+
+# Example: Using unbacked (strongest guarantee against guards)
+llm = LLM(
+ model="meta-llama/Llama-3.2-1B",
+ compilation_config=CompilationConfig(
+ dynamic_shapes_config=DynamicShapesConfig(
+ type=DynamicShapesType.UNBACKED
+ )
+ )
+)
+
+# Generate outputs
+prompts = ["Hello, my name is", "The future of AI is"]
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+outputs = llm.generate(prompts, sampling_params)
+```
+
+#### Online Serving Example (Using vllm serve)
+
+When using `vllm serve` for online serving, you can configure dynamic shapes
+through the `--compilation-config` flag:
+
+```bash
+# Example: Using unbacked
+vllm serve meta-llama/Llama-3.2-1B \
+ --compilation-config '{"dynamic_shapes_config": {"type": "unbacked"}}'
+
+
+# Alternative: Using dot notation (simpler for single values)
+vllm serve meta-llama/Llama-3.2-1B -O.dynamic_shapes_config.type=unbacked
+```
+
+#### Choosing the Right Mode
+
+- **BACKED** (default): Use when you're willing to accept potential unsafe dropping of guards
+for maximal performance. Guard could be unsoundly added and then ignored.
+
+- **UNBACKED** Use when you need the strongest guarantee against guards.
+ This is the most conservative option but may miss some optimization opportunities.
+
+- **BACKED_SIZE_OBLIVIOUS**: Use when you want a balance between avoiding guards
+ and performance. This experimental mode is safer than BACKED but still not as
+ conservative as UNBACKED.
+
## Python Code Compilation
In the very verbose logs, we can see:
@@ -122,7 +225,7 @@ When all the shapes are known, `torch.compile` can compare different configs, an
triton_mm_4 0.0130 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=2
triton_mm_8 0.0134 ms 97.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
triton_mm_12 0.0148 ms 87.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4
- mm 0.0160 ms 81.6%
+ mm 0.0160 ms 81.6%
triton_mm_16 0.0165 ms 78.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8
triton_mm_3 0.0199 ms 65.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=2
triton_mm_1 0.0203 ms 64.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=2
diff --git a/docs/features/quantization/inc.md b/docs/features/quantization/inc.md
index 5e86e9388f328..9875bc44c9144 100644
--- a/docs/features/quantization/inc.md
+++ b/docs/features/quantization/inc.md
@@ -22,9 +22,6 @@ export QUANT_CONFIG=/path/to/quant/config/inc/meta-llama-3.1-405b-instruct/maxab
vllm serve meta-llama/Llama-3.1-405B-Instruct --quantization inc --kv-cache-dtype fp8_inc --tensor_paralel_size 8
```
-!!! tip
- If you are just prototyping or testing your model with FP8, you can use the `VLLM_SKIP_WARMUP=true` environment variable to disable the warmup stage, which can take a long time. However, we do not recommend disabling this feature in production environments as it causes a significant performance drop.
-
!!! tip
When using FP8 models, you may experience timeouts caused by the long compilation time of FP8 operations. To mitigate this problem, you can use the below environment variables:
`VLLM_ENGINE_ITERATION_TIMEOUT_S` - to adjust the vLLM server timeout. You can set the value in seconds, e.g., 600 equals 10 minutes.
diff --git a/docs/features/structured_outputs.md b/docs/features/structured_outputs.md
index e38627c707884..7d52891bea7b9 100644
--- a/docs/features/structured_outputs.md
+++ b/docs/features/structured_outputs.md
@@ -7,7 +7,7 @@ This document shows you some examples of the different options that are
available to generate structured outputs.
!!! warning
- If you are still using the following deprecated API fields, please update your code to use `structured_outputs` as demonstrated in the rest of this document:
+ If you are still using the following deprecated API fields which were removed in v0.12.0, please update your code to use `structured_outputs` as demonstrated in the rest of this document:
- `guided_json` -> `{"structured_outputs": {"json": ...}}` or `StructuredOutputsParams(json=...)`
- `guided_regex` -> `{"structured_outputs": {"regex": ...}}` or `StructuredOutputsParams(regex=...)`
diff --git a/docs/getting_started/quickstart.md b/docs/getting_started/quickstart.md
index 9e86f785b10c7..94920dc5306b3 100644
--- a/docs/getting_started/quickstart.md
+++ b/docs/getting_started/quickstart.md
@@ -283,7 +283,7 @@ Currently, vLLM supports multiple backends for efficient Attention computation a
If desired, you can also manually set the backend of your choice by configuring the environment variable `VLLM_ATTENTION_BACKEND` to one of the following options:
-- On NVIDIA CUDA: `FLASH_ATTN`, `FLASHINFER` or `XFORMERS`.
+- On NVIDIA CUDA: `FLASH_ATTN` or `FLASHINFER`.
- On AMD ROCm: `TRITON_ATTN`, `ROCM_ATTN`, `ROCM_AITER_FA` or `ROCM_AITER_UNIFIED_ATTN`.
For AMD ROCm, you can further control the specific Attention implementation using the following variables:
diff --git a/docs/models/hardware_supported_models/cpu.md b/docs/models/hardware_supported_models/cpu.md
index 0832755f8fbe2..811778b2ad529 100644
--- a/docs/models/hardware_supported_models/cpu.md
+++ b/docs/models/hardware_supported_models/cpu.md
@@ -1,25 +1,33 @@
# CPU - Intel® Xeon®
+## Validated Hardware
+
+| Hardware |
+| ----------------------------------------- |
+| [Intel® Xeon® 6 Processors](https://www.intel.com/content/www/us/en/products/details/processors/xeon.html) |
+| [Intel® Xeon® 5 Processors](https://www.intel.com/content/www/us/en/products/docs/processors/xeon/5th-gen-xeon-scalable-processors.html) |
+
## Supported Models
### Text-only Language Models
| Model | Architecture | Supported |
|--------------------------------------|-------------------------------------------|-----------|
-| meta-llama/Llama-3.1 / 3.3 | LlamaForCausalLM | ✅ |
-| meta-llama/Llama-4-Scout | Llama4ForConditionalGeneration | ✅ |
-| meta-llama/Llama-4-Maverick | Llama4ForConditionalGeneration | ✅ |
-| ibm-granite/granite (Granite-MOE) | GraniteMoeForCausalLM | ✅ |
-| Qwen/Qwen3 | Qwen3ForCausalLM | ✅ |
-| zai-org/GLM-4.5 | GLMForCausalLM | ✅ |
-| google/gemma | GemmaForCausalLM | ✅ |
+| meta-llama/Llama-3.1-8B-Instruct | LlamaForCausalLM | ✅ |
+| meta-llama/Llama-3.2-3B-Instruct | LlamaForCausalLM | ✅ |
+| ibm-granite/granite-3.2-2b-instruct | GraniteForCausalLM | ✅ |
+| Qwen/Qwen3-1.7B | Qwen3ForCausalLM | ✅ |
+| Qwen/Qwen3-4B | Qwen3ForCausalLM | ✅ |
+| Qwen/Qwen3-8B | Qwen3ForCausalLM | ✅ |
+| zai-org/glm-4-9b-hf | GLMForCausalLM | ✅ |
+| google/gemma-7b | GemmaForCausalLM | ✅ |
### Multimodal Language Models
| Model | Architecture | Supported |
|--------------------------------------|-------------------------------------------|-----------|
-| Qwen/Qwen2.5-VL | Qwen2VLForConditionalGeneration | ✅ |
-| openai/whisper | WhisperForConditionalGeneration | ✅ |
+| Qwen/Qwen2.5-VL-7B-Instruct | Qwen2VLForConditionalGeneration | ✅ |
+| openai/whisper-large-v3 | WhisperForConditionalGeneration | ✅ |
✅ Runs and optimized.
🟨 Runs and correct but not optimized to green yet.
diff --git a/docs/models/hardware_supported_models/xpu.md b/docs/models/hardware_supported_models/xpu.md
new file mode 100644
index 0000000000000..7b8dcf5c9af26
--- /dev/null
+++ b/docs/models/hardware_supported_models/xpu.md
@@ -0,0 +1,65 @@
+# XPU - Intel® GPUs
+
+## Validated Hardware
+
+| Hardware |
+| ----------------------------------------- |
+| [Intel® Arc™ Pro B-Series Graphics](https://www.intel.com/content/www/us/en/products/docs/discrete-gpus/arc/workstations/b-series/overview.html) |
+
+## Supported Models
+
+### Text-only Language Models
+
+| Model | Architecture | FP16 | Dynamic FP8 | MXFP4 |
+| ----------------------------------------- | ---------------------------------------------------- | ---- | ----------- | ----- |
+| openai/gpt-oss-20b | GPTForCausalLM | | | ✅ |
+| openai/gpt-oss-120b | GPTForCausalLM | | | ✅ |
+| deepseek-ai/DeepSeek-R1-Distill-Llama-8B | LlamaForCausalLM | ✅ | ✅ | |
+| deepseek-ai/DeepSeek-R1-Distill-Qwen-14B | QwenForCausalLM | ✅ | ✅ | |
+| deepseek-ai/DeepSeek-R1-Distill-Qwen-32B | QwenForCausalLM | ✅ | ✅ | |
+| deepseek-ai/DeepSeek-R1-Distill-Llama-70B | LlamaForCausalLM | ✅ | ✅ | |
+| Qwen/Qwen2.5-72B-Instruct | Qwen2ForCausalLM | ✅ | ✅ | |
+| Qwen/Qwen3-14B | Qwen3ForCausalLM | ✅ | ✅ | |
+| Qwen/Qwen3-32B | Qwen3ForCausalLM | ✅ | ✅ | |
+| Qwen/Qwen3-30B-A3B | Qwen3ForCausalLM | ✅ | ✅ | |
+| Qwen/Qwen3-30B-A3B-GPTQ-Int4 | Qwen3ForCausalLM | ✅ | ✅ | |
+| Qwen/Qwen3-coder-30B-A3B-Instruct | Qwen3ForCausalLM | ✅ | ✅ | |
+| Qwen/QwQ-32B | QwenForCausalLM | ✅ | ✅ | |
+| deepseek-ai/DeepSeek-V2-Lite | DeepSeekForCausalLM | ✅ | ✅ | |
+| meta-llama/Llama-3.1-8B-Instruct | LlamaForCausalLM | ✅ | ✅ | |
+| baichuan-inc/Baichuan2-13B-Chat | BaichuanForCausalLM | ✅ | ✅ | |
+| THUDM/GLM-4-9B-chat | GLMForCausalLM | ✅ | ✅ | |
+| THUDM/CodeGeex4-All-9B | CodeGeexForCausalLM | ✅ | ✅ | |
+| chuhac/TeleChat2-35B | LlamaForCausalLM (TeleChat2 based on Llama arch) | ✅ | ✅ | |
+| 01-ai/Yi1.5-34B-Chat | YiForCausalLM | ✅ | ✅ | |
+| THUDM/CodeGeex4-All-9B | CodeGeexForCausalLM | ✅ | ✅ | |
+| deepseek-ai/DeepSeek-Coder-33B-base | DeepSeekCoderForCausalLM | ✅ | ✅ | |
+| baichuan-inc/Baichuan2-13B-Chat | BaichuanForCausalLM | ✅ | ✅ | |
+| meta-llama/Llama-2-13b-chat-hf | LlamaForCausalLM | ✅ | ✅ | |
+| THUDM/CodeGeex4-All-9B | CodeGeexForCausalLM | ✅ | ✅ | |
+| Qwen/Qwen1.5-14B-Chat | QwenForCausalLM | ✅ | ✅ | |
+| Qwen/Qwen1.5-32B-Chat | QwenForCausalLM | ✅ | ✅ | |
+
+### Multimodal Language Models
+
+| Model | Architecture | FP16 | Dynamic FP8 | MXFP4 |
+| ---------------------------- | -------------------------------- | ---- | ----------- | ----- |
+| OpenGVLab/InternVL3_5-8B | InternVLForConditionalGeneration | ✅ | ✅ | |
+| OpenGVLab/InternVL3_5-14B | InternVLForConditionalGeneration | ✅ | ✅ | |
+| OpenGVLab/InternVL3_5-38B | InternVLForConditionalGeneration | ✅ | ✅ | |
+| Qwen/Qwen2-VL-7B-Instruct | Qwen2VLForConditionalGeneration | ✅ | ✅ | |
+| Qwen/Qwen2.5-VL-72B-Instruct | Qwen2VLForConditionalGeneration | ✅ | ✅ | |
+| Qwen/Qwen2.5-VL-32B-Instruct | Qwen2VLForConditionalGeneration | ✅ | ✅ | |
+| THUDM/GLM-4v-9B | GLM4vForConditionalGeneration | ✅ | ✅ | |
+| openbmb/MiniCPM-V-4 | MiniCPMVForConditionalGeneration | ✅ | ✅ | |
+
+### Embedding and Reranker Language Models
+
+| Model | Architecture | FP16 | Dynamic FP8 | MXFP4 |
+| ----------------------- | ------------------------------ | ---- | ----------- | ----- |
+| Qwen/Qwen3-Embedding-8B | Qwen3ForTextEmbedding | ✅ | ✅ | |
+| Qwen/Qwen3-Reranker-8B | Qwen3ForSequenceClassification | ✅ | ✅ | |
+
+✅ Runs and optimized.
+🟨 Runs and correct but not optimized to green yet.
+❌ Does not pass accuracy test or does not run.
diff --git a/docs/models/pooling_models.md b/docs/models/pooling_models.md
index 18bb645ea9a9c..aca865f4bf77d 100644
--- a/docs/models/pooling_models.md
+++ b/docs/models/pooling_models.md
@@ -1,15 +1,15 @@
# Pooling Models
-vLLM also supports pooling models, such as embedding, classification and reward models.
+vLLM also supports pooling models, such as embedding, classification, and reward models.
In vLLM, pooling models implement the [VllmModelForPooling][vllm.model_executor.models.VllmModelForPooling] interface.
These models use a [Pooler][vllm.model_executor.layers.pooler.Pooler] to extract the final hidden states of the input
before returning them.
!!! note
- We currently support pooling models primarily as a matter of convenience. This is not guaranteed to have any performance improvement over using HF Transformers / Sentence Transformers directly.
+ We currently support pooling models primarily for convenience. This is not guaranteed to provide any performance improvements over using Hugging Face Transformers or Sentence Transformers directly.
- We are now planning to optimize pooling models in vLLM. Please comment on if you have any suggestions!
+ We plan to optimize pooling models in vLLM. Please comment on if you have any suggestions!
## Configuration
@@ -19,7 +19,7 @@ Run a model in pooling mode via the option `--runner pooling`.
!!! tip
There is no need to set this option in the vast majority of cases as vLLM can automatically
- detect the model runner to use via `--runner auto`.
+ detect the appropriate model runner via `--runner auto`.
### Model Conversion
@@ -78,7 +78,7 @@ When loading [Sentence Transformers](https://huggingface.co/sentence-transformer
its Sentence Transformers configuration file (`modules.json`) takes priority over the model's defaults.
You can further customize this via the `--pooler-config` option,
-which takes priority over both the model's and Sentence Transformers's defaults.
+which takes priority over both the model's and Sentence Transformers' defaults.
## Offline Inference
@@ -168,11 +168,11 @@ The [encode][vllm.LLM.encode] method is available to all pooling models in vLLM.
- For embeddings, use `LLM.embed(...)` or `pooling_task="embed"`.
- For classification logits, use `LLM.classify(...)` or `pooling_task="classify"`.
- - For similarity scores, use `LLM.score(...)`.
+ - For similarity scores, use `LLM.score(...)`.
- For rewards, use `LLM.reward(...)` or `pooling_task="token_classify"`.
- For token classification, use `pooling_task="token_classify"`.
- - For multi-vector retrieval, use `pooling_task="token_embed"`
- - For IO Processor Plugins , use `pooling_task="plugin"`
+ - For multi-vector retrieval, use `pooling_task="token_embed"`.
+ - For IO Processor Plugins, use `pooling_task="plugin"`.
```python
from vllm import LLM
@@ -194,15 +194,15 @@ Our [OpenAI-Compatible Server](../serving/openai_compatible_server.md) provides
- [Pooling API](../serving/openai_compatible_server.md#pooling-api) is similar to `LLM.encode`, being applicable to all types of pooling models.
!!! note
- Please use one of the more specific methods or set the task directly when using [Pooling API](../serving/openai_compatible_server.md#pooling-api) api.:
+ Please use one of the more specific endpoints or set the task directly when using the [Pooling API](../serving/openai_compatible_server.md#pooling-api):
- For embeddings, use [Embeddings API](../serving/openai_compatible_server.md#embeddings-api) or `"task":"embed"`.
- - For classification logits, use [Classification API](../serving/openai_compatible_server.md#classification-api) or `task":"classify"`.
- - For similarity scores, use [Score API](../serving/openai_compatible_server.md#score-api).
- - For rewards, `task":"token_classify"`.
- - For token classification, use `task":"token_classify"`.
- - For multi-vector retrieval, use `task":"token_embed"`
- - For IO Processor Plugins , use `task":"plugin"`
+ - For classification logits, use [Classification API](../serving/openai_compatible_server.md#classification-api) or `"task":"classify"`.
+ - For similarity scores, use [Score API](../serving/openai_compatible_server.md#score-api).
+ - For rewards, use `"task":"token_classify"`.
+ - For token classification, use `"task":"token_classify"`.
+ - For multi-vector retrieval, use `"task":"token_embed"`.
+ - For IO Processor Plugins, use `"task":"plugin"`.
```python
# start a supported embeddings model server with `vllm serve`, e.g.
@@ -232,7 +232,7 @@ for output in response.json()["data"]:
## Matryoshka Embeddings
-[Matryoshka Embeddings](https://sbert.net/examples/sentence_transformer/training/matryoshka/README.html#matryoshka-embeddings) or [Matryoshka Representation Learning (MRL)](https://arxiv.org/abs/2205.13147) is a technique used in training embedding models. It allows user to trade off between performance and cost.
+[Matryoshka Embeddings](https://sbert.net/examples/sentence_transformer/training/matryoshka/README.html#matryoshka-embeddings) or [Matryoshka Representation Learning (MRL)](https://arxiv.org/abs/2205.13147) is a technique used in training embedding models. It allows users to trade off between performance and cost.
!!! warning
Not all embedding models are trained using Matryoshka Representation Learning. To avoid misuse of the `dimensions` parameter, vLLM returns an error for requests that attempt to change the output dimension of models that do not support Matryoshka Embeddings.
@@ -245,9 +245,9 @@ for output in response.json()["data"]:
### Manually enable Matryoshka Embeddings
-There is currently no official interface for specifying support for Matryoshka Embeddings. In vLLM, if `is_matryoshka` is `True` in `config.json,` it is allowed to change the output to arbitrary dimensions. Using `matryoshka_dimensions` can control the allowed output dimensions.
+There is currently no official interface for specifying support for Matryoshka Embeddings. In vLLM, if `is_matryoshka` is `True` in `config.json`, you can change the output dimension to arbitrary values. Use `matryoshka_dimensions` to control the allowed output dimensions.
-For models that support Matryoshka Embeddings but not recognized by vLLM, please manually override the config using `hf_overrides={"is_matryoshka": True}`, `hf_overrides={"matryoshka_dimensions": []}` (offline) or `--hf-overrides '{"is_matryoshka": true}'`, `--hf-overrides '{"matryoshka_dimensions": []}'`(online).
+For models that support Matryoshka Embeddings but are not recognized by vLLM, manually override the config using `hf_overrides={"is_matryoshka": True}` or `hf_overrides={"matryoshka_dimensions": []}` (offline), or `--hf-overrides '{"is_matryoshka": true}'` or `--hf-overrides '{"matryoshka_dimensions": []}'` (online).
Here is an example to serve a model with Matryoshka Embeddings enabled.
@@ -278,7 +278,7 @@ A code example can be found here: [examples/offline_inference/pooling/embed_matr
### Online Inference
-Use the following command to start vllm server.
+Use the following command to start the vLLM server.
```bash
vllm serve jinaai/jina-embeddings-v3 --trust-remote-code
@@ -310,11 +310,11 @@ An OpenAI client example can be found here: [examples/online_serving/pooling/ope
### Encode task
-We have split the `encode` task into two more specific token wise tasks: `token_embed` and `token_classify`:
+We have split the `encode` task into two more specific token-wise tasks: `token_embed` and `token_classify`:
-- `token_embed` is the same as embed, using normalize as activation.
-- `token_classify` is the same as classify, default using softmax as activation.
+- `token_embed` is the same as `embed`, using normalization as the activation.
+- `token_classify` is the same as `classify`, by default using softmax as the activation.
### Remove softmax from PoolingParams
-We are going to remove `softmax` and `activation` from `PoolingParams`. Instead, you should set `use_activation`, since we actually allow `classify` and `token_classify` to use any activation function.
+We are going to remove `softmax` and `activation` from `PoolingParams`. Instead, use `use_activation`, since we allow `classify` and `token_classify` to use any activation function.
diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 404519f887dc6..25579835faf63 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -680,6 +680,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + IE+ + VE+ | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ |
| `H2OVLChatModel` | H2OVL | T + IE+ | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ |
+| `HunYuanVLForConditionalGeneration` | HunyuanOCR | T + IE+ | `tencent/HunyuanOCR`, etc. | ✅︎ | ✅︎ |
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | |
| `InternS1ForConditionalGeneration` | Intern-S1 | T + IE+ + VE+ | `internlm/Intern-S1`, `internlm/Intern-S1-mini`, etc. | ✅︎ | ✅︎ |
| `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + IE+ + (VE+) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ |
diff --git a/docs/serving/openai_compatible_server.md b/docs/serving/openai_compatible_server.md
index 23df3963823aa..e3280bd15b55c 100644
--- a/docs/serving/openai_compatible_server.md
+++ b/docs/serving/openai_compatible_server.md
@@ -49,7 +49,8 @@ We currently support the following OpenAI APIs:
- *Note: `suffix` parameter is not supported.*
- [Chat Completions API](#chat-api) (`/v1/chat/completions`)
- Only applicable to [text generation models](../models/generative_models.md) with a [chat template](../serving/openai_compatible_server.md#chat-template).
- - *Note: `parallel_tool_calls` and `user` parameters are ignored.*
+ - *Note: `user` parameter is ignored.*
+ - *Note:* Setting the `parallel_tool_calls` parameter to `false` ensures vLLM only returns zero or one tool call per request. Setting it to `true` (the default) allows returning more than one tool call per request. There is no guarantee more than one tool call will be returned if this is set to `true`, as that behavior is model dependent and not all models are designed to support parallel tool calls.
- [Embeddings API](#embeddings-api) (`/v1/embeddings`)
- Only applicable to [embedding models](../models/pooling_models.md).
- [Transcriptions API](#transcriptions-api) (`/v1/audio/transcriptions`)
diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py
old mode 100644
new mode 100755
index 04e6f99f8957e..df6e96ca375fc
--- a/examples/offline_inference/audio_language.py
+++ b/examples/offline_inference/audio_language.py
@@ -425,6 +425,13 @@ def parse_args():
default=None,
help="Set the seed when initializing `vllm.LLM`.",
)
+ parser.add_argument(
+ "--tensor-parallel-size",
+ "-tp",
+ type=int,
+ default=None,
+ help="Tensor parallel size to override the model's default setting. ",
+ )
return parser.parse_args()
@@ -434,6 +441,12 @@ def main(args):
if model not in model_example_map:
raise ValueError(f"Model type {model} is not supported.")
+ if args.tensor_parallel_size is not None and args.tensor_parallel_size < 1:
+ raise ValueError(
+ f"tensor_parallel_size must be a positive integer, "
+ f"got {args.tensor_parallel_size}"
+ )
+
audio_count = args.num_audios
req_data = model_example_map[model](
question_per_audio_count[audio_count], audio_count
@@ -446,6 +459,8 @@ def main(args):
)
engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
+ if args.tensor_parallel_size is not None:
+ engine_args["tensor_parallel_size"] = args.tensor_parallel_size
llm = LLM(**engine_args)
# We set temperature to 0.2 so that outputs can be different
diff --git a/examples/offline_inference/qwen3_omni/only_thinker.py b/examples/offline_inference/qwen3_omni/only_thinker.py
new file mode 100644
index 0000000000000..88a61ed694c2e
--- /dev/null
+++ b/examples/offline_inference/qwen3_omni/only_thinker.py
@@ -0,0 +1,170 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+This example shows how to use vLLM for running offline inference
+with the correct prompt format on Qwen2.5-Omni (thinker only).
+"""
+
+from typing import NamedTuple
+
+from vllm import LLM, SamplingParams
+from vllm.assets.audio import AudioAsset
+from vllm.assets.image import ImageAsset
+from vllm.assets.video import VideoAsset
+from vllm.multimodal.image import convert_image_mode
+from vllm.utils.argparse_utils import FlexibleArgumentParser
+
+
+class QueryResult(NamedTuple):
+ inputs: dict
+ limit_mm_per_prompt: dict[str, int]
+
+
+# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
+# lower-end GPUs.
+# Unless specified, these settings have been tested to work on a single L4.
+
+default_system = (
+ "You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
+ "Group, capable of perceiving auditory and visual inputs, as well as "
+ "generating text and speech."
+)
+
+
+def get_mixed_modalities_query() -> QueryResult:
+ question = (
+ "What is recited in the audio? "
+ "What is the content of this image? Why is this video funny?"
+ )
+ prompt = (
+ f"<|im_start|>system\n{default_system}<|im_end|>\n"
+ "<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|>"
+ "<|vision_start|><|image_pad|><|vision_end|>"
+ "<|vision_start|><|video_pad|><|vision_end|>"
+ f"{question}<|im_end|>\n"
+ f"<|im_start|>assistant\n"
+ )
+ return QueryResult(
+ inputs={
+ "prompt": prompt,
+ "multi_modal_data": {
+ "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
+ "image": convert_image_mode(
+ ImageAsset("cherry_blossom").pil_image, "RGB"
+ ),
+ "video": VideoAsset(name="baby_reading", num_frames=16).np_ndarrays,
+ },
+ },
+ limit_mm_per_prompt={"audio": 1, "image": 1, "video": 1},
+ )
+
+
+def get_use_audio_in_video_query() -> QueryResult:
+ question = (
+ "Describe the content of the video in details, then convert what the "
+ "baby say into text."
+ )
+ prompt = (
+ f"<|im_start|>system\n{default_system}<|im_end|>\n"
+ "<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>"
+ f"{question}<|im_end|>\n"
+ f"<|im_start|>assistant\n"
+ )
+ asset = VideoAsset(name="baby_reading", num_frames=16)
+ audio = asset.get_audio(sampling_rate=16000)
+ return QueryResult(
+ inputs={
+ "prompt": prompt,
+ "multi_modal_data": {
+ "video": asset.np_ndarrays,
+ "audio": audio,
+ },
+ "mm_processor_kwargs": {
+ "use_audio_in_video": True,
+ },
+ },
+ limit_mm_per_prompt={"audio": 1, "video": 1},
+ )
+
+
+def get_multi_audios_query() -> QueryResult:
+ question = "Are these two audio clips the same?"
+ prompt = (
+ f"<|im_start|>system\n{default_system}<|im_end|>\n"
+ "<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|>"
+ "<|audio_start|><|audio_pad|><|audio_end|>"
+ f"{question}<|im_end|>\n"
+ f"<|im_start|>assistant\n"
+ )
+ return QueryResult(
+ inputs={
+ "prompt": prompt,
+ "multi_modal_data": {
+ "audio": [
+ AudioAsset("winning_call").audio_and_sample_rate,
+ AudioAsset("mary_had_lamb").audio_and_sample_rate,
+ ],
+ },
+ },
+ limit_mm_per_prompt={
+ "audio": 2,
+ },
+ )
+
+
+query_map = {
+ "mixed_modalities": get_mixed_modalities_query,
+ "use_audio_in_video": get_use_audio_in_video_query,
+ "multi_audios": get_multi_audios_query,
+}
+
+
+def main(args):
+ model_name = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
+ query_result = query_map[args.query_type]()
+
+ llm = LLM(
+ model=model_name,
+ max_model_len=12800,
+ max_num_seqs=5,
+ limit_mm_per_prompt=query_result.limit_mm_per_prompt,
+ seed=args.seed,
+ )
+
+ # We set temperature to 0.2 so that outputs can be different
+ # even when all prompts are identical when running batch inference.
+ sampling_params = SamplingParams(temperature=0.2, max_tokens=256)
+
+ outputs = llm.generate(query_result.inputs, sampling_params=sampling_params)
+
+ for o in outputs:
+ generated_text = o.outputs[0].text
+ print(generated_text)
+
+
+def parse_args():
+ parser = FlexibleArgumentParser(
+ description="Demo on using vLLM for offline inference with "
+ "audio language models"
+ )
+ parser.add_argument(
+ "--query-type",
+ "-q",
+ type=str,
+ default="mixed_modalities",
+ choices=query_map.keys(),
+ help="Query type.",
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=None,
+ help="Set the seed when initializing `vllm.LLM`.",
+ )
+
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py
index 67a0732459709..29b2e95d262f8 100644
--- a/examples/offline_inference/spec_decode.py
+++ b/examples/offline_inference/spec_decode.py
@@ -133,7 +133,7 @@ def main(args):
tensor_parallel_size=args.tp,
enable_chunked_prefill=args.enable_chunked_prefill,
enforce_eager=args.enforce_eager,
- gpu_memory_utilization=0.8,
+ gpu_memory_utilization=0.9,
speculative_config=speculative_config,
disable_log_stats=False,
max_model_len=args.max_model_len,
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
old mode 100644
new mode 100755
index 624de2a2debc3..8f72bf6f0b0d1
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -538,6 +538,31 @@ def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
)
+# HunyuanOCR
+def run_hunyuan_vl(questions: list[str], modality: str) -> ModelRequestData:
+ assert modality == "image"
+
+ model_name = "tencent/HunyuanOCR"
+
+ engine_args = EngineArgs(
+ model=model_name,
+ max_model_len=8192,
+ limit_mm_per_prompt={modality: 1},
+ )
+
+ placeholder = "<|hy_place▁holder▁no▁100|><|hy_place▁holder▁no▁102|><|hy_place▁holder▁no▁101|>" # noqa: E501
+ prompts = [
+ f"<|hy_begin▁of▁sentence|>{placeholder}{question}<|hy_User|>"
+ for question in questions
+ ]
+
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompts=prompts,
+ stop_token_ids=None,
+ )
+
+
# naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B
def run_hyperclovax_seed_vision(
questions: list[str], modality: str
@@ -1820,6 +1845,7 @@ model_example_map = {
"glm4_5v": run_glm4_5v,
"glm4_5v_fp8": run_glm4_5v_fp8,
"h2ovl_chat": run_h2ovl,
+ "hunyuan_vl": run_hunyuan_vl,
"hyperclovax_seed_vision": run_hyperclovax_seed_vision,
"idefics3": run_idefics3,
"interns1": run_interns1,
@@ -2038,6 +2064,13 @@ def parse_args():
help="If True, will send all requests in a second batch with empty mm "
"data to verify cache hits with UUIDs.",
)
+ parser.add_argument(
+ "--tensor-parallel-size",
+ "-tp",
+ type=int,
+ default=None,
+ help="Tensor parallel size to override the model's default setting. ",
+ )
return parser.parse_args()
@@ -2046,6 +2079,12 @@ def main(args):
if model not in model_example_map:
raise ValueError(f"Model type {model} is not supported.")
+ if args.tensor_parallel_size is not None and args.tensor_parallel_size < 1:
+ raise ValueError(
+ f"tensor_parallel_size must be a positive integer, "
+ f"got {args.tensor_parallel_size}"
+ )
+
modality = args.modality
mm_input = get_multi_modal_input(args)
data = mm_input["data"]
@@ -2063,6 +2102,8 @@ def main(args):
"seed": args.seed,
"mm_processor_cache_gb": 0 if args.disable_mm_processor_cache else 4,
}
+ if args.tensor_parallel_size is not None:
+ engine_args["tensor_parallel_size"] = args.tensor_parallel_size
llm = LLM(**engine_args)
# Don't want to check the flag multiple times, so just hijack `prompts`.
diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py
old mode 100644
new mode 100755
index d6e169548f15b..7ba4e64b567de
--- a/examples/offline_inference/vision_language_multi_image.py
+++ b/examples/offline_inference/vision_language_multi_image.py
@@ -1110,6 +1110,7 @@ def load_r_vl(question: str, image_urls: list[str]) -> ModelRequestData:
model=model_name,
max_model_len=16384,
max_num_seqs=16,
+ trust_remote_code=True,
limit_mm_per_prompt={"image": len(image_urls)},
)
@@ -1351,10 +1352,18 @@ model_example_map = {
}
-def run_generate(model, question: str, image_urls: list[str], seed: int | None):
+def run_generate(
+ model,
+ question: str,
+ image_urls: list[str],
+ seed: int | None,
+ tensor_parallel_size: int | None,
+):
req_data = model_example_map[model](question, image_urls)
- engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
+ engine_args = asdict(req_data.engine_args) | {"seed": seed}
+ if tensor_parallel_size is not None:
+ engine_args["tensor_parallel_size"] = tensor_parallel_size
llm = LLM(**engine_args)
sampling_params = SamplingParams(
@@ -1377,7 +1386,13 @@ def run_generate(model, question: str, image_urls: list[str], seed: int | None):
print("-" * 50)
-def run_chat(model: str, question: str, image_urls: list[str], seed: int | None):
+def run_chat(
+ model: str,
+ question: str,
+ image_urls: list[str],
+ seed: int | None,
+ tensor_parallel_size: int | None,
+):
req_data = model_example_map[model](question, image_urls)
# Disable other modalities to save memory
@@ -1387,6 +1402,8 @@ def run_chat(model: str, question: str, image_urls: list[str], seed: int | None)
)
engine_args = asdict(req_data.engine_args) | {"seed": seed}
+ if tensor_parallel_size is not None:
+ engine_args["tensor_parallel_size"] = tensor_parallel_size
llm = LLM(**engine_args)
sampling_params = (
@@ -1462,6 +1479,13 @@ def parse_args():
default=2,
help="Number of images to use for the demo.",
)
+ parser.add_argument(
+ "--tensor-parallel-size",
+ "-tp",
+ type=int,
+ default=None,
+ help="Tensor parallel size to override the model's default setting. ",
+ )
return parser.parse_args()
@@ -1469,13 +1493,20 @@ def main(args: Namespace):
model = args.model_type
method = args.method
seed = args.seed
+ tensor_parallel_size = args.tensor_parallel_size
+
+ if tensor_parallel_size is not None and tensor_parallel_size < 1:
+ raise ValueError(
+ f"tensor_parallel_size must be a positive integer, "
+ f"got {tensor_parallel_size}"
+ )
image_urls = IMAGE_URLS[: args.num_images]
if method == "generate":
- run_generate(model, QUESTION, image_urls, seed)
+ run_generate(model, QUESTION, image_urls, seed, tensor_parallel_size)
elif method == "chat":
- run_chat(model, QUESTION, image_urls, seed)
+ run_chat(model, QUESTION, image_urls, seed, tensor_parallel_size)
else:
raise ValueError(f"Invalid method: {method}")
diff --git a/examples/online_serving/openai_embedding_long_text/service.sh b/examples/online_serving/openai_embedding_long_text/service.sh
index 1577de85f7ff2..b5c92749466b0 100644
--- a/examples/online_serving/openai_embedding_long_text/service.sh
+++ b/examples/online_serving/openai_embedding_long_text/service.sh
@@ -22,7 +22,6 @@ API_KEY=${API_KEY:-"your-api-key"}
POOLING_TYPE=${POOLING_TYPE:-"auto"} # auto, MEAN, CLS, LAST
export VLLM_ENABLE_CHUNKED_PROCESSING=true
export CUDA_VISIBLE_DEVICES=2,3,4,5
-# export VLLM_ATTENTION_BACKEND=XFORMERS
echo "🚀 Starting vLLM Embedding Server with Enhanced Chunked Processing"
echo "=================================================================="
diff --git a/requirements/cpu-build.txt b/requirements/cpu-build.txt
index 81d429a5e5f8d..0c6fdd3b33cd1 100644
--- a/requirements/cpu-build.txt
+++ b/requirements/cpu-build.txt
@@ -4,9 +4,9 @@ packaging>=24.2
setuptools>=77.0.3,<81.0.0
setuptools-scm>=8
--extra-index-url https://download.pytorch.org/whl/cpu
-torch==2.8.0+cpu; platform_machine == "x86_64" or platform_machine == "s390x"
+torch==2.9.0+cpu; platform_machine == "x86_64" or platform_machine == "s390x"
torch==2.9.0; platform_system == "Darwin"
-torch==2.8.0; platform_machine == "ppc64le" or platform_machine == "aarch64"
+torch==2.9.0; platform_machine == "ppc64le" or platform_machine == "aarch64"
scons; platform_machine == "aarch64" # needed to build Arm Compute Library (ACL)
wheel
jinja2>=3.1.6
diff --git a/requirements/cpu.txt b/requirements/cpu.txt
index e23d3286f3f78..8c04d6d5ce1b0 100644
--- a/requirements/cpu.txt
+++ b/requirements/cpu.txt
@@ -7,17 +7,17 @@ numba == 0.61.2; platform_machine != "s390x" # Required for N-gram speculative d
packaging>=24.2
setuptools>=77.0.3,<81.0.0
--extra-index-url https://download.pytorch.org/whl/cpu
-torch==2.8.0+cpu; platform_machine == "x86_64" or platform_machine == "s390x"
+torch==2.9.0+cpu; platform_machine == "x86_64" or platform_machine == "s390x"
torch==2.9.0; platform_system == "Darwin"
-torch==2.8.0; platform_machine == "ppc64le" or platform_machine == "aarch64"
+torch==2.9.0; platform_machine == "ppc64le" or platform_machine == "aarch64"
# required for the image processor of minicpm-o-2_6, this must be updated alongside torch
torchaudio; platform_machine != "ppc64le" and platform_machine != "s390x"
-torchaudio==2.8.0; platform_machine == "ppc64le"
+torchaudio==2.9.0; platform_machine == "ppc64le"
# required for the image processor of phi3v, this must be updated alongside torch
torchvision; platform_machine != "ppc64le" and platform_machine != "s390x"
-torchvision==0.23.0; platform_machine == "ppc64le"
+torchvision==0.24.0; platform_machine == "ppc64le"
datasets # for benchmark scripts
# Intel Extension for PyTorch, only for x86_64 CPUs
diff --git a/requirements/cuda.txt b/requirements/cuda.txt
index d63fe9e1e77c1..462f18ef7159b 100644
--- a/requirements/cuda.txt
+++ b/requirements/cuda.txt
@@ -9,6 +9,5 @@ torch==2.9.0
torchaudio==2.9.0
# These must be updated alongside torch
torchvision==0.24.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
-xformers==0.0.33.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.9
# FlashInfer should be updated together with the Dockerfile
-flashinfer-python==0.5.2
+flashinfer-python==0.5.3
diff --git a/requirements/kv_connectors.txt b/requirements/kv_connectors.txt
index b1f3269cd3813..083230c171096 100644
--- a/requirements/kv_connectors.txt
+++ b/requirements/kv_connectors.txt
@@ -1,2 +1,2 @@
lmcache
-nixl >= 0.6.0 # Required for disaggregated prefill
+nixl >= 0.7.1 # Required for disaggregated prefill
diff --git a/requirements/nightly_torch_test.txt b/requirements/nightly_torch_test.txt
index d9c5d89c1d52f..53b012372be8e 100644
--- a/requirements/nightly_torch_test.txt
+++ b/requirements/nightly_torch_test.txt
@@ -29,7 +29,7 @@ opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test
mteb>=1.38.11, <2 # required for mteb test
-transformers==4.57.1
+transformers==4.57.3
tokenizers==0.22.0
schemathesis>=3.39.15 # Required for openai schema test.
# quantization
diff --git a/requirements/rocm-test.txt b/requirements/rocm-test.txt
index 2d57e7e167869..8a91b59de6f72 100644
--- a/requirements/rocm-test.txt
+++ b/requirements/rocm-test.txt
@@ -45,3 +45,7 @@ multiprocess==0.70.16
# Plugins test
terratorch @ git+https://github.com/IBM/terratorch.git@07184fcf91a1324f831ff521dd238d97fe350e3e
+torchgeo==0.7.0
+
+# Required for suffix decoding test
+arctic-inference == 0.1.1
diff --git a/requirements/test.in b/requirements/test.in
index 05f6bcca5c2c4..da7a7db1f00c9 100644
--- a/requirements/test.in
+++ b/requirements/test.in
@@ -37,7 +37,7 @@ datamodel_code_generator # required for minicpm3 test
# TODO: Use lm-eval[api]==0.4.10 once released
lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test
mteb[bm25s]>=2, <3 # required for mteb test
-transformers==4.57.1
+transformers==4.57.3
tokenizers==0.22.0
schemathesis>=3.39.15 # Required for openai schema test.
# quantization
diff --git a/requirements/test.txt b/requirements/test.txt
index bcd511660f85e..c5f103b8b0d78 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -1196,7 +1196,7 @@ tqdm==4.66.6
# transformers
tqdm-multiprocess==0.0.11
# via lm-eval
-transformers==4.57.1
+transformers==4.57.3
# via
# -r requirements/test.in
# genai-perf
diff --git a/requirements/tpu.txt b/requirements/tpu.txt
index 4241cbb2b0333..e6fff58f7b794 100644
--- a/requirements/tpu.txt
+++ b/requirements/tpu.txt
@@ -12,6 +12,4 @@ ray[data]
setuptools==78.1.0
nixl==0.3.0
tpu_info==0.4.0
-
-# Install torch_xla
-torch_xla[tpu, pallas]==2.8.0
\ No newline at end of file
+tpu-inference==0.11.1
diff --git a/requirements/xpu.txt b/requirements/xpu.txt
index 59ea710684a2c..c1dc4195b5231 100644
--- a/requirements/xpu.txt
+++ b/requirements/xpu.txt
@@ -10,9 +10,9 @@ wheel
jinja2>=3.1.6
datasets # for benchmark scripts
numba == 0.61.2 # Required for N-gram speculative decoding
-torch==2.8.0+xpu
+--extra-index-url=https://download.pytorch.org/whl/xpu
+torch==2.9.0+xpu
torchaudio
torchvision
---extra-index-url=https://download.pytorch.org/whl/xpu
-intel-extension-for-pytorch @ https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.8.10.post1%2Bxpu-cp312-cp312-linux_x86_64.whl
+intel-extension-for-pytorch @ https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.9.10.post0%2Bxpu-cp312-cp312-linux_x86_64.whl
diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py
index 0cf1e85d4e8ee..521d6c33dd390 100644
--- a/tests/basic_correctness/test_basic_correctness.py
+++ b/tests/basic_correctness/test_basic_correctness.py
@@ -74,9 +74,6 @@ def test_models(
model_executor: str,
enable_prompt_embeds: bool,
) -> None:
- if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
- pytest.skip(f"{backend} does not support gemma2 with full context length.")
-
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", backend)
diff --git a/tests/compile/fullgraph/test_simple.py b/tests/compile/fullgraph/test_simple.py
index e258133ab50a7..36cc1510ed798 100644
--- a/tests/compile/fullgraph/test_simple.py
+++ b/tests/compile/fullgraph/test_simple.py
@@ -55,7 +55,7 @@ class SillyModel(nn.Module):
def _run_simple_model(
splitting_ops,
use_inductor_graph_partition,
- use_inductor,
+ backend,
expected_num_piecewise_graphs_seen,
expected_num_piecewise_capturable_graphs_seen,
expected_num_backend_compilations,
@@ -64,7 +64,7 @@ def _run_simple_model(
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
- use_inductor=use_inductor,
+ backend=backend,
splitting_ops=splitting_ops,
use_inductor_graph_partition=use_inductor_graph_partition,
cudagraph_copy_inputs=True,
@@ -124,14 +124,14 @@ def _run_simple_model(
assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))
-@pytest.mark.parametrize("use_inductor", [True, False])
+@pytest.mark.parametrize("backend", ["inductor", "eager"])
@torch.inference_mode()
@create_new_process_for_each_test("spawn")
-def test_simple_piecewise_compile(use_inductor):
+def test_simple_piecewise_compile(backend):
_run_simple_model(
splitting_ops=["silly::attention"],
use_inductor_graph_partition=False,
- use_inductor=use_inductor,
+ backend=backend,
# 2 * num_layers + 1
expected_num_piecewise_graphs_seen=5,
# 1 + num_layers
@@ -155,7 +155,7 @@ def test_simple_inductor_graph_partition(monkeypatch):
_run_simple_model(
splitting_ops=["silly::attention"],
use_inductor_graph_partition=True,
- use_inductor=True,
+ backend="inductor",
# Since not splitting at fx graph level
expected_num_piecewise_graphs_seen=1,
# Since not splitting at fx graph level
diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py
index 1e8a882a7f3eb..a9e5ccee520e3 100644
--- a/tests/compile/test_config.py
+++ b/tests/compile/test_config.py
@@ -172,8 +172,8 @@ def test_splitting_ops_dynamic():
config = VllmConfig()
# Default V1 config leaves cudagraph mode unset; splitting ops are only
# populated when the engine decides to use piecewise compilation.
- assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
- assert not config.compilation_config.splitting_ops_contain_attention()
+ assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
+ assert config.compilation_config.splitting_ops_contain_attention()
# When use_inductor_graph_partition=True
config = VllmConfig(
diff --git a/tests/compile/test_dynamic_shapes_compilation.py b/tests/compile/test_dynamic_shapes_compilation.py
new file mode 100644
index 0000000000000..c20aea822fe81
--- /dev/null
+++ b/tests/compile/test_dynamic_shapes_compilation.py
@@ -0,0 +1,88 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import gc
+
+import pytest
+import torch
+
+from vllm import LLM, SamplingParams
+from vllm.config.compilation import CompilationMode, DynamicShapesType
+from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.utils.torch_utils import is_torch_equal_or_newer
+
+
+def get_test_models():
+ """Get list of models to test based on PyTorch version"""
+ # TODO "Qwen/Qwen3-4B-Instruct-2507" fails Fix issue and support it.
+ return ["gpt2", "Qwen/Qwen2-7B-Instruct", "meta-llama/Llama-3.1-8B"]
+
+
+@pytest.mark.parametrize("model_name", get_test_models())
+@pytest.mark.parametrize(
+ "shapes_type",
+ [
+ DynamicShapesType.BACKED,
+ DynamicShapesType.UNBACKED,
+ DynamicShapesType.BACKED_SIZE_OBLIVIOUS,
+ ],
+)
+@pytest.mark.parametrize("use_aot_compile", ["0"])
+@pytest.mark.parametrize("use_bytecode_hook", [True, False])
+@pytest.mark.skipif(
+ not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
+)
+def test_dynamic_shapes_compilation(
+ monkeypatch, model_name, shapes_type, use_aot_compile, use_bytecode_hook
+):
+ """Test that all dynamic shapes types compile successfully"""
+ print(
+ f"\nTesting model: {model_name} with {shapes_type.name}, "
+ f"AOT compile: {use_aot_compile}, "
+ f"Bytecode hook: {use_bytecode_hook}"
+ )
+ if use_bytecode_hook and shapes_type == DynamicShapesType.UNBACKED:
+ pytest.skip("UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0")
+
+ monkeypatch.setenv("VLLM_USE_AOT_COMPILE", use_aot_compile)
+ monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")
+
+ prompt = "Hello, my name is"
+
+ print(f"Testing {shapes_type.name} dynamic shapes...")
+
+ # Initialize the model with specific dynamic shapes configuration
+ model = LLM(
+ model=model_name,
+ compilation_config={
+ "mode": CompilationMode.VLLM_COMPILE,
+ "dynamic_shapes_config": {
+ "type": shapes_type.value,
+ },
+ },
+ )
+
+ output = model.generate(prompt)
+ result = output[0].outputs[0].text
+ # Example of setting the sampling parameters
+ tokenizer = get_tokenizer(model_name)
+ yes_tokens = tokenizer.encode("yes", add_special_tokens=False)
+ no_tokens = tokenizer.encode("no", add_special_tokens=False)
+ allowed_ids = list(set(yes_tokens + no_tokens))
+ sampling_params = SamplingParams(
+ max_tokens=1, temperature=0, allowed_token_ids=allowed_ids
+ )
+
+ output = model.generate(
+ "answer with yes or no is " + result + " rubbish for prompt " + prompt + "?",
+ sampling_params=sampling_params,
+ )
+ result = output[0].outputs[0].text
+ assert result == "yes"
+
+ # Clean up GPU memory
+ del model
+ gc.collect()
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+ print("GPU memory cleared")
diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py
index ea61c94953a77..dbe12dc5de705 100644
--- a/tests/compile/test_fusion_attn.py
+++ b/tests/compile/test_fusion_attn.py
@@ -9,8 +9,9 @@ from tests.compile.backend import LazyInitPass, TestBackend
from tests.utils import flat_product
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.registry import AttentionBackendEnum
+from vllm.attention.layer import Attention
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
diff --git a/tests/compile/test_qk_norm_rope_fusion.py b/tests/compile/test_qk_norm_rope_fusion.py
index 511e50f5fdc24..5ebb95b6db332 100644
--- a/tests/compile/test_qk_norm_rope_fusion.py
+++ b/tests/compile/test_qk_norm_rope_fusion.py
@@ -5,7 +5,8 @@ import pytest
import torch
from tests.compile.backend import TestBackend
-from vllm.attention import Attention, AttentionType
+from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.layer import Attention
from vllm.compilation.matcher_utils import FLASHINFER_ROTARY_OP, RMS_OP, ROTARY_OP
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
diff --git a/tests/distributed/test_eplb_execute.py b/tests/distributed/test_eplb_execute.py
index 9498e75b279b7..781dfd44c1ef6 100644
--- a/tests/distributed/test_eplb_execute.py
+++ b/tests/distributed/test_eplb_execute.py
@@ -1,13 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import asyncio
import random
import pytest
import torch
import torch.distributed
-from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
+from vllm.distributed.eplb.rebalance_execute import (
+ move_from_buffer,
+ rearrange_expert_weights_inplace,
+ transfer_layer,
+)
from vllm.distributed.parallel_state import (
ensure_model_parallel_initialized,
get_tp_group,
@@ -231,6 +236,100 @@ def verify_redundant_experts_have_same_weights(
)
+def _test_async_transfer_layer_without_mtp_worker(
+ env,
+ world_size: int,
+ num_layers: int,
+ num_local_experts: int,
+ num_logical_experts: int,
+) -> None:
+ set_env_vars_and_device(env)
+ ensure_model_parallel_initialized(
+ tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
+ )
+
+ tp_group = get_tp_group()
+ ep_group = tp_group.device_group
+ ep_rank = torch.distributed.get_rank()
+ device = torch.device(f"cuda:{ep_rank}")
+
+ total_physical_experts = world_size * num_local_experts
+ hidden_sizes = [16, 32]
+
+ redundancy_config = create_redundancy_config(
+ num_logical_experts,
+ total_physical_experts,
+ )
+ old_indices = create_expert_indices_with_redundancy(
+ num_layers,
+ num_logical_experts,
+ total_physical_experts,
+ redundancy_config,
+ )
+
+ new_redundancy_config = create_redundancy_config(
+ num_logical_experts,
+ total_physical_experts,
+ )
+ new_indices = create_expert_indices_with_redundancy(
+ num_layers,
+ num_logical_experts,
+ total_physical_experts,
+ new_redundancy_config,
+ )
+
+ expert_weights = create_expert_weights(
+ num_layers,
+ num_local_experts,
+ hidden_sizes,
+ ep_rank,
+ device,
+ old_indices,
+ )
+
+ expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
+ cuda_stream = torch.cuda.Stream(device=device)
+
+ for layer_idx in range(num_layers):
+ is_unchanged, is_received_locally, experts_recv_loc = asyncio.run(
+ transfer_layer(
+ old_global_expert_indices=old_indices,
+ new_global_expert_indices=new_indices,
+ expert_weights=expert_weights,
+ expert_weights_buffer=expert_buffer,
+ ep_group=ep_group,
+ layer=layer_idx,
+ cuda_stream=cuda_stream,
+ )
+ )
+
+ cuda_stream.synchronize()
+ move_from_buffer(
+ expert_weights=expert_weights[layer_idx],
+ expert_weights_buffer=expert_buffer,
+ is_unchanged=is_unchanged,
+ is_received_locally=is_received_locally,
+ experts_recv_loc=experts_recv_loc,
+ new_indices=new_indices[layer_idx].tolist(),
+ ep_group=ep_group,
+ )
+
+ verify_expert_weights_after_shuffle(
+ expert_weights,
+ new_indices,
+ hidden_sizes,
+ ep_rank,
+ num_local_experts,
+ )
+ verify_redundant_experts_have_same_weights(
+ expert_weights,
+ new_indices,
+ hidden_sizes,
+ world_size,
+ num_local_experts,
+ )
+
+
def _test_rearrange_expert_weights_with_redundancy(
env, world_size, num_layers, num_local_experts, num_logical_experts
) -> None:
@@ -399,6 +498,32 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
)
+@pytest.mark.parametrize(
+ "world_size,num_layers,num_local_experts,num_logical_experts",
+ [
+ (2, 2, 2, 3),
+ ],
+)
+def test_async_transfer_layer_without_mtp(
+ world_size: int,
+ num_layers: int,
+ num_local_experts: int,
+ num_logical_experts: int,
+):
+ """Exercise async EPLB transfer path without MTP/spec decode."""
+
+ if torch.cuda.device_count() < world_size:
+ pytest.skip(f"Need at least {world_size} GPUs to run the test")
+
+ distributed_run(
+ _test_async_transfer_layer_without_mtp_worker,
+ world_size,
+ num_layers,
+ num_local_experts,
+ num_logical_experts,
+ )
+
+
@pytest.mark.parametrize("world_size", [2, 4])
def test_rearrange_expert_weights_no_change(world_size):
"""
diff --git a/tests/distributed/test_eplb_spec_decode.py b/tests/distributed/test_eplb_spec_decode.py
index 11e23f128f331..c055b7a3f6dd7 100644
--- a/tests/distributed/test_eplb_spec_decode.py
+++ b/tests/distributed/test_eplb_spec_decode.py
@@ -10,10 +10,11 @@ from tests.utils import large_gpu_mark
def get_model_args(
model_name: str,
- spec_model_name: str,
+ spec_model_name: str | None,
spec_method: str,
tp_size: int,
model_max_len: int,
+ use_async: bool = False,
) -> dict:
speculative_config = {
"method": spec_method,
@@ -37,6 +38,8 @@ def get_model_args(
"enable_eplb": True,
"max_model_len": model_max_len,
}
+ if use_async:
+ model_args["eplb_config"] = {"use_async": True}
return model_args
@@ -94,3 +97,37 @@ def test_eplb_spec_decode(
measured_value - RTOL < expected_gsm8k_value
and measured_value + RTOL > expected_gsm8k_value
), f"Expected: {expected_gsm8k_value} | Measured: {measured_value}"
+
+
+@large_gpu_mark(min_gb=80)
+def test_eplb_spec_decode_qwen3_next_mtp_async() -> None:
+ """
+ Ensure async EPLB works with MTP speculative decoding for Qwen3-Next.
+ """
+
+ TASK = "gsm8k"
+ FILTER = "exact_match,strict-match"
+ RTOL = 0.03
+ expected_gsm8k_value = 0.86
+
+ model_args = get_model_args(
+ model_name="Qwen/Qwen3-Next-80B-A3B-Instruct",
+ spec_model_name=None,
+ spec_method="mtp",
+ tp_size=4,
+ model_max_len=4096,
+ use_async=True,
+ )
+
+ results = lm_eval.simple_evaluate(
+ model="vllm",
+ model_args=model_args,
+ tasks=TASK,
+ batch_size=64,
+ num_fewshot=8,
+ )
+ measured_value = results["results"][TASK][FILTER]
+ assert (
+ measured_value - RTOL < expected_gsm8k_value
+ and measured_value + RTOL > expected_gsm8k_value
+ ), f"Expected: {expected_gsm8k_value} | Measured: {measured_value}"
diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py
index 10827e3b4b9cd..0077609b2f365 100644
--- a/tests/engine/test_arg_utils.py
+++ b/tests/engine/test_arg_utils.py
@@ -222,6 +222,47 @@ def test_media_io_kwargs_parser(arg, expected):
assert args.media_io_kwargs == expected
+@pytest.mark.parametrize(
+ ("args", "expected"),
+ [
+ (["-O", "1"], "1"),
+ (["-O", "2"], "2"),
+ (["-O", "3"], "3"),
+ (["-O0"], "0"),
+ (["-O1"], "1"),
+ (["-O2"], "2"),
+ (["-O3"], "3"),
+ ],
+)
+def test_optimization_level(args, expected):
+ """
+ Test space-separated optimization levels (-O 1, -O 2, -O 3) map to
+ optimization_level.
+ """
+ parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
+ parsed_args = parser.parse_args(args)
+ assert parsed_args.optimization_level == expected
+ assert parsed_args.compilation_config.mode is None
+
+
+@pytest.mark.parametrize(
+ ("args", "expected"),
+ [
+ (["-O.mode=0"], 0),
+ (["-O.mode=1"], 1),
+ (["-O.mode=2"], 2),
+ (["-O.mode=3"], 3),
+ ],
+)
+def test_mode_parser(args, expected):
+ """
+ Test compilation config modes (-O.mode=int) map to compilation_config.
+ """
+ parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
+ parsed_args = parser.parse_args(args)
+ assert parsed_args.compilation_config.mode == expected
+
+
def test_compilation_config():
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
@@ -229,34 +270,17 @@ def test_compilation_config():
args = parser.parse_args([])
assert args.compilation_config == CompilationConfig()
- # set to O3
- args = parser.parse_args(["-O0"])
- assert args.compilation_config.mode == 0
-
- # set to O 3 (space)
- args = parser.parse_args(["-O", "1"])
- assert args.compilation_config.mode == 1
-
- # set to O 3 (equals)
- args = parser.parse_args(["-O=2"])
- assert args.compilation_config.mode == 2
-
- # set to O.mode 3
- args = parser.parse_args(["-O.mode", "3"])
- assert args.compilation_config.mode == 3
-
# set to string form of a dict
args = parser.parse_args(
[
"-O",
- '{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
- '"use_inductor": false}',
+ '{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], "backend": "eager"}',
]
)
assert (
args.compilation_config.mode == 3
and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
- and not args.compilation_config.use_inductor
+ and args.compilation_config.backend == "eager"
)
# set to string form of a dict
@@ -264,13 +288,13 @@ def test_compilation_config():
[
"--compilation-config="
'{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
- '"use_inductor": true}',
+ '"backend": "inductor"}',
]
)
assert (
args.compilation_config.mode == 3
and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
- and args.compilation_config.use_inductor
+ and args.compilation_config.backend == "inductor"
)
@@ -278,8 +302,9 @@ def test_prefix_cache_default():
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args([])
+ # should be None by default (depends on model).
engine_args = EngineArgs.from_cli_args(args=args)
- assert engine_args.enable_prefix_caching, "prefix caching should default to on."
+ assert engine_args.enable_prefix_caching is None
# with flag to turn it on.
args = parser.parse_args(["--enable-prefix-caching"])
diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py
index 4e7b765d7713f..65a6fd20bd0d1 100644
--- a/tests/entrypoints/openai/test_metrics.py
+++ b/tests/entrypoints/openai/test_metrics.py
@@ -183,9 +183,6 @@ async def test_metrics_counts(
EXPECTED_METRICS_V1 = [
"vllm:num_requests_running",
"vllm:num_requests_waiting",
- "vllm:gpu_cache_usage_perc",
- "vllm:gpu_prefix_cache_queries",
- "vllm:gpu_prefix_cache_hits",
"vllm:kv_cache_usage_perc",
"vllm:prefix_cache_queries",
"vllm:prefix_cache_hits",
diff --git a/tests/entrypoints/openai/test_response_api_with_harmony.py b/tests/entrypoints/openai/test_response_api_with_harmony.py
index 6251e1776c30a..8fd3545eccffa 100644
--- a/tests/entrypoints/openai/test_response_api_with_harmony.py
+++ b/tests/entrypoints/openai/test_response_api_with_harmony.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
+import importlib
import json
import time
@@ -35,6 +35,10 @@ GET_WEATHER_SCHEMA = {
@pytest.fixture(scope="module")
def server():
+ assert importlib.util.find_spec("gpt_oss") is not None, (
+ "Harmony tests require gpt_oss package to be installed"
+ )
+
args = ["--enforce-eager", "--tool-server", "demo", "--max_model_len", "5000"]
env_dict = dict(
VLLM_ENABLE_RESPONSES_API_STORE="1",
diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py
index 88580ed899f1a..8045ab1468d6a 100644
--- a/tests/entrypoints/openai/test_transcription_validation.py
+++ b/tests/entrypoints/openai/test_transcription_validation.py
@@ -2,20 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# imports for structured outputs tests
-import io
import json
-import librosa
-import numpy as np
-import openai
import pytest
-import pytest_asyncio
-import soundfile as sf
from ...utils import RemoteOpenAIServer
-MODEL_NAME = "openai/whisper-large-v3-turbo"
-SERVER_ARGS = ["--enforce-eager"]
MISTRAL_FORMAT_ARGS = [
"--tokenizer_mode",
"mistral",
@@ -26,22 +18,8 @@ MISTRAL_FORMAT_ARGS = [
]
-@pytest.fixture(scope="module")
-def server():
- with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as remote_server:
- yield remote_server
-
-
-@pytest_asyncio.fixture
-async def client(server):
- async with server.get_async_client() as async_client:
- yield async_client
-
-
@pytest.mark.asyncio
-@pytest.mark.parametrize(
- "model_name", ["openai/whisper-large-v3-turbo", "mistralai/Voxtral-Mini-3B-2507"]
-)
+@pytest.mark.parametrize("model_name", ["mistralai/Voxtral-Mini-3B-2507"])
async def test_basic_audio(mary_had_lamb, model_name):
server_args = ["--enforce-eager"]
@@ -120,176 +98,3 @@ async def test_basic_audio_gemma(foscolo):
)
out = json.loads(transcription)["text"]
assert "da cui vergine nacque Venere" in out
-
-
-@pytest.mark.asyncio
-async def test_non_asr_model(winning_call):
- # text to text model
- model_name = "JackFram/llama-68m"
- with RemoteOpenAIServer(model_name, SERVER_ARGS) as remote_server:
- client = remote_server.get_async_client()
- res = await client.audio.transcriptions.create(
- model=model_name, file=winning_call, language="en", temperature=0.0
- )
- err = res.error
- assert err["code"] == 400 and not res.text
- assert err["message"] == "The model does not support Transcriptions API"
-
-
-@pytest.mark.asyncio
-async def test_bad_requests(mary_had_lamb, client):
- # invalid language
- with pytest.raises(openai.BadRequestError):
- await client.audio.transcriptions.create(
- model=MODEL_NAME, file=mary_had_lamb, language="hh", temperature=0.0
- )
-
-
-@pytest.mark.asyncio
-async def test_long_audio_request(mary_had_lamb, client):
- mary_had_lamb.seek(0)
- audio, sr = librosa.load(mary_had_lamb)
- # Add small silence after each audio for repeatability in the split process
- audio = np.pad(audio, (0, 1600))
- repeated_audio = np.tile(audio, 10)
- # Repeated audio to buffer
- buffer = io.BytesIO()
- sf.write(buffer, repeated_audio, sr, format="WAV")
- buffer.seek(0)
- transcription = await client.audio.transcriptions.create(
- model=MODEL_NAME,
- file=buffer,
- language="en",
- response_format="text",
- temperature=0.0,
- )
- out = json.loads(transcription)
- out_text = out["text"]
- out_usage = out["usage"]
- counts = out_text.count("Mary had a little lamb")
- assert counts == 10, counts
- assert out_usage["seconds"] == 161, out_usage["seconds"]
-
-
-@pytest.mark.asyncio
-async def test_completion_endpoints(client):
- # text to text model
- res = await client.chat.completions.create(
- model=MODEL_NAME,
- messages=[{"role": "system", "content": "You are a helpful assistant."}],
- )
- err = res.error
- assert err["code"] == 400
- assert err["message"] == "The model does not support Chat Completions API"
-
- res = await client.completions.create(model=MODEL_NAME, prompt="Hello")
- err = res.error
- assert err["code"] == 400
- assert err["message"] == "The model does not support Completions API"
-
-
-@pytest.mark.asyncio
-async def test_streaming_response(winning_call, client):
- transcription = ""
- res_no_stream = await client.audio.transcriptions.create(
- model=MODEL_NAME,
- file=winning_call,
- response_format="json",
- language="en",
- temperature=0.0,
- )
- res = await client.audio.transcriptions.create(
- model=MODEL_NAME,
- file=winning_call,
- language="en",
- temperature=0.0,
- stream=True,
- timeout=30,
- )
- # Reconstruct from chunks and validate
- async for chunk in res:
- text = chunk.choices[0]["delta"]["content"]
- transcription += text
-
- assert transcription == res_no_stream.text
-
-
-@pytest.mark.asyncio
-async def test_stream_options(winning_call, client):
- res = await client.audio.transcriptions.create(
- model=MODEL_NAME,
- file=winning_call,
- language="en",
- temperature=0.0,
- stream=True,
- extra_body=dict(stream_include_usage=True, stream_continuous_usage_stats=True),
- timeout=30,
- )
- final = False
- continuous = True
- async for chunk in res:
- if not len(chunk.choices):
- # final usage sent
- final = True
- else:
- continuous = continuous and hasattr(chunk, "usage")
- assert final and continuous
-
-
-@pytest.mark.asyncio
-async def test_sampling_params(mary_had_lamb, client):
- """
- Compare sampling with params and greedy sampling to assert results
- are different when extreme sampling parameters values are picked.
- """
- transcription = await client.audio.transcriptions.create(
- model=MODEL_NAME,
- file=mary_had_lamb,
- language="en",
- temperature=0.8,
- extra_body=dict(
- seed=42,
- repetition_penalty=1.9,
- top_k=12,
- top_p=0.4,
- min_p=0.5,
- frequency_penalty=1.8,
- presence_penalty=2.0,
- ),
- )
-
- greedy_transcription = await client.audio.transcriptions.create(
- model=MODEL_NAME,
- file=mary_had_lamb,
- language="en",
- temperature=0.0,
- extra_body=dict(seed=42),
- )
-
- assert greedy_transcription.text != transcription.text
-
-
-@pytest.mark.asyncio
-async def test_audio_prompt(mary_had_lamb, client):
- prompt = "This is a speech, recorded in a phonograph."
- # Prompts should not omit the part of original prompt while transcribing.
- prefix = "The first words I spoke in the original phonograph"
- transcription = await client.audio.transcriptions.create(
- model=MODEL_NAME,
- file=mary_had_lamb,
- language="en",
- response_format="text",
- temperature=0.0,
- )
- out = json.loads(transcription)["text"]
- assert prefix in out
- transcription_wprompt = await client.audio.transcriptions.create(
- model=MODEL_NAME,
- file=mary_had_lamb,
- language="en",
- response_format="text",
- prompt=prompt,
- temperature=0.0,
- )
- out_prompt = json.loads(transcription_wprompt)["text"]
- assert prefix in out_prompt
diff --git a/tests/entrypoints/openai/test_transcription_validation_whisper.py b/tests/entrypoints/openai/test_transcription_validation_whisper.py
new file mode 100644
index 0000000000000..82c50e58a0168
--- /dev/null
+++ b/tests/entrypoints/openai/test_transcription_validation_whisper.py
@@ -0,0 +1,237 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# imports for structured outputs tests
+import asyncio
+import io
+import json
+
+import librosa
+import numpy as np
+import openai
+import pytest
+import pytest_asyncio
+import soundfile as sf
+
+from ...utils import RemoteOpenAIServer
+
+MODEL_NAME = "openai/whisper-large-v3-turbo"
+SERVER_ARGS = ["--enforce-eager"]
+
+
+@pytest.fixture(scope="module")
+def server():
+ with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as remote_server:
+ yield remote_server
+
+
+@pytest_asyncio.fixture
+async def whisper_client(server):
+ async with server.get_async_client() as async_client:
+ yield async_client
+
+
+@pytest.mark.asyncio
+async def test_basic_audio(mary_had_lamb):
+ server_args = ["--enforce-eager"]
+
+ # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
+ with RemoteOpenAIServer(MODEL_NAME, server_args) as remote_server:
+ client = remote_server.get_async_client()
+ transcription = await client.audio.transcriptions.create(
+ model=MODEL_NAME,
+ file=mary_had_lamb,
+ language="en",
+ response_format="text",
+ temperature=0.0,
+ )
+ out = json.loads(transcription)
+ out_text = out["text"]
+ out_usage = out["usage"]
+ assert "Mary had a little lamb," in out_text
+ assert out_usage["seconds"] == 16, out_usage["seconds"]
+
+
+@pytest.mark.asyncio
+async def test_basic_audio_batched(mary_had_lamb, winning_call, whisper_client):
+ transcription = whisper_client.audio.transcriptions.create(
+ model=MODEL_NAME,
+ file=mary_had_lamb,
+ language="en",
+ response_format="text",
+ temperature=0.0,
+ )
+ transcription2 = whisper_client.audio.transcriptions.create(
+ model=MODEL_NAME,
+ file=winning_call,
+ language="en",
+ response_format="text",
+ temperature=0.0,
+ )
+ # Await both transcriptions by scheduling coroutines together
+ transcription, transcription2 = await asyncio.gather(transcription, transcription2)
+ out = json.loads(transcription)
+ out_text = out["text"]
+ assert "Mary had a little lamb," in out_text
+ out2 = json.loads(transcription2)
+ out_text2 = out2["text"]
+ assert "Edgar Martinez" in out_text2
+
+
+@pytest.mark.asyncio
+async def test_bad_requests(mary_had_lamb, whisper_client):
+ # invalid language
+ with pytest.raises(openai.BadRequestError):
+ await whisper_client.audio.transcriptions.create(
+ model=MODEL_NAME, file=mary_had_lamb, language="hh", temperature=0.0
+ )
+
+
+@pytest.mark.asyncio
+async def test_long_audio_request(mary_had_lamb, whisper_client):
+ mary_had_lamb.seek(0)
+ audio, sr = librosa.load(mary_had_lamb)
+ # Add small silence after each audio for repeatability in the split process
+ audio = np.pad(audio, (0, 1600))
+ repeated_audio = np.tile(audio, 10)
+ # Repeated audio to buffer
+ buffer = io.BytesIO()
+ sf.write(buffer, repeated_audio, sr, format="WAV")
+ buffer.seek(0)
+ transcription = await whisper_client.audio.transcriptions.create(
+ model=MODEL_NAME,
+ file=buffer,
+ language="en",
+ response_format="text",
+ temperature=0.0,
+ )
+ out = json.loads(transcription)
+ out_text = out["text"]
+ out_usage = out["usage"]
+ counts = out_text.count("Mary had a little lamb")
+ assert counts == 10, counts
+ assert out_usage["seconds"] == 161, out_usage["seconds"]
+
+
+@pytest.mark.asyncio
+async def test_completion_endpoints(whisper_client):
+ # text to text model
+ res = await whisper_client.chat.completions.create(
+ model=MODEL_NAME,
+ messages=[{"role": "system", "content": "You are a helpful assistant."}],
+ )
+ err = res.error
+ assert err["code"] == 400
+ assert err["message"] == "The model does not support Chat Completions API"
+
+ res = await whisper_client.completions.create(model=MODEL_NAME, prompt="Hello")
+ err = res.error
+ assert err["code"] == 400
+ assert err["message"] == "The model does not support Completions API"
+
+
+@pytest.mark.asyncio
+async def test_streaming_response(winning_call, whisper_client):
+ transcription = ""
+ res_no_stream = await whisper_client.audio.transcriptions.create(
+ model=MODEL_NAME,
+ file=winning_call,
+ response_format="json",
+ language="en",
+ temperature=0.0,
+ )
+ res = await whisper_client.audio.transcriptions.create(
+ model=MODEL_NAME,
+ file=winning_call,
+ language="en",
+ temperature=0.0,
+ stream=True,
+ timeout=30,
+ )
+ # Reconstruct from chunks and validate
+ async for chunk in res:
+ text = chunk.choices[0]["delta"]["content"]
+ transcription += text
+
+ assert transcription == res_no_stream.text
+
+
+@pytest.mark.asyncio
+async def test_stream_options(winning_call, whisper_client):
+ res = await whisper_client.audio.transcriptions.create(
+ model=MODEL_NAME,
+ file=winning_call,
+ language="en",
+ temperature=0.0,
+ stream=True,
+ extra_body=dict(stream_include_usage=True, stream_continuous_usage_stats=True),
+ timeout=30,
+ )
+ final = False
+ continuous = True
+ async for chunk in res:
+ if not len(chunk.choices):
+ # final usage sent
+ final = True
+ else:
+ continuous = continuous and hasattr(chunk, "usage")
+ assert final and continuous
+
+
+@pytest.mark.asyncio
+async def test_sampling_params(mary_had_lamb, whisper_client):
+ """
+ Compare sampling with params and greedy sampling to assert results
+ are different when extreme sampling parameters values are picked.
+ """
+ transcription = await whisper_client.audio.transcriptions.create(
+ model=MODEL_NAME,
+ file=mary_had_lamb,
+ language="en",
+ temperature=0.8,
+ extra_body=dict(
+ seed=42,
+ repetition_penalty=1.9,
+ top_k=12,
+ top_p=0.4,
+ min_p=0.5,
+ frequency_penalty=1.8,
+ presence_penalty=2.0,
+ ),
+ )
+
+ greedy_transcription = await whisper_client.audio.transcriptions.create(
+ model=MODEL_NAME,
+ file=mary_had_lamb,
+ language="en",
+ temperature=0.0,
+ extra_body=dict(seed=42),
+ )
+
+ assert greedy_transcription.text != transcription.text
+
+
+@pytest.mark.asyncio
+async def test_audio_prompt(mary_had_lamb, whisper_client):
+ prompt = "This is a speech, recorded in a phonograph."
+ # Prompts should not omit the part of original prompt while transcribing.
+ prefix = "The first words I spoke in the original phonograph"
+ transcription = await whisper_client.audio.transcriptions.create(
+ model=MODEL_NAME,
+ file=mary_had_lamb,
+ language="en",
+ response_format="text",
+ temperature=0.0,
+ )
+ out = json.loads(transcription)["text"]
+ assert prefix in out
+ transcription_wprompt = await whisper_client.audio.transcriptions.create(
+ model=MODEL_NAME,
+ file=mary_had_lamb,
+ language="en",
+ response_format="text",
+ prompt=prompt,
+ temperature=0.0,
+ )
+ out_prompt = json.loads(transcription_wprompt)["text"]
+ assert prefix in out_prompt
diff --git a/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py
index 2b68a653f4600..37e52d2cdf609 100644
--- a/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py
+++ b/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from unittest.mock import MagicMock, patch
+
import pytest
from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation
@@ -132,3 +134,129 @@ def test_extract_tool_calls_multiple_json_with_surrounding_text(parser):
assert result.tool_calls[0].function.name == "searchTool"
assert result.tool_calls[1].function.name == "getOpenIncidentsTool"
assert result.tool_calls[2].function.name == "searchTool"
+
+
+def test_extract_tool_calls_deeply_nested_json(parser):
+ # Test with deeply nested JSON parameters (5 levels)
+ model_output = (
+ '{"name": "complexTool", '
+ '"parameters": {'
+ '"level1": {'
+ '"level2": {'
+ '"level3": {'
+ '"level4": {'
+ '"value": "deep"'
+ "}}}}}}"
+ )
+ result = parser.extract_tool_calls(model_output, None)
+
+ assert result.tools_called is True
+ assert len(result.tool_calls) == 1
+ assert result.tool_calls[0].function.name == "complexTool"
+ # Verify the nested structure is preserved in the arguments
+ import json
+
+ args = json.loads(result.tool_calls[0].function.arguments)
+ assert args["level1"]["level2"]["level3"]["level4"]["value"] == "deep"
+
+
+def test_extract_tool_calls_multiple_with_deep_nesting(parser):
+ # Test with multiple tool calls where some have deeply nested parameters
+ model_output = (
+ '{"name": "simpleTool", "parameters": {"value": "test"}}; '
+ '{"name": "complexTool", "parameters": '
+ '{"config": {"database": {"connection": {"pool": {"size": 10}}}}}}'
+ )
+ result = parser.extract_tool_calls(model_output, None)
+
+ assert result.tools_called is True
+ assert len(result.tool_calls) == 2
+
+ # Check first tool call
+ assert result.tool_calls[0].function.name == "simpleTool"
+ import json
+
+ args0 = json.loads(result.tool_calls[0].function.arguments)
+ assert args0["value"] == "test"
+
+ # Check second tool call with deep nesting
+ assert result.tool_calls[1].function.name == "complexTool"
+ args1 = json.loads(result.tool_calls[1].function.arguments)
+ assert args1["config"]["database"]["connection"]["pool"]["size"] == 10
+
+
+def test_extract_tool_calls_with_quotes_and_brackets_in_string(parser):
+ # Test with quotes and brackets inside quoted string values
+ model_output = (
+ '{"name": "searchTool", '
+ '"parameters": {'
+ '"query": "test {value} [complex]",'
+ '"nested": {"inner": "more {brackets}"}'
+ "}}"
+ )
+ result = parser.extract_tool_calls(model_output, None)
+
+ assert result.tools_called is True
+ assert len(result.tool_calls) == 1
+ assert result.tool_calls[0].function.name == "searchTool"
+ # Verify the string values are preserved including brackets and quotes
+ import json
+
+ args = json.loads(result.tool_calls[0].function.arguments)
+ assert args["query"] == "test {value} [complex]"
+ assert args["nested"]["inner"] == "more {brackets}"
+
+
+def test_extract_tool_calls_with_escaped_quotes_in_nested_json(parser):
+ # Test with escaped quotes in deeply nested JSON
+ model_output = (
+ '{"name": "parserTool", "parameters": {"text": "He said \\"Hello {world}\\""}}'
+ )
+ result = parser.extract_tool_calls(model_output, None)
+
+ assert result.tools_called is True
+ assert len(result.tool_calls) == 1
+ assert result.tool_calls[0].function.name == "parserTool"
+ # Verify escaped quotes are preserved
+ import json
+
+ args = json.loads(result.tool_calls[0].function.arguments)
+ assert args["text"] == 'He said "Hello {world}"'
+
+
+def test_extract_tool_calls_missing_name_key(parser):
+ # Test that missing "name" key returns content
+ model_output = '{"parameters": {}}'
+ result = parser.extract_tool_calls(model_output, None)
+
+ assert result.tools_called is False
+ assert len(result.tool_calls) == 0
+ assert result.content == model_output
+
+
+def test_extract_tool_calls_missing_parameters_and_arguments_key(parser):
+ # Test that missing both "parameters" and "arguments" keys returns content
+ model_output = '{"name": "toolWithoutParams"}'
+ result = parser.extract_tool_calls(model_output, None)
+
+ assert result.tools_called is False
+ assert len(result.tool_calls) == 0
+ assert result.content == model_output
+
+
+def test_regex_timeout_handling(parser):
+ """Test regex timeout is handled gracefully"""
+ fake_problematic_input = "{hello world[A(A=" + "\t)A(A=,\t" * 2
+
+ # create a mock regex that raises TimeoutError
+ mock_regex = MagicMock()
+ mock_regex.finditer.side_effect = TimeoutError("Regex timeout")
+
+ with patch.object(parser, "tool_call_start_regex", mock_regex):
+ result = parser.extract_tool_calls(fake_problematic_input, None)
+
+ # should treat as regular text when regex times out
+ assert result.content == fake_problematic_input
+ assert result.tools_called is False
+ assert len(result.tool_calls) == 0
+ mock_regex.finditer.assert_called_once()
diff --git a/tests/entrypoints/pooling/correctness/__init__.py b/tests/entrypoints/pooling/basic/__init__.py
similarity index 100%
rename from tests/entrypoints/pooling/correctness/__init__.py
rename to tests/entrypoints/pooling/basic/__init__.py
diff --git a/tests/entrypoints/pooling/llm/test_encode.py b/tests/entrypoints/pooling/basic/test_encode.py
similarity index 100%
rename from tests/entrypoints/pooling/llm/test_encode.py
rename to tests/entrypoints/pooling/basic/test_encode.py
diff --git a/tests/entrypoints/pooling/openai/test_truncation.py b/tests/entrypoints/pooling/basic/test_truncation.py
similarity index 100%
rename from tests/entrypoints/pooling/openai/test_truncation.py
rename to tests/entrypoints/pooling/basic/test_truncation.py
diff --git a/tests/entrypoints/pooling/llm/__init__.py b/tests/entrypoints/pooling/classify/__init__.py
similarity index 100%
rename from tests/entrypoints/pooling/llm/__init__.py
rename to tests/entrypoints/pooling/classify/__init__.py
diff --git a/tests/entrypoints/pooling/llm/test_classify.py b/tests/entrypoints/pooling/classify/test_offline.py
similarity index 100%
rename from tests/entrypoints/pooling/llm/test_classify.py
rename to tests/entrypoints/pooling/classify/test_offline.py
diff --git a/tests/entrypoints/pooling/openai/test_classification.py b/tests/entrypoints/pooling/classify/test_online.py
similarity index 100%
rename from tests/entrypoints/pooling/openai/test_classification.py
rename to tests/entrypoints/pooling/classify/test_online.py
diff --git a/tests/entrypoints/pooling/openai/test_vision_classification.py b/tests/entrypoints/pooling/classify/test_online_vision.py
similarity index 100%
rename from tests/entrypoints/pooling/openai/test_vision_classification.py
rename to tests/entrypoints/pooling/classify/test_online_vision.py
diff --git a/tests/entrypoints/pooling/openai/__init__.py b/tests/entrypoints/pooling/embed/__init__.py
similarity index 100%
rename from tests/entrypoints/pooling/openai/__init__.py
rename to tests/entrypoints/pooling/embed/__init__.py
diff --git a/tests/entrypoints/pooling/correctness/test_mteb_embed.py b/tests/entrypoints/pooling/embed/test_correctness_mteb.py
similarity index 100%
rename from tests/entrypoints/pooling/correctness/test_mteb_embed.py
rename to tests/entrypoints/pooling/embed/test_correctness_mteb.py
diff --git a/tests/entrypoints/pooling/llm/test_embedding.py b/tests/entrypoints/pooling/embed/test_offline.py
similarity index 100%
rename from tests/entrypoints/pooling/llm/test_embedding.py
rename to tests/entrypoints/pooling/embed/test_offline.py
diff --git a/tests/entrypoints/pooling/openai/test_embedding.py b/tests/entrypoints/pooling/embed/test_online.py
similarity index 100%
rename from tests/entrypoints/pooling/openai/test_embedding.py
rename to tests/entrypoints/pooling/embed/test_online.py
diff --git a/tests/entrypoints/pooling/openai/test_embedding_dimensions.py b/tests/entrypoints/pooling/embed/test_online_dimensions.py
similarity index 100%
rename from tests/entrypoints/pooling/openai/test_embedding_dimensions.py
rename to tests/entrypoints/pooling/embed/test_online_dimensions.py
diff --git a/tests/entrypoints/pooling/openai/test_embedding_long_text.py b/tests/entrypoints/pooling/embed/test_online_long_text.py
similarity index 100%
rename from tests/entrypoints/pooling/openai/test_embedding_long_text.py
rename to tests/entrypoints/pooling/embed/test_online_long_text.py
diff --git a/tests/entrypoints/pooling/openai/test_vision_embedding.py b/tests/entrypoints/pooling/embed/test_online_vision.py
similarity index 100%
rename from tests/entrypoints/pooling/openai/test_vision_embedding.py
rename to tests/entrypoints/pooling/embed/test_online_vision.py
diff --git a/tests/entrypoints/pooling/pooling/__init__.py b/tests/entrypoints/pooling/pooling/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/entrypoints/pooling/openai/test_pooling.py b/tests/entrypoints/pooling/pooling/test_online.py
similarity index 100%
rename from tests/entrypoints/pooling/openai/test_pooling.py
rename to tests/entrypoints/pooling/pooling/test_online.py
diff --git a/tests/entrypoints/pooling/reward/__init__.py b/tests/entrypoints/pooling/reward/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/entrypoints/pooling/llm/test_reward.py b/tests/entrypoints/pooling/reward/test_offline.py
similarity index 100%
rename from tests/entrypoints/pooling/llm/test_reward.py
rename to tests/entrypoints/pooling/reward/test_offline.py
diff --git a/tests/entrypoints/pooling/score/__init__.py b/tests/entrypoints/pooling/score/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/entrypoints/pooling/correctness/test_mteb_score.py b/tests/entrypoints/pooling/score/test_correctness_mteb.py
similarity index 100%
rename from tests/entrypoints/pooling/correctness/test_mteb_score.py
rename to tests/entrypoints/pooling/score/test_correctness_mteb.py
diff --git a/tests/entrypoints/pooling/llm/test_score.py b/tests/entrypoints/pooling/score/test_offline.py
similarity index 100%
rename from tests/entrypoints/pooling/llm/test_score.py
rename to tests/entrypoints/pooling/score/test_offline.py
diff --git a/tests/entrypoints/pooling/openai/test_rerank.py b/tests/entrypoints/pooling/score/test_online_rerank.py
similarity index 100%
rename from tests/entrypoints/pooling/openai/test_rerank.py
rename to tests/entrypoints/pooling/score/test_online_rerank.py
diff --git a/tests/entrypoints/pooling/openai/test_score.py b/tests/entrypoints/pooling/score/test_online_score.py
similarity index 100%
rename from tests/entrypoints/pooling/openai/test_score.py
rename to tests/entrypoints/pooling/score/test_online_score.py
diff --git a/tests/entrypoints/test_responses_utils.py b/tests/entrypoints/test_responses_utils.py
index 91c818374e3fd..893d806b65742 100644
--- a/tests/entrypoints/test_responses_utils.py
+++ b/tests/entrypoints/test_responses_utils.py
@@ -2,6 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
+from openai.types.responses.response_function_tool_call_output_item import (
+ ResponseFunctionToolCallOutputItem,
+)
from openai.types.responses.response_reasoning_item import (
Content,
ResponseReasoningItem,
@@ -76,6 +79,18 @@ class TestResponsesUtils:
== 'Hmm, the user has just started with a simple "Hello,"'
)
+ tool_call_output = ResponseFunctionToolCallOutputItem(
+ id="temp_id",
+ type="function_call_output",
+ call_id="temp",
+ output="1234",
+ status="completed",
+ )
+ formatted_item = construct_chat_message_with_tool_call(tool_call_output)
+ assert formatted_item["role"] == "tool"
+ assert formatted_item["content"] == "1234"
+ assert formatted_item["tool_call_id"] == "temp"
+
item = ResponseReasoningItem(
id="lol",
summary=[],
diff --git a/tests/kernels/attention/test_attention.py b/tests/kernels/attention/test_attention.py
index 9662e73321ebe..1a7d5ce0ddc1e 100644
--- a/tests/kernels/attention/test_attention.py
+++ b/tests/kernels/attention/test_attention.py
@@ -13,12 +13,6 @@ from vllm.attention.layer import Attention, MultiHeadAttention
from vllm.platforms import current_platform
from vllm.utils.mem_utils import get_max_shared_memory_bytes
-if not current_platform.is_rocm():
- from xformers import ops as xops
- from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
-
- from tests.kernels.utils import make_alibi_bias
-
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
@@ -448,129 +442,6 @@ def ref_multi_query_kv_attention(
return torch.cat(ref_outputs, dim=0)
-@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
-@pytest.mark.parametrize("num_heads", NUM_HEADS)
-@pytest.mark.parametrize("head_size", HEAD_SIZES)
-@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.skipif(
- current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
-)
-@torch.inference_mode()
-def test_multi_query_kv_attention(
- num_seqs: int,
- num_heads: tuple[int, int],
- head_size: int,
- dtype: torch.dtype,
- seed: int,
- device: str,
- use_alibi: bool = False,
-) -> None:
- current_platform.seed_everything(seed)
- torch.set_default_device(device)
- # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
- # As the xformers library is already tested with its own tests, we can use
- # a smaller MAX_SEQ_LEN here.
- max_len = min(MAX_SEQ_LEN, 4096)
- seq_lens = random.sample(range(1, max_len), num_seqs)
- num_tokens = sum(seq_lens)
-
- scale = float(1.0 / (head_size**0.5))
- num_query_heads, num_kv_heads = num_heads
- qkv = torch.empty(
- num_tokens, num_query_heads + 2 * num_kv_heads, head_size, dtype=dtype
- )
- qkv.uniform_(-scale, scale)
- query, key, value = qkv.split([num_query_heads, num_kv_heads, num_kv_heads], dim=1)
-
- num_queries_per_kv = num_query_heads // num_kv_heads
- if num_queries_per_kv > 1:
- # Handle MQA and GQA
- key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
- value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
- alibi_bias = None
- if use_alibi:
- alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
- attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
- output = torch.empty_like(query)
- start = 0
- # Dynamic sequence length not supported with custom attn_bias.
- for i, seq_len in enumerate(seq_lens):
- end = start + seq_len
- out = xops.memory_efficient_attention_forward(
- query[None, start:end],
- key[None, start:end],
- value[None, start:end],
- attn_bias=attn_bias[i],
- p=0.0,
- scale=scale,
- )
- output[start:end].copy_(out.view_as(query[start:end]))
- start += seq_len
- # xformers.AttentionBias to Tensor for use in reference impl.
- alibi_bias = [
- b.materialize((1, num_query_heads, i, i), device=device).squeeze()
- for b, i in zip(attn_bias, seq_lens)
- ]
- else:
- attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
- output = xops.memory_efficient_attention_forward(
- query.unsqueeze(0),
- key.unsqueeze(0),
- value.unsqueeze(0),
- attn_bias=attn_bias,
- p=0.0,
- scale=scale,
- )
- output = output.squeeze(0)
-
- cu_seq_lens = [0]
- for seq_len in seq_lens:
- cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
- ref_output = ref_multi_query_kv_attention(
- cu_seq_lens,
- query,
- key,
- value,
- scale,
- alibi_bias,
- dtype,
- )
- atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
- rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
- torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
-
-
-@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
-@pytest.mark.parametrize("num_heads", NUM_HEADS)
-@pytest.mark.parametrize("head_size", [64])
-@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.skipif(
- current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
-)
-@torch.inference_mode()
-def test_multi_query_kv_attention_with_alibi(
- num_seqs: int,
- num_heads: tuple[int, int],
- head_size: int,
- dtype: torch.dtype,
- seed: int,
- device: str,
-) -> None:
- return test_multi_query_kv_attention(
- num_seqs,
- num_heads,
- head_size,
- dtype,
- seed,
- device,
- use_alibi=True,
- )
-
-
@pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention])
def test_num_heads_not_divisble_by_num_kv_heads(attention_cls: type) -> None:
head_size = 64
diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py
index 9be56a33f76c8..cd34b520ea71b 100644
--- a/tests/kernels/attention/test_attention_selector.py
+++ b/tests/kernels/attention/test_attention_selector.py
@@ -34,7 +34,7 @@ DEVICE_MLA_BACKENDS = {
}
DEVICE_REGULAR_ATTN_BACKENDS = {
- "cuda": ["XFORMERS", "FLASHINFER", "FLASH_ATTN"],
+ "cuda": ["FLASHINFER", "FLASH_ATTN"],
"hip": ["ROCM_ATTN"],
"cpu": ["CPU_ATTN"],
}
@@ -207,12 +207,6 @@ def test_env(
)
expected = "FLASHINFER"
assert backend.get_name() == expected
- elif name == "XFORMERS":
- backend = get_attn_backend(
- 32, torch.float16, None, block_size, use_mla=use_mla
- )
- expected = "XFORMERS"
- assert backend.get_name() == expected
elif name == "FLASH_ATTN":
backend = get_attn_backend(
32, torch.float16, None, block_size, use_mla=use_mla
diff --git a/tests/kernels/attention/test_cache.py b/tests/kernels/attention/test_cache.py
index 028e164cb801b..acf46d75d62eb 100644
--- a/tests/kernels/attention/test_cache.py
+++ b/tests/kernels/attention/test_cache.py
@@ -921,12 +921,16 @@ def test_gather_and_maybe_dequant_cache_mla(
)
_fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
- seq_len_tensor = torch.randint(0, max_seq_len + 1, (batch_size,), device=device)
+ seq_len_tensor = torch.randint(
+ max_seq_len, max_seq_len + 1, (batch_size,), device=device
+ )
total_tokens = seq_len_tensor.sum()
cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device)
cu_seq_lens[0] = 0
cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)
+ token_to_seq = torch.arange(0, batch_size, dtype=torch.int32, device=device)
+ token_to_seq = torch.repeat_interleave(token_to_seq, seq_len_tensor)
print("seq_len_tensor", seq_len_tensor)
tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
@@ -977,7 +981,8 @@ def test_gather_and_maybe_dequant_cache_mla(
dst,
block_table,
cu_seq_lens,
- batch_size,
+ token_to_seq,
+ total_tokens,
kv_cache_dtype,
scale,
None,
@@ -990,7 +995,8 @@ def test_gather_and_maybe_dequant_cache_mla(
dst,
block_table,
cu_seq_lens,
- batch_size,
+ token_to_seq,
+ total_tokens,
kv_cache_dtype,
scale,
None,
diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py
index a878ac6396ce5..ae3c63cc62d6b 100644
--- a/tests/kernels/attention/test_mha_attn.py
+++ b/tests/kernels/attention/test_mha_attn.py
@@ -24,10 +24,6 @@ from vllm.platforms.rocm import RocmPlatform
def clear_cache():
"""Clear lru cache to ensure each test case runs without caching."""
_cached_get_attn_backend.cache_clear()
- # Clear xformers availability cache
- import vllm.attention.layer as layer_module
-
- layer_module.USE_XFORMERS_OPS = None
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py
index 638741e91619b..a6977f222408d 100644
--- a/tests/kernels/moe/test_flashinfer.py
+++ b/tests/kernels/moe/test_flashinfer.py
@@ -11,7 +11,6 @@ from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
-from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8,
flashinfer_cutlass_moe_fp8,
@@ -151,14 +150,11 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True)
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids = Llama4MoE.custom_routing_function(
hidden_states=td.hidden_states,
- router_logits=score,
- use_grouped_topk=False,
- top_k=topk,
+ gating_output=score,
+ topk=topk,
renormalize=False,
- custom_routing_function=Llama4MoE.custom_routing_function,
- scoring_func="softmax",
)
quant_config = fp8_w8a8_moe_quant_config(
@@ -219,14 +215,11 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
)
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids = Llama4MoE.custom_routing_function(
hidden_states=td.hidden_states,
- router_logits=score,
- use_grouped_topk=False,
- top_k=topk,
+ gating_output=score,
+ topk=topk,
renormalize=False,
- custom_routing_function=Llama4MoE.custom_routing_function,
- scoring_func="softmax",
)
quant_config = fp8_w8a8_moe_quant_config(
diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py
index 5d5a26fbfc2cd..b8148ce06b3fd 100644
--- a/tests/kernels/utils.py
+++ b/tests/kernels/utils.py
@@ -14,7 +14,7 @@ import torch
from torch._prims_common import TensorLikeType
from tests.kernels.quant_utils import native_w8a8_block_matmul
-from vllm.attention import AttentionType
+from vllm.attention.backends.abstract import AttentionType
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils import (
@@ -509,43 +509,6 @@ def pack_qkv(qkv: QKVInputs, device: torch.device | str) -> PackedQKVInputs:
)
-def make_alibi_bias(
- alibi_slopes: torch.Tensor,
- num_kv_heads: int,
- dtype: torch.dtype,
- seq_lens: list[int],
-) -> list[Any]:
- """Create ALiBi biases compatible with xFormers attention tests."""
- from xformers.ops.fmha.attn_bias import LowerTriangularMaskWithTensorBias
-
- if alibi_slopes is None:
- return [None for _ in seq_lens]
-
- attn_biases: list[Any] = []
- num_heads = alibi_slopes.shape[0]
- assert num_heads >= num_kv_heads, (
- "ALiBi slopes expect at least as many heads as KV heads"
- )
-
- for seq_len in seq_lens:
- bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
- bias = bias[None, :] - bias[:, None]
-
- padded_len = (seq_len + 7) // 8 * 8
- bias_tensor = torch.empty(
- 1,
- num_heads,
- seq_len,
- padded_len,
- device=alibi_slopes.device,
- dtype=dtype,
- )[:, :, :, :seq_len].copy_(bias)
- bias_tensor.mul_(alibi_slopes[:, None, None])
- attn_biases.append(LowerTriangularMaskWithTensorBias(bias_tensor))
-
- return attn_biases
-
-
def _make_metadata_tensors(
seq_lens: list[int] | None,
context_lens: list[int] | None,
@@ -649,23 +612,12 @@ def make_kv_cache(
Returns:
- * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
- * for backend 'XFORMERS'
* kv_cache: 2 x num_blocks x block_size x num_heads x head_size
* for backend 'FLASH_ATTN'
"""
- if backend == "XFORMERS":
- kv_cache = torch.rand((2, num_blocks, block_size * num_heads * head_size)).to(
- device
- )
- elif backend == "FLASH_ATTN":
- kv_cache = torch.rand((2, num_blocks, block_size, num_heads, head_size)).to(
- device
- )
- else:
- raise ValueError(
- f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or 'FLASH_ATTN'."
- )
+ if backend != "FLASH_ATTN":
+ raise ValueError(f"Unknown backend value: '{backend}'. Expected 'FLASH_ATTN'.")
+ kv_cache = torch.rand((2, num_blocks, block_size, num_heads, head_size)).to(device)
if default_val is not None:
kv_cache[:, :, :] = default_val
return kv_cache
@@ -843,22 +795,14 @@ def assert_actual_matches_ideal(
* output_under_test: actually observed output value
"""
ideal_output = test_params.packed_qkvo.ideal_output
- if backend == "XFORMERS":
- torch.testing.assert_close(
- ideal_output, output_under_test.view_as(ideal_output)
- )
-
- elif backend == "FLASH_ATTN":
- # For FlashAttention override the accuracy thresholds to non default
- # values since we notice a higher difference between the ideal and
- # actual output.
- torch.testing.assert_close(
- ideal_output, output_under_test.view_as(ideal_output), atol=0.01, rtol=0.016
- )
- else:
- raise ValueError(
- f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or 'FLASH_ATTN'."
- )
+ if backend != "FLASH_ATTN":
+ raise ValueError(f"Unknown backend value: '{backend}'. Expected 'FLASH_ATTN'.")
+ # For FlashAttention override the accuracy thresholds to non default
+ # values since we notice a higher difference between the ideal and
+ # actual output.
+ torch.testing.assert_close(
+ ideal_output, output_under_test.view_as(ideal_output), atol=0.01, rtol=0.016
+ )
# Copied/modified from torch._refs.__init__.py
diff --git a/tests/lora/test_lora_checkpoints.py b/tests/lora/test_lora_checkpoints.py
index 2219d470e91a1..b9b1bc59c6ed7 100644
--- a/tests/lora/test_lora_checkpoints.py
+++ b/tests/lora/test_lora_checkpoints.py
@@ -28,12 +28,13 @@ def test_load_checkpoints(
packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
- expected_lora_modules: list[str] = []
+ expected_lora_lst: list[str] = []
for module in BAICHUAN_LORA_MODULES:
if module in packed_modules_mapping:
- expected_lora_modules.extend(packed_modules_mapping[module])
+ expected_lora_lst.extend(packed_modules_mapping[module])
else:
- expected_lora_modules.append(module)
+ expected_lora_lst.append(module)
+ expected_lora_modules = set(expected_lora_lst)
if lora_name == "baichuan7B":
peft_helper = PEFTHelper.from_local_dir(
baichuan_lora_files, max_position_embeddings=4096
@@ -103,13 +104,13 @@ def test_lora_weights_mapping(baichuan_lora_files):
packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
- expected_lora_modules: list[str] = []
+ expected_lora_lst: list[str] = []
for module in BAICHUAN_LORA_MODULES:
if module in packed_modules_mapping:
- expected_lora_modules.extend(packed_modules_mapping[module])
+ expected_lora_lst.extend(packed_modules_mapping[module])
else:
- expected_lora_modules.append(module)
-
+ expected_lora_lst.append(module)
+ expected_lora_modules = set(expected_lora_lst)
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"model.": "language_model.model.",
diff --git a/tests/lora/test_lora_huggingface.py b/tests/lora/test_lora_huggingface.py
index 7d20faef541aa..6a787471c74fd 100644
--- a/tests/lora/test_lora_huggingface.py
+++ b/tests/lora/test_lora_huggingface.py
@@ -26,13 +26,13 @@ def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping
embedding_modules = LlamaForCausalLM.embedding_modules
embed_padding_modules = LlamaForCausalLM.embedding_padding_modules
- expected_lora_modules: list[str] = []
+ expected_lora_lst: list[str] = []
for module in LLAMA_LORA_MODULES:
if module in packed_modules_mapping:
- expected_lora_modules.extend(packed_modules_mapping[module])
+ expected_lora_lst.extend(packed_modules_mapping[module])
else:
- expected_lora_modules.append(module)
-
+ expected_lora_lst.append(module)
+ expected_lora_modules = set(expected_lora_lst)
lora_path = get_adapter_absolute_path(lora_name)
# lora loading should work for either absolute path and huggingface id.
diff --git a/tests/lora/test_minicpmv_tp.py b/tests/lora/test_minicpmv_tp.py
index 1cf8ed602b6a4..e430826461a14 100644
--- a/tests/lora/test_minicpmv_tp.py
+++ b/tests/lora/test_minicpmv_tp.py
@@ -57,10 +57,6 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
return generated_texts
-@pytest.mark.xfail(
- current_platform.is_rocm(),
- reason="MiniCPM-V dependency xformers incompatible with ROCm",
-)
def test_minicpmv_lora(minicpmv_lora_files):
llm = vllm.LLM(
MODEL_PATH,
@@ -84,10 +80,6 @@ def test_minicpmv_lora(minicpmv_lora_files):
@pytest.mark.skipif(
current_platform.is_cuda_alike(), reason="Skipping to avoid redundant model tests"
)
-@pytest.mark.xfail(
- current_platform.is_rocm(),
- reason="MiniCPM-V dependency xformers incompatible with ROCm",
-)
@multi_gpu_test(num_gpus=4)
def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files):
llm = vllm.LLM(
@@ -108,10 +100,6 @@ def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files):
@pytest.mark.skipif(
current_platform.is_cuda_alike(), reason="Skipping to avoid redundant model tests"
)
-@pytest.mark.xfail(
- current_platform.is_rocm(),
- reason="MiniCPM-V dependency xformers incompatible with ROCm",
-)
@multi_gpu_test(num_gpus=4)
def test_minicpmv_tp4_fully_sharded_loras(minicpmv_lora_files):
llm = vllm.LLM(
diff --git a/tests/lora/test_qwen2vl.py b/tests/lora/test_qwen2vl.py
index 1800ca107a426..7d8c940100ca4 100644
--- a/tests/lora/test_qwen2vl.py
+++ b/tests/lora/test_qwen2vl.py
@@ -2,12 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
-import pytest
-
import vllm
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest
-from vllm.platforms import current_platform
from vllm.sampling_params import BeamSearchParams
@@ -142,10 +139,6 @@ QWEN2VL_MODEL_PATH = "Qwen/Qwen2-VL-2B-Instruct"
QWEN25VL_MODEL_PATH = "Qwen/Qwen2.5-VL-3B-Instruct"
-@pytest.mark.xfail(
- current_platform.is_rocm(),
- reason="Qwen2-VL dependency xformers incompatible with ROCm",
-)
def test_qwen2vl_lora(qwen2vl_lora_files):
"""Test Qwen 2.0 VL model with LoRA"""
config = TestConfig(model_path=QWEN2VL_MODEL_PATH, lora_path=qwen2vl_lora_files)
@@ -156,10 +149,6 @@ def test_qwen2vl_lora(qwen2vl_lora_files):
tester.run_test(TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS, lora_id=lora_id)
-@pytest.mark.xfail(
- current_platform.is_rocm(),
- reason="Qwen2-VL dependency xformers incompatible with ROCm",
-)
def test_qwen2vl_lora_beam_search(qwen2vl_lora_files):
"""Test Qwen 2.0 VL model with LoRA through beam search."""
config = TestConfig(model_path=QWEN2VL_MODEL_PATH, lora_path=qwen2vl_lora_files)
@@ -178,10 +167,6 @@ def test_qwen2vl_lora_beam_search(qwen2vl_lora_files):
)
-@pytest.mark.xfail(
- current_platform.is_rocm(),
- reason="Qwen2.5-VL dependency xformers incompatible with ROCm",
-)
def test_qwen25vl_lora(qwen25vl_lora_files):
"""Test Qwen 2.5 VL model with LoRA"""
config = TestConfig(model_path=QWEN25VL_MODEL_PATH, lora_path=qwen25vl_lora_files)
diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py
index 9121284de85b7..7d95dcddca711 100644
--- a/tests/model_executor/test_enabled_custom_ops.py
+++ b/tests/model_executor/test_enabled_custom_ops.py
@@ -5,7 +5,12 @@ import pytest
import torch
from vllm._aiter_ops import rocm_aiter_ops
-from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
+from vllm.config import (
+ CompilationConfig,
+ VllmConfig,
+ get_cached_compilation_config,
+ set_current_vllm_config,
+)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import (
GeluAndMul,
@@ -86,6 +91,7 @@ def test_enabled_ops(
backend=backend, mode=compilation_mode, custom_ops=custom_ops
)
)
+ get_cached_compilation_config.cache_clear()
with set_current_vllm_config(vllm_config):
assert CustomOp.default_on() == default_on
diff --git a/tests/model_executor/test_qwen3_omni.py b/tests/model_executor/test_qwen3_omni.py
new file mode 100644
index 0000000000000..c92c61dcd3bc2
--- /dev/null
+++ b/tests/model_executor/test_qwen3_omni.py
@@ -0,0 +1,221 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from unittest.mock import Mock
+
+import pytest
+from transformers import PretrainedConfig
+
+from vllm.multimodal.processing import InputProcessingContext
+
+
+# Helper function to print input IDs with coalesced audio/video tokens.
+def print_input_ids(input_ids):
+ """
+ Print input IDs, compressing consecutive special tokens.
+ - 151675: <|audio_pad|>
+ - 151656: <|video_pad|>
+ """
+ if not input_ids:
+ print("[]")
+ return
+
+ result = []
+ i = 0
+
+ while i < len(input_ids):
+ current_id = input_ids[i]
+
+ # Check if it's a special token that should be compressed
+ if current_id in [151675, 151656]:
+ # Count consecutive occurrences
+ count = 1
+ while i + count < len(input_ids) and input_ids[i + count] == current_id:
+ count += 1
+
+ # Add compressed representation
+ token_name = "<|audio_pad|>" if current_id == 151675 else "<|video_pad|>"
+ result.append(f"{token_name} * {count}")
+ i += count
+ else:
+ # Regular token, just add it
+ result.append(str(current_id))
+ i += 1
+
+ print(", ".join(result))
+
+
+@pytest.fixture
+def mock_qwen3_omni_config():
+ """Create a mock Qwen3OmniMoeThinker config."""
+ config = Mock(spec=PretrainedConfig)
+ # Token IDs from https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Instruct/blob/main/tokenizer_config.json
+ config.audio_token_id = 151675 # <|audio_pad|>
+ config.video_token_id = 151656 # <|video_pad|>
+ config.image_token_id = 151655 # <|image_pad|>
+ config.audio_start_token_id = 151669 # <|audio_start|>
+ config.audio_end_token_id = 151670 # <|audio_end|>
+ config.vision_start_token_id = 151652 # <|vision_start|>
+ config.position_id_per_seconds = 12.5
+
+ # Vision config
+ vision_config = Mock()
+ vision_config.spatial_merge_size = 2
+ config.vision_config = vision_config
+
+ return config
+
+
+@pytest.fixture
+def mock_processor():
+ """Create a mock HF processor."""
+ from transformers.models.whisper import WhisperFeatureExtractor
+
+ processor = Mock()
+ processor.audio_token = "<|audio_pad|>"
+ processor.image_token = "<|image_pad|>"
+ processor.video_token = "<|video_pad|>"
+
+ # Create a real WhisperFeatureExtractor instance for the feature_extractor attribute
+ feature_extractor = WhisperFeatureExtractor()
+ processor.feature_extractor = feature_extractor
+
+ return processor
+
+
+@pytest.fixture
+def mock_tokenizer():
+ """Create a mock tokenizer."""
+ tokenizer = Mock()
+ # Token IDs from https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Instruct/blob/main/tokenizer_config.json
+ tokenizer.get_vocab = Mock(
+ return_value={
+ "<|audio_pad|>": 151675,
+ "<|video_pad|>": 151656,
+ "<|image_pad|>": 151655,
+ "<|audio_start|>": 151669,
+ "<|audio_end|>": 151670,
+ "<|vision_start|>": 151652,
+ "<|vision_end|>": 151653,
+ }
+ )
+ tokenizer.encode = Mock(
+ side_effect=lambda x: {
+ "<|vision_start|>": [151652],
+ "<|vision_end|>": [151653],
+ "<|audio_start|>": [151669],
+ "<|audio_end|>": [151670],
+ "<|audio_pad|>": [151675],
+ "<|image_pad|>": [151655],
+ "<|video_pad|>": [151656],
+ }.get(x, [0])
+ )
+ tokenizer.vision_bos_token = "<|vision_start|>"
+ tokenizer.vision_eos_token = "<|vision_end|>"
+ tokenizer.audio_bos_token = "<|audio_start|>"
+ tokenizer.audio_eos_token = "<|audio_end|>"
+ return tokenizer
+
+
+@pytest.fixture
+def mock_image_processor():
+ """Create a mock image processor."""
+ image_processor = Mock()
+ image_processor.merge_size = 2
+ return image_processor
+
+
+def test_qwen3_omni_get_updates_use_audio_in_video(
+ mock_qwen3_omni_config,
+ mock_processor,
+ mock_tokenizer,
+ mock_image_processor,
+):
+ """Test the get_updates_use_audio_in_video method directly."""
+
+ from vllm.model_executor.models.qwen3_omni_moe_thinker import (
+ Qwen3OmniMoeThinkerMultiModalProcessor,
+ Qwen3OmniMoeThinkerProcessingInfo,
+ )
+
+ # Create a mock context
+ mock_ctx = Mock(spec=InputProcessingContext)
+
+ # Create processing info
+ info = Qwen3OmniMoeThinkerProcessingInfo(mock_ctx)
+ info.get_hf_config = Mock(return_value=mock_qwen3_omni_config)
+ info.get_hf_processor = Mock(return_value=mock_processor)
+ info.get_tokenizer = Mock(return_value=mock_tokenizer)
+ info.get_image_processor = Mock(return_value=mock_image_processor)
+
+ # Create a mock dummy_inputs builder
+ mock_dummy_inputs = Mock()
+
+ # Create the processor
+ processor = Qwen3OmniMoeThinkerMultiModalProcessor(info, mock_dummy_inputs)
+
+ # Test parameters from reference video
+ # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/draw.mp4
+ audio_len = 85
+ video_grid_thw = [6, 36, 64]
+ video_second_per_grid_t = 2.0
+
+ # Call the method
+ updates = processor.get_updates_use_audio_in_video(
+ thinker_config=mock_qwen3_omni_config,
+ audio_len=audio_len,
+ video_grid_thw=video_grid_thw,
+ video_second_per_grid_t=video_second_per_grid_t,
+ )
+
+ # Updated input ids should align with HF implementation.
+ # 151669,
+ # <|video_pad|> * 576, <|audio_pad|> * 25,
+ # <|video_pad|> * 576, <|audio_pad|> * 25,
+ # <|video_pad|> * 576, <|audio_pad|> * 25,
+ # <|video_pad|> * 576, <|audio_pad|> * 10,
+ # <|video_pad|> * 1152,
+ # 151670
+ print_input_ids(updates)
+
+ # Verify structure
+ assert isinstance(updates, list)
+ assert len(updates) > 0
+
+ # Verify start and end tokens
+ audio_start_token_id = mock_qwen3_omni_config.audio_start_token_id
+ audio_end_token_id = mock_qwen3_omni_config.audio_end_token_id
+
+ assert updates[0] == audio_start_token_id
+ assert updates[-1] == audio_end_token_id
+
+ # Verify both audio and video tokens are present
+ audio_token_id = mock_qwen3_omni_config.audio_token_id
+ video_token_id = mock_qwen3_omni_config.video_token_id
+
+ audio_count = updates.count(audio_token_id)
+ video_count = updates.count(video_token_id)
+
+ assert audio_count == audio_len, (
+ f"Expected {audio_len} audio tokens, got {audio_count}"
+ )
+
+ # Calculate expected video token count
+ spatial_merge_size = mock_qwen3_omni_config.vision_config.spatial_merge_size
+ height = video_grid_thw[1] // spatial_merge_size
+ width = video_grid_thw[2] // spatial_merge_size
+ expected_video_count = video_grid_thw[0] * height * width
+
+ assert video_count == expected_video_count, (
+ f"Expected {expected_video_count} video tokens, got {video_count}"
+ )
+
+ # Total tokens should be: 1 (start) + audio_len + video_count + 1 (end)
+ expected_total = 1 + audio_len + expected_video_count + 1
+ assert len(updates) == expected_total, (
+ f"Expected {expected_total} total tokens, got {len(updates)}"
+ )
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 758ec54493aa3..c9d4823d52792 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -436,6 +436,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"SolarForCausalLM": _HfExamplesInfo(
"upstage/solar-pro-preview-instruct", trust_remote_code=True
),
+ "TeleChatForCausalLM": _HfExamplesInfo(
+ "chuhac/TeleChat2-35B", trust_remote_code=True
+ ),
"TeleChat2ForCausalLM": _HfExamplesInfo(
"Tele-AI/TeleChat2-3B", trust_remote_code=True
),
@@ -626,6 +629,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B",
trust_remote_code=True,
),
+ "HunYuanVLForConditionalGeneration": _HfExamplesInfo(
+ "tencent/HunyuanOCR",
+ is_available_online=False,
+ ),
"Idefics3ForConditionalGeneration": _HfExamplesInfo(
"HuggingFaceM4/Idefics3-8B-Llama3",
extras={"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"},
diff --git a/tests/models/test_gguf_download.py b/tests/models/test_gguf_download.py
new file mode 100644
index 0000000000000..155768ac9bff7
--- /dev/null
+++ b/tests/models/test_gguf_download.py
@@ -0,0 +1,240 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from vllm.config import ModelConfig
+from vllm.config.load import LoadConfig
+from vllm.model_executor.model_loader.gguf_loader import GGUFModelLoader
+from vllm.model_executor.model_loader.weight_utils import download_gguf
+
+
+class TestGGUFDownload:
+ """Test GGUF model downloading functionality."""
+
+ @patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf")
+ def test_download_gguf_single_file(self, mock_download):
+ """Test downloading a single GGUF file."""
+ # Setup mock
+ mock_folder = "/tmp/mock_cache"
+ mock_download.return_value = mock_folder
+
+ # Mock glob to return a single file
+ with patch("glob.glob") as mock_glob:
+ mock_glob.side_effect = lambda pattern, **kwargs: (
+ [f"{mock_folder}/model-IQ1_S.gguf"] if "IQ1_S" in pattern else []
+ )
+
+ result = download_gguf("unsloth/Qwen3-0.6B-GGUF", "IQ1_S")
+
+ # Verify download_weights_from_hf was called with correct patterns
+ mock_download.assert_called_once_with(
+ model_name_or_path="unsloth/Qwen3-0.6B-GGUF",
+ cache_dir=None,
+ allow_patterns=[
+ "*-IQ1_S.gguf",
+ "*-IQ1_S-*.gguf",
+ "*/*-IQ1_S.gguf",
+ "*/*-IQ1_S-*.gguf",
+ ],
+ revision=None,
+ ignore_patterns=None,
+ )
+
+ # Verify result is the file path, not folder
+ assert result == f"{mock_folder}/model-IQ1_S.gguf"
+
+ @patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf")
+ def test_download_gguf_sharded_files(self, mock_download):
+ """Test downloading sharded GGUF files."""
+ mock_folder = "/tmp/mock_cache"
+ mock_download.return_value = mock_folder
+
+ # Mock glob to return sharded files
+ with patch("glob.glob") as mock_glob:
+ mock_glob.side_effect = lambda pattern, **kwargs: (
+ [
+ f"{mock_folder}/model-Q2_K-00001-of-00002.gguf",
+ f"{mock_folder}/model-Q2_K-00002-of-00002.gguf",
+ ]
+ if "Q2_K" in pattern
+ else []
+ )
+
+ result = download_gguf("unsloth/gpt-oss-120b-GGUF", "Q2_K")
+
+ # Should return the first file after sorting
+ assert result == f"{mock_folder}/model-Q2_K-00001-of-00002.gguf"
+
+ @patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf")
+ def test_download_gguf_subdir(self, mock_download):
+ """Test downloading GGUF files from subdirectory."""
+ mock_folder = "/tmp/mock_cache"
+ mock_download.return_value = mock_folder
+
+ with patch("glob.glob") as mock_glob:
+ mock_glob.side_effect = lambda pattern, **kwargs: (
+ [f"{mock_folder}/Q2_K/model-Q2_K.gguf"]
+ if "Q2_K" in pattern or "**/*.gguf" in pattern
+ else []
+ )
+
+ result = download_gguf("unsloth/gpt-oss-120b-GGUF", "Q2_K")
+
+ assert result == f"{mock_folder}/Q2_K/model-Q2_K.gguf"
+
+ @patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf")
+ @patch("glob.glob", return_value=[])
+ def test_download_gguf_no_files_found(self, mock_glob, mock_download):
+ """Test error when no GGUF files are found."""
+ mock_folder = "/tmp/mock_cache"
+ mock_download.return_value = mock_folder
+
+ with pytest.raises(ValueError, match="Downloaded GGUF files not found"):
+ download_gguf("unsloth/Qwen3-0.6B-GGUF", "IQ1_S")
+
+
+class TestGGUFModelLoader:
+ """Test GGUFModelLoader class methods."""
+
+ @patch("os.path.isfile", return_value=True)
+ def test_prepare_weights_local_file(self, mock_isfile):
+ """Test _prepare_weights with local file."""
+ load_config = LoadConfig(load_format="gguf")
+ loader = GGUFModelLoader(load_config)
+
+ # Create a simple mock ModelConfig with only the model attribute
+ model_config = MagicMock()
+ model_config.model = "/path/to/model.gguf"
+
+ result = loader._prepare_weights(model_config)
+ assert result == "/path/to/model.gguf"
+ mock_isfile.assert_called_once_with("/path/to/model.gguf")
+
+ @patch("vllm.model_executor.model_loader.gguf_loader.hf_hub_download")
+ @patch("os.path.isfile", return_value=False)
+ def test_prepare_weights_https_url(self, mock_isfile, mock_hf_download):
+ """Test _prepare_weights with HTTPS URL."""
+ load_config = LoadConfig(load_format="gguf")
+ loader = GGUFModelLoader(load_config)
+
+ mock_hf_download.return_value = "/downloaded/model.gguf"
+
+ # Create a simple mock ModelConfig with only the model attribute
+ model_config = MagicMock()
+ model_config.model = "https://huggingface.co/model.gguf"
+
+ result = loader._prepare_weights(model_config)
+ assert result == "/downloaded/model.gguf"
+ mock_hf_download.assert_called_once_with(
+ url="https://huggingface.co/model.gguf"
+ )
+
+ @patch("vllm.model_executor.model_loader.gguf_loader.hf_hub_download")
+ @patch("os.path.isfile", return_value=False)
+ def test_prepare_weights_repo_filename(self, mock_isfile, mock_hf_download):
+ """Test _prepare_weights with repo_id/filename.gguf format."""
+ load_config = LoadConfig(load_format="gguf")
+ loader = GGUFModelLoader(load_config)
+
+ mock_hf_download.return_value = "/downloaded/model.gguf"
+
+ # Create a simple mock ModelConfig with only the model attribute
+ model_config = MagicMock()
+ model_config.model = "unsloth/Qwen3-0.6B-GGUF/model.gguf"
+
+ result = loader._prepare_weights(model_config)
+ assert result == "/downloaded/model.gguf"
+ mock_hf_download.assert_called_once_with(
+ repo_id="unsloth/Qwen3-0.6B-GGUF", filename="model.gguf"
+ )
+
+ @patch("vllm.config.model.get_hf_image_processor_config", return_value=None)
+ @patch("vllm.transformers_utils.config.file_or_path_exists", return_value=True)
+ @patch("vllm.config.model.get_config")
+ @patch("vllm.config.model.is_gguf", return_value=True)
+ @patch("vllm.model_executor.model_loader.gguf_loader.download_gguf")
+ @patch("os.path.isfile", return_value=False)
+ def test_prepare_weights_repo_quant_type(
+ self,
+ mock_isfile,
+ mock_download_gguf,
+ mock_is_gguf,
+ mock_get_config,
+ mock_file_exists,
+ mock_get_image_config,
+ ):
+ """Test _prepare_weights with repo_id:quant_type format."""
+ mock_hf_config = MagicMock()
+ mock_hf_config.architectures = ["Qwen3ForCausalLM"]
+
+ class MockTextConfig:
+ max_position_embeddings = 4096
+ sliding_window = None
+ model_type = "qwen3"
+ num_attention_heads = 32
+
+ mock_text_config = MockTextConfig()
+ mock_hf_config.get_text_config.return_value = mock_text_config
+ mock_hf_config.dtype = "bfloat16"
+ mock_get_config.return_value = mock_hf_config
+
+ load_config = LoadConfig(load_format="gguf")
+ loader = GGUFModelLoader(load_config)
+
+ mock_download_gguf.return_value = "/downloaded/model-IQ1_S.gguf"
+
+ model_config = ModelConfig(
+ model="unsloth/Qwen3-0.6B-GGUF:IQ1_S", tokenizer="Qwen/Qwen3-0.6B"
+ )
+ result = loader._prepare_weights(model_config)
+ # The actual result will be the downloaded file path from mock
+ assert result == "/downloaded/model-IQ1_S.gguf"
+ mock_download_gguf.assert_called_once_with(
+ "unsloth/Qwen3-0.6B-GGUF",
+ "IQ1_S",
+ cache_dir=None,
+ revision=None,
+ ignore_patterns=["original/**/*"],
+ )
+
+ @patch("vllm.config.model.get_hf_image_processor_config", return_value=None)
+ @patch("vllm.config.model.get_config")
+ @patch("vllm.config.model.is_gguf", return_value=False)
+ @patch("vllm.transformers_utils.utils.check_gguf_file", return_value=False)
+ @patch("os.path.isfile", return_value=False)
+ def test_prepare_weights_invalid_format(
+ self,
+ mock_isfile,
+ mock_check_gguf,
+ mock_is_gguf,
+ mock_get_config,
+ mock_get_image_config,
+ ):
+ """Test _prepare_weights with invalid format."""
+ mock_hf_config = MagicMock()
+ mock_hf_config.architectures = ["Qwen3ForCausalLM"]
+
+ class MockTextConfig:
+ max_position_embeddings = 4096
+ sliding_window = None
+ model_type = "qwen3"
+ num_attention_heads = 32
+
+ mock_text_config = MockTextConfig()
+ mock_hf_config.get_text_config.return_value = mock_text_config
+ mock_hf_config.dtype = "bfloat16"
+ mock_get_config.return_value = mock_hf_config
+
+ load_config = LoadConfig(load_format="gguf")
+ loader = GGUFModelLoader(load_config)
+
+ # Create ModelConfig with a valid repo_id to avoid validation errors
+ # Then test _prepare_weights with invalid format
+ model_config = ModelConfig(model="unsloth/Qwen3-0.6B")
+ # Manually set model to invalid format after creation
+ model_config.model = "invalid-format"
+ with pytest.raises(ValueError, match="Unrecognised GGUF reference"):
+ loader._prepare_weights(model_config)
diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py
index 31b65189b5ec3..412b21328a325 100644
--- a/tests/quantization/test_compressed_tensors.py
+++ b/tests/quantization/test_compressed_tensors.py
@@ -10,6 +10,7 @@ import torch
from compressed_tensors.quantization import QuantizationType
from tests.models.utils import check_logprobs_close
+from vllm.model_executor.layers.fused_moe import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensors24,
CompressedTensorsLinearMethod,
@@ -767,3 +768,50 @@ def test_compressed_tensors_fp8_block_enabled(vllm_runner):
output = llm.generate_greedy("Hello my name is", max_tokens=4)
assert output
+
+
+@pytest.mark.skipif(
+ not current_platform.is_cuda(),
+ reason="This test is not for non-CUDA platforms",
+)
+def test_compressed_tensors_moe_ignore_with_model(vllm_runner):
+ """
+ Integration test for MoE layer ignore functionality with a real model.
+
+ This test would verify that when loading a compressed-tensors quantized
+ MoE model where some MoE layers are in the ignore list, those layers
+ use UnquantizedFusedMoEMethod while non-ignored layers use the
+ quantized method.
+
+ Expected model structure:
+ - Compressed-tensors quantized MoE model (e.g., Mixtral-based)
+ - Config with ignore list containing specific MoE layers
+ - Multiple MoE layers where some are quantized and some are not
+ """
+
+ # model_path = "nm-testing/tinysmokeqwen3moe-W4A16-first-only" # CT 12.3
+ model_path = "nm-testing/tinysmokeqwen3moe-W4A16-first-only-CTstable" # CT 12.2
+
+ with vllm_runner(model_path, enforce_eager=True) as llm:
+
+ def check_model(model):
+ from vllm.model_executor.layers.fused_moe import FusedMoE
+ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
+ CompressedTensorsMoEMethod,
+ )
+
+ # Check layer 0 MoE (should be quantized)
+ layer_quantized = model.model.layers[0].mlp.experts
+ assert isinstance(layer_quantized, FusedMoE)
+ assert isinstance(layer_quantized.quant_method, CompressedTensorsMoEMethod)
+
+ # Check layer 10 MoE (should be unquantized + ignored)
+ layer_unquantized = model.model.layers[3].mlp.experts
+ assert isinstance(layer_unquantized, FusedMoE)
+ assert isinstance(layer_unquantized.quant_method, UnquantizedFusedMoEMethod)
+
+ llm.apply_model(check_model)
+
+ # Verify the model can generate output
+ output = llm.generate_greedy("Hello, my name is", max_tokens=4)
+ assert output
diff --git a/tests/test_config.py b/tests/test_config.py
index 16f68d18fc68b..080e4d2afacc6 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -8,9 +8,20 @@ from unittest.mock import patch
import pytest
from vllm.compilation.backends import VllmBackend
-from vllm.config import ModelConfig, PoolerConfig, VllmConfig, update_config
+from vllm.config import (
+ CompilationConfig,
+ ModelConfig,
+ PoolerConfig,
+ VllmConfig,
+ update_config,
+)
+from vllm.config.compilation import CompilationMode, CUDAGraphMode
from vllm.config.load import LoadConfig
from vllm.config.utils import get_field
+from vllm.config.vllm import (
+ OPTIMIZATION_LEVEL_TO_CONFIG,
+ OptimizationLevel,
+)
from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform
@@ -235,6 +246,43 @@ def test_default_pooling_type(model_id, default_pooling_type, pooling_type):
assert model_config.pooler_config.pooling_type == pooling_type
+@pytest.mark.parametrize(
+ ("model_id", "expected_is_moe_model"),
+ [
+ ("RedHatAI/Qwen3-8B-speculator.eagle3", False),
+ ("RedHatAI/Llama-3.1-8B-Instruct-NVFP4", False),
+ ("RedHatAI/Llama-3.2-1B-FP8", False),
+ ("RedHatAI/Mistral-Small-24B-Instruct-2501-quantized.w8a8", False),
+ ("RedHatAI/gpt-oss-20b", True),
+ ("RedHatAI/DeepSeek-V2.5-1210-FP8", True),
+ ("RedHatAI/Llama-4-Scout-17B-16E-Instruct", True),
+ ("RedHatAI/Mixtral-8x7B-Instruct-v0.1", True),
+ ],
+)
+def test_moe_model_detection(model_id, expected_is_moe_model):
+ model_config = ModelConfig(model_id)
+ # Just check that is_moe_model field exists and is a boolean
+ assert model_config.is_model_moe() == expected_is_moe_model
+
+
+@pytest.mark.parametrize(
+ ("model_id", "quantized"),
+ [
+ ("RedHatAI/Qwen3-8B-speculator.eagle3", False),
+ ("RedHatAI/Llama-3.1-8B-Instruct-NVFP4", True),
+ ("RedHatAI/Llama-3.2-1B-FP8", True),
+ ("RedHatAI/Mistral-Small-24B-Instruct-2501-quantized.w8a8", True),
+ ("RedHatAI/gpt-oss-20b", True),
+ ("RedHatAI/DeepSeek-V2.5-1210-FP8", True),
+ ("RedHatAI/Mixtral-8x7B-Instruct-v0.1", False),
+ ],
+)
+def test_is_quantized(model_id, quantized):
+ model_config = ModelConfig(model_id)
+ # Just check that quantized field exists and is a boolean
+ assert model_config.is_quantized() == quantized
+
+
@pytest.mark.skipif(
current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
)
@@ -552,3 +600,260 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files):
assert os.path.exists(config1.tokenizer) and os.path.isdir(config1.tokenizer)
assert os.path.exists(config2.model) and os.path.isdir(config2.model)
assert os.path.exists(config2.tokenizer) and os.path.isdir(config2.tokenizer)
+
+
+@pytest.mark.parametrize(
+ ("backend", "custom_ops", "expected"),
+ [
+ ("eager", [], True),
+ ("eager", ["+fused_layernorm"], True),
+ ("eager", ["all", "-fused_layernorm"], False),
+ ("inductor", [], False),
+ ("inductor", ["none", "+fused_layernorm"], True),
+ ("inductor", ["none", "-fused_layernorm"], False),
+ ],
+)
+def test_is_custom_op_enabled(backend: str, custom_ops: list[str], expected: bool):
+ """Test that is_custom_op_enabled works correctly."""
+ config = VllmConfig(
+ compilation_config=CompilationConfig(backend=backend, custom_ops=custom_ops)
+ )
+ assert config.compilation_config.is_custom_op_enabled("fused_layernorm") is expected
+
+
+def test_vllm_config_defaults_are_none():
+ """Verify that optimization-level defaults are None when not set by user."""
+ # Test all optimization levels to ensure defaults work correctly
+ for opt_level in OptimizationLevel:
+ config = object.__new__(VllmConfig)
+ config.compilation_config = CompilationConfig()
+ config.optimization_level = opt_level
+ config.model_config = None
+
+ # Use the global optimization level defaults
+ default_config = OPTIMIZATION_LEVEL_TO_CONFIG[opt_level]
+
+ # Verify that all pass_config values are None before defaults are applied
+ for pass_k in default_config["compilation_config"]["pass_config"]:
+ assert getattr(config.compilation_config.pass_config, pass_k) is None
+
+ # Verify that other config values are None before defaults are applied
+ for k in default_config["compilation_config"]:
+ if k != "pass_config":
+ assert getattr(config.compilation_config, k) is None
+
+
+@pytest.mark.parametrize(
+ ("model_id", "compiliation_config", "optimization_level"),
+ [
+ (
+ None,
+ CompilationConfig(backend="eager", custom_ops=["+quant_fp8"]),
+ OptimizationLevel.O0,
+ ),
+ (None, CompilationConfig(), OptimizationLevel.O0),
+ (None, CompilationConfig(), OptimizationLevel.O1),
+ (None, CompilationConfig(), OptimizationLevel.O2),
+ (None, CompilationConfig(), OptimizationLevel.O3),
+ (
+ "RedHatAI/Qwen3-8B-speculator.eagle3",
+ CompilationConfig(backend="inductor", custom_ops=["+quant_fp8"]),
+ OptimizationLevel.O2,
+ ),
+ (
+ "RedHatAI/Qwen3-8B-speculator.eagle3",
+ CompilationConfig(),
+ OptimizationLevel.O0,
+ ),
+ (
+ "RedHatAI/Qwen3-8B-speculator.eagle3",
+ CompilationConfig(),
+ OptimizationLevel.O1,
+ ),
+ (
+ "RedHatAI/Qwen3-8B-speculator.eagle3",
+ CompilationConfig(),
+ OptimizationLevel.O2,
+ ),
+ (
+ "RedHatAI/Qwen3-8B-speculator.eagle3",
+ CompilationConfig(),
+ OptimizationLevel.O3,
+ ),
+ ("RedHatAI/DeepSeek-V2.5-1210-FP8", CompilationConfig(), OptimizationLevel.O0),
+ ("RedHatAI/DeepSeek-V2.5-1210-FP8", CompilationConfig(), OptimizationLevel.O1),
+ ("RedHatAI/DeepSeek-V2.5-1210-FP8", CompilationConfig(), OptimizationLevel.O2),
+ ("RedHatAI/DeepSeek-V2.5-1210-FP8", CompilationConfig(), OptimizationLevel.O3),
+ ],
+)
+def test_vllm_config_defaults(model_id, compiliation_config, optimization_level):
+ """Test that optimization-level defaults are correctly applied."""
+
+ model_config = None
+ if model_id is not None:
+ model_config = ModelConfig(model_id)
+ vllm_config = VllmConfig(
+ model_config=model_config,
+ compilation_config=compiliation_config,
+ optimization_level=optimization_level,
+ )
+ else:
+ vllm_config = VllmConfig(
+ compilation_config=compiliation_config,
+ optimization_level=optimization_level,
+ )
+ # Use the global optimization level defaults
+ default_config = OPTIMIZATION_LEVEL_TO_CONFIG[optimization_level]
+
+ # Verify pass_config defaults (nested under compilation_config)
+ pass_config_dict = default_config["compilation_config"]["pass_config"]
+ for pass_k, pass_v in pass_config_dict.items():
+ actual = getattr(vllm_config.compilation_config.pass_config, pass_k)
+ expected = pass_v(vllm_config) if callable(pass_v) else pass_v
+ assert actual == expected, (
+ f"pass_config.{pass_k}: expected {expected}, got {actual}"
+ )
+
+ # Verify other compilation_config defaults
+ compilation_config_dict = default_config["compilation_config"]
+ for k, v in compilation_config_dict.items():
+ if k != "pass_config":
+ actual = getattr(vllm_config.compilation_config, k)
+ expected = v(vllm_config) if callable(v) else v
+ assert actual == expected, (
+ f"compilation_config.{k}: expected {expected}, got {actual}"
+ )
+
+
+def test_vllm_config_callable_defaults():
+ """Test that callable defaults work in the config system.
+
+ Verifies that lambdas in default configs can inspect VllmConfig properties
+ (e.g., is_quantized, is_model_moe) to conditionally set optimization flags.
+ """
+ config_no_model = VllmConfig(optimization_level=OptimizationLevel.O2)
+
+ # Callable that checks if model exists
+ has_model = lambda cfg: cfg.model_config is not None
+ assert has_model(config_no_model) is False
+
+ # Test with quantized model
+ quantized_model = ModelConfig("RedHatAI/Llama-3.2-1B-FP8")
+ config_quantized = VllmConfig(
+ model_config=quantized_model, optimization_level=OptimizationLevel.O2
+ )
+ enable_if_quantized = lambda cfg: (
+ cfg.model_config is not None and cfg.model_config.is_quantized()
+ )
+ assert enable_if_quantized(config_quantized) is True
+ assert enable_if_quantized(config_no_model) is False
+
+ # Test with MoE model
+ moe_model = ModelConfig("deepseek-ai/DeepSeek-V2-Lite")
+ config_moe = VllmConfig(
+ model_config=moe_model, optimization_level=OptimizationLevel.O2
+ )
+ enable_if_sequential = lambda cfg: (
+ cfg.model_config is not None and not cfg.model_config.is_model_moe()
+ )
+ assert enable_if_sequential(config_moe) is False
+ assert enable_if_sequential(config_quantized) is True
+
+
+def test_vllm_config_explicit_overrides():
+ """Test that explicit property overrides work correctly with callable defaults.
+
+ When users explicitly set configuration properties, those values
+ take precedence over callable defaults, across different models and
+ optimization levels.
+ """
+ from vllm.config.compilation import PassConfig
+
+ quantized_model = ModelConfig("RedHatAI/Llama-3.2-1B-FP8")
+ moe_model = ModelConfig("deepseek-ai/DeepSeek-V2-Lite")
+ regular_model = ModelConfig("Qwen/Qwen1.5-7B")
+
+ # Explicit compilation mode override on O0 (where default is NONE)
+ compilation_config = CompilationConfig(mode=CompilationMode.VLLM_COMPILE)
+ config = VllmConfig(
+ optimization_level=OptimizationLevel.O0,
+ compilation_config=compilation_config,
+ )
+ assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE
+ assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
+
+ # Explicit pass config flags to override defaults
+ pass_config = PassConfig(enable_noop=True, enable_attn_fusion=True)
+ compilation_config = CompilationConfig(pass_config=pass_config)
+ config = VllmConfig(
+ optimization_level=OptimizationLevel.O0,
+ compilation_config=compilation_config,
+ )
+ assert config.compilation_config.pass_config.enable_noop is True
+ assert config.compilation_config.pass_config.enable_attn_fusion is True
+
+ # Explicit cudagraph mode override on quantized model at O2
+ pass_config = PassConfig(enable_async_tp=True)
+ compilation_config = CompilationConfig(
+ cudagraph_mode=CUDAGraphMode.NONE, pass_config=pass_config
+ )
+ config = VllmConfig(
+ model_config=quantized_model,
+ optimization_level=OptimizationLevel.O2,
+ compilation_config=compilation_config,
+ )
+ assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
+ assert config.compilation_config.pass_config.enable_async_tp is True
+ # Mode should still use default for O2
+ assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE
+
+ # Different optimization levels with same model
+ config_o0 = VllmConfig(
+ model_config=regular_model, optimization_level=OptimizationLevel.O0
+ )
+ config_o2 = VllmConfig(
+ model_config=regular_model, optimization_level=OptimizationLevel.O2
+ )
+ assert config_o0.compilation_config.mode == CompilationMode.NONE
+ assert config_o2.compilation_config.mode == CompilationMode.VLLM_COMPILE
+ assert config_o0.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
+ assert (
+ config_o2.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
+ )
+
+ # Same optimization level across different model types
+ config_moe_o2 = VllmConfig(
+ model_config=moe_model, optimization_level=OptimizationLevel.O2
+ )
+ config_regular_o2 = VllmConfig(
+ model_config=regular_model, optimization_level=OptimizationLevel.O2
+ )
+ config_quantized_o2 = VllmConfig(
+ model_config=quantized_model, optimization_level=OptimizationLevel.O2
+ )
+ # All should have same base compilation settings at O2
+ assert config_moe_o2.compilation_config.mode == CompilationMode.VLLM_COMPILE
+ assert config_regular_o2.compilation_config.mode == CompilationMode.VLLM_COMPILE
+ assert config_quantized_o2.compilation_config.mode == CompilationMode.VLLM_COMPILE
+ assert (
+ config_moe_o2.compilation_config.cudagraph_mode
+ == CUDAGraphMode.FULL_AND_PIECEWISE
+ )
+ assert (
+ config_regular_o2.compilation_config.cudagraph_mode
+ == CUDAGraphMode.FULL_AND_PIECEWISE
+ )
+
+ # Override one field but not others
+ pass_config = PassConfig(enable_noop=False)
+ compilation_config = CompilationConfig(pass_config=pass_config)
+ config = VllmConfig(
+ model_config=regular_model,
+ optimization_level=OptimizationLevel.O2,
+ compilation_config=compilation_config,
+ )
+ # Explicit override should be respected
+ assert config.compilation_config.pass_config.enable_noop is False
+ # Other fields should still use defaults
+ assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE
+ assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
diff --git a/tests/test_routing_simulator.py b/tests/test_routing_simulator.py
index 5a162fa8f791b..e8826eb441a24 100644
--- a/tests/test_routing_simulator.py
+++ b/tests/test_routing_simulator.py
@@ -9,9 +9,16 @@ different routing strategies and analyze their performance, including
integration tests with FusedMoE layer.
"""
+import tempfile
+
import pytest
import torch
+from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.distributed import (
+ init_distributed_environment,
+ initialize_model_parallel,
+)
from vllm.model_executor.layers.fused_moe.routing_simulator import (
DistributionBasedRouting,
RoutingSimulator,
@@ -89,6 +96,28 @@ def test_routing_strategy_integration(monkeypatch, device):
# Test different routing strategies
strategies = RoutingSimulator.get_available_strategies()
+ vllm_config = VllmConfig()
+ with set_current_vllm_config(vllm_config):
+ temp_file = tempfile.mkstemp()[1]
+ init_distributed_environment(
+ world_size=1,
+ rank=0,
+ local_rank=0,
+ distributed_init_method=f"file://{temp_file}",
+ )
+ initialize_model_parallel(
+ tensor_model_parallel_size=1,
+ pipeline_model_parallel_size=1,
+ )
+ fused_moe = FusedMoE(
+ num_experts=num_experts,
+ top_k=top_k,
+ hidden_size=hidden_size,
+ intermediate_size=0,
+ use_grouped_topk=False,
+ renormalize=True,
+ )
+
for strategy in strategies:
# Set environment variable
env_name = "VLLM_MOE_ROUTING_SIMULATION_STRATEGY"
@@ -98,13 +127,9 @@ def test_routing_strategy_integration(monkeypatch, device):
envs.environment_variables[env_name] = lambda s=strategy: s
# Test the select_experts method
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = fused_moe.select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
- top_k=top_k,
- use_grouped_topk=False,
- renormalize=True,
- indices_type=torch.long,
)
# Verify output shapes
diff --git a/tests/tool_use/test_parallel_tool_calls.py b/tests/tool_use/test_parallel_tool_calls.py
index 9af94a6a64a25..77084ec2d9456 100644
--- a/tests/tool_use/test_parallel_tool_calls.py
+++ b/tests/tool_use/test_parallel_tool_calls.py
@@ -212,3 +212,60 @@ async def test_parallel_tool_calls_with_results(
assert finish_reason_count == 1
assert len(chunks)
assert "".join(chunks) == choice.message.content
+
+
+@pytest.mark.asyncio
+async def test_parallel_tool_calls_false(client: openai.AsyncOpenAI):
+ """
+ Ensure only one tool call is returned when parallel_tool_calls is False.
+ """
+
+ models = await client.models.list()
+ model_name: str = models.data[0].id
+ chat_completion = await client.chat.completions.create(
+ messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
+ temperature=0,
+ max_completion_tokens=200,
+ model=model_name,
+ tools=[WEATHER_TOOL, SEARCH_TOOL],
+ logprobs=False,
+ parallel_tool_calls=False,
+ )
+
+ stop_reason = chat_completion.choices[0].finish_reason
+ non_streamed_tool_calls = chat_completion.choices[0].message.tool_calls
+
+ # make sure only 1 tool call is present
+ assert len(non_streamed_tool_calls) == 1
+ assert stop_reason == "tool_calls"
+
+ # make the same request, streaming
+ stream = await client.chat.completions.create(
+ model=model_name,
+ messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
+ temperature=0,
+ max_completion_tokens=200,
+ tools=[WEATHER_TOOL, SEARCH_TOOL],
+ logprobs=False,
+ parallel_tool_calls=False,
+ stream=True,
+ )
+
+ finish_reason_count: int = 0
+ tool_call_id_count: int = 0
+
+ async for chunk in stream:
+ # if there's a finish reason make sure it's tools
+ if chunk.choices[0].finish_reason:
+ finish_reason_count += 1
+ assert chunk.choices[0].finish_reason == "tool_calls"
+
+ streamed_tool_calls = chunk.choices[0].delta.tool_calls
+ if streamed_tool_calls and len(streamed_tool_calls) > 0:
+ tool_call = streamed_tool_calls[0]
+ if tool_call.id:
+ tool_call_id_count += 1
+
+ # make sure only 1 streaming tool call is present
+ assert tool_call_id_count == 1
+ assert finish_reason_count == 1
diff --git a/tests/transformers_utils/test_utils.py b/tests/transformers_utils/test_utils.py
index bfe1cec76c138..a8d0b9be9ec29 100644
--- a/tests/transformers_utils/test_utils.py
+++ b/tests/transformers_utils/test_utils.py
@@ -1,11 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from pathlib import Path
+from unittest.mock import patch
+import pytest
from vllm.transformers_utils.utils import (
is_cloud_storage,
is_gcs,
+ is_gguf,
+ is_remote_gguf,
is_s3,
+ split_remote_gguf,
)
@@ -28,3 +34,143 @@ def test_is_cloud_storage():
assert is_cloud_storage("s3://model-path/path-to-model")
assert not is_cloud_storage("/unix/local/path")
assert not is_cloud_storage("nfs://nfs-fqdn.local")
+
+
+class TestIsRemoteGGUF:
+ """Test is_remote_gguf utility function."""
+
+ def test_is_remote_gguf_with_colon_and_slash(self):
+ """Test is_remote_gguf with repo_id:quant_type format."""
+ # Valid quant types
+ assert is_remote_gguf("unsloth/Qwen3-0.6B-GGUF:IQ1_S")
+ assert is_remote_gguf("user/repo:Q2_K")
+ assert is_remote_gguf("repo/model:Q4_K")
+ assert is_remote_gguf("repo/model:Q8_0")
+
+ # Invalid quant types should return False
+ assert not is_remote_gguf("repo/model:quant")
+ assert not is_remote_gguf("repo/model:INVALID")
+ assert not is_remote_gguf("repo/model:invalid_type")
+
+ def test_is_remote_gguf_without_colon(self):
+ """Test is_remote_gguf without colon."""
+ assert not is_remote_gguf("repo/model")
+ assert not is_remote_gguf("unsloth/Qwen3-0.6B-GGUF")
+
+ def test_is_remote_gguf_without_slash(self):
+ """Test is_remote_gguf without slash."""
+ assert not is_remote_gguf("model.gguf")
+ # Even with valid quant_type, no slash means not remote GGUF
+ assert not is_remote_gguf("model:IQ1_S")
+ assert not is_remote_gguf("model:quant")
+
+ def test_is_remote_gguf_local_path(self):
+ """Test is_remote_gguf with local file path."""
+ assert not is_remote_gguf("/path/to/model.gguf")
+ assert not is_remote_gguf("./model.gguf")
+
+ def test_is_remote_gguf_with_path_object(self):
+ """Test is_remote_gguf with Path object."""
+ assert is_remote_gguf(Path("unsloth/Qwen3-0.6B-GGUF:IQ1_S"))
+ assert not is_remote_gguf(Path("repo/model"))
+
+ def test_is_remote_gguf_with_http_https(self):
+ """Test is_remote_gguf with HTTP/HTTPS URLs."""
+ # HTTP/HTTPS URLs should return False even with valid quant_type
+ assert not is_remote_gguf("http://example.com/repo/model:IQ1_S")
+ assert not is_remote_gguf("https://huggingface.co/repo/model:Q2_K")
+ assert not is_remote_gguf("http://repo/model:Q4_K")
+ assert not is_remote_gguf("https://repo/model:Q8_0")
+
+ def test_is_remote_gguf_with_cloud_storage(self):
+ """Test is_remote_gguf with cloud storage paths."""
+ # Cloud storage paths should return False even with valid quant_type
+ assert not is_remote_gguf("s3://bucket/repo/model:IQ1_S")
+ assert not is_remote_gguf("gs://bucket/repo/model:Q2_K")
+ assert not is_remote_gguf("s3://repo/model:Q4_K")
+ assert not is_remote_gguf("gs://repo/model:Q8_0")
+
+
+class TestSplitRemoteGGUF:
+ """Test split_remote_gguf utility function."""
+
+ def test_split_remote_gguf_valid(self):
+ """Test split_remote_gguf with valid repo_id:quant_type format."""
+ repo_id, quant_type = split_remote_gguf("unsloth/Qwen3-0.6B-GGUF:IQ1_S")
+ assert repo_id == "unsloth/Qwen3-0.6B-GGUF"
+ assert quant_type == "IQ1_S"
+
+ repo_id, quant_type = split_remote_gguf("repo/model:Q2_K")
+ assert repo_id == "repo/model"
+ assert quant_type == "Q2_K"
+
+ def test_split_remote_gguf_with_path_object(self):
+ """Test split_remote_gguf with Path object."""
+ repo_id, quant_type = split_remote_gguf(Path("unsloth/Qwen3-0.6B-GGUF:IQ1_S"))
+ assert repo_id == "unsloth/Qwen3-0.6B-GGUF"
+ assert quant_type == "IQ1_S"
+
+ def test_split_remote_gguf_invalid(self):
+ """Test split_remote_gguf with invalid format."""
+ # Invalid format (no colon) - is_remote_gguf returns False
+ with pytest.raises(ValueError, match="Wrong GGUF model"):
+ split_remote_gguf("repo/model")
+
+ # Invalid quant type - is_remote_gguf returns False
+ with pytest.raises(ValueError, match="Wrong GGUF model"):
+ split_remote_gguf("repo/model:INVALID_TYPE")
+
+ # HTTP URL - is_remote_gguf returns False
+ with pytest.raises(ValueError, match="Wrong GGUF model"):
+ split_remote_gguf("http://repo/model:IQ1_S")
+
+ # Cloud storage - is_remote_gguf returns False
+ with pytest.raises(ValueError, match="Wrong GGUF model"):
+ split_remote_gguf("s3://bucket/repo/model:Q2_K")
+
+
+class TestIsGGUF:
+ """Test is_gguf utility function."""
+
+ @patch("vllm.transformers_utils.utils.check_gguf_file", return_value=True)
+ def test_is_gguf_with_local_file(self, mock_check_gguf):
+ """Test is_gguf with local GGUF file."""
+ assert is_gguf("/path/to/model.gguf")
+ assert is_gguf("./model.gguf")
+
+ def test_is_gguf_with_remote_gguf(self):
+ """Test is_gguf with remote GGUF format."""
+ # Valid remote GGUF format (repo_id:quant_type with valid quant_type)
+ assert is_gguf("unsloth/Qwen3-0.6B-GGUF:IQ1_S")
+ assert is_gguf("repo/model:Q2_K")
+ assert is_gguf("repo/model:Q4_K")
+
+ # Invalid quant_type should return False
+ assert not is_gguf("repo/model:quant")
+ assert not is_gguf("repo/model:INVALID")
+
+ @patch("vllm.transformers_utils.utils.check_gguf_file", return_value=False)
+ def test_is_gguf_false(self, mock_check_gguf):
+ """Test is_gguf returns False for non-GGUF models."""
+ assert not is_gguf("unsloth/Qwen3-0.6B")
+ assert not is_gguf("repo/model")
+ assert not is_gguf("model")
+
+ def test_is_gguf_edge_cases(self):
+ """Test is_gguf with edge cases."""
+ # Empty string
+ assert not is_gguf("")
+
+ # Only colon, no slash (even with valid quant_type)
+ assert not is_gguf("model:IQ1_S")
+
+ # Only slash, no colon
+ assert not is_gguf("repo/model")
+
+ # HTTP/HTTPS URLs
+ assert not is_gguf("http://repo/model:IQ1_S")
+ assert not is_gguf("https://repo/model:Q2_K")
+
+ # Cloud storage
+ assert not is_gguf("s3://bucket/repo/model:IQ1_S")
+ assert not is_gguf("gs://bucket/repo/model:Q2_K")
diff --git a/tests/utils_/test_argparse_utils.py b/tests/utils_/test_argparse_utils.py
index 3310753d2b6d6..c0519155c4ba8 100644
--- a/tests/utils_/test_argparse_utils.py
+++ b/tests/utils_/test_argparse_utils.py
@@ -28,6 +28,7 @@ def parser():
parser.add_argument("--enable-feature", action="store_true")
parser.add_argument("--hf-overrides", type=json.loads)
parser.add_argument("-O", "--compilation-config", type=json.loads)
+ parser.add_argument("--optimization-level", type=int)
return parser
@@ -166,7 +167,7 @@ def test_dict_args(parser):
"--hf-overrides.key2.key4",
"val3",
# Test compile config and compilation mode
- "-O.use_inductor=true",
+ "-O.use_inductor_graph_partition=true",
"-O.backend",
"custom",
"-O1",
@@ -217,9 +218,9 @@ def test_dict_args(parser):
"key15": "-minus.and.dot",
},
}
+ assert parsed_args.optimization_level == 1
assert parsed_args.compilation_config == {
- "mode": 1,
- "use_inductor": True,
+ "use_inductor_graph_partition": True,
"backend": "custom",
"custom_ops": ["-quant_fp8", "+silu_mul", "-rms_norm"],
}
@@ -241,12 +242,13 @@ def test_duplicate_dict_args(caplog_vllm, parser):
parsed_args = parser.parse_args(args)
# Should be the last value
assert parsed_args.hf_overrides == {"key1": "val2"}
- assert parsed_args.compilation_config == {"mode": 3}
+ assert parsed_args.optimization_level == 3
+ assert parsed_args.compilation_config == {"mode": 2}
assert len(caplog_vllm.records) == 1
assert "duplicate" in caplog_vllm.text
assert "--hf-overrides.key1" in caplog_vllm.text
- assert "-O.mode" in caplog_vllm.text
+ assert "--optimization-level" in caplog_vllm.text
def test_model_specification(
@@ -383,7 +385,7 @@ def test_compilation_mode_string_values(parser):
assert args.compilation_config == {"mode": 0}
args = parser.parse_args(["-O3"])
- assert args.compilation_config == {"mode": 3}
+ assert args.optimization_level == 3
args = parser.parse_args(["-O.mode=NONE"])
assert args.compilation_config == {"mode": "NONE"}
diff --git a/tests/v1/attention/test_rocm_attention_backends_selection.py b/tests/v1/attention/test_rocm_attention_backends_selection.py
new file mode 100644
index 0000000000000..77790be6f892b
--- /dev/null
+++ b/tests/v1/attention/test_rocm_attention_backends_selection.py
@@ -0,0 +1,340 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests for attention backend selectors."""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+
+from vllm.attention.backends.registry import AttentionBackendEnum
+from vllm.platforms import current_platform
+
+# ROCm-specific attention backend selection tests
+pytestmark = pytest.mark.skipif(
+ not current_platform.is_rocm(), reason="ROCm-specific tests"
+)
+
+
+@pytest.fixture
+def mock_vllm_config():
+ """Create a mock VllmConfig for testing."""
+ config = MagicMock()
+ config.model_config.dtype = torch.float16
+ config.model_config.hf_config.architectures = ["LlamaForCausalLM"]
+ config.cache_config.block_size = 16
+ return config
+
+
+@pytest.fixture
+def mock_on_gfx9():
+ """Mock the on_gfx9 function to return True."""
+ with patch("vllm.platforms.rocm.on_gfx9", return_value=True):
+ yield
+
+
+@pytest.mark.parametrize(
+ "env_vars, selected_backend, expected_backend_path",
+ [
+ # Test Case: Explicit FLEX_ATTENTION backend
+ (
+ {},
+ "FLEX_ATTENTION",
+ AttentionBackendEnum.FLEX_ATTENTION.get_path(),
+ ),
+ # Test Case 1: Default (no env vars, no explicit backend)
+ (
+ {},
+ None,
+ AttentionBackendEnum.TRITON_ATTN.get_path(),
+ ),
+ # Test Case 2: Explicit TRITON_ATTN backend
+ (
+ {},
+ "TRITON_ATTN",
+ AttentionBackendEnum.TRITON_ATTN.get_path(),
+ ),
+ # Test Case 3: Explicit ROCM_ATTN backend
+ (
+ {},
+ "ROCM_ATTN",
+ AttentionBackendEnum.ROCM_ATTN.get_path(),
+ ),
+ # Test Case 4: Explicit ROCM_AITER_FA backend
+ (
+ {},
+ "ROCM_AITER_FA",
+ AttentionBackendEnum.ROCM_AITER_FA.get_path(),
+ ),
+ # Test Case 5: Explicit ROCM_AITER_UNIFIED_ATTN backend
+ (
+ {},
+ "ROCM_AITER_UNIFIED_ATTN",
+ AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path(),
+ ),
+ # Test Case 6: VLLM_ROCM_USE_AITER=1
+ # (defaults to AITER FA when MHA not explicitly disabled)
+ (
+ {"VLLM_ROCM_USE_AITER": "1"},
+ None,
+ AttentionBackendEnum.ROCM_AITER_FA.get_path(),
+ ),
+ # Test Case 7: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=1
+ (
+ {"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "1"},
+ None,
+ AttentionBackendEnum.ROCM_AITER_FA.get_path(),
+ ),
+ # Test Case 8: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION=1
+ (
+ {
+ "VLLM_ROCM_USE_AITER": "1",
+ "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": "1",
+ },
+ None,
+ AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path(),
+ ),
+ # Test Case 9: VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1
+ (
+ {"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1"},
+ None,
+ AttentionBackendEnum.ROCM_ATTN.get_path(),
+ ),
+ # Test Case 10: VLLM_ROCM_USE_AITER=1 + explicit TRITON_ATTN
+ (
+ {"VLLM_ROCM_USE_AITER": "1"},
+ "TRITON_ATTN",
+ AttentionBackendEnum.TRITON_ATTN.get_path(),
+ ),
+ # Test Case 11: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=0
+ # (explicitly disabled)
+ (
+ {"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "0"},
+ None,
+ AttentionBackendEnum.TRITON_ATTN.get_path(),
+ ),
+ # Test Case 12: VLLM_ROCM_USE_AITER=1 + explicit ROCM_ATTN
+ (
+ {"VLLM_ROCM_USE_AITER": "1"},
+ "ROCM_ATTN",
+ AttentionBackendEnum.ROCM_ATTN.get_path(),
+ ),
+ ],
+)
+def test_standard_attention_backend_selection(
+ env_vars,
+ selected_backend,
+ expected_backend_path,
+ mock_vllm_config,
+ mock_on_gfx9,
+ monkeypatch,
+):
+ """Test standard attention backend selection with various configurations."""
+ # Set environment variables
+ for key, value in env_vars.items():
+ monkeypatch.setenv(key, value)
+
+ # Import after setting env vars to ensure they're picked up
+ # Reload envs to pick up new environment variables
+ import importlib
+
+ import vllm.envs as envs
+
+ importlib.reload(envs)
+
+ # Convert string backend to enum if provided
+ backend_enum = None
+ if selected_backend:
+ backend_enum = getattr(AttentionBackendEnum, selected_backend)
+
+ # Get the backend class path
+ from vllm.platforms.rocm import RocmPlatform
+
+ backend_path = RocmPlatform.get_attn_backend_cls(
+ selected_backend=backend_enum,
+ head_size=128,
+ dtype=torch.float16,
+ kv_cache_dtype="auto",
+ block_size=16,
+ use_mla=False,
+ has_sink=False,
+ use_sparse=False,
+ )
+ assert backend_path == expected_backend_path
+
+
+@pytest.mark.parametrize(
+ "env_vars, selected_backend, block_size, expected_backend_path, should_raise",
+ [
+ # Test Case 1: TRITON_MLA with block_size != 1
+ (
+ {},
+ "TRITON_MLA",
+ 16,
+ AttentionBackendEnum.TRITON_MLA.get_path(),
+ False,
+ ),
+ # Test Case 2: TRITON_MLA with block_size == 1 (should raise)
+ (
+ {},
+ "TRITON_MLA",
+ 1,
+ None,
+ True,
+ ),
+ # Test Case 3: ROCM_AITER_MLA with block_size == 1
+ (
+ {},
+ "ROCM_AITER_MLA",
+ 1,
+ AttentionBackendEnum.ROCM_AITER_MLA.get_path(),
+ False,
+ ),
+ # Test Case 4: ROCM_AITER_MLA with block_size != 1 (should raise)
+ (
+ {},
+ "ROCM_AITER_MLA",
+ 16,
+ AttentionBackendEnum.ROCM_AITER_MLA.get_path(),
+ False,
+ ),
+ # Test Case 5: VLLM_ROCM_USE_AITER=1 with block_size == 1
+ (
+ {"VLLM_ROCM_USE_AITER": "1"},
+ None,
+ 1,
+ AttentionBackendEnum.ROCM_AITER_MLA.get_path(),
+ False,
+ ),
+ # Test Case 6: VLLM_ROCM_USE_AITER=1 with block_size == 16
+ # (should use ROCM_AITER_MLA now, as it supports block_size 16)
+ (
+ {"VLLM_ROCM_USE_AITER": "1"},
+ None,
+ 16,
+ AttentionBackendEnum.ROCM_AITER_MLA.get_path(),
+ False,
+ ),
+ # Test Case 7: VLLM_ROCM_USE_AITER=1 + explicit TRITON_MLA
+ (
+ {"VLLM_ROCM_USE_AITER": "1"},
+ "TRITON_MLA",
+ 16,
+ AttentionBackendEnum.TRITON_MLA.get_path(),
+ False,
+ ),
+ # Test Case 8: Explicit ROCM_AITER_TRITON_MLA
+ (
+ {},
+ "ROCM_AITER_TRITON_MLA",
+ 16,
+ AttentionBackendEnum.ROCM_AITER_TRITON_MLA.get_path(),
+ False,
+ ),
+ ],
+)
+def test_mla_backend_selection(
+ env_vars,
+ selected_backend,
+ block_size,
+ expected_backend_path,
+ should_raise,
+ mock_vllm_config,
+ monkeypatch,
+):
+ """Test MLA backend selection with various configurations."""
+ # Set environment variables
+ for key, value in env_vars.items():
+ monkeypatch.setenv(key, value)
+
+ # Import after setting env vars
+ # Reload envs
+ import importlib
+
+ import vllm.envs as envs
+
+ importlib.reload(envs)
+
+ # Mock is_aiter_mla_enabled based on env vars and block_size
+ aiter_enabled = env_vars.get("VLLM_ROCM_USE_AITER") == "1"
+
+ mock_rocm_ops = MagicMock()
+ mock_rocm_ops.is_mla_enabled.return_value = aiter_enabled
+ mock_aiter_module = MagicMock()
+ mock_aiter_module.rocm_aiter_ops = mock_rocm_ops
+
+ with patch.dict("sys.modules", {"vllm._aiter_ops": mock_aiter_module}):
+ # Convert string backend to enum if provided
+ backend_enum = None
+ if selected_backend:
+ backend_enum = getattr(AttentionBackendEnum, selected_backend)
+
+ from vllm.platforms.rocm import RocmPlatform
+
+ if should_raise:
+ with pytest.raises(ValueError):
+ RocmPlatform.get_attn_backend_cls(
+ selected_backend=backend_enum,
+ head_size=128,
+ dtype=torch.float16,
+ kv_cache_dtype="auto",
+ block_size=block_size,
+ use_mla=True,
+ has_sink=False,
+ use_sparse=False,
+ )
+ else:
+ backend_path = RocmPlatform.get_attn_backend_cls(
+ selected_backend=backend_enum,
+ head_size=128,
+ dtype=torch.float16,
+ kv_cache_dtype="auto",
+ block_size=block_size,
+ use_mla=True,
+ has_sink=False,
+ use_sparse=False,
+ )
+ assert backend_path == expected_backend_path
+
+
+def test_aiter_fa_requires_gfx9(mock_vllm_config):
+ """Test that ROCM_AITER_FA requires gfx9 architecture."""
+ from vllm.platforms.rocm import RocmPlatform
+
+ # Mock on_gfx9 to return False
+ with (
+ patch("vllm.platforms.rocm.on_gfx9", return_value=False),
+ pytest.raises(
+ ValueError,
+ match="only supported on gfx9",
+ ),
+ ):
+ RocmPlatform.get_attn_backend_cls(
+ selected_backend=AttentionBackendEnum.ROCM_AITER_FA,
+ head_size=128,
+ dtype=torch.float16,
+ kv_cache_dtype="auto",
+ block_size=16,
+ use_mla=False,
+ has_sink=False,
+ use_sparse=False,
+ )
+
+
+def test_sparse_not_supported(mock_vllm_config):
+ """Test that sparse attention is not supported on ROCm."""
+ from vllm.platforms.rocm import RocmPlatform
+
+ with pytest.raises(
+ AssertionError, match="Sparse MLA backend on ROCm only supports block size 1"
+ ):
+ RocmPlatform.get_attn_backend_cls(
+ selected_backend=None,
+ head_size=128,
+ dtype=torch.float16,
+ kv_cache_dtype="auto",
+ block_size=16,
+ use_mla=False,
+ has_sink=False,
+ use_sparse=True,
+ )
diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py
index dea89babd4b47..df3d53332c7cd 100644
--- a/tests/v1/attention/utils.py
+++ b/tests/v1/attention/utils.py
@@ -340,4 +340,11 @@ full_cg_backend_configs = {
"cudagraph_mode": "FULL_AND_PIECEWISE",
},
),
+ "RocmAttn": BackendConfig(
+ name="RocmAttn",
+ env_vars={"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1"},
+ comp_config={
+ "cudagraph_mode": "FULL",
+ },
+ ),
}
diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py
index 24611a4aaa1b8..58a7a2692bfc8 100644
--- a/tests/v1/core/test_kv_cache_utils.py
+++ b/tests/v1/core/test_kv_cache_utils.py
@@ -1248,7 +1248,9 @@ def test_allocate_with_lookahead():
)
# Test case 1: Requires additional lookahead tokens
- kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100)
+ kv_cache_manager = KVCacheManager(
+ kv_cache_config=config, max_model_len=100, hash_block_size=block_size
+ )
blocks = kv_cache_manager.allocate_slots(
request,
num_new_tokens=3,
@@ -1257,7 +1259,9 @@ def test_allocate_with_lookahead():
assert len(blocks.get_block_ids()[0]) == 2 # ceil(5/4)=2 blocks
# Test case 2: With precomputed blocks
- kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100)
+ kv_cache_manager = KVCacheManager(
+ kv_cache_config=config, max_model_len=100, hash_block_size=block_size
+ )
# required_blocks = ceil((3 + 2) /4) = 2
blocks = kv_cache_manager.allocate_slots(
request,
@@ -1268,7 +1272,9 @@ def test_allocate_with_lookahead():
# Test case 3: With precomputed blocks
# required_blocks = ceil((3 + 4) / 4) = 2
- kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100)
+ kv_cache_manager = KVCacheManager(
+ kv_cache_config=config, max_model_len=100, hash_block_size=block_size
+ )
blocks = kv_cache_manager.allocate_slots(
request,
num_new_tokens=3,
@@ -1436,7 +1442,67 @@ def test_get_kv_cache_config_one_worker():
],
)
- # different hidden size
+ # 6 full + 5 sliding, pad to 6 full + 6 sliding. This is a typical case for gpt-oss
+ # eagle where there is only one more full attention layer than sliding window layers
+ kv_cache_specs_hybrid = {
+ "layer_1": new_kv_cache_spec(),
+ "layer_2": new_kv_cache_spec(),
+ "layer_3": new_kv_cache_spec(),
+ "layer_4": new_kv_cache_spec(),
+ "layer_5": new_kv_cache_spec(),
+ "layer_6": new_kv_cache_spec(),
+ "layer_7": new_sliding_window_spec(),
+ "layer_8": new_sliding_window_spec(),
+ "layer_9": new_sliding_window_spec(),
+ "layer_10": new_sliding_window_spec(),
+ "layer_11": new_sliding_window_spec(),
+ }
+
+ kv_cache_config_hybrid = get_kv_cache_configs(
+ vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 6 * 32]
+ )[0]
+ print(kv_cache_config_hybrid)
+ assert kv_cache_config_hybrid == KVCacheConfig(
+ num_blocks=32,
+ kv_cache_tensors=[
+ KVCacheTensor(
+ size=mem_per_block_per_layer * 32,
+ shared_by=["layer_1", "layer_7"],
+ ),
+ KVCacheTensor(
+ size=mem_per_block_per_layer * 32,
+ shared_by=["layer_2", "layer_8"],
+ ),
+ KVCacheTensor(
+ size=mem_per_block_per_layer * 32,
+ shared_by=["layer_3", "layer_9"],
+ ),
+ KVCacheTensor(
+ size=mem_per_block_per_layer * 32,
+ shared_by=["layer_4", "layer_10"],
+ ),
+ KVCacheTensor(
+ size=mem_per_block_per_layer * 32,
+ shared_by=["layer_5", "layer_11"],
+ ),
+ KVCacheTensor(
+ size=mem_per_block_per_layer * 32,
+ shared_by=["layer_6"],
+ ),
+ ],
+ kv_cache_groups=[
+ KVCacheGroupSpec(
+ ["layer_1", "layer_2", "layer_3", "layer_4", "layer_5", "layer_6"],
+ new_kv_cache_spec(),
+ ),
+ KVCacheGroupSpec(
+ ["layer_7", "layer_8", "layer_9", "layer_10", "layer_11"],
+ new_sliding_window_spec(),
+ ),
+ ],
+ )
+
+ # different hidden size but same type, use UniformTypeKVCacheSpecs
kv_cache_specs_hybrid = {
"layer_1": new_kv_cache_spec(head_size=128),
"layer_2": new_kv_cache_spec(head_size=64),
@@ -1460,6 +1526,40 @@ def test_get_kv_cache_config_one_worker():
],
)
+ # Different hidden size and different type, align by different block size
+ kv_cache_specs_hybrid = {
+ "layer_1": new_kv_cache_spec(head_size=64),
+ "layer_2": new_sliding_window_spec(head_size=32),
+ }
+ kv_cache_config_hybrid = get_kv_cache_configs(
+ vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 32]
+ )[0]
+ assert kv_cache_config_hybrid == KVCacheConfig(
+ num_blocks=32,
+ kv_cache_tensors=[
+ KVCacheTensor(
+ size=mem_per_block_per_layer * 32, shared_by=["layer_1", "layer_2"]
+ ),
+ ],
+ kv_cache_groups=[
+ KVCacheGroupSpec(["layer_1"], new_kv_cache_spec(head_size=64)),
+ KVCacheGroupSpec(
+ ["layer_2"], new_sliding_window_spec(head_size=32, block_size=32)
+ ),
+ ],
+ )
+
+ # different hidden size that cannot be aligned by using different block size
+ kv_cache_specs_hybrid = {
+ "layer_1": new_kv_cache_spec(head_size=64),
+ "layer_2": new_sliding_window_spec(head_size=96),
+ }
+
+ with pytest.raises(NotImplementedError):
+ get_kv_cache_configs(
+ vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32]
+ )[0]
+
# Test num_gpu_blocks_override
vllm_config.cache_config.num_gpu_blocks_override = 16
kv_cache_config_override_blocks = get_kv_cache_configs(
diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py
index 2291f363731f2..64fd5ab1dd9aa 100644
--- a/tests/v1/core/test_prefix_caching.py
+++ b/tests/v1/core/test_prefix_caching.py
@@ -134,6 +134,7 @@ def test_prefill(hash_fn):
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
+ hash_block_size=block_size,
)
# Complete 3 blocks (48 tokens)
@@ -256,6 +257,7 @@ def test_prefill_hybrid_model():
make_kv_cache_config_hybrid_model(block_size, 21),
max_model_len=8192,
enable_caching=True,
+ hash_block_size=block_size,
)
hash_fn = sha256
@@ -416,6 +418,7 @@ def test_prefill_plp():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
+ hash_block_size=block_size,
)
# the default hash function is sha256
hash_fn = sha256
@@ -523,6 +526,7 @@ def test_decode():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
+ hash_block_size=block_size,
)
# Complete 3 blocks (48 tokens)
@@ -585,6 +589,7 @@ def test_evict():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
+ hash_block_size=block_size,
)
last_token_id = 5 * 16 + 7
@@ -643,6 +648,7 @@ def test_hash_block_correct_reuse():
make_kv_cache_config(16, 2),
max_model_len=8192,
enable_caching=True,
+ hash_block_size=block_size,
)
# Allocate 1 block and cache it.
@@ -683,6 +689,7 @@ def test_computed_blocks_not_evicted():
make_kv_cache_config(block_size, 3),
max_model_len=8192,
enable_caching=True,
+ hash_block_size=block_size,
)
# Allocate a block and cache it.
@@ -741,6 +748,7 @@ def test_basic_prefix_caching_disabled():
make_kv_cache_config(block_size, 5),
max_model_len=8192,
enable_caching=False,
+ hash_block_size=block_size,
)
req1 = make_request(
@@ -790,6 +798,7 @@ def test_cache_blocks(hash_fn):
block_pool = BlockPool(
num_gpu_blocks=5,
enable_caching=True,
+ hash_block_size=block_size,
)
# Req:
# Block 0: [0, 1, 2, 3]
@@ -833,7 +842,9 @@ def test_cache_blocks_multi_group():
This tests that blocks are cached correctly for different kv cache groups.
"""
block_size = 4
- block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True)
+ block_pool = BlockPool(
+ num_gpu_blocks=10, enable_caching=True, hash_block_size=block_size
+ )
# Req:
# Block 0/4: [0, 1, 2, 3]
@@ -921,6 +932,7 @@ def test_mm_prefix_caching():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
+ hash_block_size=block_size,
)
# Common prompt tokens (T is text tokens and P is image placeholder tokens)
@@ -1020,6 +1032,7 @@ def test_cache_key_salting():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
+ hash_block_size=block_size,
)
# 3 complete blocks and an incomplete block with 11 tokens.
@@ -1101,6 +1114,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
+ hash_block_size=block_size,
)
# Complete 3 blocks (48 tokens)
# | Common-0 | Common-1 | Common-2 | ... |
@@ -1173,6 +1187,7 @@ def test_reset_prefix_cache():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
+ hash_block_size=block_size,
)
full_block_token_ids = [i for i in range(3) for _ in range(16)]
@@ -1213,6 +1228,7 @@ def test_prefix_cache_stats_disabled():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
+ hash_block_size=block_size,
log_stats=False, # Disable logging stats
)
assert manager.prefix_cache_stats is None
@@ -1232,7 +1248,7 @@ def test_prefix_cache_stats_disabled():
def test_maybe_evict_cached_block():
- pool = BlockPool(num_gpu_blocks=4, enable_caching=True)
+ pool = BlockPool(num_gpu_blocks=4, enable_caching=True, hash_block_size=16)
block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000)
block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000)
block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000)
@@ -1293,6 +1309,7 @@ def test_kv_cache_events(blocks_to_cache: int):
max_model_len=8192,
enable_caching=True,
enable_kv_cache_events=True,
+ hash_block_size=block_size,
)
num_tokens = block_size * blocks_to_cache
@@ -1351,6 +1368,7 @@ def test_kv_cache_events_with_lora(blocks_to_cache: int):
max_model_len=8192,
enable_caching=True,
enable_kv_cache_events=True,
+ hash_block_size=block_size,
)
# Test with LoRA request
@@ -1405,6 +1423,7 @@ def test_eagle_enabled_removes_last_block():
max_model_len=8192,
enable_caching=True,
use_eagle=True,
+ hash_block_size=block_size,
)
# Request with 3 full blocks (48 tokens)
@@ -1437,6 +1456,7 @@ def test_eagle_with_partial_blocks():
max_model_len=8192,
enable_caching=True,
use_eagle=True,
+ hash_block_size=block_size,
)
# 2 full blocks + 5 tokens (non-divisible length)
token_ids = [0] * (2 * block_size + 5)
@@ -1476,6 +1496,7 @@ def test_eagle_with_sliding_window():
max_model_len=8192,
enable_caching=True,
use_eagle=True,
+ hash_block_size=block_size,
)
# 2 full blocks + 5 tokens (non-divisible length)
@@ -1522,6 +1543,76 @@ def test_eagle_with_sliding_window():
assert num_tokens == 0
+def test_different_block_size():
+ block_size = 16
+ # full attention and sliding window attention layers have the same page size:
+ # (32 tokens/block * float16 token, vs. 16 tokens/block * float32 token)
+ kv_cache_config = KVCacheConfig(
+ num_blocks=100,
+ kv_cache_tensors=[],
+ kv_cache_groups=[
+ KVCacheGroupSpec(
+ ["layer1"],
+ FullAttentionSpec(block_size * 2, 1, 1, torch.float16),
+ ),
+ KVCacheGroupSpec(
+ ["layer2"],
+ SlidingWindowSpec(
+ block_size,
+ 1,
+ 1,
+ torch.float32,
+ sliding_window=2 * block_size,
+ ),
+ ),
+ ],
+ )
+ manager = KVCacheManager(
+ kv_cache_config=kv_cache_config,
+ max_model_len=8192,
+ enable_caching=True,
+ hash_block_size=block_size,
+ )
+
+ # 10 blocks of 16 tokens each. Token ids are not strictly aligned for each block.
+ common_token_ids = [i for i in range(10) for _ in range(block_size)]
+
+ req0 = make_request("0", common_token_ids, block_size, sha256)
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
+ assert not computed_blocks.blocks[0]
+ assert not computed_blocks.blocks[1]
+ assert num_computed_tokens == 0
+ blocks = manager.allocate_slots(
+ req0, 7 * block_size, len(computed_blocks.blocks[0]) * 16, computed_blocks
+ )
+ assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7, 8, 9, 10, 11])
+ req1 = make_request("1", common_token_ids[: 7 * block_size + 1], block_size, sha256)
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
+ assert len(computed_blocks.blocks[0]) == 3
+ assert len(computed_blocks.blocks[1]) == 6
+ assert num_computed_tokens == 6 * 16
+
+ req2 = make_request("2", common_token_ids[: 6 * block_size + 1], block_size, sha256)
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
+ assert len(computed_blocks.blocks[0]) == 3
+ assert len(computed_blocks.blocks[1]) == 6
+ assert num_computed_tokens == 6 * 16
+
+ # Evict some blocks to make sliding window cache hit length 5*16
+ # But should return 4 * 16 because full attention cache hit length must be
+ # a multiple of 32
+ manager.block_pool.cached_block_hash_to_block.pop(
+ make_block_hash_with_group_id(req1.block_hashes[6], 1), 11
+ )
+ manager.block_pool.cached_block_hash_to_block.pop(
+ make_block_hash_with_group_id(req1.block_hashes[5], 1), 10
+ )
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
+ assert len(computed_blocks.blocks[0]) == 2
+ assert len(computed_blocks.blocks[1]) == 4
+ assert num_computed_tokens == 4 * 16
+
+
def test_block_lookup_cache_single_block_per_key():
cache = BlockHashToBlockMap()
key0 = BlockHashWithGroupId(b"hash0")
diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index 09acde6e08faa..fe4153e609971 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -641,6 +641,34 @@ def test_schedule_concurrent_batches(
scheduler.update_from_output(scheduler_output1, model_runner_output)
+@pytest.mark.parametrize("enable_chunked_prefill", [True, False])
+def test_schedule_order(enable_chunked_prefill: bool):
+ scheduler = create_scheduler(
+ max_num_batched_tokens=1024,
+ max_num_seqs=3,
+ enable_chunked_prefill=enable_chunked_prefill,
+ )
+
+ # long requests
+ requests = create_requests(num_requests=2, num_tokens=800)
+ # short requests
+ requests += create_requests(num_requests=2, num_tokens=10)
+
+ for request in requests:
+ scheduler.add_request(request)
+
+ scheduler_output1 = scheduler.schedule()
+
+ if enable_chunked_prefill:
+ # When enable chunked prefill, long requests will be chunked.
+ assert len(scheduler_output1.scheduled_new_reqs) == 2
+ else:
+ # When disable chunked prefill, should not skip the long requests,
+ # and scheduling subsequent short requests in advance,
+ # even though there is still token budgets remaining.
+ assert len(scheduler_output1.scheduled_new_reqs) == 1
+
+
def test_preempt_during_execution():
# NOTE(woosuk): The actual number of available blocks is 10 instead of 11
# because block 0 is reserved as the null block.
diff --git a/tests/v1/core/test_single_type_kv_cache_manager.py b/tests/v1/core/test_single_type_kv_cache_manager.py
index a27f32938c08b..e6a69dc8a949a 100644
--- a/tests/v1/core/test_single_type_kv_cache_manager.py
+++ b/tests/v1/core/test_single_type_kv_cache_manager.py
@@ -41,7 +41,9 @@ def test_chunked_local_attention_possible_cached_prefix():
attention_chunk_size=4,
)
- block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
+ block_pool = BlockPool(
+ num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
+ )
manager = get_chunked_local_attention_manager(
chunked_local_attention_spec, block_pool
)
@@ -70,6 +72,7 @@ def test_chunked_local_attention_possible_cached_prefix():
block_pool=block_pool,
kv_cache_spec=chunked_local_attention_spec,
use_eagle=False,
+ alignment_tokens=block_size,
)[0]
assert len(computed_blocks) == expect_length
@@ -111,7 +114,9 @@ def test_sliding_window_possible_cached_prefix():
sliding_window=4,
)
- block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
+ block_pool = BlockPool(
+ num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
+ )
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
def run_one_case(block_is_cached, expect_length):
@@ -138,6 +143,7 @@ def test_sliding_window_possible_cached_prefix():
block_pool=block_pool,
kv_cache_spec=sliding_window_spec,
use_eagle=False,
+ alignment_tokens=block_size,
)[0]
assert len(computed_blocks) == expect_length
@@ -178,7 +184,7 @@ def test_chunked_local_attention_remove_skipped_blocks():
attention_chunk_size=4,
)
- block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
+ block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True, hash_block_size=2)
manager = get_chunked_local_attention_manager(attention_spec, block_pool)
@@ -239,7 +245,7 @@ def test_sliding_window_remove_skipped_blocks():
sliding_window=4,
)
- block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
+ block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True, hash_block_size=2)
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
@@ -316,7 +322,9 @@ def test_get_num_blocks_to_allocate():
sliding_window=4, # Placeholder value, not related to test result
)
- block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
+ block_pool = BlockPool(
+ num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
+ )
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [
@@ -341,7 +349,9 @@ def test_chunked_local_attention_get_num_blocks_to_allocate():
attention_chunk_size=4, # Placeholder value, not related to test result
)
- block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
+ block_pool = BlockPool(
+ num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
+ )
manager = get_chunked_local_attention_manager(attention_spec, block_pool)
cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [
diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py
index 6830f68736453..7537c7a60476b 100644
--- a/tests/v1/core/utils.py
+++ b/tests/v1/core/utils.py
@@ -42,6 +42,7 @@ def create_scheduler(
model: str = "facebook/opt-125m",
max_num_seqs: int = 16,
max_num_batched_tokens: int = 8192,
+ enable_chunked_prefill: bool = True,
enable_prefix_caching: bool = False,
long_prefill_token_threshold: int = 0,
disable_chunked_mm_input: bool = False,
@@ -76,7 +77,7 @@ def create_scheduler(
max_model_len=max_model_len,
long_prefill_token_threshold=long_prefill_token_threshold,
disable_chunked_mm_input=disable_chunked_mm_input,
- enable_chunked_prefill=True,
+ enable_chunked_prefill=enable_chunked_prefill,
async_scheduling=async_scheduling,
)
model_config = ModelConfig(
diff --git a/tests/v1/cudagraph/test_cudagraph_dispatch.py b/tests/v1/cudagraph/test_cudagraph_dispatch.py
index bb953e5c70c8c..314e7094ef97f 100644
--- a/tests/v1/cudagraph/test_cudagraph_dispatch.py
+++ b/tests/v1/cudagraph/test_cudagraph_dispatch.py
@@ -42,12 +42,24 @@ def _create_vllm_config(
mock_config.compilation_config = compilation_config
mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs)
mock_config.parallel_config = ParallelConfig()
+ mock_config.speculative_config = None # No speculative decoding
if not lora_config:
mock_config.lora_config = None
# Mimic the behavior of VllmConfig.__post_init__()
if compilation_config.mode == CompilationMode.VLLM_COMPILE:
compilation_config.set_splitting_ops_for_v1()
+ # mimic VllmConfig.__post_init__
+ if compilation_config.cudagraph_capture_sizes:
+ compilation_config.max_cudagraph_capture_size = (
+ compilation_config.cudagraph_capture_sizes[-1]
+ )
+
+ compilation_config.post_init_cudagraph_sizes()
+ mock_config.pad_for_cudagraph = (
+ lambda batch_size: compilation_config.bs_to_padded_graph_size[batch_size]
+ )
+
return mock_config
@@ -109,9 +121,11 @@ class TestCudagraphDispatcher:
# 1. non-uniform batch, size in cudagraph size list
desc_full_exact = BatchDescriptor(
num_tokens=8,
- uniform_decode=False,
+ uniform=False,
+ )
+ rt_mode, key = dispatcher.dispatch(
+ num_tokens=8, uniform_decode=False, has_lora=False
)
- rt_mode, key = dispatcher.dispatch(desc_full_exact)
if cudagraph_mode_str == "FULL":
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_full_exact
@@ -122,32 +136,37 @@ class TestCudagraphDispatcher:
assert rt_mode == CUDAGraphMode.NONE
# 2. uniform decode batch, size in cudagraph size list
- desc_uniform_exact = BatchDescriptor(num_tokens=8, uniform_decode=True)
- rt_mode, key = dispatcher.dispatch(desc_uniform_exact)
+ desc_uniform_exact = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True)
+ rt_mode, key = dispatcher.dispatch(
+ num_tokens=8, uniform_decode=True, has_lora=False
+ )
if cudagraph_mode_str == "FULL":
assert rt_mode == CUDAGraphMode.FULL
- assert key == desc_uniform_exact.non_uniform
+ assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]:
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_uniform_exact
elif cudagraph_mode_str == "PIECEWISE":
assert rt_mode == CUDAGraphMode.PIECEWISE
- assert key == desc_uniform_exact.non_uniform
+ assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
else:
assert rt_mode == CUDAGraphMode.NONE
# 3. No key match
- desc_no_match = BatchDescriptor(num_tokens=15, uniform_decode=False)
- rt_mode, key = dispatcher.dispatch(desc_no_match)
+ rt_mode, key = dispatcher.dispatch(
+ num_tokens=15, uniform_decode=False, has_lora=False
+ )
assert rt_mode == CUDAGraphMode.NONE
- assert key is None
+ assert key == BatchDescriptor(num_tokens=15)
# 4. Cascade attention should have a fall back mode
- desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
- rt_mode, key = dispatcher.dispatch(desc_full_exact, use_cascade_attn=True)
+ desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False)
+ rt_mode, key = dispatcher.dispatch(
+ num_tokens=8, uniform_decode=False, has_lora=False, use_cascade_attn=True
+ )
if "PIECEWISE" in cudagraph_mode_str: # string contains check
assert rt_mode == CUDAGraphMode.PIECEWISE
- assert key == desc_full_exact.non_uniform
+ assert key == desc_full_exact.relax_for_mixed_batch_cudagraphs()
else:
assert rt_mode == CUDAGraphMode.NONE
diff --git a/tests/v1/cudagraph/test_cudagraph_mode.py b/tests/v1/cudagraph/test_cudagraph_mode.py
index d6bde16eba36b..12621d493e549 100644
--- a/tests/v1/cudagraph/test_cudagraph_mode.py
+++ b/tests/v1/cudagraph/test_cudagraph_mode.py
@@ -35,14 +35,22 @@ def temporary_environ(env_vars):
# test attention backend and cudagraph_mode combo
# (backend_name, cudagraph_mode, supported)
-combo_cases_1 = [
- ("FA3", "FULL", True),
- ("FA3", "FULL_AND_PIECEWISE", True),
- ("FA2", "FULL", True), # Should fallback to FULL_AND_PIECEWISE
- ("FA2", "FULL_AND_PIECEWISE", True),
- ("FlashInfer", "FULL", True), # Should fallback to FULL_AND_PIECEWISE
- ("FlashInfer", "FULL_AND_PIECEWISE", True),
-]
+if current_platform.is_rocm():
+ combo_cases_1 = [
+ ("RocmAttn", "FULL", True),
+ ("RocmAttn", "FULL_AND_PIECEWISE", True),
+ ("TritonAttn", "FULL", True),
+ ("TritonAttn", "FULL_AND_PIECEWISE", True),
+ ]
+else:
+ combo_cases_1 = [
+ ("FA3", "FULL", True),
+ ("FA3", "FULL_AND_PIECEWISE", True),
+ ("FA2", "FULL", True), # Should fallback to FULL_AND_PIECEWISE
+ ("FA2", "FULL_AND_PIECEWISE", True),
+ ("FlashInfer", "FULL", True), # Should fallback to FULL_AND_PIECEWISE
+ ("FlashInfer", "FULL_AND_PIECEWISE", True),
+ ]
@pytest.mark.parametrize("backend_name, cudagraph_mode, supported", combo_cases_1)
@@ -92,18 +100,32 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
# test cudagraph_mode with different compilation mode.
# (backend_name, cudagraph_mode, compilation_mode, supported)
-combo_cases_2 = [
- ("FA2", "FULL", CompilationMode.NONE, True),
- ("FA2", "FULL", CompilationMode.VLLM_COMPILE, True),
- ("FA2", "PIECEWISE", CompilationMode.NONE, False),
- ("FA2", "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
- ("FA2", "FULL_AND_PIECEWISE", CompilationMode.NONE, False),
- ("FA2", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
- ("FA2", "FULL_DECODE_ONLY", CompilationMode.NONE, True),
- ("FA2", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
- ("FA2", "NONE", CompilationMode.NONE, True),
- ("FA2", "NONE", CompilationMode.VLLM_COMPILE, True),
-]
+if current_platform.is_rocm():
+ combo_cases_2 = [
+ ("RocmAttn", "FULL", CompilationMode.NONE, True),
+ ("RocmAttn", "FULL", CompilationMode.VLLM_COMPILE, True),
+ ("RocmAttn", "PIECEWISE", CompilationMode.NONE, False),
+ ("RocmAttn", "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
+ ("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.NONE, False),
+ ("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
+ ("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.NONE, True),
+ ("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
+ ("RocmAttn", "NONE", CompilationMode.NONE, True),
+ ("RocmAttn", "NONE", CompilationMode.VLLM_COMPILE, True),
+ ]
+else:
+ combo_cases_2 = [
+ ("FA2", "FULL", CompilationMode.NONE, True),
+ ("FA2", "FULL", CompilationMode.VLLM_COMPILE, True),
+ ("FA2", "PIECEWISE", CompilationMode.NONE, True),
+ ("FA2", "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
+ ("FA2", "FULL_AND_PIECEWISE", CompilationMode.NONE, True),
+ ("FA2", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
+ ("FA2", "FULL_DECODE_ONLY", CompilationMode.NONE, True),
+ ("FA2", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
+ ("FA2", "NONE", CompilationMode.NONE, True),
+ ("FA2", "NONE", CompilationMode.VLLM_COMPILE, True),
+ ]
@pytest.mark.parametrize(
diff --git a/tests/v1/determinism/test_batch_invariance.py b/tests/v1/determinism/test_batch_invariance.py
index b9e2daafb8705..4311547baccf1 100644
--- a/tests/v1/determinism/test_batch_invariance.py
+++ b/tests/v1/determinism/test_batch_invariance.py
@@ -159,7 +159,6 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
"backend",
BACKENDS,
)
-@pytest.mark.forked
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
backend, monkeypatch: pytest.MonkeyPatch
):
@@ -429,7 +428,6 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
"backend",
BACKENDS,
)
-@pytest.mark.forked
def test_logprobs_without_batch_invariance_should_fail(
backend, monkeypatch: pytest.MonkeyPatch
):
@@ -646,7 +644,6 @@ def test_logprobs_without_batch_invariance_should_fail(
@skip_unsupported
@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
-@pytest.mark.forked
def test_decode_logprobs_match_prefill_logprobs(
backend, monkeypatch: pytest.MonkeyPatch
):
diff --git a/tests/v1/determinism/utils.py b/tests/v1/determinism/utils.py
index ecbb6a1126933..0d7da107728b4 100644
--- a/tests/v1/determinism/utils.py
+++ b/tests/v1/determinism/utils.py
@@ -8,6 +8,7 @@ import torch
from vllm.attention.utils.fa_utils import flash_attn_supports_mla
from vllm.platforms import current_platform
+from vllm.utils.flashinfer import has_flashinfer
skip_unsupported = pytest.mark.skipif(
not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
@@ -16,9 +17,11 @@ skip_unsupported = pytest.mark.skipif(
BACKENDS: list[str] = [
"FLASH_ATTN",
- "FLASHINFER",
]
+if has_flashinfer():
+ BACKENDS.append("FLASHINFER")
+
if flash_attn_supports_mla():
BACKENDS.append("FLASH_ATTN_MLA")
diff --git a/tests/v1/distributed/test_async_llm_dp.py b/tests/v1/distributed/test_async_llm_dp.py
index 60f9017184ea0..3b5f2e5e8d72f 100644
--- a/tests/v1/distributed/test_async_llm_dp.py
+++ b/tests/v1/distributed/test_async_llm_dp.py
@@ -12,6 +12,7 @@ from vllm import SamplingParams
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import PromptType
+from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.core_client import DPAsyncMPClient
@@ -84,6 +85,10 @@ async def test_load(
if async_scheduling and data_parallel_backend == "ray":
# TODO(NickLucche) Re-enable when async scheduling is supported
pytest.skip("Async scheduling is not supported with ray")
+ elif data_parallel_backend == "ray" and current_platform.is_rocm():
+ pytest.skip(
+ "Ray as the distributed executor backend is not supported with ROCm."
+ )
stats_loggers = {}
@dataclass
diff --git a/tests/v1/distributed/test_eagle_dp.py b/tests/v1/distributed/test_eagle_dp.py
new file mode 100644
index 0000000000000..9f6a6614fc1fd
--- /dev/null
+++ b/tests/v1/distributed/test_eagle_dp.py
@@ -0,0 +1,77 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import asyncio
+import os
+from contextlib import AsyncExitStack
+from dataclasses import replace
+
+import pytest
+
+from vllm import SamplingParams
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.sampling_params import RequestOutputKind
+from vllm.v1.engine.async_llm import AsyncLLM
+
+DP_SIZE = int(os.getenv("DP_SIZE", 2))
+
+
+@pytest.mark.asyncio
+async def test_run_eagle_dp():
+ target_model = "meta-llama/Llama-3.1-8B-Instruct"
+ draft_model = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
+
+ engine_args = AsyncEngineArgs(
+ model=target_model,
+ tokenizer_mode="auto",
+ enforce_eager=False,
+ tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
+ data_parallel_size=DP_SIZE,
+ data_parallel_backend="mp", # ray takes more time
+ trust_remote_code=True,
+ max_model_len=16384,
+ )
+
+ eagle_engine_args = replace(
+ engine_args,
+ speculative_config={
+ "model": draft_model,
+ "method": "eagle",
+ "num_speculative_tokens": 3,
+ },
+ )
+
+ prompt = "This is a test of data parallel with eagle"
+ num_expected_tokens = 100
+ sampling_params = SamplingParams(
+ min_tokens=num_expected_tokens,
+ max_tokens=num_expected_tokens,
+ ignore_eos=True,
+ output_kind=RequestOutputKind.FINAL_ONLY,
+ temperature=0,
+ )
+
+ async def generate_with_timeout(given_engine: AsyncLLM):
+ async for out in given_engine.generate(
+ request_id="test-eagle-dp", prompt=prompt, sampling_params=sampling_params
+ ):
+ token_ids = out.outputs[0].token_ids
+ assert len(token_ids) == num_expected_tokens
+ return token_ids
+
+ async def engine_create_and_generate(engine_args: AsyncEngineArgs):
+ async with AsyncExitStack() as after:
+ engine = AsyncLLM.from_engine_args(engine_args)
+ after.callback(engine.shutdown)
+
+ token_ids = await asyncio.wait_for(
+ generate_with_timeout(engine), timeout=30
+ )
+
+ assert not engine.output_processor.has_unfinished_requests()
+ return token_ids
+
+ token_ids_with_eagle = await engine_create_and_generate(eagle_engine_args)
+ token_ids_no_eagle = await engine_create_and_generate(engine_args)
+
+ # Test for correctness
+ assert token_ids_with_eagle == token_ids_no_eagle
diff --git a/tests/v1/e2e/test_async_scheduling.py b/tests/v1/e2e/test_async_scheduling.py
index 00d93e1ba0b53..945276376d665 100644
--- a/tests/v1/e2e/test_async_scheduling.py
+++ b/tests/v1/e2e/test_async_scheduling.py
@@ -87,6 +87,11 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
# Set small draft model len to force doesn't-fit-in-drafter case.
spec_config_short = spec_config | {"max_model_len": 50}
+ test_sampling_params = [
+ dict(),
+ dict(logprobs=2),
+ ]
+
# test_preemption, executor, async_scheduling,
# spec_config, test_prefill_chunking
test_configs = [
@@ -103,7 +108,7 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
(True, "uni", True, spec_config_short, True),
]
- run_tests(monkeypatch, MTP_MODEL, test_configs, [{}])
+ run_tests(monkeypatch, MTP_MODEL, test_configs, test_sampling_params)
@dynamo_config.patch(cache_size_limit=16)
diff --git a/tests/v1/ec_connector/integration/test_epd_correctness.py b/tests/v1/ec_connector/integration/test_epd_correctness.py
index 69c4c58e349b9..616d34441ab8e 100644
--- a/tests/v1/ec_connector/integration/test_epd_correctness.py
+++ b/tests/v1/ec_connector/integration/test_epd_correctness.py
@@ -237,9 +237,8 @@ def main():
for i, prompt_data in enumerate(test_prompts):
print(
- f"\nRunning prompt {i + 1}/{len(test_prompts)}: {
- prompt_data['description']
- }"
+ f"\nRunning prompt {i + 1}/{len(test_prompts)}: "
+ f"{prompt_data['description']}"
)
output_str = run_chat_completion(
diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py
index d1b037b7956cf..85f108786c05a 100644
--- a/tests/v1/entrypoints/llm/test_struct_output_generate.py
+++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py
@@ -3,7 +3,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
-from dataclasses import fields
from enum import Enum
from typing import TYPE_CHECKING, Any
@@ -21,7 +20,6 @@ from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
from vllm.sampling_params import (
- GuidedDecodingParams,
SamplingParams,
StructuredOutputsParams,
)
@@ -108,23 +106,6 @@ class CarDescription(BaseModel):
car_type: CarType
-def test_guided_decoding_deprecated():
- with pytest.warns(DeprecationWarning, match="GuidedDecodingParams is deprecated.*"):
- guided_decoding = GuidedDecodingParams(json_object=True)
-
- structured_outputs = StructuredOutputsParams(json_object=True)
- assert fields(guided_decoding) == fields(structured_outputs)
-
- with pytest.warns(DeprecationWarning, match="guided_decoding is deprecated.*"):
- sp1 = SamplingParams(guided_decoding=guided_decoding)
-
- with pytest.warns(DeprecationWarning, match="guided_decoding is deprecated.*"):
- sp2 = SamplingParams.from_optional(guided_decoding=guided_decoding)
-
- assert sp1 == sp2
- assert sp1.structured_outputs == guided_decoding
-
-
@pytest.mark.parametrize(
"model_name, backend, tokenizer_mode, speculative_config",
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE,
@@ -899,13 +880,11 @@ def test_structured_output_batched_with_non_structured_outputs_requests(
output_json = json.loads(generated_text)
-@pytest.mark.parametrize("guided_decoding_backend", ["xgrammar"])
-def test_structured_output_with_structural_tag(
- guided_decoding_backend: str,
-):
+@pytest.mark.parametrize("backend", ["xgrammar"])
+def test_structured_output_with_structural_tag(backend: str):
llm = LLM(
model="Qwen/Qwen2.5-1.5B-Instruct",
- guided_decoding_backend=guided_decoding_backend,
+ structured_outputs_config=StructuredOutputsConfig(backend=backend),
)
structural_tag_config = {
@@ -923,7 +902,7 @@ def test_structured_output_with_structural_tag(
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=500,
- guided_decoding=StructuredOutputsParams(
+ structured_outputs=StructuredOutputsParams(
structural_tag=json.dumps(structural_tag_config)
),
)
diff --git a/tests/v1/kv_connector/unit/test_backwards_compatibility.py b/tests/v1/kv_connector/unit/test_backwards_compatibility.py
index f51001a6ec12a..7cd23805c599d 100644
--- a/tests/v1/kv_connector/unit/test_backwards_compatibility.py
+++ b/tests/v1/kv_connector/unit/test_backwards_compatibility.py
@@ -14,6 +14,7 @@ from unittest.mock import patch
import pytest
+from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1,
@@ -24,7 +25,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
from .utils import create_scheduler, create_vllm_config
if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
@@ -68,7 +68,7 @@ class OldStyleTestConnector(KVConnectorBase_V1):
self,
layer_name: str,
kv_layer,
- attn_metadata: "AttentionMetadata",
+ attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
pass
@@ -119,7 +119,7 @@ class NewStyleTestConnector(KVConnectorBase_V1):
self,
layer_name: str,
kv_layer,
- attn_metadata: "AttentionMetadata",
+ attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
pass
diff --git a/tests/v1/kv_offload/test_cpu_offloading.py b/tests/v1/kv_offload/test_cpu_offloading.py
index 406d4c0b4c1fd..57474a3dc01e7 100644
--- a/tests/v1/kv_offload/test_cpu_offloading.py
+++ b/tests/v1/kv_offload/test_cpu_offloading.py
@@ -20,6 +20,8 @@ ATTN_BACKENDS = ["FLASH_ATTN"]
if current_platform.is_cuda():
ATTN_BACKENDS.append("FLASHINFER")
+elif current_platform.is_rocm():
+ ATTN_BACKENDS = ["TRITON_ATTN"]
class MockSubscriber:
diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index d0f1b703fcb92..89669ee8b71a0 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -5,8 +5,8 @@ import numpy as np
import pytest
import torch
-from vllm.attention import Attention
from vllm.attention.backends.abstract import MultipleOf
+from vllm.attention.layer import Attention
from vllm.config import (
CacheConfig,
ModelConfig,
diff --git a/tests/v1/worker/test_utils.py b/tests/v1/worker/test_utils.py
index f987b09e603e7..bcf5611e35228 100644
--- a/tests/v1/worker/test_utils.py
+++ b/tests/v1/worker/test_utils.py
@@ -7,7 +7,7 @@ from vllm.v1.worker.utils import bind_kv_cache
def test_bind_kv_cache():
- from vllm.attention import Attention
+ from vllm.attention.layer import Attention
ctx = {
"layers.0.self_attn": Attention(32, 128, 0.1),
@@ -35,7 +35,7 @@ def test_bind_kv_cache():
def test_bind_kv_cache_non_attention():
- from vllm.attention import Attention
+ from vllm.attention.layer import Attention
# example from Jamba PP=2
ctx = {
diff --git a/tools/ep_kernels/install_python_libraries.sh b/tools/ep_kernels/install_python_libraries.sh
index 1cea1bef8dbc9..88be5cd778fff 100755
--- a/tools/ep_kernels/install_python_libraries.sh
+++ b/tools/ep_kernels/install_python_libraries.sh
@@ -1,22 +1,68 @@
#!/usr/bin/env bash
set -ex
-# usage: ./build.sh [workspace_dir] [mode]
-# mode: "install" (default) → install directly into current Python env
-# "wheel" → build wheels into WORKSPACE/dist
+# usage: ./install_python_libraries.sh [options]
+# --workspace workspace directory (default: ./ep_kernels_workspace)
+# --mode "install" (default) or "wheel"
+# --pplx-ref pplx-kernels commit hash
+# --deepep-ref DeepEP commit hash
+
+CUDA_HOME=${CUDA_HOME:-/usr/local/cuda}
+PPLX_COMMIT_HASH=${PPLX_COMMIT_HASH:-"12cecfd"}
+DEEPEP_COMMIT_HASH=${DEEPEP_COMMIT_HASH:-"73b6ea4"}
+NVSHMEM_VER=3.3.9
+WORKSPACE=${WORKSPACE:-$(pwd)/ep_kernels_workspace}
+MODE=${MODE:-install}
+
+# Parse arguments
+while [[ $# -gt 0 ]]; do
+ case $1 in
+ --workspace)
+ if [[ -z "$2" || "$2" =~ ^- ]]; then
+ echo "Error: --workspace requires an argument." >&2
+ exit 1
+ fi
+ WORKSPACE="$2"
+ shift 2
+ ;;
+ --mode)
+ if [[ -z "$2" || "$2" =~ ^- ]]; then
+ echo "Error: --mode requires an argument." >&2
+ exit 1
+ fi
+ MODE="$2"
+ shift 2
+ ;;
+ --pplx-ref)
+ if [[ -z "$2" || "$2" =~ ^- ]]; then
+ echo "Error: --pplx-ref requires an argument." >&2
+ exit 1
+ fi
+ PPLX_COMMIT_HASH="$2"
+ shift 2
+ ;;
+ --deepep-ref)
+ if [[ -z "$2" || "$2" =~ ^- ]]; then
+ echo "Error: --deepep-ref requires an argument." >&2
+ exit 1
+ fi
+ DEEPEP_COMMIT_HASH="$2"
+ shift 2
+ ;;
+ *)
+ echo "Error: Unknown argument '$1'" >&2
+ exit 1
+ ;;
+ esac
+done
-WORKSPACE=${1:-$(pwd)/ep_kernels_workspace}
-MODE=${2:-install}
mkdir -p "$WORKSPACE"
WHEEL_DIR="$WORKSPACE/dist"
mkdir -p "$WHEEL_DIR"
-NVSHMEM_VER=3.3.9
pushd "$WORKSPACE"
-CUDA_HOME=${CUDA_HOME:-/usr/local/cuda}
-
# install dependencies if not installed
if [ -z "$VIRTUAL_ENV" ]; then
uv pip install --system cmake torch ninja
@@ -133,7 +179,7 @@ do_build \
"https://github.com/ppl-ai/pplx-kernels" \
"pplx-kernels" \
"setup.py" \
- "12cecfd" \
+ "$PPLX_COMMIT_HASH" \
""
# build DeepEP
@@ -141,7 +187,7 @@ do_build \
"https://github.com/deepseek-ai/DeepEP" \
"DeepEP" \
"setup.py" \
- "73b6ea4" \
+ "$DEEPEP_COMMIT_HASH" \
"export NVSHMEM_DIR=$WORKSPACE/nvshmem; "
if [ "$MODE" = "wheel" ]; then
diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
index db79b3f5e8bcb..a8f472d147a0d 100644
--- a/vllm/_aiter_ops.py
+++ b/vllm/_aiter_ops.py
@@ -294,6 +294,8 @@ def _rocm_aiter_mla_decode_fwd_impl(
kv_last_page_lens: torch.Tensor | None = None,
sm_scale: float = 1.0,
logit_cap: float = 0.0,
+ q_scale: torch.Tensor | None = None,
+ kv_scale: torch.Tensor | None = None,
) -> None:
from aiter.mla import mla_decode_fwd
@@ -308,6 +310,8 @@ def _rocm_aiter_mla_decode_fwd_impl(
max_seqlen_qo,
sm_scale=sm_scale,
logit_cap=logit_cap,
+ q_scale=q_scale,
+ kv_scale=kv_scale,
)
@@ -322,6 +326,8 @@ def _rocm_aiter_mla_decode_fwd_fake(
kv_last_page_lens: torch.Tensor | None = None,
sm_scale: float = 1.0,
logit_cap: float = 0.0,
+ q_scale: torch.Tensor | None = None,
+ kv_scale: torch.Tensor | None = None,
) -> None:
pass
@@ -806,6 +812,8 @@ class rocm_aiter_ops:
kv_indices: torch.Tensor | None = None,
kv_last_page_lens: torch.Tensor | None = None,
logit_cap: float = 0.0,
+ q_scale: torch.Tensor | None = None,
+ kv_scale: torch.Tensor | None = None,
):
torch.ops.vllm.rocm_aiter_mla_decode_fwd(
q,
@@ -818,6 +826,8 @@ class rocm_aiter_ops:
kv_last_page_lens,
sm_scale=sm_scale,
logit_cap=logit_cap,
+ q_scale=q_scale,
+ kv_scale=kv_scale,
)
@staticmethod
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 0f625a7945241..4a1bcc761f994 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -2201,7 +2201,8 @@ def gather_and_maybe_dequant_cache(
dst: torch.Tensor,
block_table: torch.Tensor,
cu_seq_lens: torch.Tensor,
- batch_size: int,
+ token_to_seq: torch.Tensor,
+ num_tokens: int,
kv_cache_dtype: str,
scale: torch.Tensor,
seq_starts: torch.Tensor | None = None,
@@ -2211,7 +2212,8 @@ def gather_and_maybe_dequant_cache(
dst,
block_table,
cu_seq_lens,
- batch_size,
+ token_to_seq,
+ num_tokens,
kv_cache_dtype,
scale,
seq_starts,
diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py
index 8b4dc4013362e..e69de29bb2d1d 100644
--- a/vllm/attention/__init__.py
+++ b/vllm/attention/__init__.py
@@ -1,19 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from vllm.attention.backends.abstract import (
- AttentionBackend,
- AttentionMetadata,
- AttentionType,
-)
-from vllm.attention.layer import Attention
-from vllm.attention.selector import get_attn_backend, get_mamba_attn_backend
-
-__all__ = [
- "Attention",
- "AttentionBackend",
- "AttentionMetadata",
- "AttentionType",
- "get_attn_backend",
- "get_mamba_attn_backend",
-]
diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index bd7e81b15bfc3..c290670eeacb0 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -6,11 +6,10 @@ from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args
import torch
-from vllm.model_executor.layers.linear import ColumnParallelLinear
-from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
-
if TYPE_CHECKING:
from vllm.config.cache import CacheDType
+ from vllm.model_executor.layers.linear import ColumnParallelLinear
+ from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.utils import KVCacheLayoutType
@@ -178,8 +177,6 @@ class AttentionBackend(ABC):
By default, only supports decoder attention.
Backends should override this to support other attention types.
"""
- from vllm.attention import AttentionType
-
return attn_type == AttentionType.DECODER
@classmethod
@@ -360,7 +357,7 @@ class AttentionImpl(ABC, Generic[T]):
) -> torch.Tensor:
raise NotImplementedError
- def fused_output_quant_supported(self, quant_key: QuantKey):
+ def fused_output_quant_supported(self, quant_key: "QuantKey"):
"""
Does this attention implementation support fused output quantization.
This is used by the AttnFusionPass to only fuse output quantization
@@ -412,7 +409,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
qk_rope_head_dim: int,
qk_head_dim: int,
v_head_dim: int,
- kv_b_proj: ColumnParallelLinear,
+ kv_b_proj: "ColumnParallelLinear",
indexer: object | None = None,
) -> None:
raise NotImplementedError
diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py
index 6747cf7743b14..125e4e3827747 100644
--- a/vllm/attention/backends/registry.py
+++ b/vllm/attention/backends/registry.py
@@ -43,7 +43,6 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
- XFORMERS = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"
ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"
ROCM_AITER_TRITON_MLA = (
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index a8e796a1eab63..62ac38751aa01 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -10,8 +10,11 @@ import torch.nn as nn
import torch.nn.functional as F
import vllm.envs as envs
-from vllm.attention import AttentionType
-from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
+from vllm.attention.backends.abstract import (
+ AttentionBackend,
+ AttentionType,
+ MLAAttentionImpl,
+)
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
@@ -51,31 +54,6 @@ else:
FP8_DTYPE = current_platform.fp8_dtype()
logger = init_logger(__name__)
-USE_XFORMERS_OPS = None
-
-
-def check_xformers_availability():
- global USE_XFORMERS_OPS
- if USE_XFORMERS_OPS is not None:
- return USE_XFORMERS_OPS
-
- if current_platform.is_cuda() and current_platform.has_device_capability(100):
- # Xformers FA is not compatible with B200
- USE_XFORMERS_OPS = False
- else:
- try:
- from importlib.util import find_spec
-
- find_spec("xformers.ops")
- USE_XFORMERS_OPS = True
- except ImportError:
- USE_XFORMERS_OPS = False
-
- # the warning only needs to be shown once
- if not USE_XFORMERS_OPS:
- logger.warning("Xformers is not available, falling back.")
-
- return USE_XFORMERS_OPS
def check_upstream_fa_availability(dtype: torch.dtype):
@@ -533,7 +511,6 @@ class MultiHeadAttention(nn.Module):
if backend
in {
AttentionBackendEnum.TORCH_SDPA,
- AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.PALLAS,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.FLASH_ATTN,
@@ -549,12 +526,6 @@ class MultiHeadAttention(nn.Module):
)
)
- if (
- self.attn_backend == AttentionBackendEnum.XFORMERS
- and not check_xformers_availability()
- ):
- self.attn_backend = AttentionBackendEnum.TORCH_SDPA
-
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
@@ -614,12 +585,6 @@ class MultiHeadAttention(nn.Module):
max_seqlen_k=kv_len,
softmax_scale=self.scale,
)
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- from xformers import ops as xops
-
- out = xops.memory_efficient_attention_forward(
- query, key, value, scale=self.scale
- )
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
diff --git a/vllm/attention/layers/chunked_local_attention.py b/vllm/attention/layers/chunked_local_attention.py
index 48fcc6fa736bb..0ced0028ded9e 100644
--- a/vllm/attention/layers/chunked_local_attention.py
+++ b/vllm/attention/layers/chunked_local_attention.py
@@ -5,6 +5,7 @@ import functools
import torch
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
+from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
from vllm.config.vllm import VllmConfig
@@ -22,8 +23,6 @@ from vllm.v1.kv_cache_interface import (
KVCacheSpec,
)
-from ..layer import Attention
-
@functools.lru_cache
def create_chunked_local_attention_backend(
diff --git a/vllm/attention/layers/cross_attention.py b/vllm/attention/layers/cross_attention.py
index 5b44c7e3e7ec8..068fd0a0eb7d0 100644
--- a/vllm/attention/layers/cross_attention.py
+++ b/vllm/attention/layers/cross_attention.py
@@ -25,15 +25,6 @@ from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec
logger = init_logger(__name__)
-def _get_max_encoder_len(vllm_config: "VllmConfig") -> int:
- """Gets the max number of encoder input tokens from the config."""
- sc = vllm_config.scheduler_config
- assert sc and isinstance(sc.max_num_encoder_input_tokens, int), (
- "max_num_encoder_input_tokens must be int for enc-dec models"
- )
- return sc.max_num_encoder_input_tokens
-
-
def _get_cross_slot_mapping(
encoder_seq_lens: np.ndarray,
block_table_tensor: torch.Tensor,
@@ -93,23 +84,32 @@ def create_cross_attention_backend(
) -> AttentionMetadata:
new_metadata = copy(common_attn_metadata)
new_metadata.causal = False
- max_encoder_len = _get_max_encoder_len(self.vllm_config)
+ max_encoder_len = int(new_metadata.encoder_seq_lens_cpu.max())
new_metadata.max_seq_len = max_encoder_len
+ # Any computed tokens indicated decode step>1 (no chunked prefill)
+ num_cache_decodes = (
+ (common_attn_metadata.num_computed_tokens_cpu > 0).sum().item()
+ )
+ if num_cache_decodes > 0:
+ # CrossAttn KV cache has already been populated on first decoder step,
+ # skip slot_mapping calculation for requests that do not need
+ # reshape_and_cache.
+ num_tokens = common_attn_metadata.num_computed_tokens_cpu.numpy()
+ new_metadata.encoder_seq_lens_cpu = np.where(
+ num_tokens > 0, 0, new_metadata.encoder_seq_lens_cpu
+ )
- new_metadata.seq_lens = torch.full(
- (new_metadata.num_reqs,),
- max_encoder_len,
- dtype=torch.int32,
- device=self.device,
- )
- new_metadata.seq_lens_cpu = torch.full(
- (new_metadata.num_reqs,),
- max_encoder_len,
- dtype=torch.int32,
- device="cpu",
+ # seq_lens is provided by model runner: initial encoder input length is
+ # needed here to know how many tokens to attend to from the cached
+ # cross-attention KV cache.
+ new_metadata.seq_lens = common_attn_metadata.encoder_seq_lens
+ new_metadata.seq_lens_cpu = torch.from_numpy(
+ common_attn_metadata.encoder_seq_lens_cpu
)
+
+ # NOTE (NickLucche) use `new_metadata` instead of `common_*` (initial) here
new_metadata.slot_mapping = _get_cross_slot_mapping(
- new_metadata.encoder_seq_lens,
+ new_metadata.encoder_seq_lens_cpu,
new_metadata.block_table_tensor,
self.kv_cache_spec,
self.device,
diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py
index 67c5f7dbba9c0..af6766bdd1615 100644
--- a/vllm/attention/ops/common.py
+++ b/vllm/attention/ops/common.py
@@ -194,7 +194,6 @@ def _cp_lse_common(
cp_attn_lse = cp_attn_lse.contiguous()
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
- assert out.is_contiguous()
return out, lse
diff --git a/vllm/attention/ops/triton_merge_attn_states.py b/vllm/attention/ops/triton_merge_attn_states.py
index 3c87a24afd9c7..74e4d778ded87 100644
--- a/vllm/attention/ops/triton_merge_attn_states.py
+++ b/vllm/attention/ops/triton_merge_attn_states.py
@@ -20,7 +20,11 @@ def merge_attn_states(
num_query_heads = output.shape[1]
head_size = output.shape[2]
padded_head_size = triton.next_power_of_2(head_size)
-
+ # We assume the output stride on num_head is not always as same as the
+ # `suffix_output` and `prefix_output`, as them might be padded by the attention
+ # backend.
+ prefix_head_stride = prefix_output.stride(1)
+ output_head_stride = output.stride(1)
# TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
merge_attn_states_kernel[(num_tokens, num_query_heads)](
output,
@@ -29,6 +33,8 @@ def merge_attn_states(
prefix_lse,
suffix_output,
suffix_lse,
+ prefix_head_stride,
+ output_head_stride,
head_size,
padded_head_size,
output_lse is not None,
@@ -43,6 +49,8 @@ def merge_attn_states_kernel(
prefix_lse, # [NUM_HEADS, NUM_TOKENS]
suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
suffix_lse, # [NUM_HEADS, NUM_TOKENS]
+ prefix_head_stride,
+ output_head_stride,
HEAD_SIZE: tl.constexpr,
PADDED_HEAD_SIZE: tl.constexpr,
OUTPUT_LSE: tl.constexpr,
@@ -79,15 +87,15 @@ def merge_attn_states_kernel(
head_mask = head_arange < HEAD_SIZE
p_out = tl.load(
prefix_output
- + token_idx * num_heads * HEAD_SIZE
- + head_idx * HEAD_SIZE
+ + token_idx * num_heads * prefix_head_stride
+ + head_idx * prefix_head_stride
+ head_arange,
mask=head_mask,
)
s_out = tl.load(
suffix_output
- + token_idx * num_heads * HEAD_SIZE
- + head_idx * HEAD_SIZE
+ + token_idx * num_heads * prefix_head_stride
+ + head_idx * prefix_head_stride
+ head_arange,
mask=head_mask,
)
@@ -99,7 +107,10 @@ def merge_attn_states_kernel(
s_scale = s_se / out_se
out = p_out * p_scale + s_out * s_scale
tl.store(
- output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,
+ output
+ + token_idx * num_heads * output_head_stride
+ + head_idx * output_head_stride
+ + head_arange,
out,
mask=head_mask,
)
diff --git a/vllm/attention/ops/vit_attn_wrappers.py b/vllm/attention/ops/vit_attn_wrappers.py
index 06a9f7cd82266..46f8f5117f7a7 100644
--- a/vllm/attention/ops/vit_attn_wrappers.py
+++ b/vllm/attention/ops/vit_attn_wrappers.py
@@ -3,7 +3,7 @@
"""
This file contains ops for ViT attention to be compatible with torch.compile
as there are operations here not supported by torch.compile (for instance,
-`to_list` in xformers attn, or `.item()` in flash attention)
+`.item()` in flash attention)
Using these ops and wrapping vision blocks with `torch.compile` can speed up
throughput in vision models by ~5% relative on H100, and improve token
@@ -19,42 +19,6 @@ import torch.nn.functional as F
from vllm.utils.torch_utils import direct_register_custom_op
-def xformers_attn_seqlens_wrapper(
- q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
-) -> torch.Tensor:
- from xformers import ops as xops
- from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-
- attn_bias = BlockDiagonalMask.from_seqlens(
- q_seqlen=seqlens.tolist(), kv_seqlen=None, device=q.device
- )
- context_layer = xops.memory_efficient_attention_forward(
- q, k, v, attn_bias=attn_bias, p=0, scale=None
- )
- context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
- return context_layer
-
-
-def xformers_attn_seqlens_wrapper_fake(
- q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
-) -> torch.Tensor:
- b, s, h, d = q.shape
- return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
-
-
-direct_register_custom_op(
- op_name="xformers_attn_seqlens_wrapper",
- op_func=xformers_attn_seqlens_wrapper,
- fake_impl=xformers_attn_seqlens_wrapper_fake,
-)
-
-
-def vit_xformers_attn_wrapper(
- q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
-) -> torch.Tensor:
- return torch.ops.vllm.xformers_attn_seqlens_wrapper(q, k, v, seqlens)
-
-
def flash_attn_maxseqlen_wrapper(
q: torch.Tensor,
k: torch.Tensor,
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index e9af08b2316d2..ad19b58aa155c 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -36,7 +36,14 @@ def get_env_variable_attn_backend() -> AttentionBackendEnum | None:
* None otherwise
"""
backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
- return None if backend_name is None else AttentionBackendEnum[backend_name]
+ if backend_name is None:
+ return None
+ if backend_name == "XFORMERS":
+ raise ValueError(
+ "Attention backend 'XFORMERS' has been removed (See PR #29262 for "
+ "details). Please select a supported attention backend."
+ )
+ return AttentionBackendEnum[backend_name]
# Global state allows a particular choice of backend
diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py
index dddb050ec180e..519303c0bfa0a 100644
--- a/vllm/benchmarks/serve.py
+++ b/vllm/benchmarks/serve.py
@@ -1005,7 +1005,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
help="Key-value pairs (e.g, --header x-additional-info=0.3.3) "
"for headers to be passed with each request. These headers override "
"per backend constants and values set via environment variable, and "
- "will be overriden by other arguments (such as request ids).",
+ "will be overridden by other arguments (such as request ids).",
)
parser.add_argument(
"--max-concurrency",
@@ -1138,7 +1138,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
"--percentile-metrics",
type=str,
default=None,
- help="Comma-separated list of selected metrics to report percentils. "
+ help="Comma-separated list of selected metrics to report percentiles. "
"This argument specifies the metrics to report percentiles. "
'Allowed metric names are "ttft", "tpot", "itl", "e2el". '
'If not specified, defaults to "ttft,tpot,itl" for generative models '
diff --git a/vllm/benchmarks/sweep/serve.py b/vllm/benchmarks/sweep/serve.py
index 45ac446a7aedf..1298e4acbd87d 100644
--- a/vllm/benchmarks/sweep/serve.py
+++ b/vllm/benchmarks/sweep/serve.py
@@ -211,6 +211,7 @@ def run_combs(
output_dir: Path,
num_runs: int,
dry_run: bool,
+ links: list[tuple[str, str]],
):
all_data = list[dict[str, object]]()
for serve_comb in serve_params:
@@ -226,6 +227,14 @@ def run_combs(
else contextlib.nullcontext()
) as server:
for bench_comb in bench_params:
+ should_run = all(
+ serve_key in serve_comb
+ and bench_key in bench_comb
+ and serve_comb[serve_key] == bench_comb[bench_key]
+ for serve_key, bench_key in links
+ )
+ if not should_run:
+ continue
base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)
comb_data = run_comb(
@@ -262,6 +271,7 @@ class SweepServeArgs:
num_runs: int
dry_run: bool
resume: str | None
+ link_vars: list[tuple[str, str]] | None
parser_name: ClassVar[str] = "serve"
parser_help: ClassVar[str] = "Run vLLM server benchmark under multiple settings."
@@ -285,7 +295,7 @@ class SweepServeArgs:
else:
# i.e.: run bench_cmd without any modification
bench_params = ParameterSweep.from_records([{}])
-
+ link_vars = cls.parse_link_vars(args.link_vars)
num_runs = args.num_runs
if num_runs < 1:
raise ValueError("`num_runs` should be at least 1.")
@@ -301,6 +311,7 @@ class SweepServeArgs:
num_runs=num_runs,
dry_run=args.dry_run,
resume=args.resume,
+ link_vars=link_vars,
)
@classmethod
@@ -376,8 +387,28 @@ class SweepServeArgs:
"parameter combinations for which there are still no output files.",
)
+ parser.add_argument(
+ "--link-vars",
+ type=str,
+ default="",
+ help=(
+ "Comma-separated list of linked variables between serve and bench, "
+ "e.g. max_num_seqs=max_concurrency,max_model_len=random_input_len"
+ ),
+ )
+
return parser
+ @staticmethod
+ def parse_link_vars(s: str) -> list[tuple[str, str]]:
+ if not s:
+ return []
+ pairs = []
+ for item in s.split(","):
+ a, b = item.split("=")
+ pairs.append((a.strip(), b.strip()))
+ return pairs
+
def run_main(args: SweepServeArgs):
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -397,6 +428,7 @@ def run_main(args: SweepServeArgs):
output_dir=output_dir,
num_runs=args.num_runs,
dry_run=args.dry_run,
+ links=args.link_vars,
)
except BaseException as exc:
raise RuntimeError(
diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index 1e66f21ff6388..1773913d0b6c6 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -11,6 +11,7 @@ import pprint
import time
from collections.abc import Callable, Sequence
from contextlib import contextmanager
+from copy import deepcopy
from functools import partial
from typing import Any
@@ -63,13 +64,14 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
else:
logger.debug("Using InductorAdaptor")
return InductorAdaptor()
- else:
- assert compilation_config.backend == "eager", (
- "Custom backends not supported with CompilationMode.VLLM_COMPILE"
- )
-
+ elif compilation_config.backend == "eager":
logger.debug("Using EagerAdaptor")
return EagerAdaptor()
+ else:
+ logger.debug("Using custom backend: %s", compilation_config.backend)
+ compiler = resolve_obj_by_qualname(current_platform.get_compile_backend())()
+ assert isinstance(compiler, CompilerInterface)
+ return compiler
class CompilerManager:
@@ -428,7 +430,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
self.vllm_backend.compiler_manager.compile(
submod,
args,
- self.compilation_config.inductor_compile_config,
+ self.vllm_backend.inductor_config,
self.compilation_config,
graph_index=index,
num_graphs=len(self.compile_submod_names),
@@ -530,6 +532,9 @@ class VllmBackend:
sym_tensor_indices: list[int]
input_buffers: list[torch.Tensor]
compiler_manager: CompilerManager
+ # Copy of CompilationConfig.inductor_compile_config +
+ # an entry for PostGradPassManager
+ inductor_config: dict[str, Any]
def __init__(
self,
@@ -545,7 +550,10 @@ class VllmBackend:
self.prefix = prefix or model_tag
# Passes to run on the graph post-grad.
- self.post_grad_pass_manager = PostGradPassManager()
+ self.pass_manager = resolve_obj_by_qualname(
+ current_platform.get_pass_manager_cls()
+ )()
+ self.pass_key = current_platform.pass_key
self.sym_tensor_indices = []
self.input_buffers = []
@@ -557,29 +565,30 @@ class VllmBackend:
self.compilation_config
)
+ # Deepcopy the inductor config to detach the post-grad custom pass
+ # from CompilationConfig.
+ # We want to avoid PostGradPassManager in CompilationConfig because
+ # in future we need PostGradPassManager.uuid() to be executed
+ # only at compile time.
+ self.inductor_config = deepcopy(self.compilation_config.inductor_compile_config)
# `torch.compile` is JIT compiled, so we don't need to
# do anything here
def configure_post_pass(self):
- config = self.compilation_config
- self.post_grad_pass_manager.configure(self.vllm_config)
+ self.pass_manager.configure(self.vllm_config)
# Post-grad custom passes are run using the post_grad_custom_post_pass
# hook. If a pass for that hook exists, add it to the pass manager.
- inductor_config = config.inductor_compile_config
- PASS_KEY = "post_grad_custom_post_pass"
- if PASS_KEY in inductor_config:
- if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
- # PassManager already added to config, make sure it's correct
- assert (
- inductor_config[PASS_KEY].uuid()
- == self.post_grad_pass_manager.uuid()
+ if self.pass_key in self.inductor_config:
+ if isinstance(self.inductor_config[self.pass_key], PostGradPassManager):
+ raise ValueError(
+ "PostGradPassManager can not be kept in CompilationConfig."
)
else:
# Config should automatically wrap all inductor passes
- assert isinstance(inductor_config[PASS_KEY], InductorPass)
- self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
- inductor_config[PASS_KEY] = self.post_grad_pass_manager
+ assert isinstance(self.inductor_config[self.pass_key], InductorPass)
+ self.pass_manager.add(self.inductor_config[self.pass_key])
+ self.inductor_config[self.pass_key] = self.pass_manager
def __call__(
self, graph: fx.GraphModule, example_inputs
@@ -638,9 +647,7 @@ class VllmBackend:
self.compilation_config.local_cache_dir = local_cache_dir
# Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
- disable_cache = not is_compile_cache_enabled(
- self.compilation_config.inductor_compile_config
- )
+ disable_cache = not is_compile_cache_enabled(self.inductor_config)
if disable_cache:
logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")
diff --git a/vllm/compilation/caching.py b/vllm/compilation/caching.py
index 6297d9f995aa4..ce482572b401b 100644
--- a/vllm/compilation/caching.py
+++ b/vllm/compilation/caching.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import hashlib
import inspect
import os
import pickle
@@ -14,6 +13,7 @@ import vllm.envs as envs
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.utils import hash_factors
from vllm.logger import init_logger
+from vllm.utils.hashing import safe_hash
try:
from torch._dynamo.aot_compile import SerializableCallable
@@ -160,7 +160,7 @@ def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str:
# e.g. exec(). We can't actually check these.
continue
hash_content.append(content)
- return hashlib.md5(
+ return safe_hash(
"\n".join(hash_content).encode(), usedforsecurity=False
).hexdigest()
diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py
index 11cf0f85c1787..7deaba1a99fad 100644
--- a/vllm/compilation/compiler_interface.py
+++ b/vllm/compilation/compiler_interface.py
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import copy
-import hashlib
import os
from collections.abc import Callable
from contextlib import ExitStack
@@ -16,6 +15,7 @@ import torch.fx as fx
import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.config import VllmConfig
+from vllm.utils.hashing import safe_hash
from vllm.utils.torch_utils import is_torch_equal_or_newer
@@ -197,9 +197,9 @@ class InductorStandaloneAdaptor(CompilerInterface):
def compute_hash(self, vllm_config: VllmConfig) -> str:
factors = get_inductor_factors()
- hash_str = hashlib.md5(
- str(factors).encode(), usedforsecurity=False
- ).hexdigest()[:10]
+ hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[
+ :10
+ ]
return hash_str
def initialize_cache(
@@ -286,9 +286,9 @@ class InductorAdaptor(CompilerInterface):
def compute_hash(self, vllm_config: VllmConfig) -> str:
factors = get_inductor_factors()
- hash_str = hashlib.md5(
- str(factors).encode(), usedforsecurity=False
- ).hexdigest()[:10]
+ hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[
+ :10
+ ]
return hash_str
def initialize_cache(
diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py
index 11a18c0e6bb78..6d9da1c488c6d 100644
--- a/vllm/compilation/decorators.py
+++ b/vllm/compilation/decorators.py
@@ -24,6 +24,7 @@ from vllm.config import (
get_current_vllm_config,
set_current_vllm_config,
)
+from vllm.config.compilation import DynamicShapesType
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import resolve_obj_by_qualname
@@ -104,6 +105,7 @@ def support_torch_compile(
dynamic_arg_dims: dict[str, int | list[int]] | None = None,
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
enable_if: Callable[[VllmConfig], bool] | None = None,
+ shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> Callable[[_T], _T] | _T:
"""
A decorator to add support for compiling the forward method of a class.
@@ -161,6 +163,14 @@ def support_torch_compile(
dim to be decorated with `mark_unbacked`. This is useful if we would like to
enforce that dynamo does not specialize on 0/1 values in the case of dummy input
such as for vision model compilation
+
+ `shape_invariants` is a function that gets compiled right before forward.
+ The function should have the torch._check calls that are needed to set
+ the relationships between different input sizes. For example:
+ torch._check(input_ids.size()[0] == inputs_embeds.size()[0])
+ This enforces constraints on the symbolic shapes without hardcoding
+ specific values. It is needed for some models to avoid data dependent
+ errors.
"""
def cls_decorator_helper(cls: _T) -> _T:
@@ -199,7 +209,11 @@ def support_torch_compile(
f"Argument {k} not found in the forward method of {cls}"
)
return _support_torch_compile(
- cls, inferred_dynamic_arg_dims, mark_unbacked_dims, enable_if
+ cls,
+ inferred_dynamic_arg_dims,
+ mark_unbacked_dims,
+ enable_if,
+ shape_invariants,
)
if cls is not None:
@@ -242,6 +256,7 @@ def _support_torch_compile(
dynamic_arg_dims: dict[str, int | list[int]],
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
enable_if: Callable[[VllmConfig], bool] | None = None,
+ shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> _T:
"""
A decorator to add support for compiling the forward method of a class.
@@ -276,11 +291,12 @@ def _support_torch_compile(
old_init(self, **kwargs)
self.vllm_config = vllm_config
+ self.compilation_config = self.vllm_config.compilation_config
enable_compile = enable_if is None or enable_if(vllm_config)
# for CompilationMode.STOCK_TORCH_COMPILE , the upper level model runner
# will handle the compilation, so we don't need to do anything here.
self.do_not_compile = (
- vllm_config.compilation_config.mode
+ self.compilation_config.mode
in [CompilationMode.NONE, CompilationMode.STOCK_TORCH_COMPILE]
or not supports_dynamo()
or _should_ignore_torch_compile(self.__class__)
@@ -289,29 +305,38 @@ def _support_torch_compile(
if self.do_not_compile:
return
+ self._check_shape_invariants = shape_invariants
+
compilation_counter.num_models_seen += 1
self.compiled = False
TorchCompileWithNoGuardsWrapper.__init__(self)
cls.__init__ = __init__
- def _mark_dynamic_inputs(mod, *args, **kwargs):
+ def _mark_dynamic_inputs(mod, type, *args, **kwargs):
+ def mark_dynamic(arg, dims):
+ if type == DynamicShapesType.UNBACKED:
+ torch._dynamo.decorators.mark_unbacked(arg, dims)
+ else:
+ torch._dynamo.mark_dynamic(arg, dims)
+
sig = inspect.signature(mod.__class__.forward)
bound_args = sig.bind(mod, *args, **kwargs)
bound_args.apply_defaults()
for k, dims in dynamic_arg_dims.items():
arg = bound_args.arguments.get(k)
+
if arg is not None:
dims = [dims] if isinstance(dims, int) else dims
if isinstance(arg, torch.Tensor):
# In case dims is specified with negative indexing
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
- torch._dynamo.mark_dynamic(arg, dims)
+ mark_dynamic(arg, dims)
elif isinstance(arg, IntermediateTensors):
for tensor in arg.tensors.values():
# In case dims is specified with negative indexing
dims = [tensor.ndim + dim if dim < 0 else dim for dim in dims]
- torch._dynamo.mark_dynamic(tensor, dims)
+ mark_dynamic(tensor, dims)
else:
raise ValueError(
"Unsupported dynamic dimensions"
@@ -338,6 +363,7 @@ def _support_torch_compile(
if getattr(self, "aot_compiled_fn", None) is not None:
return self.aot_compiled_fn(self, *args, **kwargs)
+ ds_type = self.compilation_config.dynamic_shapes_config.type
cache_dir = None
aot_compilation_path = None
if envs.VLLM_USE_AOT_COMPILE:
@@ -352,6 +378,14 @@ def _support_torch_compile(
serialized backend artifacts), then we need to generate a new AOT
compile artifact from scratch.
"""
+ # Validate that AOT compile is not used with unbacked dynamic
+ # shapes. aot_compile re-allocates backed symbols post dynamo!
+ if ds_type == DynamicShapesType.UNBACKED:
+ raise ValueError(
+ "AOT compilation is not compatible with UNBACKED dynamic shapes. "
+ "Please use BACKED or BACKED_SIZE_OBLIVIOUS dynamic shapes type "
+ "when VLLM_USE_AOT_COMPILE is enabled."
+ )
from .caching import compilation_config_hash_factors
factors: list[str] = compilation_config_hash_factors(self.vllm_config)
@@ -401,7 +435,12 @@ def _support_torch_compile(
# This is the path for the first compilation.
# the first compilation needs to have dynamic shapes marked
- _mark_dynamic_inputs(self, *args, **kwargs)
+ _mark_dynamic_inputs(
+ self,
+ ds_type,
+ *args,
+ **kwargs,
+ )
# here, it is the starting point of the `torch.compile` process
start_monitoring_torch_compile(self.vllm_config)
@@ -417,9 +456,7 @@ def _support_torch_compile(
# properly when any of these files change.
# 1. the file containing the top-level forward function
- self.vllm_config.compilation_config.traced_files.add(
- original_code_object.co_filename
- )
+ self.compilation_config.traced_files.add(original_code_object.co_filename)
# 2. every time Dynamo sees a function call, it will inline
# the function by calling InliningInstructionTranslator.inline_call_
@@ -429,7 +466,7 @@ def _support_torch_compile(
def patched_inline_call(self_):
code = self_.f_code
- self.vllm_config.compilation_config.traced_files.add(code.co_filename)
+ self.compilation_config.traced_files.add(code.co_filename)
return inline_call(self_)
# Disable the C++ compilation of symbolic shape guards. C++-fication
@@ -445,12 +482,18 @@ def _support_torch_compile(
# if the config doesn't exist
logger.debug("enable_cpp_symbolic_shape_guards config not available")
+ # Prepare backed_size_oblivious config patch if needed
+ fx_config_patches = {}
+ if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS:
+ fx_config_patches["backed_size_oblivious"] = True
+
with (
patch.object(
InliningInstructionTranslator, "inline_call_", patched_inline_call
),
torch._dynamo.config.patch(**dynamo_config_patches),
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
+ torch.fx.experimental._config.patch(**fx_config_patches),
_torch27_patch_tensor_subclasses(),
):
if envs.VLLM_USE_AOT_COMPILE:
diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py
index 4f44faece75e5..6dcbbd85d7031 100644
--- a/vllm/compilation/fusion_attn.py
+++ b/vllm/compilation/fusion_attn.py
@@ -10,7 +10,7 @@ from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py
index 2931580afbbb0..e535d2c461c6e 100644
--- a/vllm/compilation/piecewise_backend.py
+++ b/vllm/compilation/piecewise_backend.py
@@ -107,7 +107,7 @@ class PiecewiseBackend:
entry.runnable = self.vllm_backend.compiler_manager.compile(
self.graph,
args,
- self.compilation_config.inductor_compile_config,
+ self.vllm_backend.inductor_config,
self.compilation_config,
graph_index=self.piecewise_compile_index,
num_graphs=self.total_piecewise_compiles,
diff --git a/vllm/compilation/qk_norm_rope_fusion.py b/vllm/compilation/qk_norm_rope_fusion.py
index e3c399e079063..794cd8e3fce56 100644
--- a/vllm/compilation/qk_norm_rope_fusion.py
+++ b/vllm/compilation/qk_norm_rope_fusion.py
@@ -9,7 +9,7 @@ from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py
index 493e57f97f0f4..b120c85bf232e 100644
--- a/vllm/compilation/wrapper.py
+++ b/vllm/compilation/wrapper.py
@@ -6,6 +6,7 @@ import sys
from abc import abstractmethod
from contextlib import contextmanager
from types import CodeType
+from typing import Any
import torch
import torch._C._dynamo.guards
@@ -85,6 +86,12 @@ class TorchCompileWithNoGuardsWrapper:
since we drop all guards.
"""
+ def check_invariants_and_forward(self, *args, **kwargs):
+ assert hasattr(self, "_check_shape_invariants")
+ self._check_shape_invariants(*args, **kwargs)
+
+ return self.forward(*args, **kwargs)
+
def __init__(self):
self.compiled = False
@@ -104,6 +111,21 @@ class TorchCompileWithNoGuardsWrapper:
# Drop all the guards.
options["guard_filter_fn"] = lambda x: [False for _ in x]
+ # Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
+ from vllm.compilation.decorators import DynamicShapesType
+
+ ds_type = vllm_config.compilation_config.dynamic_shapes_config.type
+ compiled_ptr: Any = self.forward
+ if ds_type == DynamicShapesType.UNBACKED:
+ if envs.VLLM_USE_BYTECODE_HOOK:
+ # reason is that bytecode does this hack torch._dynamo.eval_frame.
+ # remove_from_cache(self.original_code_object()) to force a new
+ # re-compilation.
+ raise ValueError(
+ "UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0. "
+ )
+ compiled_ptr = self.check_invariants_and_forward
+
if envs.VLLM_USE_AOT_COMPILE:
if hasattr(torch._dynamo.config, "enable_aot_compile"):
torch._dynamo.config.enable_aot_compile = True
@@ -114,7 +136,7 @@ class TorchCompileWithNoGuardsWrapper:
logger.warning(msg)
self._compiled_callable = torch.compile(
- self.forward,
+ compiled_ptr,
fullgraph=True,
dynamic=False,
backend=backend,
diff --git a/vllm/config/cache.py b/vllm/config/cache.py
index ef6928d8ebd5c..00530846fce00 100644
--- a/vllm/config/cache.py
+++ b/vllm/config/cache.py
@@ -144,7 +144,7 @@ class CacheConfig:
kv_offloading_backend: KVOffloadingBackend | None = None
"""The backend to use for KV cache offloading. Supported backends include
- 'native' (vLLM native CPU offloading), 'lmcache' This option must be used
+ 'native' (vLLM native CPU offloading), 'lmcache' This option must be used
together with kv_offloading_size."""
def compute_hash(self) -> str:
@@ -167,8 +167,6 @@ class CacheConfig:
"num_gpu_blocks_override",
"enable_prefix_caching",
"prefix_caching_hash_algo",
- # `cpu_offload_gb` does not use `torch.compile` yet.
- "cpu_offload_gb",
"cpu_kvcache_space_bytes",
"mamba_page_size_padded",
# Post-init/derived counters
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 9b5309598d0e2..da2c100dae3dc 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -8,7 +8,7 @@ from dataclasses import asdict, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Literal
-from pydantic import TypeAdapter, field_validator
+from pydantic import Field, TypeAdapter, field_validator
from pydantic.dataclasses import dataclass
import vllm.envs as envs
@@ -97,19 +97,25 @@ class PassConfig:
This is separate from general `CompilationConfig` so that inductor passes
don't all have access to full configuration - that would create a cycle as
- the `PassManager` is set as a property of config."""
+ the `PassManager` is set as a property of config.
- enable_fusion: bool = False
+ You must pass PassConfig to VLLMConfig constructor via the CompilationConfig
+ constructor. VLLMConfig's post_init does further initialization.
+ If used outside of the VLLMConfig, some fields may be left in an
+ improper state.
+ """
+
+ enable_fusion: bool = Field(default=None)
"""Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass."""
- enable_attn_fusion: bool = False
+ enable_attn_fusion: bool = Field(default=None)
"""Whether to enable the custom attention+quant fusion pass."""
- enable_noop: bool = False
+ enable_noop: bool = Field(default=None)
"""Whether to enable the custom no-op elimination pass."""
- enable_sequence_parallelism: bool = False
+ enable_sequence_parallelism: bool = Field(default=None)
"""Whether to enable sequence parallelism."""
- enable_async_tp: bool = False
+ enable_async_tp: bool = Field(default=None)
"""Whether to enable async TP."""
- enable_fi_allreduce_fusion: bool = False
+ enable_fi_allreduce_fusion: bool = Field(default=None)
"""Whether to enable flashinfer allreduce fusion."""
fi_allreduce_fusion_max_size_mb: float | None = None
"""The threshold of the communicated tensor sizes under which
@@ -167,6 +173,22 @@ class PassConfig:
"""
return InductorPass.hash_dict(asdict(self))
+ @field_validator(
+ "enable_fusion",
+ "enable_attn_fusion",
+ "enable_noop",
+ "enable_sequence_parallelism",
+ "enable_async_tp",
+ "enable_fi_allreduce_fusion",
+ mode="wrap",
+ )
+ @classmethod
+ def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
+ """Skip validation if the value is `None` when initialisation is delayed."""
+ if value is None:
+ return value
+ return handler(value)
+
def __post_init__(self) -> None:
if not self.enable_noop:
if self.enable_fusion:
@@ -192,10 +214,64 @@ class PassConfig:
self.enable_qk_norm_rope_fusion = False
+class DynamicShapesType(str, enum.Enum):
+ """Types of dynamic shapes handling in torch.compile().
+ see Dynamic shapes and vllm guard dropping in torch_compile.md
+ for more details."""
+
+ BACKED = "backed"
+ """Use backed dynamic shapes. torch.compile() guards on backed dynamic
+ shapes and may add guards. Symbols are specialized to 0, 1, or >=2 even
+ without encountering branching on those ranges."""
+
+ UNBACKED = "unbacked"
+ """Use unbacked dynamic shapes. Guaranteed not to be guarded on and not
+ 0/1 specialized, but may throw data dependent errors when branches require
+ their value without explicit unbacked handling."""
+
+ BACKED_SIZE_OBLIVIOUS = "backed_size_oblivious"
+ """Experimental flag that treats backed symbols as unbacked when explicit
+ unbacked handling is defined."""
+
+
+@config
+@dataclass
+class DynamicShapesConfig:
+ """Configuration to control/debug torch compile dynamic shapes."""
+
+ type: DynamicShapesType = DynamicShapesType.BACKED
+ """Controls the type of dynamic shapes handling to use with torch.compile().
+
+ - BACKED: Default PyTorch behavior with potential guards ignored.
+ - UNBACKED: No guards guaranteed (most sound) but may throw
+ data dependent errors.
+ - BACKED_SIZE_OBLIVIOUS: Experimental safer alternative to
+ backed/unbacked.
+ """
+
+ # TODO add a debug mode to fail
+
+ def compute_hash(self) -> str:
+ """
+ Provide a hash for DynamicShapesConfig
+ """
+
+ from vllm.config.utils import get_hash_factors, hash_factors
+
+ factors = get_hash_factors(self, {})
+ return hash_factors(factors)
+
+
@config
@dataclass
class CompilationConfig:
- """Configuration for compilation. It has three parts:
+ """Configuration for compilation.
+
+ You must pass CompilationConfig to VLLMConfig constructor.
+ VLLMConfig's post_init does further initialization. If used outside of the
+ VLLMConfig, some fields will be left in an improper state.
+
+ It has three parts:
- Top-level Compilation control:
- [`mode`][vllm.config.CompilationConfig.mode]
@@ -216,7 +292,6 @@ class CompilationConfig:
- [`cudagraph_copy_inputs`]
[vllm.config.CompilationConfig.cudagraph_copy_inputs]
- Inductor compilation:
- - [`use_inductor`][vllm.config.CompilationConfig.use_inductor]
- [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes]
- [`inductor_compile_config`]
[vllm.config.CompilationConfig.inductor_compile_config]
@@ -235,14 +310,14 @@ class CompilationConfig:
"""
# Top-level Compilation control
- level: int | None = None
+ level: int = Field(default=None)
"""
Level is deprecated and will be removed in the next release,
either 0.12.0 or 0.11.2 whichever is soonest.
Please use mode. Currently all levels are mapped to mode.
"""
# Top-level Compilation control
- mode: CompilationMode | None = None
+ mode: CompilationMode = Field(default=None)
"""The compilation approach used for torch.compile-based compilation of the
model.
@@ -283,9 +358,9 @@ class CompilationConfig:
We use string to avoid serialization issues when using compilation in a
distributed setting. When the compilation mode is 1 or 2, the backend is
used for the compilation directly (it sees the whole graph). When the
- compilation mode is 3, the backend is used for the piecewise compilation
- (it sees a part of the graph). The backend can not be custom for compilation
- mode 3, i.e. the backend must be either eager or inductor. Furthermore,
+ compilation mode is 3, the backend supports both whole graph and piecewise
+ compilation, available backends include eager, inductor, and custom backends,
+ the latter of which can be defined via `get_compile_backend`. Furthermore,
compilation is only piecewise if splitting ops is set accordingly and
use_inductor_graph_partition is off. Note that the default options for
splitting ops are sufficient for piecewise compilation.
@@ -300,7 +375,7 @@ class CompilationConfig:
- 'none,+op1,+op2' to enable only op1 and op2
By default, all custom ops are enabled when running without Inductor and
- disabled when running with Inductor: mode>=VLLM_COMPILE and use_inductor=True.
+ disabled when running with Inductor: mode>=VLLM_COMPILE and backend="inductor".
Inductor generates (fused) Triton kernels for disabled custom ops."""
splitting_ops: list[str] | None = None
"""A list of ops to exclude from cudagraphs, used in piecewise compilation.
@@ -322,35 +397,19 @@ class CompilationConfig:
If empty list [], no ops are excluded (suitable for full cudagraphs)."""
compile_mm_encoder: bool = False
"""Whether or not to compile the multimodal encoder.
- Currently, this only works for `Qwen2_5_vl` on selected platforms.
+ Currently, this only works for `Qwen2_5_vl` on selected platforms.
Disabled by default until more models are supported/tested to work."""
# Inductor capture
- use_inductor: bool | None = None
- """
- Whether to use inductor compilation.
-
- This flag is deprecated and will be removed in the next release 0.12.0.
- Please use the 'backend' option instead.
-
- - False: inductor compilation is not used. graph runs in eager
- (custom_ops enabled by default).
- - True: inductor compilation is used (custom_ops disabled by default).
- One graph for symbolic shape and one graph per size in compile_sizes
- are compiled using configurations in inductor_compile_config.
-
- This setting is ignored if mode512) that would
greatly increase startup time with limited performance benefit.
"""
+
+ dynamic_shapes_config: DynamicShapesConfig = field(
+ default_factory=DynamicShapesConfig
+ )
+ """Configuration for dynamic shapes options"""
+
local_cache_dir: str = field(default=None, init=False) # type: ignore
"""local cache dir for each rank"""
+
bs_to_padded_graph_size: list[int] = field(
default=None, # type: ignore
init=False,
@@ -530,6 +596,7 @@ class CompilationConfig:
from vllm.config.utils import get_hash_factors, hash_factors
factors = get_hash_factors(self, ignored_factors)
+
factors["pass_config"] = self.pass_config.compute_hash()
return hash_factors(factors)
@@ -609,6 +676,20 @@ class CompilationConfig:
)
return value
+ @field_validator(
+ "level",
+ "mode",
+ "cudagraph_mode",
+ "use_inductor_graph_partition",
+ mode="wrap",
+ )
+ @classmethod
+ def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
+ """Skip validation if the value is `None` when initialisation is delayed."""
+ if value is None:
+ return value
+ return handler(value)
+
def __post_init__(self) -> None:
if self.level is not None:
logger.warning(
@@ -701,16 +782,8 @@ class CompilationConfig:
f"Invalid backend for piecewise compilation: {self.backend}"
)
- if self.use_inductor is not None:
- logger.warning_once(
- "The 'use_inductor' flag is deprecated and will be "
- "removed in the next release (v0.12.0). "
- "Please use the 'backend' option instead.",
- )
- self.backend = "inductor" if self.use_inductor else "eager"
-
if self.backend == "":
- self.backend = current_platform.simple_compile_backend
+ self.backend = current_platform.get_compile_backend()
def init_backend(self, vllm_config: "VllmConfig") -> str | Callable:
"""
@@ -742,9 +815,7 @@ class CompilationConfig:
assert self.mode == CompilationMode.VLLM_COMPILE
if self.backend not in ["eager", "inductor"]:
- raise ValueError(
- f"Invalid backend for piecewise compilation: {self.backend}"
- )
+ logger.info("Using OOT custom backend for compilation.")
from vllm.compilation.backends import VllmBackend
@@ -919,6 +990,13 @@ class CompilationConfig:
op,
)
+ def is_custom_op_enabled(self, op: str) -> bool:
+ if "all" in self.custom_ops:
+ return f"-{op}" not in self.custom_ops
+
+ assert "none" in self.custom_ops
+ return f"+{op}" in self.custom_ops
+
def adjust_cudagraph_sizes_for_spec_decode(
self, uniform_decode_query_len: int, tensor_parallel_size: int
):
diff --git a/vllm/config/device.py b/vllm/config/device.py
index e85cd15de8cf4..85662ddff76b7 100644
--- a/vllm/config/device.py
+++ b/vllm/config/device.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import hashlib
from dataclasses import field
from typing import Any, Literal
@@ -10,6 +9,7 @@ from pydantic import ConfigDict, SkipValidation
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
+from vllm.utils.hashing import safe_hash
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
@@ -45,7 +45,7 @@ class DeviceConfig:
# the device/platform information will be summarized
# by torch/vllm automatically.
factors: list[Any] = []
- hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
+ hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
def __post_init__(self):
diff --git a/vllm/config/kv_transfer.py b/vllm/config/kv_transfer.py
index dfd7ef63712a3..88f8b91c292bb 100644
--- a/vllm/config/kv_transfer.py
+++ b/vllm/config/kv_transfer.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import hashlib
import uuid
from dataclasses import field
from typing import Any, Literal, get_args
@@ -9,6 +8,7 @@ from typing import Any, Literal, get_args
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
+from vllm.utils.hashing import safe_hash
KVProducer = Literal["kv_producer", "kv_both"]
KVConsumer = Literal["kv_consumer", "kv_both"]
@@ -79,7 +79,7 @@ class KVTransferConfig:
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
- hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
+ hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
def __post_init__(self) -> None:
diff --git a/vllm/config/load.py b/vllm/config/load.py
index e424f8c5edb62..579a0bc31020e 100644
--- a/vllm/config/load.py
+++ b/vllm/config/load.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import hashlib
from typing import TYPE_CHECKING, Any
from pydantic import Field, field_validator
@@ -9,6 +8,7 @@ from pydantic.dataclasses import dataclass
from vllm.config.utils import config
from vllm.logger import init_logger
+from vllm.utils.hashing import safe_hash
if TYPE_CHECKING:
from vllm.model_executor.model_loader import LoadFormats
@@ -104,7 +104,7 @@ class LoadConfig:
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
- hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
+ hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@field_validator("load_format", mode="after")
diff --git a/vllm/config/lora.py b/vllm/config/lora.py
index 072e0ec2104f5..6a8fd6359aadd 100644
--- a/vllm/config/lora.py
+++ b/vllm/config/lora.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import hashlib
from typing import TYPE_CHECKING, Any, Literal
import torch
@@ -11,6 +10,7 @@ from typing_extensions import Self
from vllm.config.utils import config
from vllm.logger import init_logger
+from vllm.utils.hashing import safe_hash
if TYPE_CHECKING:
from vllm.config import ModelConfig
@@ -74,7 +74,7 @@ class LoRAConfig:
factors.append(self.fully_sharded_loras)
factors.append(self.lora_dtype)
- hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
+ hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@model_validator(mode="after")
diff --git a/vllm/config/model.py b/vllm/config/model.py
index 49688e17cf932..21d602b30ac1a 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -14,6 +14,7 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers.configuration_utils import ALLOWED_LAYER_TYPES
import vllm.envs as envs
+from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig
from vllm.config.pooler import PoolerConfig
from vllm.config.scheduler import RunnerType
@@ -33,12 +34,18 @@ from vllm.transformers_utils.config import (
try_get_safetensors_metadata,
try_get_tokenizer_config,
uses_mrope,
+ uses_xdrope_dim,
)
from vllm.transformers_utils.gguf_utils import (
maybe_patch_hf_config_from_gguf,
)
from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri
-from vllm.transformers_utils.utils import check_gguf_file, maybe_model_redirect
+from vllm.transformers_utils.utils import (
+ is_gguf,
+ is_remote_gguf,
+ maybe_model_redirect,
+ split_remote_gguf,
+)
from vllm.utils.import_utils import LazyLoader
from vllm.utils.torch_utils import common_broadcastable_dtype
@@ -47,7 +54,6 @@ if TYPE_CHECKING:
import vllm.model_executor.layers.quantization as me_quant
import vllm.model_executor.models as me_models
- from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.load import LoadConfig
from vllm.config.parallel import ParallelConfig
from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -55,7 +61,6 @@ if TYPE_CHECKING:
else:
PretrainedConfig = Any
- AttentionBackendEnum = Any
me_quant = LazyLoader(
"model_executor", globals(), "vllm.model_executor.layers.quantization"
)
@@ -293,9 +298,6 @@ class ModelConfig:
pooler_config: PoolerConfig | None = None
"""Pooler config which controls the behaviour of output pooling in pooling
models."""
- override_pooler_config: dict | PoolerConfig | None = None
- """[DEPRECATED] Use `pooler_config` instead. This field will be removed in
- v0.12.0 or v1.0.0, whichever is sooner."""
# Multimodal config and init vars
multimodal_config: MultiModalConfig | None = None
@@ -342,7 +344,6 @@ class ModelConfig:
"logprobs_mode",
"disable_cascade_attn",
"skip_tokenizer_init",
- "enable_prompt_embeds",
"served_model_name",
"config_format",
"hf_token",
@@ -353,7 +354,6 @@ class ModelConfig:
"logits_processors",
"io_processor_plugin",
"pooler_config",
- "override_pooler_config",
"multimodal_config",
"limit_mm_per_prompt",
"media_io_kwargs",
@@ -439,12 +439,6 @@ class ModelConfig:
self.model = maybe_model_redirect(self.model)
# The tokenizer is consistent with the model by default.
if self.tokenizer is None:
- if check_gguf_file(self.model):
- raise ValueError(
- "Using a tokenizer is mandatory when loading a GGUF model. "
- "Please specify the tokenizer path or name using the "
- "--tokenizer argument."
- )
self.tokenizer = self.model
if self.tokenizer_revision is None:
self.tokenizer_revision = self.revision
@@ -585,16 +579,26 @@ class ModelConfig:
else: # task == "auto"
pass
else:
- debug_info = {
- "architectures": architectures,
- "is_generative_model": is_generative_model,
- "is_pooling_model": is_pooling_model,
- }
- raise AssertionError(
- "The model should be a generative or "
- "pooling model when task is set to "
- f"{self.task!r}. Found: {debug_info}"
- )
+ # Neither generative nor pooling model - try to convert if possible
+ if is_pooling_task:
+ runner = "pooling"
+ convert = _task_to_convert(self.task)
+ msg_hint = (
+ "Please replace this option with `--runner pooling "
+ f"--convert {convert}` to continue using this model "
+ "as a pooling model."
+ )
+ else:
+ debug_info = {
+ "architectures": architectures,
+ "is_generative_model": is_generative_model,
+ "is_pooling_model": is_pooling_model,
+ }
+ raise AssertionError(
+ "The model should be a generative or "
+ "pooling model when task is set to "
+ f"{self.task!r}. Found: {debug_info}"
+ )
self.runner = runner
self.convert = convert
@@ -631,18 +635,6 @@ class ModelConfig:
# Init pooler config if needed
if self.runner_type == "pooling":
- if self.override_pooler_config is not None:
- logger.warning_once(
- "`override_pooler_config` is deprecated and will be "
- "removed in v0.12.0 or v1.0.0, whichever is sooner. "
- "Please use `pooler_config` instead."
- )
-
- if isinstance(self.override_pooler_config, dict):
- self.pooler_config = PoolerConfig(**self.override_pooler_config)
- else:
- self.pooler_config = self.override_pooler_config
-
if self.pooler_config is None:
self.pooler_config = PoolerConfig()
@@ -700,6 +692,14 @@ class ModelConfig:
self.multimodal_config = MultiModalConfig(**mm_config_kwargs)
+ # Multimodal GGUF models must use original repo for mm processing
+ if is_gguf(self.tokenizer) and self.is_multimodal_model:
+ raise ValueError(
+ "Loading a multimodal GGUF model needs to use original "
+ "tokenizer. Please specify the unquantized hf model's "
+ "repo name or path using the --tokenizer argument."
+ )
+
if self.disable_sliding_window:
# Set after get_and_verify_max_len to ensure that max_model_len
# can be correctly capped to sliding window size
@@ -821,7 +821,10 @@ class ModelConfig:
self.tokenizer = object_storage_tokenizer.dir
def _get_encoder_config(self):
- return get_sentence_transformer_tokenizer_config(self.model, self.revision)
+ model = self.model
+ if is_remote_gguf(model):
+ model, _ = split_remote_gguf(model)
+ return get_sentence_transformer_tokenizer_config(model, self.revision)
def _verify_tokenizer_mode(self) -> None:
tokenizer_mode = cast(TokenizerMode, self.tokenizer_mode.lower())
@@ -1605,6 +1608,10 @@ class ModelConfig:
def uses_mrope(self) -> bool:
return uses_mrope(self.hf_config)
+ @property
+ def uses_xdrope_dim(self) -> int:
+ return uses_xdrope_dim(self.hf_config)
+
@property
def is_multimodal_model(self) -> bool:
return self.multimodal_config is not None
@@ -1745,6 +1752,14 @@ class ModelConfig:
logger.info("Using max model len %s", max_model_len)
return max_model_len
+ def is_model_moe(
+ self,
+ ) -> bool:
+ return self.get_num_experts() > 1
+
+ def is_quantized(self) -> bool:
+ return getattr(self.hf_config, "quantization_config", None) is not None
+
def get_served_model_name(model: str, served_model_name: str | list[str] | None):
"""
diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py
index 9f62b35ed515c..8a2936de96d6f 100644
--- a/vllm/config/multimodal.py
+++ b/vllm/config/multimodal.py
@@ -1,19 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import hashlib
from collections.abc import Mapping
-from typing import TYPE_CHECKING, Any, Literal, TypeAlias
+from typing import Any, Literal, TypeAlias
from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
+from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.utils import config
-
-if TYPE_CHECKING:
- from vllm.attention.backends.registry import AttentionBackendEnum
-else:
- AttentionBackendEnum = Any
+from vllm.utils.hashing import safe_hash
@dataclass
@@ -170,8 +166,11 @@ class MultiModalConfig:
def _validate_mm_encoder_attn_backend(
cls, value: str | AttentionBackendEnum | None
) -> AttentionBackendEnum | None:
- # We need to import the real type here (deferred to avoid circular import).
- from vllm.attention.backends.registry import AttentionBackendEnum
+ if isinstance(value, str) and value.upper() == "XFORMERS":
+ raise ValueError(
+ "Attention backend 'XFORMERS' has been removed (See PR #29262 for "
+ "details). Please select a supported attention backend."
+ )
if value is None or isinstance(value, AttentionBackendEnum):
return value
@@ -210,7 +209,7 @@ class MultiModalConfig:
if self.mm_encoder_attn_backend is not None
else None
]
- hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
+ hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
def get_limit_per_prompt(self, modality: str) -> int:
diff --git a/vllm/config/observability.py b/vllm/config/observability.py
index 564c4f7aed419..ff35e12fe20ed 100644
--- a/vllm/config/observability.py
+++ b/vllm/config/observability.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import hashlib
from functools import cached_property
from typing import Any, Literal, cast
@@ -11,6 +10,7 @@ from pydantic.dataclasses import dataclass
from vllm import version
from vllm.config.utils import config
+from vllm.utils.hashing import safe_hash
DetailedTraceModules = Literal["model", "worker", "all"]
@@ -78,7 +78,7 @@ class ObservabilityConfig:
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
- hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
+ hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@field_validator("show_hidden_metrics_for_version")
diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py
index 4b0236d8de3f5..4a8c8bc17cfc3 100644
--- a/vllm/config/parallel.py
+++ b/vllm/config/parallel.py
@@ -60,6 +60,10 @@ class EPLBConfig:
Log the balancedness each step of expert parallelism.
This is turned off by default since it will cause communication overhead.
"""
+ use_async: bool = False
+ """
+ Whether to use non-blocking EPLB.
+ """
@config
@@ -137,22 +141,6 @@ class ParallelConfig:
- "deepep_high_throughput": Use deepep high-throughput kernels
- "deepep_low_latency": Use deepep low-latency kernels
- "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""
- num_redundant_experts: int | None = None
- """`num_redundant_experts` is deprecated and has been replaced with
- `eplb_config.num_redundant_experts`. This will be removed in v0.12.0.
- Please use `eplb_config.num_redundant_experts` instead."""
- eplb_window_size: int | None = None
- """`eplb_window_size` is deprecated and has been replaced with
- `eplb_config.window_size`. This will be removed in v0.12.0.
- Please use `eplb_config.window_size` instead."""
- eplb_step_interval: int | None = None
- """`eplb_step_interval` is deprecated and has been replaced with
- `eplb_config.step_interval`. This will be removed in v0.12.0.
- Please use `eplb_config.step_interval` instead."""
- eplb_log_balancedness: bool | None = None
- """`eplb_log_balancedness` is deprecated and has been replaced with
- `eplb_config.log_balancedness`. This will be removed in v0.12.0.
- Please use `eplb_config.log_balancedness` instead."""
max_parallel_loading_workers: int | None = None
"""Maximum number of parallel loading workers when loading model
@@ -250,9 +238,9 @@ class ParallelConfig:
cp_kv_cache_interleave_size: int = 1
"""Interleave size of kv_cache storage while using DCP or PCP.
For `total_cp_rank = pcp_rank * dcp_world_size + dcp_rank`,
- and `total_cp_world_size = pcp_world_size * dcp_world_szie`.
+ and `total_cp_world_size = pcp_world_size * dcp_world_size`.
store interleave_size tokens on total_cp_rank i,
- then store next interleave_size tokens on taotal_cp_rank i+1.
+ then store next interleave_size tokens on total_cp_rank i+1.
Interleave_size=1: token-level alignment, where token `i` is stored on
total_cp_rank `i % total_cp_world_size`.
Interleave_size=block_size: block-level alignment, where tokens are
@@ -512,40 +500,6 @@ class ParallelConfig:
"--all2all-backend command-line argument instead."
)
- # Forward deprecated fields to their new location
- if self.num_redundant_experts is not None:
- self.eplb_config.num_redundant_experts = self.num_redundant_experts
- logger.warning_once(
- "num_redundant_experts is deprecated and has been replaced "
- "with eplb_config.num_redundant_experts. This will be removed "
- "in v0.12.0. Changing this field after initialization will "
- "have no effect."
- )
- if self.eplb_window_size is not None:
- self.eplb_config.window_size = self.eplb_window_size
- logger.warning_once(
- "eplb_window_size is deprecated and has been replaced "
- "with eplb_config.window_size. This will be removed "
- "in v0.12.0. Changing this field after initialization will "
- "have no effect."
- )
- if self.eplb_step_interval is not None:
- self.eplb_config.step_interval = self.eplb_step_interval
- logger.warning_once(
- "eplb_step_interval is deprecated and has been replaced "
- "with eplb_config.step_interval. This will be removed "
- "in v0.12.0. Changing this field after initialization will "
- "have no effect."
- )
- if self.eplb_log_balancedness is not None:
- self.eplb_config.log_balancedness = self.eplb_log_balancedness
- logger.warning_once(
- "eplb_log_balancedness is deprecated and has been replaced "
- "with eplb_config.log_balancedness. This will be removed "
- "in v0.12.0. Changing this field after initialization will "
- "have no effect."
- )
-
# Continue with the rest of the initialization
self.world_size = (
self.pipeline_parallel_size
@@ -639,9 +593,10 @@ class ParallelConfig:
"max_parallel_loading_workers is currently "
"not supported and will be ignored."
)
- if self.distributed_executor_backend != "mp" and self.nnodes > 1:
+ if self.distributed_executor_backend not in ("mp", "uni") and self.nnodes > 1:
raise ValueError(
- "nnodes > 1 can only be set when distributed exectuor backend is mp."
+ "nnodes > 1 can only be set when distributed executor "
+ "backend is mp or uni."
)
@property
diff --git a/vllm/config/pooler.py b/vllm/config/pooler.py
index 6bece8d0785bd..85950bbcd666f 100644
--- a/vllm/config/pooler.py
+++ b/vllm/config/pooler.py
@@ -1,13 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import hashlib
from typing import Any
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
from vllm.logger import init_logger
+from vllm.utils.hashing import safe_hash
logger = init_logger(__name__)
@@ -102,7 +102,7 @@ class PoolerConfig:
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
- hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
+ hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py
index b6078706daacf..ff1ac0e18f324 100644
--- a/vllm/config/scheduler.py
+++ b/vllm/config/scheduler.py
@@ -1,17 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import hashlib
from collections.abc import Callable
from dataclasses import InitVar
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
from pydantic import Field, field_validator
from pydantic.dataclasses import dataclass
-from typing_extensions import Self, deprecated
+from typing_extensions import Self
from vllm.config.utils import config
from vllm.logger import init_logger
+from vllm.utils.hashing import safe_hash
from vllm.utils.import_utils import resolve_obj_by_qualname
if TYPE_CHECKING:
@@ -178,7 +178,7 @@ class SchedulerConfig:
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
- hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
+ hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@field_validator("scheduler_cls", "async_scheduling", mode="wrap")
@@ -224,19 +224,6 @@ class SchedulerConfig:
self.verify_max_model_len(max_model_len)
- @property
- @deprecated(
- "`SchedulerConfig.chunked_prefill_enabled` has been renamed to "
- "`SchedulerConfig.enable_chunked_prefill`. "
- "The old name will be removed in v0.12."
- )
- def chunked_prefill_enabled(self) -> bool:
- return self.enable_chunked_prefill
-
- @chunked_prefill_enabled.setter
- def chunked_prefill_enabled(self, value: bool):
- self.enable_chunked_prefill = value
-
def verify_max_model_len(self, max_model_len: int) -> Self:
if (
self.max_num_batched_tokens < max_model_len
diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py
index d7c019c73d598..80d53a543f149 100644
--- a/vllm/config/speculative.py
+++ b/vllm/config/speculative.py
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
-import hashlib
from typing import TYPE_CHECKING, Any, Literal, get_args
from pydantic import Field, SkipValidation, model_validator
@@ -13,6 +12,7 @@ from vllm.config.model import ModelConfig
from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config
from vllm.logger import init_logger
+from vllm.utils.hashing import safe_hash
from vllm.utils.import_utils import LazyLoader, has_arctic_inference
if TYPE_CHECKING:
@@ -162,7 +162,7 @@ class SpeculativeConfig:
# Eagle3 affects the computation graph because it returns intermediate
# hidden states in addition to the final hidden state.
factors.append(self.method == "eagle3")
- hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
+ hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@staticmethod
diff --git a/vllm/config/structured_outputs.py b/vllm/config/structured_outputs.py
index 9530d3d81e15d..1b32675c3dbd2 100644
--- a/vllm/config/structured_outputs.py
+++ b/vllm/config/structured_outputs.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import hashlib
from typing import Any, Literal
from pydantic import model_validator
@@ -9,6 +8,7 @@ from pydantic.dataclasses import dataclass
from typing_extensions import Self
from vllm.config.utils import config
+from vllm.utils.hashing import safe_hash
StructuredOutputsBackend = Literal[
"auto", "xgrammar", "guidance", "outlines", "lm-format-enforcer"
@@ -58,7 +58,7 @@ class StructuredOutputsConfig:
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
- hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
+ hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@model_validator(mode="after")
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index d64e315b4fe39..c576275e80fe3 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -3,15 +3,15 @@
import copy
import getpass
-import hashlib
import json
import os
import tempfile
import threading
import time
from contextlib import contextmanager
-from dataclasses import replace
+from dataclasses import is_dataclass, replace
from datetime import datetime
+from enum import IntEnum
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar, get_args
@@ -25,6 +25,7 @@ from vllm.config.speculative import EagleModelTypes
from vllm.logger import enable_trace_function_call, init_logger
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid
+from vllm.utils.hashing import safe_hash
from .cache import CacheConfig
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
@@ -57,6 +58,103 @@ else:
logger = init_logger(__name__)
+class OptimizationLevel(IntEnum):
+ """Optimization level enum."""
+
+ O0 = 0
+ """O0 : No optimization. no compilation, no cudagraphs, no other
+ optimization, just starting up immediately"""
+ O1 = 1
+ """O1: Quick optimizations. Dynamo+Inductor compilation and Piecewise
+ cudagraphs"""
+ O2 = 2
+ """O2: Full optimizations. -O1 as well as Full and Piecewise cudagraphs."""
+ O3 = 3
+ """O3: Currently the same as -O2s."""
+
+
+IS_QUANTIZED = False
+IS_DENSE = False
+# The optimizations that depend on these properties currently set to False
+# in all cases.
+# if model_config is not None:
+# IS_QUANTIZED = lambda c: c.model_config.is_quantized()
+# IS_DENSE = lambda c: not c.model_config.is_model_moe()
+# See https://github.com/vllm-project/vllm/issues/25689.
+
+
+def enable_fusion(cfg: "VllmConfig") -> bool:
+ """Returns True if RMS norm or quant FP8 is enabled."""
+ return cfg.compilation_config.is_custom_op_enabled(
+ "rms_norm"
+ ) or cfg.compilation_config.is_custom_op_enabled("quant_fp8")
+
+
+OPTIMIZATION_LEVEL_00 = {
+ "compilation_config": {
+ "pass_config": {
+ "enable_noop": False,
+ "enable_fusion": False,
+ "enable_fi_allreduce_fusion": False,
+ "enable_attn_fusion": False,
+ "enable_sequence_parallelism": False,
+ "enable_async_tp": False,
+ },
+ "cudagraph_mode": CUDAGraphMode.NONE,
+ "use_inductor_graph_partition": False,
+ },
+}
+OPTIMIZATION_LEVEL_01 = {
+ "compilation_config": {
+ "pass_config": {
+ "enable_noop": True,
+ "enable_fusion": enable_fusion,
+ "enable_fi_allreduce_fusion": False,
+ "enable_attn_fusion": False,
+ "enable_sequence_parallelism": False,
+ "enable_async_tp": False,
+ },
+ "cudagraph_mode": CUDAGraphMode.PIECEWISE,
+ "use_inductor_graph_partition": False,
+ },
+}
+OPTIMIZATION_LEVEL_02 = {
+ "compilation_config": {
+ "pass_config": {
+ "enable_noop": True,
+ "enable_fusion": enable_fusion,
+ "enable_fi_allreduce_fusion": False,
+ "enable_attn_fusion": IS_QUANTIZED,
+ "enable_sequence_parallelism": IS_DENSE,
+ "enable_async_tp": IS_DENSE,
+ },
+ "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
+ "use_inductor_graph_partition": False,
+ },
+}
+OPTIMIZATION_LEVEL_03 = {
+ "compilation_config": {
+ "pass_config": {
+ "enable_noop": True,
+ "enable_fusion": enable_fusion,
+ "enable_fi_allreduce_fusion": False,
+ "enable_attn_fusion": IS_QUANTIZED,
+ "enable_sequence_parallelism": IS_DENSE,
+ "enable_async_tp": IS_DENSE,
+ },
+ "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
+ "use_inductor_graph_partition": False,
+ },
+}
+
+OPTIMIZATION_LEVEL_TO_CONFIG = {
+ OptimizationLevel.O0: OPTIMIZATION_LEVEL_00,
+ OptimizationLevel.O1: OPTIMIZATION_LEVEL_01,
+ OptimizationLevel.O2: OPTIMIZATION_LEVEL_02,
+ OptimizationLevel.O3: OPTIMIZATION_LEVEL_03,
+}
+
+
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class VllmConfig:
@@ -96,7 +194,7 @@ class VllmConfig:
"""`torch.compile` and cudagraph capture configuration for the model.
As a shorthand, one can append compilation arguments via
- -0.parameter=arguement such as `-O.mode=3` (same as `-O='{"mode":3}'`).
+ -0.parameter=argument such as `-O.mode=3` (same as `-O='{"mode":3}'`).
You can specify the full compilation config like so:
`{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
@@ -116,6 +214,11 @@ class VllmConfig:
you are using. Contents must be hashable."""
instance_id: str = ""
"""The ID of the vLLM instance."""
+ optimization_level: OptimizationLevel = OptimizationLevel.O2
+ """The optimization level. These levels trade startup time cost for
+ performance, with -O0 having the best startup time and -O3 having the best
+ performance. -02 is used by defult. See OptimizationLevel for full
+ description."""
def compute_hash(self) -> str:
"""
@@ -193,7 +296,7 @@ class VllmConfig:
vllm_factors.append("None")
if self.additional_config:
if isinstance(additional_config := self.additional_config, dict):
- additional_config_hash = hashlib.md5(
+ additional_config_hash = safe_hash(
json.dumps(additional_config, sort_keys=True).encode(),
usedforsecurity=False,
).hexdigest()
@@ -204,9 +307,9 @@ class VllmConfig:
vllm_factors.append("None")
factors.append(vllm_factors)
- hash_str = hashlib.md5(
- str(factors).encode(), usedforsecurity=False
- ).hexdigest()[:10]
+ hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[
+ :10
+ ]
return hash_str
def pad_for_cudagraph(self, batch_size: int) -> int:
@@ -297,6 +400,50 @@ class VllmConfig:
return replace(self, model_config=model_config)
+ def _set_config_default(self, config_obj: Any, key: str, value: Any) -> None:
+ """Set config attribute to default if not already set by user.
+
+ Args:
+ config_obj: Configuration object to update.
+ key: Attribute name.
+ value: Default value (static or callable).
+ """
+ if getattr(config_obj, key) is None:
+ # Some config values are known before initialization and are
+ # hard coded.
+ # Other values depend on the user given configuration, so they are
+ # implemented with lambda functions and decided at run time.
+ setattr(config_obj, key, value(self) if callable(value) else value)
+
+ def _apply_optimization_level_defaults(self, defaults: dict[str, Any]) -> None:
+ """Apply optimization level defaults using self as root.
+
+ Recursively applies values from defaults into nested config objects.
+ Only fields present in defaults are overwritten.
+
+ If the user configuration does not specify a value for a default field
+ and if the default field is still None after all user selections are
+ applied, then default values will be applied to the field. User speciied
+ fields will not be overridden by the default.
+
+ Args:
+ defaults: Dictionary of default values to apply.
+ """
+
+ def apply_recursive(config_obj: Any, config_defaults: dict[str, Any]) -> None:
+ """Recursively apply defaults to config_obj, using self as root."""
+ for key, value in config_defaults.items():
+ if not hasattr(config_obj, key):
+ continue
+
+ current = getattr(config_obj, key)
+ if isinstance(value, dict) and is_dataclass(current):
+ apply_recursive(current, value)
+ else:
+ self._set_config_default(config_obj, key, value)
+
+ apply_recursive(self, defaults)
+
def _post_init_kv_transfer_config(self) -> None:
"""Update KVTransferConfig based on top-level configs in VllmConfig.
@@ -434,17 +581,47 @@ class VllmConfig:
"precision for chunked prefill triton kernels."
)
- # If the user does not explicitly set a compilation mode, then
- # we use the default mode. The default mode depends on other
- # settings (see the below code).
+ if (
+ self.optimization_level > OptimizationLevel.O0
+ and self.model_config is not None
+ and self.model_config.enforce_eager
+ ):
+ logger.warning("Enforce eager set, overriding optimization level to -O0")
+ self.optimization_level = OptimizationLevel.O0
+
+ if self.compilation_config.backend == "eager" or (
+ self.compilation_config.mode is not None
+ and self.compilation_config.mode != CompilationMode.VLLM_COMPILE
+ ):
+ logger.warning(
+ "Inductor compilation was disabled by user settings,"
+ "Optimizations settings that are only active during"
+ "Inductor compilation will be ignored."
+ )
+
+ def has_blocked_weights():
+ if self.quant_config is not None:
+ if hasattr(self.quant_config, "weight_block_size"):
+ return self.quant_config.weight_block_size is not None
+ elif hasattr(self.quant_config, "has_blocked_weights"):
+ return self.quant_config.has_blocked_weights()
+ return False
+
+ # Enable quant_fp8 CUDA ops (TODO disable in follow up)
+ # On H100 the CUDA kernel is faster than
+ # native implementation
+ # https://github.com/vllm-project/vllm/issues/25094
+ if has_blocked_weights():
+ custom_ops = self.compilation_config.custom_ops
+ if "-quant_fp8" not in custom_ops:
+ custom_ops.append("+quant_fp8")
+
if self.compilation_config.mode is None:
- if self.model_config is not None and not self.model_config.enforce_eager:
+ if self.optimization_level > OptimizationLevel.O0:
self.compilation_config.mode = CompilationMode.VLLM_COMPILE
else:
self.compilation_config.mode = CompilationMode.NONE
- # If user does not set custom ops via none or all set it here based on
- # compilation mode and backend.
if all(s not in self.compilation_config.custom_ops for s in ("all", "none")):
if (
self.compilation_config.backend == "inductor"
@@ -454,23 +631,33 @@ class VllmConfig:
else:
self.compilation_config.custom_ops.append("all")
+ default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level]
+ self._apply_optimization_level_defaults(default_config)
+ if (
+ self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+ and self.compilation_config.mode != CompilationMode.VLLM_COMPILE
+ ):
+ logger.info(
+ "Cudagraph mode %s is not compatible with compilation mode %s."
+ "Overriding to NONE.",
+ self.compilation_config.cudagraph_mode,
+ self.compilation_config.mode,
+ )
+ self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
+
# async tp is built on top of sequence parallelism
# and requires it to be enabled.
if self.compilation_config.pass_config.enable_async_tp:
self.compilation_config.pass_config.enable_sequence_parallelism = True
+ if self.compilation_config.pass_config.enable_sequence_parallelism:
+ if "-rms_norm" in self.compilation_config.custom_ops:
+ logger.warning(
+ "RMS norm force disabled, sequence parallelism might break"
+ )
+ else:
+ self.compilation_config.custom_ops.append("+rms_norm")
if current_platform.support_static_graph_mode():
- # if cudagraph_mode is not explicitly set by users, set default
- # value
- if self.compilation_config.cudagraph_mode is None:
- if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
- # default to full and piecewise for most models
- self.compilation_config.cudagraph_mode = (
- CUDAGraphMode.FULL_AND_PIECEWISE
- )
- else:
- self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
-
# if cudagraph_mode has full cudagraphs, we need to check support
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
# decode context parallel does not support full cudagraphs
diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py
index a7724a86cc6a5..fa99078e9ff0d 100644
--- a/vllm/distributed/device_communicators/tpu_communicator.py
+++ b/vllm/distributed/device_communicators/tpu_communicator.py
@@ -97,11 +97,3 @@ class TpuCommunicator(DeviceCommunicatorBase):
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
assert dim == -1, "TPUs only support dim=-1 for all-gather."
return xm.all_gather(input_, dim=dim)
-
-
-if USE_TPU_INFERENCE:
- from tpu_inference.distributed.device_communicators import (
- TpuCommunicator as TpuInferenceCommunicator,
- )
-
- TpuCommunicator = TpuInferenceCommunicator # type: ignore
diff --git a/vllm/distributed/eplb/async_worker.py b/vllm/distributed/eplb/async_worker.py
new file mode 100644
index 0000000000000..e4b4fc92eeaaa
--- /dev/null
+++ b/vllm/distributed/eplb/async_worker.py
@@ -0,0 +1,115 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+The async worker that transfers experts in the background.
+"""
+
+import asyncio
+import threading
+from typing import TYPE_CHECKING
+
+import torch
+from torch.distributed import ProcessGroup
+
+from vllm.distributed.parallel_state import get_ep_group
+from vllm.logger import init_logger
+
+from .rebalance_execute import transfer_layer
+
+if TYPE_CHECKING:
+ from .eplb_state import EplbState
+
+logger = init_logger(__name__)
+
+
+def start_async_worker(
+ state: "EplbState",
+ rank_mapping: dict[int, int] | None = None,
+ is_profile: bool = False,
+) -> threading.Thread:
+ ep_group = get_ep_group().device_group
+ rank = ep_group.rank()
+ device_index = state.cuda_device_index
+
+ def thread_target() -> None:
+ assert device_index is not None
+ torch.cuda.set_device(device_index)
+ cuda_stream = torch.cuda.Stream(device=device_index)
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ try:
+ loop.run_until_complete(
+ transfer_run_periodically(
+ state=state,
+ ep_group=ep_group,
+ is_profile=is_profile,
+ rank_mapping=rank_mapping,
+ cuda_stream=cuda_stream,
+ )
+ )
+ except Exception as exc: # pragma: no cover - diagnostic path
+ logger.exception("async loop error (Rank %d): %s", rank, str(exc))
+ finally:
+ loop.close()
+
+ thread = threading.Thread(target=thread_target, daemon=True)
+ thread.start()
+ return thread
+
+
+async def transfer_run_periodically(
+ state: "EplbState",
+ ep_group: ProcessGroup,
+ is_profile: bool = False,
+ rank_mapping: dict[int, int] | None = None,
+ cuda_stream: torch.cuda.Stream = None,
+) -> None:
+ while True:
+ await asyncio.to_thread(state.rearrange_event.wait)
+ logger.info("async worker woke up for EPLB transfer")
+
+ for model_state in state.model_states.values():
+ if not model_state.is_async_enabled:
+ continue
+ current_num_layers = model_state.model.num_moe_layers
+ while (
+ model_state.rebalanced
+ and model_state.layer_to_transfer < current_num_layers
+ ):
+ if (
+ not model_state.ep_buffer_ready
+ and model_state.rebalanced
+ and model_state.new_physical_to_logical_map is not None
+ ):
+ await asyncio.to_thread(model_state.buffer_lock.acquire)
+ try:
+ if model_state.layer_to_transfer >= current_num_layers:
+ break
+
+ (
+ model_state.is_unchanged,
+ model_state.is_received_locally,
+ model_state.experts_recv_loc,
+ ) = await transfer_layer(
+ old_global_expert_indices=model_state.physical_to_logical_map,
+ new_global_expert_indices=model_state.new_physical_to_logical_map,
+ expert_weights=model_state.model.expert_weights,
+ expert_weights_buffer=model_state.expert_buffer,
+ ep_group=ep_group,
+ is_profile=is_profile,
+ layer=model_state.layer_to_transfer,
+ cuda_stream=cuda_stream,
+ rank_mapping=rank_mapping,
+ )
+ event = torch.cuda.Event(blocking=False)
+ cuda_stream.record_event(event)
+ model_state.buffer_ready_event = event
+ model_state.ep_buffer_ready = 1
+ finally:
+ model_state.buffer_lock.release()
+ else:
+ if not model_state.rebalanced:
+ break
+ await asyncio.sleep(0.001)
+
+ state.rearrange_event.clear()
diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py
index 526d3ceac7b8f..9f8798a96a2fc 100644
--- a/vllm/distributed/eplb/eplb_state.py
+++ b/vllm/distributed/eplb/eplb_state.py
@@ -26,6 +26,7 @@ MoE layer. If we have 32 EP ranks, then each GPU will hold 288 / 32 = 9 local
physical experts.
"""
+import threading
import time
from collections.abc import Sequence
from dataclasses import dataclass
@@ -43,8 +44,9 @@ from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MixtureOfExperts
+from .async_worker import start_async_worker
from .rebalance_algo import rebalance_experts
-from .rebalance_execute import rearrange_expert_weights_inplace
+from .rebalance_execute import move_from_buffer, rearrange_expert_weights_inplace
logger = init_logger(__name__)
@@ -132,6 +134,74 @@ class EplbModelState:
"""
model_name: str
model: MixtureOfExperts
+ expert_buffer: list[torch.Tensor]
+ """
+ The buffer to store the expert weights during transfer.
+ """
+ buffer_lock: threading.Lock
+ """
+ The lock to protect the expert buffer.
+ """
+ buffer_ready_event: torch.cuda.Event | None
+ """
+ CUDA event recorded when the async worker finishes filling the buffer.
+ The main thread waits on this before consuming the buffer.
+ """
+ ep_buffer_ready: int
+ """
+ The flag indicates whether the expert buffer is ready for transfer.
+ 0 or 1.
+ """
+ layer_to_transfer: int
+ """
+ The layer index to transfer in async mode.
+ """
+ rebalanced: bool
+ """
+ The flag indicates whether the experts rebalance have been computed.
+ """
+ pending_global_ready_check: bool
+ """
+ Whether the async EPLB needs to poll peers for buffer readiness.
+ """
+ is_unchanged: list[bool]
+ """
+ intermediate variable between `move_to_buffer` and `move_to_workspace`.
+ The size is same as the num of physical experts in the current layer.
+ """
+ is_received_locally: list[bool]
+ """
+ intermediate variable between `move_to_buffer` and `move_to_workspace`.
+ The size is same as the num of physical experts in the current layer.
+ """
+ experts_recv_loc: dict[int, int]
+ """
+ intermediate variable between `move_to_buffer` and `move_to_workspace`.
+ The size is same as the num of physical experts in the current layer.
+ """
+ is_async_enabled: bool
+ """
+ The flag indicates whether the EPLB is running in async mode.
+ """
+ cuda_device_index: int | None
+ """
+ CUDA device index for the async EPLB worker thread.
+ """
+ new_physical_to_logical_map: torch.Tensor | None = None
+ """
+ intermediate variable between `move_to_buffer` and `move_to_workspace`.
+ the size is same as physical_to_logical_map
+ """
+ new_logical_to_physical_map: torch.Tensor | None = None
+ """
+ intermediate variable between `move_to_buffer` and `move_to_workspace`.
+ the size is same as logical_to_physical_map
+ """
+ new_logical_replica_count: torch.Tensor | None = None
+ """
+ intermediate variable between `move_to_buffer` and `move_to_workspace`.
+ the size is same as logical_replica_count
+ """
class EplbState:
@@ -164,12 +234,31 @@ class EplbState:
Otherwise, the rearrangement will hang at collective
communication calls.
"""
- self.expert_rearrangement_step: int = 0
+ self.expert_rearrangement_step_interval: int = 0
"""
Interval for expert rearrangement steps.
This is a constant and is taken from the config.
"""
- self.expert_rearrangement_step_interval: int = 0
+ self.is_async: bool = False
+ """
+ The flag indicates whether the EPLB is running in async mode.
+ """
+ self.rearrange_event = threading.Event()
+ """
+ Event to signal when a new rearrangement is needed for the async thread.
+ """
+ self.async_worker: threading.Thread | None = None
+ """
+ Background thread handling async transfers.
+ """
+ self.cuda_device_index: int | None = None
+ """
+ CUDA device index for the async EPLB worker thread.
+ """
+ if self.device.type == "cuda":
+ self.cuda_device_index = self.device.index
+ if self.cuda_device_index is None and torch.cuda.is_available():
+ self.cuda_device_index = torch.cuda.current_device()
@staticmethod
def build_initial_global_physical_to_logical_map(
@@ -239,6 +328,8 @@ class EplbState:
Build the initial EPLB state.
"""
self.validate_ep_configuration(model)
+ self.is_async = self.parallel_config.eplb_config.use_async
+
physical_to_logical_map_list = (
EplbState.build_initial_global_physical_to_logical_map(
model.num_routed_experts,
@@ -368,7 +459,12 @@ class EplbState:
physical_to_logical_map = new_physical_to_logical_map.to(self.device)
logical_to_physical_map.copy_(new_logical_to_physical_map)
logical_replica_count.copy_(new_logical_replica_count)
+ else:
+ new_physical_to_logical_map = None
+ new_logical_to_physical_map = None
+
+ new_logical_replica_count = None
model.set_eplb_state(
expert_load_pass,
logical_to_physical_map,
@@ -385,15 +481,33 @@ class EplbState:
)
self.expert_rearrangement_step = 0
- self.model_states[model_config.compute_hash()] = EplbModelState(
- physical_to_logical_map,
- logical_to_physical_map,
- logical_replica_count,
- expert_load_pass,
- expert_load_window,
- model_config.model,
- model,
+ expert_buffer = [torch.empty_like(w) for w in model.expert_weights[0]]
+
+ model_state = EplbModelState(
+ physical_to_logical_map=physical_to_logical_map,
+ logical_to_physical_map=logical_to_physical_map,
+ logical_replica_count=logical_replica_count,
+ expert_load_pass=expert_load_pass,
+ expert_load_window=expert_load_window,
+ model_name=model_config.model,
+ model=model,
+ expert_buffer=expert_buffer,
+ buffer_lock=threading.Lock(),
+ buffer_ready_event=None,
+ ep_buffer_ready=0,
+ layer_to_transfer=0,
+ rebalanced=False,
+ pending_global_ready_check=False,
+ is_unchanged=[],
+ is_received_locally=[],
+ experts_recv_loc={},
+ is_async_enabled=self.is_async,
+ cuda_device_index=self.cuda_device_index,
+ new_physical_to_logical_map=new_physical_to_logical_map,
+ new_logical_to_physical_map=new_logical_to_physical_map,
+ new_logical_replica_count=new_logical_replica_count,
)
+ self.model_states[model_config.compute_hash()] = model_state
def step(
self,
@@ -420,7 +534,7 @@ class EplbState:
- `max_tokens`: The maximum load across ranks.
- `balancedness`: The ratio of average load to maximum load.
"""
-
+ ep_group = get_ep_group().device_group
if is_profile:
self.rearrange(is_profile=True)
return
@@ -488,7 +602,49 @@ class EplbState:
# rearrangement step and perform rearrangement to ensure all ranks are
# performing collective communication.
self.expert_rearrangement_step += 1
+
+ if self.is_async:
+ for eplb_model_state in self.model_states.values():
+ if not eplb_model_state.is_async_enabled:
+ continue
+
+ all_ranks_buffer_ready = False
+ if eplb_model_state.pending_global_ready_check:
+ all_ranks_buffer_ready = self._all_ranks_buffer_ready(
+ eplb_model_state
+ )
+ if (
+ eplb_model_state.is_async_enabled
+ and eplb_model_state.ep_buffer_ready
+ and all_ranks_buffer_ready
+ ):
+ self.move_to_workspace(
+ model_state=eplb_model_state,
+ ep_group=ep_group,
+ is_profile=is_profile,
+ )
+ if (
+ eplb_model_state.layer_to_transfer
+ >= eplb_model_state.model.num_moe_layers
+ ):
+ self.post_eplb(eplb_model_state, is_profile)
+ eplb_model_state.rebalanced = False
+ eplb_model_state.layer_to_transfer = 0
+ eplb_model_state.pending_global_ready_check = False
+ logger.info(
+ "finish async transfer for model %s rank %d layer %d",
+ eplb_model_state.model_name,
+ ep_group.rank(),
+ eplb_model_state.model.num_moe_layers,
+ )
+
if self.expert_rearrangement_step >= self.expert_rearrangement_step_interval:
+ if any(
+ eplb_model_state.is_async_enabled and eplb_model_state.rebalanced
+ for eplb_model_state in self.model_states.values()
+ ):
+ # Still performing asynchronous rearrangement
+ return
self.expert_rearrangement_step = 0
self.rearrange()
@@ -524,7 +680,11 @@ class EplbState:
if is_main_rank:
torch.cuda.synchronize()
time_start = time.perf_counter()
- logger.info("Rearranging experts %s...", "(profile)" if is_profile else "")
+ logger.info(
+ "Rearranging experts %s %s...",
+ "(async mode)" if self.is_async else "sync mode",
+ "(profile)" if is_profile else "",
+ )
if global_expert_loads is None:
# Map the physical expert load to global logical experts
@@ -593,6 +753,7 @@ class EplbState:
model = eplb_model_state.model
num_replicas = model.num_physical_experts
num_groups = model.num_expert_groups
+
if rank_mapping is not None and len(rank_mapping) == ep_group.size():
# NOTE(yongji): scale down, we need to rebalance the experts on
# remaining GPUs, transfer the experts while we haven't shutdown
@@ -608,7 +769,7 @@ class EplbState:
num_gpus = ep_group.size()
if num_gpus % num_nodes != 0:
- self.num_nodes = 1
+ num_nodes = 1
logger.warning_once(
f"num_gpus % num_nodes != 0, "
"not using hierarchical rearrangement algorithm.\n"
@@ -631,60 +792,216 @@ class EplbState:
num_gpus,
)
- # Update expert weights
- rearrange_expert_weights_inplace(
- eplb_model_state.physical_to_logical_map,
- new_physical_to_logical_map,
- eplb_model_state.model.expert_weights,
- ep_group,
- is_profile,
- rank_mapping,
- )
+ if not eplb_model_state.is_async_enabled or is_profile:
+ # Update expert weights
+ rearrange_expert_weights_inplace(
+ eplb_model_state.physical_to_logical_map,
+ new_physical_to_logical_map,
+ eplb_model_state.model.expert_weights,
+ ep_group,
+ is_profile,
+ rank_mapping,
+ )
- if not is_profile:
- if (
- eplb_model_state.physical_to_logical_map.shape[1]
- != new_physical_to_logical_map.shape[1]
- ):
- eplb_model_state.physical_to_logical_map = (
- new_physical_to_logical_map.to(
- eplb_model_state.physical_to_logical_map.device
+ if not is_profile:
+ if (
+ eplb_model_state.physical_to_logical_map.shape[1]
+ != new_physical_to_logical_map.shape[1]
+ ):
+ eplb_model_state.physical_to_logical_map = (
+ new_physical_to_logical_map.to(
+ eplb_model_state.physical_to_logical_map.device
+ )
)
+ else:
+ eplb_model_state.physical_to_logical_map.copy_(
+ new_physical_to_logical_map
+ )
+ max_physical_slots = new_logical_to_physical_map.shape[-1]
+ assert (
+ max_physical_slots
+ <= eplb_model_state.logical_to_physical_map.shape[-1]
)
- else:
- eplb_model_state.physical_to_logical_map.copy_(
- new_physical_to_logical_map
+ new_logical_to_physical_map = torch.nn.functional.pad(
+ new_logical_to_physical_map,
+ (
+ 0,
+ eplb_model_state.logical_to_physical_map.shape[-1]
+ - max_physical_slots,
+ ),
+ value=-1,
)
- max_physical_slots = new_logical_to_physical_map.shape[-1]
- assert (
- max_physical_slots
- <= eplb_model_state.logical_to_physical_map.shape[-1]
- )
- new_logical_to_physical_map = torch.nn.functional.pad(
+ eplb_model_state.logical_to_physical_map.copy_(
+ new_logical_to_physical_map
+ )
+ eplb_model_state.logical_replica_count.copy_(
+ new_logical_replica_count
+ )
+ if is_main_rank:
+ assert time_start is not None
+ torch.cuda.synchronize()
+ time_end = time.perf_counter()
+ logger.info(
+ "Rearranged experts%sin %.2f seconds.",
+ " (profile) " if is_profile else " ",
+ time_end - time_start,
+ )
+ else:
+ device = eplb_model_state.physical_to_logical_map.device
+ new_physical = new_physical_to_logical_map.to(device)
+ max_slots = eplb_model_state.logical_to_physical_map.shape[-1]
+ padded_logical = torch.nn.functional.pad(
new_logical_to_physical_map,
- (
- 0,
- eplb_model_state.logical_to_physical_map.shape[-1]
- - max_physical_slots,
- ),
+ (0, max(0, max_slots - new_logical_to_physical_map.shape[-1])),
value=-1,
+ ).to(eplb_model_state.logical_to_physical_map.device)
+ new_replica = new_logical_replica_count.to(
+ eplb_model_state.logical_replica_count.device
)
- eplb_model_state.logical_to_physical_map.copy_(
- new_logical_to_physical_map
- )
- eplb_model_state.logical_replica_count.copy_(new_logical_replica_count)
- if is_main_rank:
- assert time_start is not None
- torch.cuda.synchronize()
- time_end = time.perf_counter()
- logger.info(
- "Rearranged experts%sin %.2f seconds.",
- " (profile) " if is_profile else " ",
- time_end - time_start,
- )
+ eplb_model_state.new_physical_to_logical_map = new_physical
+ eplb_model_state.new_logical_to_physical_map = padded_logical
+ eplb_model_state.new_logical_replica_count = new_replica
+
+ eplb_model_state.rebalanced = True
+ eplb_model_state.layer_to_transfer = 0
+ eplb_model_state.pending_global_ready_check = True
+
+ # Signal async thread to start transferring layers
+ if self.is_async and (not is_profile):
+ self.rearrange_event.set()
return None
+ def start_async_loop(
+ self,
+ rank_mapping: dict[int, int] | None = None,
+ is_profile: bool = False,
+ ):
+ if not self.is_async:
+ return
+ if self.async_worker is None:
+ self.async_worker = start_async_worker(
+ self,
+ rank_mapping=rank_mapping,
+ is_profile=is_profile,
+ )
+
+ def _update_layer_mapping_from_new(
+ self, model_state: EplbModelState, layer: int
+ ) -> None:
+ if (
+ model_state.new_physical_to_logical_map is None
+ or model_state.new_logical_to_physical_map is None
+ or model_state.new_logical_replica_count is None
+ ):
+ return
+
+ target_device = model_state.physical_to_logical_map.device
+ new_physical = model_state.new_physical_to_logical_map
+ if model_state.physical_to_logical_map.shape[1] != new_physical.shape[1]:
+ model_state.physical_to_logical_map = new_physical.to(target_device)
+ else:
+ model_state.physical_to_logical_map[layer].copy_(
+ new_physical[layer].to(target_device)
+ )
+
+ logical_device = model_state.logical_to_physical_map.device
+ new_logical = model_state.new_logical_to_physical_map[layer].to(logical_device)
+ max_slots = model_state.logical_to_physical_map.shape[-1]
+ slot_delta = max_slots - new_logical.shape[-1]
+ if slot_delta > 0:
+ new_logical = torch.nn.functional.pad(
+ new_logical, (0, slot_delta), value=-1
+ )
+ model_state.logical_to_physical_map[layer].copy_(new_logical)
+
+ replica_device = model_state.logical_replica_count.device
+ model_state.logical_replica_count[layer].copy_(
+ model_state.new_logical_replica_count[layer].to(replica_device)
+ )
+
+ def _all_ranks_buffer_ready(self, model_state: EplbModelState) -> bool:
+ parallel_state = get_ep_group()
+ cpu_group = getattr(parallel_state, "cpu_group", None)
+ if cpu_group is not None and cpu_group.size() > 1:
+ flag = torch.tensor(
+ (int(model_state.ep_buffer_ready),), dtype=torch.int32, device="cpu"
+ )
+ all_reduce(flag, group=cpu_group)
+ return int(flag.item()) == cpu_group.size()
+
+ device_group = parallel_state.device_group
+ if device_group.size() <= 1:
+ return bool(model_state.ep_buffer_ready)
+
+ device = getattr(
+ parallel_state, "device", model_state.physical_to_logical_map.device
+ )
+ flag = torch.tensor(
+ (int(model_state.ep_buffer_ready),), dtype=torch.int32, device=device
+ )
+ all_reduce(flag, group=device_group)
+ return int(flag.item()) == device_group.size()
+
+ def move_to_workspace(
+ self,
+ model_state: EplbModelState,
+ ep_group: ProcessGroup,
+ is_profile: bool = False,
+ ):
+ if not model_state.buffer_lock.acquire(blocking=False):
+ return
+ try:
+ assert model_state.new_physical_to_logical_map is not None
+ device_index = model_state.cuda_device_index or self.cuda_device_index
+ if model_state.buffer_ready_event is not None and device_index is not None:
+ stream = torch.cuda.current_stream(device=device_index)
+ stream.wait_event(model_state.buffer_ready_event)
+ model_state.buffer_ready_event = None
+ move_from_buffer(
+ expert_weights=model_state.model.expert_weights[
+ model_state.layer_to_transfer
+ ],
+ expert_weights_buffer=model_state.expert_buffer,
+ is_unchanged=model_state.is_unchanged,
+ is_received_locally=model_state.is_received_locally,
+ experts_recv_loc=model_state.experts_recv_loc,
+ new_indices=model_state.new_physical_to_logical_map[
+ model_state.layer_to_transfer
+ ].tolist(),
+ ep_group=ep_group,
+ )
+ transferred_layer = model_state.layer_to_transfer
+ self._update_layer_mapping_from_new(model_state, transferred_layer)
+ # After the main thread consumes, advance layer_to_transfer
+ model_state.layer_to_transfer += 1
+ model_state.ep_buffer_ready = 0
+ logger.info(
+ "model %s successfully move_to_workspace layer %d",
+ model_state.model_name,
+ transferred_layer,
+ )
+ finally:
+ try:
+ model_state.buffer_lock.release()
+ except Exception as e:
+ logger.error(
+ "Rank %d: buffer_lock release failed in move_to_workspace: %s",
+ ep_group.rank(),
+ str(e),
+ )
+
+ def post_eplb(self, model_state: EplbModelState, is_profile: bool = False) -> None:
+ assert model_state.new_physical_to_logical_map is not None
+ assert model_state.new_logical_to_physical_map is not None
+ assert model_state.new_logical_replica_count is not None
+ if not is_profile:
+ for layer_idx in range(model_state.physical_to_logical_map.shape[0]):
+ self._update_layer_mapping_from_new(model_state, layer_idx)
+ model_state.new_physical_to_logical_map = None
+ model_state.new_logical_to_physical_map = None
+ model_state.new_logical_replica_count = None
+
@staticmethod
def recv_state() -> tuple[list[torch.Tensor], list[torch.Tensor]]:
"""
diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py
index 5c1efbaf03bab..376dad8a72ef1 100644
--- a/vllm/distributed/eplb/rebalance_execute.py
+++ b/vllm/distributed/eplb/rebalance_execute.py
@@ -100,18 +100,19 @@ def get_ep_ranks_with_expert(
return ranks_to_send, ranks_to_recv_actual
-def shuffle_layer(
+def move_to_buffer(
num_local_experts: int,
- ep_rank: int,
old_indices: Sequence[int],
new_indices: Sequence[int],
expert_weights: Iterable[torch.Tensor],
expert_weights_buffer: Sequence[torch.Tensor],
+ cuda_stream: torch.cuda.Stream | None,
ep_group: ProcessGroup,
-) -> None:
+) -> tuple[list[bool], list[bool], dict[int, int]]:
"""
Perform expert weights rearrangement of one layer.
"""
+ ep_rank = ep_group.rank()
local2global = partial(
idx_local_to_global,
local_cnt=num_local_experts,
@@ -137,7 +138,8 @@ def shuffle_layer(
if old_indices[src_global] == new_indices[dst_global]:
is_received_locally[dst] = True
for weight, buffer in zip(expert_weights, expert_weights_buffer):
- buffer[dst].copy_(weight[src])
+ with torch.cuda.stream(cuda_stream):
+ buffer[dst].copy_(weight[src], non_blocking=True)
p2p_ops: list[P2POp] = []
@@ -225,25 +227,115 @@ def shuffle_layer(
]
# 4. Execute the P2P operations. The real communication happens here.
- if p2p_ops:
+ if p2p_ops and cuda_stream is not None:
+ with torch.cuda.stream(cuda_stream):
+ reqs = batch_isend_irecv(p2p_ops)
+ for req in reqs:
+ req.wait()
+ elif p2p_ops:
reqs = batch_isend_irecv(p2p_ops)
for req in reqs:
req.wait()
+ # wait for the communication to finish
+ return is_unchanged, is_received_locally, experts_recv_loc
+
+
+def move_from_buffer(
+ expert_weights: Iterable[torch.Tensor],
+ expert_weights_buffer: list[torch.Tensor],
+ is_unchanged: list[bool],
+ is_received_locally: list[bool],
+ experts_recv_loc: dict[int, int],
+ new_indices: Sequence[int],
+ ep_group: ProcessGroup,
+) -> None:
+ ep_rank = ep_group.rank()
+ num_local_experts = len(is_unchanged)
+
+ local2global = partial(
+ idx_local_to_global, local_cnt=num_local_experts, ep_rank=ep_rank
+ )
- # 5. Copy the weights from the buffer back to the original weights.
for dst in range(num_local_experts):
if is_unchanged[dst]:
continue
if is_received_locally[dst]:
for weight, buffer in zip(expert_weights, expert_weights_buffer):
- weight[dst].copy_(buffer[dst])
+ weight[dst].copy_(buffer[dst], non_blocking=True)
else:
expert = new_indices[local2global(dst)]
if expert == -1:
continue
src = experts_recv_loc[expert]
for weight, buffer in zip(expert_weights, expert_weights_buffer):
- weight[dst].copy_(buffer[src])
+ weight[dst].copy_(buffer[src], non_blocking=True)
+
+
+async def transfer_layer(
+ old_global_expert_indices: torch.Tensor,
+ new_global_expert_indices: torch.Tensor,
+ expert_weights: Sequence[Iterable[torch.Tensor]],
+ expert_weights_buffer: Sequence[torch.Tensor],
+ ep_group: ProcessGroup,
+ is_profile: bool = False,
+ layer: int = 0,
+ cuda_stream: torch.cuda.Stream | None = None,
+ rank_mapping: dict[int, int] | None = None,
+) -> tuple[list[bool], list[bool], dict[int, int]]:
+ """
+ Rearranges the expert weights in place according to the new expert indices.
+
+ The value of the indices arguments are logical indices of the experts,
+ while keys are physical.
+
+ Args:
+ old_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
+ new_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
+ expert_weights: A sequence of shape (num_moe_layers)(weight_count)
+ of tensors of shape (num_local_physical_experts, hidden_size_i).
+ For example, a linear layer may have up and down projection,
+ so weight_count = 2. Each weight's hidden size can be different.
+ ep_group: The device process group for expert parallelism.
+ is_profile (bool): If `True`, do not perform any actual weight copy.
+ This is used during profile run, where we only perform dummy
+ communications to reserve enough memory for the buffers.
+ """
+ ep_size = ep_group.size()
+ if rank_mapping is not None:
+ if len(rank_mapping) == ep_group.size():
+ # scale down
+ new_global_expert_indices = _map_new_expert_indices_with_rank_mapping(
+ new_global_expert_indices,
+ rank_mapping,
+ )
+ else:
+ # scale up
+ old_global_expert_indices = _map_old_expert_indices_with_rank_mapping(
+ old_global_expert_indices,
+ rank_mapping,
+ ep_group.size(),
+ )
+
+ assert old_global_expert_indices.shape[1] == new_global_expert_indices.shape[1]
+ num_moe_layers, num_physical_experts = old_global_expert_indices.shape
+ assert len(expert_weights) == num_moe_layers
+ num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
+ assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
+ assert num_physical_experts == ep_size * num_local_physical_experts
+ # A buffer to hold the expert weights in one layer during the exchange.
+ # NOTE: Currently we assume the same weights across different layers
+ # have the same shape.
+
+ is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer(
+ num_local_experts=num_local_physical_experts,
+ old_indices=old_global_expert_indices[layer].tolist(),
+ new_indices=new_global_expert_indices[layer].tolist(),
+ expert_weights=expert_weights[layer],
+ expert_weights_buffer=expert_weights_buffer,
+ cuda_stream=cuda_stream,
+ ep_group=ep_group,
+ )
+ return is_unchanged, is_received_locally, experts_recv_loc
def rearrange_expert_weights_inplace(
@@ -296,7 +388,6 @@ def rearrange_expert_weights_inplace(
num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
- ep_rank = ep_group.rank()
ep_size = ep_group.size()
assert num_physical_experts == ep_size * num_local_physical_experts
@@ -329,14 +420,24 @@ def rearrange_expert_weights_inplace(
torch.cuda.synchronize()
for layer in range(num_moe_layers):
- shuffle_layer(
- num_local_physical_experts,
- ep_rank,
- old_global_expert_indices_cpu[layer].tolist(),
- new_global_expert_indices_cpu[layer].tolist(),
- expert_weights[layer],
- expert_weights_buffer,
- ep_group,
+ is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer(
+ num_local_experts=num_local_physical_experts,
+ old_indices=old_global_expert_indices_cpu[layer].tolist(),
+ new_indices=new_global_expert_indices_cpu[layer].tolist(),
+ expert_weights=expert_weights[layer],
+ expert_weights_buffer=expert_weights_buffer,
+ cuda_stream=None,
+ ep_group=ep_group,
+ )
+
+ move_from_buffer(
+ expert_weights=expert_weights[layer],
+ expert_weights_buffer=expert_weights_buffer,
+ is_unchanged=is_unchanged,
+ is_received_locally=is_received_locally,
+ experts_recv_loc=experts_recv_loc,
+ new_indices=new_global_expert_indices[layer].tolist(),
+ ep_group=ep_group,
)
@@ -428,4 +529,4 @@ def _map_new_expert_indices_with_rank_mapping(
return mapped_expert_indices
-__all__ = ["rearrange_expert_weights_inplace"]
+__all__ = ["transfer_layer", "move_from_buffer"]
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
index 74f09278b7bb1..cac45425bb7aa 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
@@ -42,12 +42,12 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional
import torch
+from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
@@ -239,7 +239,7 @@ class KVConnectorBase_V1(ABC):
return
def register_cross_layers_kv_cache(
- self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"]
+ self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
):
"""
Initialize with a single KV cache tensor used by all layers.
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py
index 9cd7d93c92fa3..e9b2bd392b0ef 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py
@@ -36,6 +36,7 @@ from typing import TYPE_CHECKING, Any, Optional
import torch
+from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1,
KVConnectorRole,
@@ -45,7 +46,6 @@ from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv
if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
@@ -117,7 +117,7 @@ class DecodeBenchConnector(KVConnectorBase_V1):
self,
layer_name: str,
kv_layer: torch.Tensor,
- attn_metadata: "AttentionMetadata",
+ attn_metadata: AttentionMetadata,
**kwargs: Any,
) -> None:
# This connector doesn't save KV cache (benchmarking only)
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
index 0c24a53fb754b..30da424ddcca0 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
@@ -7,6 +7,7 @@ from lmcache.integration.vllm.vllm_v1_adapter import (
LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
)
+from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
@@ -17,7 +18,6 @@ from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -91,7 +91,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
self,
layer_name: str,
kv_layer: torch.Tensor,
- attn_metadata: "AttentionMetadata",
+ attn_metadata: AttentionMetadata,
**kwargs: Any,
) -> None:
"""
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
index 94572b02fa872..15ac5b049fce9 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
@@ -29,6 +29,7 @@ from lmcache.v1.lookup_client.lmcache_async_lookup_client import (
from lmcache.v1.offload_server.zmq_server import ZMQOffloadServer
from lmcache.v1.plugin.plugin_launcher import PluginLauncher
+from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
@@ -50,7 +51,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.version import __version__ as VLLM_VERSION
if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.multimodal.inputs import PlaceholderRange
from vllm.v1.core.kv_cache_manager import KVCacheManager
@@ -915,7 +915,7 @@ class LMCacheConnectorV1Impl:
self,
layer_name: str,
kv_layer: torch.Tensor,
- attn_metadata: "AttentionMetadata",
+ attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
"""Start saving the a layer of KV cache from vLLM's paged buffer
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
index d1d3e475cc889..a4bddf5e03166 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
@@ -10,6 +10,7 @@ import zmq
from lmcache.integration.vllm.utils import mla_enabled
from lmcache.utils import init_logger as lmcache_init_logger
+from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
@@ -26,7 +27,6 @@ from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.utils import ConstantList
if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
@@ -490,7 +490,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
self,
layer_name: str,
kv_layer: torch.Tensor,
- attn_metadata: "AttentionMetadata",
+ attn_metadata: AttentionMetadata,
**kwargs: Any,
) -> None:
"""
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
index c9d08e9b78ed0..f47e8ca7e6c50 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any
import torch
+from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
@@ -27,7 +28,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.kv_events import KVCacheEvent
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
@@ -216,7 +216,7 @@ class MultiConnector(KVConnectorBase_V1):
self,
layer_name: str,
kv_layer: torch.Tensor,
- attn_metadata: "AttentionMetadata",
+ attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
for c in self._connectors:
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 7c0911240493c..24c8d32dafedc 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -4,7 +4,6 @@ import contextlib
import copy
import logging
import math
-import os
import queue
import threading
import time
@@ -21,7 +20,7 @@ import torch
import zmq
from vllm import envs
-from vllm.attention import AttentionBackend
+from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig
@@ -52,7 +51,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.block_table import BlockTable
if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionMetadata
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
@@ -309,7 +307,7 @@ class NixlConnector(KVConnectorBase_V1):
self,
layer_name: str,
kv_layer: torch.Tensor,
- attn_metadata: "AttentionMetadata",
+ attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
"""NixlConnector does not save explicitly."""
@@ -810,9 +808,6 @@ class NixlConnectorWorker:
self.nixl_backends = vllm_config.kv_transfer_config.get_from_extra_config(
"backends", ["UCX"]
)
- # TODO temporary, once nixl allows for telemetry flag in config
- # (next release), we can remove this env var.
- os.environ["NIXL_TELEMETRY_ENABLE"] = "1"
# Agent.
non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"]
@@ -828,10 +823,11 @@ class NixlConnectorWorker:
if nixl_agent_config is None:
config = None
else:
+ # Enable telemetry by default for NIXL 0.7.1 and above.
config = (
- nixl_agent_config(backends=self.nixl_backends)
+ nixl_agent_config(backends=self.nixl_backends, capture_telemetry=True)
if len(non_ucx_backends) > 0
- else nixl_agent_config(num_threads=num_threads)
+ else nixl_agent_config(num_threads=num_threads, capture_telemetry=True)
)
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config)
@@ -1835,35 +1831,55 @@ class NixlConnectorWorker:
done_req_ids: set[str] = set()
for req_id, handles in list(transfers.items()):
in_progress = False
- for handle, _xfer_stime in handles:
- xfer_state = self.nixl_wrapper.check_xfer_state(handle)
- if xfer_state == "DONE":
- # Get telemetry from NIXL
- res = self.nixl_wrapper.get_xfer_telemetry(handle)
- self.xfer_stats.record_transfer(res)
- self.nixl_wrapper.release_xfer_handle(handle)
- elif xfer_state == "PROC":
- in_progress = True
- continue
- else:
- # transfer failed - mark blocks as invalid
- logger.error(
- "NIXL transfer failed for request %s with state %s. "
+ for handle, xfer_start_time in handles:
+ try:
+ xfer_state = self.nixl_wrapper.check_xfer_state(handle)
+ if xfer_state == "DONE":
+ # Get telemetry from NIXL
+ res = self.nixl_wrapper.get_xfer_telemetry(handle)
+ self.xfer_stats.record_transfer(res)
+ self.nixl_wrapper.release_xfer_handle(handle)
+ elif xfer_state == "PROC":
+ in_progress = True
+ continue
+ else:
+ logger.error(
+ "NIXL transfer failed for request %s with state "
+ "%s. Marking blocks as invalid.",
+ req_id,
+ xfer_state,
+ )
+ self._handle_failed_transfer(req_id, handle)
+ in_progress = False
+ except Exception:
+ logger.exception(
+ "NIXL transfer exception for request %s. "
"Marking blocks as invalid.",
req_id,
- xfer_state,
)
- # mark all (logical)blocks for this request as invalid
- if meta := self._recving_metadata.pop(req_id, None):
- self._invalid_block_ids.update(meta.local_block_ids)
- self._recving_metadata.pop(req_id, None)
- self.nixl_wrapper.release_xfer_handle(handle)
- self.xfer_stats.record_failed_transfer()
+ self._handle_failed_transfer(req_id, handle)
+ in_progress = False
+
if not in_progress:
done_req_ids.add(req_id)
del transfers[req_id]
return done_req_ids
+ def _handle_failed_transfer(self, req_id: str, handle: int):
+ """
+ Handle a failed transfer by marking all (logical) blocks as invalid and
+ recording the failure.
+
+ Args:
+ req_id: The request ID.
+ handle: The transfer handle.
+ """
+ if meta := self._recving_metadata.pop(req_id, None):
+ self._invalid_block_ids.update(meta.local_block_ids)
+ self._recving_metadata.pop(req_id, None)
+ self.nixl_wrapper.release_xfer_handle(handle)
+ self.xfer_stats.record_failed_transfer()
+
def start_load_kv(self, metadata: NixlConnectorMetadata):
"""
Start loading by triggering non-blocking nixl_xfer.
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
index 8cd09014cab11..0ad9d4ae1b39f 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
@@ -8,7 +8,8 @@ from typing import Any, ClassVar
import torch
-from vllm.attention import Attention, AttentionBackend, AttentionMetadata
+from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
+from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1 import (
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
index a124a0d519db8..8f3a62d7bcdb0 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
import regex as re
import torch
+from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
@@ -22,7 +23,6 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -243,7 +243,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
self,
layer_name: str,
kv_layer: torch.Tensor,
- attn_metadata: "AttentionMetadata",
+ attn_metadata: AttentionMetadata,
**kwargs: Any,
) -> None:
"""Start saving the KV cache of the layer from vLLM's paged buffer
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
index 016d1d45b3593..ed641cfc43ddd 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import hashlib
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional
@@ -8,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
import safetensors
import torch
+from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
@@ -15,11 +15,11 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorRole,
)
from vllm.logger import init_logger
+from vllm.utils.hashing import safe_hash
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -211,7 +211,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
self,
layer_name: str,
kv_layer: torch.Tensor,
- attn_metadata: "AttentionMetadata",
+ attn_metadata: AttentionMetadata,
**kwargs: Any,
) -> None:
"""Start saving the KV cache of the layer from vLLM's paged buffer
@@ -423,7 +423,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
if mm_hashes:
mm_str = "-".join(mm_hashes)
token_bytes += mm_str.encode("utf-8")
- input_ids_hash = hashlib.md5(token_bytes, usedforsecurity=False).hexdigest()
+ input_ids_hash = safe_hash(token_bytes, usedforsecurity=False).hexdigest()
foldername = os.path.join(self._storage_path, input_ids_hash)
if create_folder:
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index f81612fd1f4a3..52b433cfaf1bd 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -41,7 +41,6 @@ import torch.distributed
import torch.distributed._functional_collectives as funcol
import torch.distributed._symmetric_memory
from torch.distributed import Backend, ProcessGroup
-from typing_extensions import deprecated
import vllm.envs as envs
from vllm.distributed.device_communicators.base_device_communicator import (
@@ -51,6 +50,7 @@ from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.network_utils import get_distributed_init_method
+from vllm.utils.system_utils import suppress_stdout
from vllm.utils.torch_utils import (
direct_register_custom_op,
supports_custom_op,
@@ -329,7 +329,8 @@ class GroupCoordinator:
)
# a group with `gloo` backend, to allow direct coordination between
# processes through the CPU.
- cpu_group = torch.distributed.new_group(ranks, backend="gloo")
+ with suppress_stdout():
+ cpu_group = torch.distributed.new_group(ranks, backend="gloo")
if self.rank in ranks:
self.ranks = ranks
self.world_size = len(ranks)
@@ -1076,15 +1077,6 @@ def get_tp_group() -> GroupCoordinator:
return _TP
-@deprecated(
- "`get_tensor_model_parallel_group` has been replaced with "
- "`get_tp_group` and may be removed after v0.12. Please use "
- "`get_tp_group` instead."
-)
-def get_tensor_model_parallel_group():
- return get_tp_group()
-
-
_DCP: GroupCoordinator | None = None
@@ -1128,15 +1120,6 @@ def get_pcp_group() -> GroupCoordinator:
return _PCP
-@deprecated(
- "`get_pipeline_model_parallel_group` has been replaced with "
- "`get_pp_group` and may be removed in v0.12. Please use "
- "`get_pp_group` instead."
-)
-def get_pipeline_model_parallel_group():
- return get_pp_group()
-
-
@contextmanager
def graph_capture(device: torch.device):
"""
diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py
index debf69c49b7d9..242ce393e4dc8 100644
--- a/vllm/distributed/utils.py
+++ b/vllm/distributed/utils.py
@@ -30,6 +30,7 @@ from torch.distributed.rendezvous import rendezvous
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils.network_utils import get_tcp_uri
+from vllm.utils.system_utils import suppress_stdout
from vllm.utils.torch_utils import is_torch_equal_or_newer
logger = init_logger(__name__)
@@ -427,33 +428,34 @@ def init_gloo_process_group(
Stateless init ProcessGroup with gloo backend compatible with
different torch versions.
"""
- if is_torch_equal_or_newer("2.6"):
- pg = ProcessGroup(
- prefix_store,
- group_rank,
- group_size,
- )
- else:
- options = ProcessGroup.Options(backend="gloo")
- pg = ProcessGroup(
- prefix_store,
- group_rank,
- group_size,
- options,
- )
- from torch.distributed.distributed_c10d import ProcessGroupGloo
+ with suppress_stdout():
+ if is_torch_equal_or_newer("2.6"):
+ pg = ProcessGroup(
+ prefix_store,
+ group_rank,
+ group_size,
+ )
+ else:
+ options = ProcessGroup.Options(backend="gloo")
+ pg = ProcessGroup(
+ prefix_store,
+ group_rank,
+ group_size,
+ options,
+ )
+ from torch.distributed.distributed_c10d import ProcessGroupGloo
- backend_class = ProcessGroupGloo(
- prefix_store, group_rank, group_size, timeout=timeout
- )
- backend_type = ProcessGroup.BackendType.GLOO
- device = torch.device("cpu")
- if is_torch_equal_or_newer("2.6"):
- # _set_default_backend is supported in torch >= 2.6
- pg._set_default_backend(backend_type)
- backend_class._set_sequence_number_for_group()
+ backend_class = ProcessGroupGloo(
+ prefix_store, group_rank, group_size, timeout=timeout
+ )
+ backend_type = ProcessGroup.BackendType.GLOO
+ device = torch.device("cpu")
+ if is_torch_equal_or_newer("2.6"):
+ # _set_default_backend is supported in torch >= 2.6
+ pg._set_default_backend(backend_type)
+ backend_class._set_sequence_number_for_group()
- pg._register_backend(device, backend_type, backend_class)
+ pg._register_backend(device, backend_type, backend_class)
return pg
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index b7c8f56e18c52..e4c9a82d25223 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -29,7 +29,7 @@ import regex as re
import torch
from pydantic import TypeAdapter, ValidationError
from pydantic.fields import FieldInfo
-from typing_extensions import TypeIs, deprecated
+from typing_extensions import TypeIs
import vllm.envs as envs
from vllm.attention.backends.registry import AttentionBackendEnum
@@ -77,7 +77,8 @@ from vllm.config.observability import DetailedTraceModules
from vllm.config.parallel import DistributedExecutorBackend, ExpertPlacementStrategy
from vllm.config.scheduler import SchedulerPolicy
from vllm.config.utils import get_field
-from vllm.logger import init_logger
+from vllm.config.vllm import OptimizationLevel
+from vllm.logger import init_logger, suppress_logging
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
from vllm.ray.lazy_utils import is_in_ray_actor, is_ray_initialized
@@ -86,7 +87,7 @@ from vllm.transformers_utils.config import (
is_interleaved,
maybe_override_with_speculators,
)
-from vllm.transformers_utils.utils import check_gguf_file, is_cloud_storage
+from vllm.transformers_utils.utils import is_cloud_storage, is_gguf
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.network_utils import get_ip
@@ -247,11 +248,13 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
default = field.default
# Handle pydantic.Field defaults
if isinstance(default, FieldInfo):
- default = (
- default.default
- if default.default_factory is None
- else default.default_factory()
- )
+ if default.default_factory is None:
+ default = default.default
+ else:
+ # VllmConfig's Fields have default_factory set to config classes.
+ # These could emit logs on init, which would be confusing.
+ with suppress_logging():
+ default = default.default_factory()
elif field.default_factory is not MISSING:
default = field.default_factory()
@@ -502,11 +505,6 @@ class EngineArgs:
)
reasoning_parser: str = StructuredOutputsConfig.reasoning_parser
reasoning_parser_plugin: str | None = None
- # Deprecated guided decoding fields
- guided_decoding_backend: str | None = None
- guided_decoding_disable_fallback: bool | None = None
- guided_decoding_disable_any_whitespace: bool | None = None
- guided_decoding_disable_additional_properties: bool | None = None
logits_processor_pattern: str | None = ModelConfig.logits_processor_pattern
@@ -523,9 +521,6 @@ class EngineArgs:
scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
pooler_config: PoolerConfig | None = ModelConfig.pooler_config
- override_pooler_config: dict | PoolerConfig | None = (
- ModelConfig.override_pooler_config
- )
compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
worker_cls: str = ParallelConfig.worker_cls
worker_extension_cls: str = ParallelConfig.worker_extension_cls
@@ -566,6 +561,7 @@ class EngineArgs:
stream_interval: int = SchedulerConfig.stream_interval
kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
+ optimization_level: OptimizationLevel = VllmConfig.optimization_level
kv_offloading_size: float | None = CacheConfig.kv_offloading_size
kv_offloading_backend: KVOffloadingBackend | None = (
@@ -662,11 +658,6 @@ class EngineArgs:
)
model_group.add_argument("--hf-overrides", **model_kwargs["hf_overrides"])
model_group.add_argument("--pooler-config", **model_kwargs["pooler_config"])
- model_group.add_argument(
- "--override-pooler-config",
- **model_kwargs["override_pooler_config"],
- deprecated=True,
- )
model_group.add_argument(
"--logits-processor-pattern", **model_kwargs["logits_processor_pattern"]
)
@@ -725,19 +716,6 @@ class EngineArgs:
"--reasoning-parser-plugin",
**structured_outputs_kwargs["reasoning_parser_plugin"],
)
- # Deprecated guided decoding arguments
- for arg, type in [
- ("--guided-decoding-backend", str),
- ("--guided-decoding-disable-fallback", bool),
- ("--guided-decoding-disable-any-whitespace", bool),
- ("--guided-decoding-disable-additional-properties", bool),
- ]:
- structured_outputs_group.add_argument(
- arg,
- type=type,
- help=(f"[DEPRECATED] {arg} will be removed in v0.12.0."),
- deprecated=True,
- )
# Parallel arguments
parallel_kwargs = get_kwargs(ParallelConfig)
@@ -855,30 +833,6 @@ class EngineArgs:
"--expert-placement-strategy",
**parallel_kwargs["expert_placement_strategy"],
)
- parallel_group.add_argument(
- "--num-redundant-experts",
- type=int,
- help="[DEPRECATED] --num-redundant-experts will be removed in v0.12.0.",
- deprecated=True,
- )
- parallel_group.add_argument(
- "--eplb-window-size",
- type=int,
- help="[DEPRECATED] --eplb-window-size will be removed in v0.12.0.",
- deprecated=True,
- )
- parallel_group.add_argument(
- "--eplb-step-interval",
- type=int,
- help="[DEPRECATED] --eplb-step-interval will be removed in v0.12.0.",
- deprecated=True,
- )
- parallel_group.add_argument(
- "--eplb-log-balancedness",
- action=argparse.BooleanOptionalAction,
- help="[DEPRECATED] --eplb-log-balancedness will be removed in v0.12.0.",
- deprecated=True,
- )
parallel_group.add_argument(
"--max-parallel-loading-workers",
@@ -920,7 +874,11 @@ class EngineArgs:
"--num-gpu-blocks-override", **cache_kwargs["num_gpu_blocks_override"]
)
cache_group.add_argument(
- "--enable-prefix-caching", **cache_kwargs["enable_prefix_caching"]
+ "--enable-prefix-caching",
+ **{
+ **cache_kwargs["enable_prefix_caching"],
+ "default": None,
+ },
)
cache_group.add_argument(
"--prefix-caching-hash-algo", **cache_kwargs["prefix_caching_hash_algo"]
@@ -1158,6 +1116,10 @@ class EngineArgs:
"--structured-outputs-config", **vllm_kwargs["structured_outputs_config"]
)
+ vllm_group.add_argument(
+ "--optimization-level", **vllm_kwargs["optimization_level"]
+ )
+
# Other arguments
parser.add_argument(
"--disable-log-stats",
@@ -1184,8 +1146,8 @@ class EngineArgs:
return engine_args
def create_model_config(self) -> ModelConfig:
- # gguf file needs a specific model loader and doesn't use hf_repo
- if check_gguf_file(self.model):
+ # gguf file needs a specific model loader
+ if is_gguf(self.model):
self.quantization = self.load_format = "gguf"
# NOTE(woosuk): In V1, we use separate processes for workers (unless
@@ -1279,7 +1241,6 @@ class EngineArgs:
mm_encoder_tp_mode=self.mm_encoder_tp_mode,
mm_encoder_attn_backend=self.mm_encoder_attn_backend,
pooler_config=self.pooler_config,
- override_pooler_config=self.override_pooler_config,
logits_processor_pattern=self.logits_processor_pattern,
generation_config=self.generation_config,
override_generation_config=self.override_generation_config,
@@ -1612,6 +1573,12 @@ class EngineArgs:
model_config.skip_tokenizer_init = True
logger.info("Skipping tokenizer initialization for tokens-only mode.")
+ if self.async_scheduling and not self.disable_nccl_for_dp_synchronization:
+ logger.info(
+ "Disabling NCCL for DP synchronization when using async scheduling."
+ )
+ self.disable_nccl_for_dp_synchronization = True
+
# Forward the deprecated CLI args to the EPLB config.
if self.num_redundant_experts is not None:
self.eplb_config.num_redundant_experts = self.num_redundant_experts
@@ -1736,21 +1703,6 @@ class EngineArgs:
self.reasoning_parser_plugin
)
- # Forward the deprecated CLI args to the StructuredOutputsConfig
- so_config = self.structured_outputs_config
- if self.guided_decoding_backend is not None:
- so_config.guided_decoding_backend = self.guided_decoding_backend
- if self.guided_decoding_disable_fallback is not None:
- so_config.disable_fallback = self.guided_decoding_disable_fallback
- if self.guided_decoding_disable_any_whitespace is not None:
- so_config.disable_any_whitespace = (
- self.guided_decoding_disable_any_whitespace
- )
- if self.guided_decoding_disable_additional_properties is not None:
- so_config.disable_additional_properties = (
- self.guided_decoding_disable_additional_properties
- )
-
observability_config = ObservabilityConfig(
show_hidden_metrics_for_version=self.show_hidden_metrics_for_version,
otlp_traces_endpoint=self.otlp_traces_endpoint,
@@ -1787,7 +1739,6 @@ class EngineArgs:
compilation_config.max_cudagraph_capture_size = (
self.max_cudagraph_capture_size
)
-
config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
@@ -1804,6 +1755,7 @@ class EngineArgs:
kv_events_config=self.kv_events_config,
ec_transfer_config=self.ec_transfer_config,
additional_config=self.additional_config,
+ optimization_level=self.optimization_level,
)
return config
@@ -1861,9 +1813,11 @@ class EngineArgs:
if model_config.runner_type != "pooling":
default_chunked_prefill = True
- # Disable prefix caching default for hybrid models
- # since the feature is still experimental.
- default_prefix_caching = not model_config.is_hybrid
+ # Disable prefix caching default for hybrid models and mamba-only
+ # models since the feature is still experimental.
+ default_prefix_caching = not (
+ model_config.is_hybrid or model_config.is_attention_free
+ )
else:
assert model_config.pooler_config is not None
@@ -2090,24 +2044,6 @@ class AsyncEngineArgs(EngineArgs):
enable_log_requests: bool = False
- @property
- @deprecated(
- "`disable_log_requests` is deprecated and has been replaced with "
- "`enable_log_requests`. This will be removed in v0.12.0. Please use "
- "`enable_log_requests` instead."
- )
- def disable_log_requests(self) -> bool:
- return not self.enable_log_requests
-
- @disable_log_requests.setter
- @deprecated(
- "`disable_log_requests` is deprecated and has been replaced with "
- "`enable_log_requests`. This will be removed in v0.12.0. Please use "
- "`enable_log_requests` instead."
- )
- def disable_log_requests(self, value: bool):
- self.enable_log_requests = not value
-
@staticmethod
def add_cli_args(
parser: FlexibleArgumentParser, async_args_only: bool = False
diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index 5e3374f9f6a10..6b3ee042daf3e 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -1,14 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import enum
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Iterable, Mapping
from typing import Any
from vllm.config import ModelConfig, VllmConfig
from vllm.inputs.data import PromptType
-from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor
@@ -19,13 +17,6 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.processor import Processor
-logger = init_logger(__name__)
-
-
-class Device(enum.Enum):
- GPU = enum.auto()
- CPU = enum.auto()
-
class EngineClient(ABC):
"""Protocol class for Clients to Engine"""
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 848916dbd8763..f6ee746789981 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -174,9 +174,6 @@ class LLM:
For example, for Phi-3-Vision: `{"num_crops": 4}`.
pooler_config: Initialize non-default pooling config for the pooling
model. e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
- override_pooler_config: [DEPRECATED] Use `pooler_config` instead. This
- argument is deprecated and will be removed in v0.12.0 or v1.0.0,
- whichever is sooner.
compilation_config: Either an integer or a dictionary. If it is an
integer, it is used as the mode of compilation optimization. If it
is a dictionary, it can specify the full compilation configuration.
@@ -214,7 +211,6 @@ class LLM:
hf_overrides: HfOverrides | None = None,
mm_processor_kwargs: dict[str, Any] | None = None,
pooler_config: PoolerConfig | None = None,
- override_pooler_config: PoolerConfig | None = None,
structured_outputs_config: dict[str, Any]
| StructuredOutputsConfig
| None = None,
@@ -330,7 +326,6 @@ class LLM:
hf_overrides=hf_overrides,
mm_processor_kwargs=mm_processor_kwargs,
pooler_config=pooler_config,
- override_pooler_config=override_pooler_config,
structured_outputs_config=structured_outputs_instance,
compilation_config=compilation_config_instance,
logits_processors=logits_processors,
@@ -410,6 +405,9 @@ class LLM:
lora_request: LoRA request to use for generation, if any.
priority: The priority of the requests, if any.
Only applicable when priority scheduling policy is enabled.
+ If provided, must be a list of integers matching the length
+ of `prompts`, where each priority value corresponds to the prompt
+ at the same index.
Returns:
A list of `RequestOutput` objects containing the
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index b352c3ad01db0..688ea9697d9d6 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -29,7 +29,6 @@ from openai.types.responses import (
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponsePrompt,
- ResponseReasoningItem,
ResponseReasoningTextDeltaEvent,
ResponseReasoningTextDoneEvent,
ResponseStatus,
@@ -304,9 +303,7 @@ def get_logits_processors(
return None
-ResponseInputOutputItem: TypeAlias = (
- ResponseInputItemParam | ResponseReasoningItem | ResponseFunctionToolCall
-)
+ResponseInputOutputItem: TypeAlias = ResponseInputItemParam | ResponseOutputItem
class ResponsesRequest(OpenAIBaseModel):
@@ -559,9 +556,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
) = "none"
reasoning_effort: Literal["low", "medium", "high"] | None = None
include_reasoning: bool = True
+ parallel_tool_calls: bool | None = True
- # NOTE this will be ignored by vLLM -- the model determines the behavior
- parallel_tool_calls: bool | None = False
+ # NOTE this will be ignored by vLLM
user: str | None = None
# --8<-- [start:chat-completion-sampling-params]
@@ -652,62 +649,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
default=None,
description="Additional kwargs for structured outputs",
)
- guided_json: str | dict | BaseModel | None = Field(
- default=None,
- description=(
- "`guided_json` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `json` to `structured_outputs` instead."
- ),
- )
- guided_regex: str | None = Field(
- default=None,
- description=(
- "`guided_regex` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `regex` to `structured_outputs` instead."
- ),
- )
- guided_choice: list[str] | None = Field(
- default=None,
- description=(
- "`guided_choice` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `choice` to `structured_outputs` instead."
- ),
- )
- guided_grammar: str | None = Field(
- default=None,
- description=(
- "`guided_grammar` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `grammar` to `structured_outputs` instead."
- ),
- )
- structural_tag: str | None = Field(
- default=None,
- description=(
- "`structural_tag` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `structural_tag` to `structured_outputs` instead."
- ),
- )
- guided_decoding_backend: str | None = Field(
- default=None,
- description=(
- "`guided_decoding_backend` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please remove it from your request."
- ),
- )
- guided_whitespace_pattern: str | None = Field(
- default=None,
- description=(
- "`guided_whitespace_pattern` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `whitespace_pattern` to `structured_outputs` instead."
- ),
- )
priority: int = Field(
default=0,
description=(
@@ -717,7 +658,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
),
)
request_id: str = Field(
- default_factory=lambda: f"{random_uuid()}",
+ default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
@@ -841,20 +782,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.top_logprobs
- # Forward deprecated guided_* parameters to structured_outputs
- if self.structured_outputs is None:
- kwargs = dict[str, Any](
- json=self.guided_json,
- regex=self.guided_regex,
- choice=self.guided_choice,
- grammar=self.guided_grammar,
- whitespace_pattern=self.guided_whitespace_pattern,
- structural_tag=self.structural_tag,
- )
- kwargs = {k: v for k, v in kwargs.items() if v is not None}
- if len(kwargs) > 0:
- self.structured_outputs = StructuredOutputsParams(**kwargs)
-
response_format = self.response_format
if response_format is not None:
# If structured outputs wasn't already enabled,
@@ -863,24 +790,23 @@ class ChatCompletionRequest(OpenAIBaseModel):
self.structured_outputs = StructuredOutputsParams()
# Set structured output params for response format
- if response_format is not None:
- if response_format.type == "json_object":
- self.structured_outputs.json_object = True
- elif response_format.type == "json_schema":
- json_schema = response_format.json_schema
- assert json_schema is not None
- self.structured_outputs.json = json_schema.json_schema
- elif response_format.type == "structural_tag":
- structural_tag = response_format
- assert structural_tag is not None and isinstance(
- structural_tag,
- (
- LegacyStructuralTagResponseFormat,
- StructuralTagResponseFormat,
- ),
- )
- s_tag_obj = structural_tag.model_dump(by_alias=True)
- self.structured_outputs.structural_tag = json.dumps(s_tag_obj)
+ if response_format.type == "json_object":
+ self.structured_outputs.json_object = True
+ elif response_format.type == "json_schema":
+ json_schema = response_format.json_schema
+ assert json_schema is not None
+ self.structured_outputs.json = json_schema.json_schema
+ elif response_format.type == "structural_tag":
+ structural_tag = response_format
+ assert structural_tag is not None and isinstance(
+ structural_tag,
+ (
+ LegacyStructuralTagResponseFormat,
+ StructuralTagResponseFormat,
+ ),
+ )
+ s_tag_obj = structural_tag.model_dump(by_alias=True)
+ self.structured_outputs.structural_tag = json.dumps(s_tag_obj)
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
if self.kv_transfer_params:
@@ -1140,58 +1066,6 @@ class CompletionRequest(OpenAIBaseModel):
default=None,
description="Additional kwargs for structured outputs",
)
- guided_json: str | dict | BaseModel | None = Field(
- default=None,
- description=(
- "`guided_json` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `json` to `structured_outputs` instead."
- ),
- )
- guided_regex: str | None = Field(
- default=None,
- description=(
- "`guided_regex` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `regex` to `structured_outputs` instead."
- ),
- )
- guided_choice: list[str] | None = Field(
- default=None,
- description=(
- "`guided_choice` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `choice` to `structured_outputs` instead."
- ),
- )
- guided_grammar: str | None = Field(
- default=None,
- description=(
- "`guided_grammar` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `grammar` to `structured_outputs` instead."
- ),
- )
- structural_tag: str | None = Field(
- default=None,
- description=("If specified, the output will follow the structural tag schema."),
- )
- guided_decoding_backend: str | None = Field(
- default=None,
- description=(
- "`guided_decoding_backend` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please remove it from your request."
- ),
- )
- guided_whitespace_pattern: str | None = Field(
- default=None,
- description=(
- "`guided_whitespace_pattern` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `whitespace_pattern` to `structured_outputs` instead."
- ),
- )
priority: int = Field(
default=0,
description=(
@@ -1201,7 +1075,7 @@ class CompletionRequest(OpenAIBaseModel):
),
)
request_id: str = Field(
- default_factory=lambda: f"{random_uuid()}",
+ default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
@@ -1336,35 +1210,31 @@ class CompletionRequest(OpenAIBaseModel):
echo_without_generation = self.echo and self.max_tokens == 0
- guided_json_object = None
- if self.response_format is not None:
- if self.response_format.type == "json_object":
- guided_json_object = True
- elif self.response_format.type == "json_schema":
- json_schema = self.response_format.json_schema
+ response_format = self.response_format
+ if response_format is not None:
+ # If structured outputs wasn't already enabled,
+ # we must enable it for these features to work
+ if self.structured_outputs is None:
+ self.structured_outputs = StructuredOutputsParams()
+
+ # Set structured output params for response format
+ if response_format.type == "json_object":
+ self.structured_outputs.json_object = True
+ elif response_format.type == "json_schema":
+ json_schema = response_format.json_schema
assert json_schema is not None
- self.guided_json = json_schema.json_schema
- elif self.response_format.type == "structural_tag":
- structural_tag = self.response_format
+ self.structured_outputs.json = json_schema.json_schema
+ elif response_format.type == "structural_tag":
+ structural_tag = response_format
assert structural_tag is not None and isinstance(
- structural_tag, StructuralTagResponseFormat
+ structural_tag,
+ (
+ LegacyStructuralTagResponseFormat,
+ StructuralTagResponseFormat,
+ ),
)
s_tag_obj = structural_tag.model_dump(by_alias=True)
- self.structural_tag = json.dumps(s_tag_obj)
-
- # Forward deprecated guided_* parameters to structured_outputs
- if self.structured_outputs is None:
- kwargs = dict[str, Any](
- json=self.guided_json,
- json_object=guided_json_object,
- regex=self.guided_regex,
- choice=self.guided_choice,
- grammar=self.guided_grammar,
- whitespace_pattern=self.guided_whitespace_pattern,
- )
- kwargs = {k: v for k, v in kwargs.items() if v is not None}
- if len(kwargs) > 0:
- self.structured_outputs = StructuredOutputsParams(**kwargs)
+ self.structured_outputs.structural_tag = json.dumps(s_tag_obj)
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
if self.kv_transfer_params:
@@ -1502,7 +1372,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
),
)
request_id: str = Field(
- default_factory=lambda: f"{random_uuid()}",
+ default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
@@ -1597,7 +1467,7 @@ class EmbeddingChatRequest(OpenAIBaseModel):
),
)
request_id: str = Field(
- default_factory=lambda: f"{random_uuid()}",
+ default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
@@ -2019,7 +1889,7 @@ class ClassificationCompletionRequest(OpenAIBaseModel):
),
)
request_id: str = Field(
- default_factory=lambda: f"{random_uuid()}",
+ default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
@@ -2110,7 +1980,7 @@ class ClassificationChatRequest(OpenAIBaseModel):
)
request_id: str = Field(
- default_factory=lambda: f"{random_uuid()}",
+ default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
@@ -3221,7 +3091,7 @@ class TranslationResponseVerbose(OpenAIBaseModel):
####### Tokens IN <> Tokens OUT #######
class GenerateRequest(BaseModel):
request_id: str = Field(
- default_factory=lambda: f"{random_uuid()}",
+ default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
@@ -3278,7 +3148,7 @@ class GenerateResponseChoice(BaseModel):
class GenerateResponse(BaseModel):
request_id: str = Field(
- default_factory=lambda: f"{random_uuid()}",
+ default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 6cc685acd6728..9a7051e0920af 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -55,6 +55,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_l
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall
+from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.logger import init_logger
@@ -273,6 +274,11 @@ class OpenAIServingChat(OpenAIServing):
try:
for i, engine_prompt in enumerate(engine_prompts):
prompt_text, _, _ = self._get_prompt_components(request_prompts[i])
+ # If we are creating sub requests for multiple prompts, ensure that they
+ # have unique request ids.
+ sub_request_id = (
+ request_id if len(engine_prompts) == 1 else f"{request_id}_{i}"
+ )
if self.default_sampling_params is None:
self.default_sampling_params = {}
@@ -301,7 +307,7 @@ class OpenAIServingChat(OpenAIServing):
)
self._log_inputs(
- request_id,
+ sub_request_id,
request_prompts[i],
params=sampling_params,
lora_request=lora_request,
@@ -316,14 +322,14 @@ class OpenAIServingChat(OpenAIServing):
if isinstance(sampling_params, BeamSearchParams):
generator = self.beam_search(
prompt=engine_prompt,
- request_id=request_id,
+ request_id=sub_request_id,
params=sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
)
else:
engine_request, tokenization_kwargs = await self._process_inputs(
- request_id,
+ sub_request_id,
engine_prompt,
sampling_params,
lora_request=lora_request,
@@ -334,7 +340,7 @@ class OpenAIServingChat(OpenAIServing):
generator = self.engine_client.generate(
engine_request,
sampling_params,
- request_id,
+ sub_request_id,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
@@ -1201,6 +1207,7 @@ class OpenAIServingChat(OpenAIServing):
finish_reason_sent[i] = True
+ choice_data = maybe_filter_parallel_tool_calls(choice_data, request)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
@@ -1526,6 +1533,7 @@ class OpenAIServingChat(OpenAIServing):
as_list(output.token_ids) if request.return_token_ids else None
),
)
+ choice_data = maybe_filter_parallel_tool_calls(choice_data, request)
choices.append(choice_data)
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 7dab5dbacd28c..d9feee917ff4e 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -296,11 +296,7 @@ class OpenAIServing:
parser = None
if not enable_auto_tools or tool_parser_name is None:
return parser
- logger.info(
- '"auto" tool choice has been enabled please note that while'
- " the parallel_tool_calls client option is preset for "
- "compatibility reasons, it will be ignored."
- )
+ logger.info('"auto" tool choice has been enabled.')
try:
if tool_parser_name == "pythonic" and self.model_config.model.startswith(
@@ -1242,16 +1238,19 @@ class OpenAIServing:
):
prompt_text, _, _ = self._get_prompt_components(request_prompt)
orig_priority = priority
+ sub_request = 0
while True:
+ # Ensure that each sub-request has a unique request id.
+ sub_request_id = f"{request_id}_{sub_request}"
self._log_inputs(
- request_id,
+ sub_request_id,
request_prompt,
params=sampling_params,
lora_request=lora_request,
)
trace_headers = kwargs.get("trace_headers")
engine_request, tokenization_kwargs = await self._process_inputs(
- request_id,
+ sub_request_id,
engine_prompt,
sampling_params,
lora_request=lora_request,
@@ -1262,7 +1261,7 @@ class OpenAIServing:
generator = self.engine_client.generate(
engine_request,
sampling_params,
- request_id,
+ sub_request_id,
lora_request=lora_request,
priority=priority,
prompt_text=prompt_text,
@@ -1295,6 +1294,7 @@ class OpenAIServing:
sampling_params.max_tokens = self.max_model_len - len(prompt_token_ids)
# OPTIMIZATION
priority = orig_priority - 1
+ sub_request += 1
def _get_prompt_components(
self,
@@ -1345,11 +1345,12 @@ class OpenAIServing:
raw_request: Request | None, default: str | None = None
) -> str | None:
"""Pulls the request id to use from a header, if provided"""
- default = default or random_uuid()
- if raw_request is None:
- return default
+ if raw_request is not None and (
+ (req_id := raw_request.headers.get("X-Request-Id")) is not None
+ ):
+ return req_id
- return raw_request.headers.get("X-Request-Id", default)
+ return random_uuid() if default is None else default
@staticmethod
def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py
index 06efb43ecb7b8..f546dbda7fef5 100644
--- a/vllm/entrypoints/openai/serving_responses.py
+++ b/vllm/entrypoints/openai/serving_responses.py
@@ -94,7 +94,7 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.responses_utils import (
- construct_chat_message_with_tool_call,
+ construct_input_messages,
convert_tool_responses_to_completions_format,
extract_tool_types,
)
@@ -504,7 +504,12 @@ class OpenAIServingResponses(OpenAIServing):
for tool in request.tools
]
# Construct the input messages.
- messages = self._construct_input_messages(request, prev_response)
+ messages = construct_input_messages(
+ request_instructions=request.instructions,
+ request_input=request.input,
+ prev_msg=self.msg_store.get(prev_response.id) if prev_response else None,
+ prev_response_output=prev_response.output if prev_response else None,
+ )
_, request_prompts, engine_prompts = await self._preprocess_chat(
request,
tokenizer,
@@ -869,47 +874,6 @@ class OpenAIServingResponses(OpenAIServing):
output_items.extend(last_items)
return output_items
- def _construct_input_messages(
- self,
- request: ResponsesRequest,
- prev_response: ResponsesResponse | None = None,
- ) -> list[ChatCompletionMessageParam]:
- messages: list[ChatCompletionMessageParam] = []
- if request.instructions:
- messages.append(
- {
- "role": "system",
- "content": request.instructions,
- }
- )
-
- # Prepend the conversation history.
- if prev_response is not None:
- # Add the previous messages.
- prev_msg = self.msg_store[prev_response.id]
- messages.extend(prev_msg)
-
- # Add the previous output.
- for output_item in prev_response.output:
- # NOTE: We skip the reasoning output.
- if isinstance(output_item, ResponseOutputMessage):
- for content in output_item.content:
- messages.append(
- {
- "role": "assistant",
- "content": content.text,
- }
- )
-
- # Append the new input.
- # Responses API supports simple text inputs without chat format.
- if isinstance(request.input, str):
- messages.append({"role": "user", "content": request.input})
- else:
- for item in request.input:
- messages.append(construct_chat_message_with_tool_call(item))
- return messages
-
def _construct_harmony_system_input_message(
self, request: ResponsesRequest, with_custom_tools: bool, tool_types: set[str]
) -> OpenAIHarmonyMessage:
diff --git a/vllm/entrypoints/openai/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text.py
index b9b9b1ab30ad8..3dece07748cc4 100644
--- a/vllm/entrypoints/openai/speech_to_text.py
+++ b/vllm/entrypoints/openai/speech_to_text.py
@@ -201,10 +201,10 @@ class OpenAISpeechToText(OpenAIServing):
self.engine_client.generate(
prompt,
sampling_params,
- request_id,
+ f"{request_id}_{i}",
lora_request=lora_request,
)
- for prompt in prompts
+ for i, prompt in enumerate(prompts)
]
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
index 02fc9b8a4d34e..e1fe6e90dfd0b 100644
--- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
@@ -9,6 +9,7 @@ import regex as re
from partial_json_parser.core.options import Allow
from transformers import PreTrainedTokenizerBase
+import vllm.envs as envs
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
@@ -56,12 +57,10 @@ class Llama3JsonToolParser(ToolParser):
self.bot_token_id = tokenizer.encode(self.bot_token, add_special_tokens=False)[
0
]
- # Updated regex to match multiple JSONs separated by semicolons
- # This pattern is more robust and can handle nested JSON objects
- self.tool_call_regex = re.compile(
- r"{[^{}]*(?:{[^{}]*}[^{}]*)*}(?:\s*;\s*{[^{}]*(?:{[^{}]*}[^{}]*)*})*",
- re.DOTALL,
- )
+ # Simple regex to find opening braces - we'll use JSON decoder for parsing
+ # This handles arbitrary nesting depth correctly
+ self.tool_call_start_regex = re.compile(r"\{")
+ self.json_decoder = json.JSONDecoder()
def extract_tool_calls(
self, model_output: str, request: ChatCompletionRequest
@@ -77,49 +76,84 @@ class Llama3JsonToolParser(ToolParser):
tools_called=False, tool_calls=[], content=model_output
)
- # Find JSON object(s) in the text using regex
- match = self.tool_call_regex.search(model_output)
- if not match:
+ # Keep track of the end index of the last parsed JSON object
+ # so we don't parse inner brackets
+ end_index = -1
+ tool_calls: list[ToolCall] = []
+
+ try:
+ for match in self.tool_call_start_regex.finditer(
+ model_output, timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS
+ ):
+ start_index = match.start()
+ # Skip if this brace is inside a previously parsed JSON object
+ if start_index <= end_index:
+ continue
+
+ try:
+ obj, json_end_index = self.json_decoder.raw_decode(
+ model_output[start_index:]
+ )
+ end_index = start_index + json_end_index
+
+ # raise KeyError if missing
+ name = obj["name"]
+ arguments_or_params = (
+ obj["arguments"] if "arguments" in obj else obj["parameters"]
+ )
+
+ tool_calls.append(
+ ToolCall(
+ type="function",
+ function=FunctionCall(
+ name=name,
+ # function call args are JSON but as a string
+ arguments=json.dumps(
+ arguments_or_params, ensure_ascii=False
+ ),
+ ),
+ )
+ )
+ except KeyError as e:
+ # Missing required key
+ missing_key = str(e).strip("'\"")
+ logger.exception(
+ "Couldn't extract tool call from JSON response. "
+ "Required key '%s' not present. "
+ "Returning output in content with empty tool calls.",
+ missing_key,
+ )
+ return ExtractedToolCallInformation(
+ tools_called=False, tool_calls=[], content=model_output
+ )
+ except Exception:
+ # Any other error during parsing
+ logger.exception(
+ "Error in extracting tool call from response. "
+ "Returning output in content with empty tool calls"
+ )
+ return ExtractedToolCallInformation(
+ tools_called=False, tool_calls=[], content=model_output
+ )
+ except TimeoutError:
+ logger.warning("Regex timeout occurred when matching tool call pattern.")
+ logger.debug(
+ "Regex timeout occurred when matching user input: %s", model_output
+ )
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
- try:
- json_str = match.group(0)
- # Split by semicolon and strip whitespace
- json_objects = [obj.strip() for obj in json_str.split(";")]
-
- tool_calls: list[ToolCall] = []
- for json_obj in json_objects:
- if not json_obj: # Skip empty strings
- continue
- obj = json.loads(json_obj)
- tool_calls.append(
- ToolCall(
- type="function",
- function=FunctionCall(
- name=obj["name"],
- # function call args are JSON but as a string
- arguments=json.dumps(
- obj["arguments"]
- if "arguments" in obj
- else obj["parameters"],
- ensure_ascii=False,
- ),
- ),
- )
- )
-
+ # If we have valid tool calls, return them normally
+ if tool_calls:
return ExtractedToolCallInformation(
tools_called=True, tool_calls=tool_calls, content=None
)
- except Exception:
- logger.exception("Error in extracting tool call from response.")
- # return information to just treat the tool call as regular JSON
- return ExtractedToolCallInformation(
- tools_called=False, tool_calls=[], content=model_output
- )
+ # No valid tool calls found
+ return ExtractedToolCallInformation(
+ tools_called=False, tool_calls=[], content=model_output
+ )
def extract_tool_calls_streaming(
self,
diff --git a/vllm/entrypoints/openai/utils.py b/vllm/entrypoints/openai/utils.py
new file mode 100644
index 0000000000000..6f37f6adff4c2
--- /dev/null
+++ b/vllm/entrypoints/openai/utils.py
@@ -0,0 +1,37 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import TypeVar
+
+from vllm.entrypoints.openai.protocol import (
+ ChatCompletionRequest,
+ ChatCompletionResponseChoice,
+ ChatCompletionResponseStreamChoice,
+)
+
+# Used internally
+_ChatCompletionResponseChoiceT = TypeVar(
+ "_ChatCompletionResponseChoiceT",
+ ChatCompletionResponseChoice,
+ ChatCompletionResponseStreamChoice,
+)
+
+
+def maybe_filter_parallel_tool_calls(
+ choice: _ChatCompletionResponseChoiceT, request: ChatCompletionRequest
+) -> _ChatCompletionResponseChoiceT:
+ """Filter to first tool call only when parallel_tool_calls is False."""
+
+ if request.parallel_tool_calls:
+ return choice
+
+ if isinstance(choice, ChatCompletionResponseChoice) and choice.message.tool_calls:
+ choice.message.tool_calls = choice.message.tool_calls[:1]
+ elif (
+ isinstance(choice, ChatCompletionResponseStreamChoice)
+ and choice.delta.tool_calls
+ ):
+ choice.delta.tool_calls = [
+ tool_call for tool_call in choice.delta.tool_calls if tool_call.index == 0
+ ]
+
+ return choice
diff --git a/vllm/entrypoints/responses_utils.py b/vllm/entrypoints/responses_utils.py
index 912e8a690573d..07abb80ebc9e3 100644
--- a/vllm/entrypoints/responses_utils.py
+++ b/vllm/entrypoints/responses_utils.py
@@ -9,7 +9,11 @@ from openai.types.chat import (
from openai.types.chat.chat_completion_message_tool_call_param import (
Function as FunctionCallTool,
)
-from openai.types.responses import ResponseFunctionToolCall
+from openai.types.responses import ResponseFunctionToolCall, ResponseOutputItem
+from openai.types.responses.response_function_tool_call_output_item import (
+ ResponseFunctionToolCallOutputItem,
+)
+from openai.types.responses.response_output_message import ResponseOutputMessage
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
from openai.types.responses.tool import Tool
@@ -20,6 +24,49 @@ from vllm.entrypoints.openai.protocol import (
)
+def construct_input_messages(
+ *,
+ request_instructions: str | None = None,
+ request_input: str | list[ResponseInputOutputItem],
+ prev_msg: list[ChatCompletionMessageParam] | None = None,
+ prev_response_output: list[ResponseOutputItem] | None = None,
+):
+ messages: list[ChatCompletionMessageParam] = []
+ if request_instructions:
+ messages.append(
+ {
+ "role": "system",
+ "content": request_instructions,
+ }
+ )
+
+ # Prepend the conversation history.
+ if prev_msg is not None:
+ # Add the previous messages.
+ messages.extend(prev_msg)
+ if prev_response_output is not None:
+ # Add the previous output.
+ for output_item in prev_response_output:
+ # NOTE: We skip the reasoning output.
+ if isinstance(output_item, ResponseOutputMessage):
+ for content in output_item.content:
+ messages.append(
+ {
+ "role": "assistant",
+ "content": content.text,
+ }
+ )
+
+ # Append the new input.
+ # Responses API supports simple text inputs without chat format.
+ if isinstance(request_input, str):
+ messages.append({"role": "user", "content": request_input})
+ else:
+ for item in request_input:
+ messages.append(construct_chat_message_with_tool_call(item))
+ return messages
+
+
def construct_chat_message_with_tool_call(
item: ResponseInputOutputItem,
) -> ChatCompletionMessageParam:
@@ -50,6 +97,12 @@ def construct_chat_message_with_tool_call(
"role": "assistant",
"reasoning": reasoning_content,
}
+ elif isinstance(item, ResponseFunctionToolCallOutputItem):
+ return ChatCompletionToolMessageParam(
+ role="tool",
+ content=item.output,
+ tool_call_id=item.call_id,
+ )
elif item.get("type") == "function_call_output":
# Append the function call output as a tool message.
return ChatCompletionToolMessageParam(
diff --git a/vllm/envs.py b/vllm/envs.py
index 77a705b789388..04b61c79600e0 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -647,7 +647,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Example options:
# - "TORCH_SDPA": use torch.nn.MultiheadAttention
# - "FLASH_ATTN": use FlashAttention
- # - "XFORMERS": use XFormers
# - "FLASHINFER": use flashinfer
# - "FLASHMLA": use FlashMLA
# - "FLASH_ATTN_MLA": use FlashAttention for MLA
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 25fb7181a8f29..173d366267e87 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -5,19 +5,17 @@ import time
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, NamedTuple
+from typing import Any, NamedTuple
import torch
import vllm.envs as envs
+from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.ubatch_utils import UBatchSlices
-if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionMetadata
-
logger = init_logger(__name__)
track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
@@ -35,23 +33,27 @@ class BatchDescriptor(NamedTuple):
"""
num_tokens: int
- uniform_decode: bool = False
+ num_reqs: int | None = None
"""
- False can also be used for an uniform decode batch to dispatch to the
- cudagraph supporting non-uniform batches.
+ Number of requests in the batch. Can be None for PIECEWISE cudagraphs where
+ the cudagraphs can handle any number of requests.
+ """
+ uniform: bool = False
+ """
+ True if all the requests in the batch have the same number of tokens.
"""
has_lora: bool = False
"""
Whether this batch has active LoRA adapters.
"""
- @property
- def non_uniform(self) -> "BatchDescriptor":
+ def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
"""
- Return a non-uniform version of current batch descriptor.
+ Return a relaxed version of current batch descriptor that is still compatible
+ with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs).
"""
return BatchDescriptor(
- self.num_tokens, uniform_decode=False, has_lora=self.has_lora
+ self.num_tokens, num_reqs=None, uniform=False, has_lora=self.has_lora
)
@@ -153,7 +155,7 @@ class DPMetadata:
@contextmanager
def sp_local_sizes(self, sequence_parallel_size: int):
"""
- Context mamager for setting self.local_sizes. Same as self.chunked_sizes
+ Context manager for setting self.local_sizes. Same as self.chunked_sizes
but without any chunking.
"""
self.local_sizes = _compute_sp_num_tokens(
@@ -191,7 +193,7 @@ class ForwardContext:
for each microbatch.
Set dynamically for each forward pass
"""
- attn_metadata: dict[str, "AttentionMetadata"] | list[dict[str, "AttentionMetadata"]]
+ attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]]
# TODO: remove after making all virtual_engines share the same kv cache
virtual_engine: int # set dynamically for each forward pass
# set dynamically for each forward pass
diff --git a/vllm/logger.py b/vllm/logger.py
index 772e36497b45e..ad3123c0f0149 100644
--- a/vllm/logger.py
+++ b/vllm/logger.py
@@ -7,7 +7,8 @@ import json
import logging
import os
import sys
-from collections.abc import Hashable
+from collections.abc import Generator, Hashable
+from contextlib import contextmanager
from functools import lru_cache, partial
from logging import Logger
from logging.config import dictConfig
@@ -212,6 +213,14 @@ def init_logger(name: str) -> _VllmLogger:
return cast(_VllmLogger, logger)
+@contextmanager
+def suppress_logging(level: int = logging.INFO) -> Generator[None, Any, None]:
+ current_level = logging.root.manager.disable
+ logging.disable(level)
+ yield
+ logging.disable(current_level)
+
+
# The root logger is initialized when the module is imported.
# This is thread-safe as the module is only imported once,
# guaranteed by the Python GIL.
diff --git a/vllm/lora/layers/base.py b/vllm/lora/layers/base.py
index 3bfb88c007622..a4b8fb4d2aec5 100644
--- a/vllm/lora/layers/base.py
+++ b/vllm/lora/layers/base.py
@@ -60,7 +60,7 @@ class BaseLayerWithLoRA(nn.Module):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
- model_config: PretrainedConfig | None,
+ model_config: PretrainedConfig | None = None,
) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
raise NotImplementedError
diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py
index 3e21d426c304a..904025901fba7 100644
--- a/vllm/lora/layers/column_parallel_linear.py
+++ b/vllm/lora/layers/column_parallel_linear.py
@@ -153,7 +153,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
- model_config: PretrainedConfig | None,
+ model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is ColumnParallelLinear or (
type(source_layer) is MergedColumnParallelLinear
@@ -272,7 +272,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
- model_config: PretrainedConfig | None,
+ model_config: PretrainedConfig | None = None,
) -> bool:
return (
type(source_layer) is MergedColumnParallelLinear
@@ -338,7 +338,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
- model_config: PretrainedConfig | None,
+ model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 1
@@ -396,7 +396,7 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
- model_config: PretrainedConfig | None,
+ model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 3
@@ -434,7 +434,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
- model_config: PretrainedConfig | None,
+ model_config: PretrainedConfig | None = None,
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
@@ -480,7 +480,7 @@ class MergedColumnParallelLinearWithShardedLoRA(MergedColumnParallelLinearWithLo
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
- model_config: PretrainedConfig | None,
+ model_config: PretrainedConfig | None = None,
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
@@ -516,7 +516,7 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
- model_config: PretrainedConfig | None,
+ model_config: PretrainedConfig | None = None,
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
@@ -565,7 +565,7 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
- model_config: PretrainedConfig | None,
+ model_config: PretrainedConfig | None = None,
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py
index 0eb6562bec6cd..3ad19370962ab 100644
--- a/vllm/lora/layers/fused_moe.py
+++ b/vllm/lora/layers/fused_moe.py
@@ -30,6 +30,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
FusedMoEModularMethod,
)
+from .utils import _get_lora_device
+
class FusedMoEWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: FusedMoE) -> None:
@@ -41,7 +43,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
)
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
- self.device = base_layer.w2_weight.device
+ self.device = _get_lora_device(base_layer)
self._w13_slices = 2
self._inject_lora_into_fused_moe()
@@ -399,6 +401,61 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.w13_lora_b_stacked[1][lora_id][experts_id]
)
+ def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
+ """
+ Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
+ """
+ if self.tp_size == 1 or not self.fully_sharded:
+ return w13_lora_a
+
+ # w13_lora_a shape (num_experts,rank,input_size)
+ current_lora_rank = w13_lora_a.shape[1]
+ assert current_lora_rank % self.tp_size == 0
+ # Based on S-LoRA, we slice W13/W1/W3 A along the rank dim.
+ sliced_rank = current_lora_rank // self.tp_size
+ start_idx = self.tp_rank * sliced_rank
+ end_idx = (self.tp_rank + 1) * sliced_rank
+ return w13_lora_a[:, start_idx:end_idx, :]
+
+ def _slice_w13_b(self, w13_lora_b: torch.Tensor):
+ if self.tp_size == 1:
+ return w13_lora_b
+
+ # w13_lora_b shape (num_experts,output_size,rank)
+ shard_size = self.base_layer.intermediate_size_per_partition
+ start_idx = self.tp_rank * shard_size
+ end_idx = (self.tp_rank + 1) * shard_size
+
+ return w13_lora_b[:, start_idx:end_idx, :]
+
+ def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
+ """
+ Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
+ """
+ if self.tp_size == 1:
+ return w2_lora_a
+ # w2_lora_a shape (num_experts,rank,input_size)
+ shard_size = self.base_layer.intermediate_size_per_partition
+ start_idx = self.tp_rank * shard_size
+ end_idx = (self.tp_rank + 1) * shard_size
+
+ return w2_lora_a[:, :, start_idx:end_idx]
+
+ def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
+ """
+ Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
+ """
+ if self.tp_size == 1 or not self.fully_sharded:
+ return w2_lora_b
+ # Based on S-LoRA, we slice W2 B along the hidden_size dim.
+ # w2_lora_b shape (num_experts,output_size,rank)
+ current_lora_size = w2_lora_b.shape[1]
+
+ sliced_size = current_lora_size // self.tp_size
+ start_idx = self.tp_rank * sliced_size
+ end_idx = (self.tp_rank + 1) * sliced_size
+ return w2_lora_b[:, start_idx:end_idx, :]
+
def reset_lora(self, index: int):
"""Resets the lora weights at index back to 0."""
for pos in range(self._w13_slices):
@@ -409,6 +466,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.w2_lora_b_stacked[0][index] = 0
self.adapter_enabled[index] = 0
+ #
+
def set_lora(
self,
index: int,
@@ -416,69 +475,55 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
lora_b: torch.Tensor | list[torch.Tensor],
):
"""Overwrites lora tensors at index."""
+ # Make mypy happy
assert isinstance(lora_a, list)
assert isinstance(lora_b, list)
+
self.reset_lora(index)
self.adapter_enabled[index] = 1
- for eid in range(len(lora_a) // 3):
- w1_lora_a = lora_a[eid * 3]
- w2_lora_a = lora_a[eid * 3 + 1]
- w3_lora_a = lora_a[eid * 3 + 2]
- w1_lora_b = lora_b[eid * 3]
- w2_lora_b = lora_b[eid * 3 + 1]
- w3_lora_b = lora_b[eid * 3 + 2]
- # Handle the case of adding LoRA to only a subset of experts
- if w1_lora_a is None or w2_lora_a is None or w3_lora_a is None:
- continue
+ num_experts = self.w13_lora_a_stacked[0].shape[1]
- if self.tp_size > 1:
- shard_size = self.base_layer.intermediate_size_per_partition
- start_idx = self.tp_rank * shard_size
- end_idx = (self.tp_rank + 1) * shard_size
+ w1_lora_a, w2_lora_a, w3_lora_a = lora_a
+ w1_lora_b, w2_lora_b, w3_lora_b = lora_b
+ assert (
+ num_experts
+ == w1_lora_a.shape[0]
+ == w2_lora_a.shape[0]
+ == w3_lora_a.shape[0]
+ )
- w1_lora_b = w1_lora_b[start_idx:end_idx, :]
- w3_lora_b = w3_lora_b[start_idx:end_idx, :]
- w2_lora_a = w2_lora_a[:, start_idx:end_idx]
+ slliced_w1_lora_a = self._slice_w13_a(w1_lora_a)
+ slliced_w1_lora_b = self._slice_w13_b(w1_lora_b)
+ slliced_w3_lora_a = self._slice_w13_a(w3_lora_a)
+ slliced_w3_lora_b = self._slice_w13_b(w3_lora_b)
- if self.fully_sharded:
- # Based on S-LoRA, we slice W1 and W3 A along the rank dim,
- # and W2 B along the hidden_size dim.
- w13_shard_size = self.w13_lora_a_stacked[0][index, eid].shape[0]
- w13_start_idx = self.tp_rank * w13_shard_size
- w13_end_idx = (self.tp_rank + 1) * w13_shard_size
- w1_lora_a = w1_lora_a[w13_start_idx:w13_end_idx, :]
- w3_lora_a = w3_lora_a[w13_start_idx:w13_end_idx, :]
+ sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
+ sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)
- w2_shard_size = self.w2_lora_b_stacked[0][index, eid].shape[0]
- w2_start_idx = self.tp_rank * w2_shard_size
- w2_end_idx = (self.tp_rank + 1) * w2_shard_size
- w2_lora_b = w2_lora_b[w2_start_idx:w2_end_idx, :]
- # w1 lora_a
- self.w13_lora_a_stacked[0][
- index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1]
- ].copy_(w1_lora_a, non_blocking=True)
- # w3 lora_a
- self.w13_lora_a_stacked[1][
- index, eid, : w3_lora_a.shape[0], : w3_lora_a.shape[1]
- ].copy_(w3_lora_a, non_blocking=True)
+ self.w13_lora_a_stacked[0][
+ index, :, : slliced_w1_lora_a.shape[1], : slliced_w1_lora_a.shape[2]
+ ].copy_(slliced_w1_lora_a, non_blocking=True)
- # w1 lora_b
- self.w13_lora_b_stacked[0][
- index, eid, : w1_lora_b.shape[0], : w1_lora_b.shape[1]
- ].copy_(w1_lora_b, non_blocking=True)
- # w3 lora_b
- self.w13_lora_b_stacked[1][
- index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1]
- ].copy_(w3_lora_b, non_blocking=True)
+ self.w13_lora_a_stacked[1][
+ index, :, : slliced_w3_lora_a.shape[1], : slliced_w3_lora_a.shape[2]
+ ].copy_(slliced_w3_lora_a, non_blocking=True)
- self.w2_lora_a_stacked[0][
- index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1]
- ].copy_(w2_lora_a, non_blocking=True)
+ self.w13_lora_b_stacked[0][
+ index, :, : slliced_w1_lora_b.shape[1], : slliced_w1_lora_b.shape[2]
+ ].copy_(slliced_w1_lora_b, non_blocking=True)
- self.w2_lora_b_stacked[0][
- index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1]
- ].copy_(w2_lora_b, non_blocking=True)
+ self.w13_lora_b_stacked[1][
+ index, :, : slliced_w3_lora_b.shape[1], : slliced_w3_lora_b.shape[2]
+ ].copy_(slliced_w3_lora_b, non_blocking=True)
+
+ self.w2_lora_a_stacked[0][
+ index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
+ ].copy_(sliced_w2_lora_a, non_blocking=True)
+
+ self.w2_lora_b_stacked[0][
+ index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
+ ].copy_(sliced_w2_lora_b, non_blocking=True)
def forward(self, *args, **kwargs):
return self.base_layer.forward(*args, **kwargs)
@@ -504,12 +549,12 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
- model_config: PretrainedConfig | None,
+ model_config: PretrainedConfig | None = None,
) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
- # return type(source_layer) is FusedMoE
- return type(source_layer) is FusedMoE and len(packed_modules_list) == 2
+ # source_layer is FusedMoE or SharedFusedMoE
+ return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 2
class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
@@ -553,6 +598,9 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
model_config: PretrainedConfig | None = None,
) -> None:
"""Initializes lora matrices."""
+
+ assert isinstance(model_config, PretrainedConfig)
+ self._base_model = model_config.architectures[0]
self.max_loras = lora_config.max_loras
self.fully_sharded = lora_config.fully_sharded_loras
@@ -563,20 +611,7 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
self._create_lora_a_weights(max_loras, lora_config)
self._create_lora_b_weights(max_loras, lora_config)
- def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
- if self.tp_size == 1 or not self.fully_sharded:
- return w13_lora_a
-
- # w13_lora_a shape (num_experts,rank,input_size)
- current_lora_rank = w13_lora_a.shape[1]
- assert current_lora_rank % self.tp_size == 0
-
- sliced_rank = current_lora_rank // self.tp_size
- start_idx = self.tp_rank * sliced_rank
- end_idx = (self.tp_rank + 1) * sliced_rank
- return w13_lora_a[:, start_idx:end_idx, :]
-
- def _slice_w13_b(self, w13_lora_b: torch.Tensor, is_interleave: bool = True):
+ def _slice_w13_b(self, w13_lora_b: torch.Tensor):
if self.tp_size == 1:
return w13_lora_b
@@ -584,7 +619,8 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
shard_size = self.base_layer.intermediate_size_per_partition
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
- if is_interleave:
+ # HACK: Currently, only GPT-OSS is in interleaved order
+ if self._base_model == "GptOssForCausalLM":
# For models like GPT-OSS, the weights of w1 (gate_proj) and w3 (up_proj)
# in the interleaved order, and corresponding LoRA need to be processed.
w1_lora_b = w13_lora_b[:, ::2, :]
@@ -604,28 +640,6 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
return torch.cat([sliced_w1_lora_b, sliced_w3_lora_b], dim=1)
- def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
- if self.tp_size == 1:
- return w2_lora_a
- # w2_lora_a shape (num_experts,rank,input_size)
- shard_size = self.base_layer.intermediate_size_per_partition
- start_idx = self.tp_rank * shard_size
- end_idx = (self.tp_rank + 1) * shard_size
-
- return w2_lora_a[:, :, start_idx:end_idx]
-
- def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
- if self.tp_size == 1 or not self.fully_sharded:
- return w2_lora_b
- # Based on S-LoRA, we slice W2 B along the hidden_size dim.
- # w2_lora_b shape (num_experts,output_size,rank)
- current_lora_size = w2_lora_b.shape[1]
-
- sliced_size = current_lora_size // self.tp_size
- start_idx = self.tp_rank * sliced_size
- end_idx = (self.tp_rank + 1) * sliced_size
- return w2_lora_b[:, start_idx:end_idx, :]
-
def set_lora(
self,
index: int,
@@ -656,7 +670,7 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
w2_lora_b = w2_lora_b.permute(1, 0, 2)
sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
- sliced_w13_lora_b = self._slice_w13_b(w13_lora_b, is_interleave=True)
+ sliced_w13_lora_b = self._slice_w13_b(w13_lora_b)
sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)
@@ -709,8 +723,8 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
- model_config: PretrainedConfig | None,
+ model_config: PretrainedConfig | None = None,
) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
-
- return type(source_layer) is FusedMoE and len(packed_modules_list) == 1
+ # source_layer is FusedMoE or SharedFusedMoE
+ return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 1
diff --git a/vllm/lora/layers/logits_processor.py b/vllm/lora/layers/logits_processor.py
index c01984db4e64c..01515f6136371 100644
--- a/vllm/lora/layers/logits_processor.py
+++ b/vllm/lora/layers/logits_processor.py
@@ -197,7 +197,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
- model_config: PretrainedConfig | None,
+ model_config: PretrainedConfig | None = None,
) -> bool:
# Special handling for the LogitsProcessor.
return False
diff --git a/vllm/lora/layers/replicated_linear.py b/vllm/lora/layers/replicated_linear.py
index 243736c4ebc65..62bac546ccd1a 100644
--- a/vllm/lora/layers/replicated_linear.py
+++ b/vllm/lora/layers/replicated_linear.py
@@ -53,7 +53,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
- model_config: PretrainedConfig | None,
+ model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is ReplicatedLinear
diff --git a/vllm/lora/layers/row_parallel_linear.py b/vllm/lora/layers/row_parallel_linear.py
index 95517b1aee263..958aa6af36746 100644
--- a/vllm/lora/layers/row_parallel_linear.py
+++ b/vllm/lora/layers/row_parallel_linear.py
@@ -87,7 +87,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
- model_config: PretrainedConfig | None,
+ model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is RowParallelLinear
@@ -164,7 +164,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
- model_config: PretrainedConfig | None,
+ model_config: PretrainedConfig | None = None,
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
diff --git a/vllm/lora/layers/utils.py b/vllm/lora/layers/utils.py
index 2da90f180ee74..74403240f6cc2 100644
--- a/vllm/lora/layers/utils.py
+++ b/vllm/lora/layers/utils.py
@@ -33,6 +33,15 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device:
# HQQ marlin
elif hasattr(base_layer, "W_q"):
return base_layer.W_q.device
+ # MoE layer
+ elif hasattr(base_layer, "w2_weight"):
+ return base_layer.w2_weight.device
+ # MoE Compressed Tensor
+ elif hasattr(base_layer, "w2_weight_packed"):
+ return base_layer.w2_weight_packed.device
+ # MoE GPTQ/AWQ/GGUF
+ elif hasattr(base_layer, "w2_qweight"):
+ return base_layer.w2_qweight.device
else:
raise ValueError(f"Unsupported base layer: {base_layer}")
diff --git a/vllm/lora/layers/vocal_parallel_embedding.py b/vllm/lora/layers/vocal_parallel_embedding.py
index c87ca9e24dece..4c1550d09e5e2 100644
--- a/vllm/lora/layers/vocal_parallel_embedding.py
+++ b/vllm/lora/layers/vocal_parallel_embedding.py
@@ -131,7 +131,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
- model_config: PretrainedConfig | None,
+ model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is VocabParallelEmbedding
diff --git a/vllm/lora/lora_weights.py b/vllm/lora/lora_weights.py
index f0d8e22194050..15c4a1be63eeb 100644
--- a/vllm/lora/lora_weights.py
+++ b/vllm/lora/lora_weights.py
@@ -152,6 +152,59 @@ class PackedLoRALayerWeights(LoRALayerWeights):
)
return obj
+ @classmethod
+ def pack_moe(
+ cls, loras: GenericSequence[Optional["LoRALayerWeights"]], module_name: str
+ ) -> "PackedLoRALayerWeights":
+ """Pack a list of LoRAs into a single LoRA.
+
+ If LoRA is None, it signifies that the submodule does not have a LoRA.
+ """
+
+ first_lora = next(lora for lora in loras if lora is not None)
+ assert first_lora is not None
+ rank = first_lora.rank
+ lora_alpha = first_lora.lora_alpha
+ assert len(loras) % 3 == 0
+ w1_lora_a_lst = []
+ w2_lora_a_lst = []
+ w3_lora_a_lst = []
+ w1_lora_b_lst = []
+ w2_lora_b_lst = []
+ w3_lora_b_lst = []
+ # TODO: Consider the case where some experts don't have LoRA added.
+ for eid in range(len(loras) // 3):
+ w1_lora = loras[eid * 3]
+ w2_lora = loras[eid * 3 + 1]
+ w3_lora = loras[eid * 3 + 2]
+ assert w1_lora is not None
+ assert w2_lora is not None
+ assert w3_lora is not None
+
+ w1_lora_a_lst.append(w1_lora.lora_a)
+ w2_lora_a_lst.append(w2_lora.lora_a)
+ w3_lora_a_lst.append(w3_lora.lora_a)
+
+ w1_lora_b_lst.append(w1_lora.lora_b)
+ w2_lora_b_lst.append(w2_lora.lora_b)
+ w3_lora_b_lst.append(w3_lora.lora_b)
+
+ w1_lora_a = torch.stack(w1_lora_a_lst, dim=0) # (num_experts,rank,input_size)
+ w2_lora_a = torch.stack(w2_lora_a_lst, dim=0)
+ w3_lora_a = torch.stack(w3_lora_a_lst, dim=0)
+ w1_lora_b = torch.stack(w1_lora_b_lst, dim=0) # (num_experts,output_size,rank)
+ w2_lora_b = torch.stack(w2_lora_b_lst, dim=0)
+ w3_lora_b = torch.stack(w3_lora_b_lst, dim=0)
+
+ obj = cls(
+ module_name,
+ rank,
+ [lora_alpha, lora_alpha, lora_alpha],
+ [w1_lora_a, w2_lora_a, w3_lora_a],
+ [w1_lora_b, w2_lora_b, w3_lora_b],
+ )
+ return obj
+
def optimize(self) -> "PackedLoRALayerWeights":
"""Optimize the LoRA by merging the scaling into lora_b."""
for i in range(len(self.lora_b)):
diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 636f062feb7b0..4caaf0e117cc4 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -13,7 +13,7 @@ from torch import nn
from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger
-from vllm.lora.layers import BaseLayerWithLoRA, FusedMoEWithLoRA, LoRAMapping
+from vllm.lora.layers import BaseLayerWithLoRA, FusedMoE3DWithLoRA, LoRAMapping
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.punica_wrapper import get_punica_wrapper
@@ -151,16 +151,13 @@ class LoRAModel:
if pin_memory:
loras[module_name].lora_b = loras[module_name].lora_b.pin_memory()
- for lora in loras.values():
- lora.optimize()
-
return cls(lora_model_id, peft_helper.r, loras)
@classmethod
def from_local_checkpoint(
cls,
lora_dir: str,
- expected_lora_modules: list[str],
+ expected_lora_modules: set[str],
peft_helper: PEFTHelper,
*,
lora_model_id: int | None = None,
@@ -190,10 +187,7 @@ class LoRAModel:
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")
- # new_embeddings_tensor_path = os.path.join(
- # lora_dir, "new_embeddings.safetensors"
- # )
- # new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin")
+
tensors: dict[str, torch.Tensor] = {}
unexpected_modules: list[list[str] | str] = []
@@ -201,18 +195,19 @@ class LoRAModel:
for lora_module in modules.keys(): # noqa
if is_base_embeddding_weights(lora_module):
continue
- module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
- # Handle FSDP file format where experts.base_layer is the
+ # Handle PEFT file format where experts.base_layer is the
# gate_up_proj and experts is the down_proj
if "base_layer" in lora_module:
continue
+ module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
# Case for expert lora weights
if ".experts" in module_name:
- if not any(
- module_name.endswith(ele) for ele in expected_lora_modules
- ):
+ expert_idx = module_name.find(".experts")
+ expert_suffix = module_name[expert_idx + 1 :]
+ if expert_suffix not in expected_lora_modules:
unexpected_modules.append(module_name)
- elif module_name.split(".")[-1] not in expected_lora_modules:
+
+ elif module_name.rsplit(".", 1)[-1] not in expected_lora_modules:
unexpected_modules.append(module_name)
if unexpected_modules:
@@ -358,9 +353,7 @@ class LoRAModelManager:
self.modules: dict[str, BaseLayerWithLoRA] = {}
# Dict instead of a set for compatibility with LRUCache.
self._last_mapping: LoRAMapping | None = None
- self._is_3d_moe_model = is_moe_model(self.model) and hasattr(
- self.model, "is_3d_moe_weight"
- )
+ self._is_3d_moe_model = is_moe_model(self.model) and self.model.is_3d_moe_weight
self._create_lora_modules()
self.model.lora_manager = self
@@ -411,7 +404,7 @@ class LoRAModelManager:
continue
# Note (gnovack) - If MOE lora weights are not split into
# num_experts chunks, we split them here
- if isinstance(module, FusedMoEWithLoRA) and torch.is_tensor(
+ if isinstance(module, FusedMoE3DWithLoRA) and torch.is_tensor(
module_lora.lora_a
):
# Handle PEFT file format where experts.base_layer is the
@@ -679,7 +672,10 @@ class LoRAModelManager:
"cpu",
)
subloras.append(lora)
- lora = PackedLoRALayerWeights.pack(subloras)
+ if module.__class__.__name__ == "FusedMoEWithLoRA":
+ lora = PackedLoRALayerWeights.pack_moe(subloras, module_name)
+ else:
+ lora = PackedLoRALayerWeights.pack(subloras)
model.loras[module_name] = lora
return model
@@ -739,13 +735,21 @@ class LoRAModelManager:
replaced_module_name = module_name.replace("model.", "")
if lora_model.check_lora_name(module_name):
module_name = replaced_module_name
- lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
- replacement_loras
- )
+ if module_name.endswith(".experts"):
+ lora_model.loras[module_name] = PackedLoRALayerWeights.pack_moe(
+ replacement_loras, module_name
+ )
+ else:
+ lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
+ replacement_loras
+ )
# Remove the modules that have been replaced.
for module in replaced_module:
lora_model.loras.pop(module, None)
+ for lora in lora_model.loras.values():
+ lora.optimize()
+
def _get_lora_layer_weights(
self, lora_model: LoRAModel, module_name: str
) -> LoRALayerWeights | None:
diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py
index ce38751e4b6a7..47c42b095534a 100644
--- a/vllm/lora/punica_wrapper/punica_base.py
+++ b/vllm/lora/punica_wrapper/punica_base.py
@@ -173,7 +173,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
vocab_size: int,
):
# NOTE We have remove lora extra vocab support for now. So we set
- # extra_vocab_size alwayzs to 0, and extra_vocab_size will be removed.
+ # extra_vocab_size always to 0, and extra_vocab_size will be removed.
extra_vocab_size = 0
(
diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py
index 12524994d4968..47484b2b984df 100644
--- a/vllm/lora/utils.py
+++ b/vllm/lora/utils.py
@@ -170,16 +170,15 @@ def parse_fine_tuned_lora_name(
def is_base_embeddding_weights(name: str) -> bool:
# hardcoded subfixes for input & output embedding weights
- input_embedding_subfix = ".embed_tokens.base_layer.weight"
- output_embedding_subfix = ".lm_head.base_layer.weight"
-
- return name.endswith(input_embedding_subfix) or name.endswith(
- output_embedding_subfix
+ embedding_suffixes = (
+ ".embed_tokens.base_layer.weight",
+ ".lm_head.base_layer.weight",
)
+ return name.endswith(embedding_suffixes)
def is_regex_target_modules(
- load_modules: str | list[str], expected_lora_modules: list[str]
+ load_modules: str | list[str], expected_lora_modules: set[str]
) -> bool:
"""
PEFT supports passing `target_modules` in the form of regular expressions,
@@ -195,8 +194,8 @@ def is_regex_target_modules(
except re.error:
return False
- def is_subset(sub_list, full_list):
- return set(sub_list).issubset(set(full_list))
+ def is_subset(sub_list, full_set):
+ return set(sub_list).issubset(full_set)
# Similar to PEFT's processing logic, regex-related operations are only
# executed when the load_modules is a `str`.
@@ -290,7 +289,7 @@ def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]:
# the expert indices are expanded based on the configured number
# of routed experts.
packed_modules_mapping = get_packed_modules_mapping(model)
- if not hasattr(model, "is_3d_moe_weight"):
+ if not model.is_3d_moe_weight:
# 3D MoE LoRA does not need `packed_modules_mapping`
packed_modules_mapping["experts"] = [
weight_name.rstrip(".")
diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py
index 4cc201a6414f1..d9a03f0500497 100644
--- a/vllm/lora/worker_manager.py
+++ b/vllm/lora/worker_manager.py
@@ -88,15 +88,15 @@ class WorkerLoRAManager:
try:
supported_lora_modules = self._adapter_manager.supported_lora_modules
packed_modules_mapping = self._adapter_manager.packed_modules_mapping
- expected_lora_modules: list[str] = []
+ expected_lora_lst: list[str] = []
for module in supported_lora_modules:
if module in packed_modules_mapping:
- expected_lora_modules.extend(packed_modules_mapping[module])
+ expected_lora_lst.extend(packed_modules_mapping[module])
else:
- expected_lora_modules.append(module)
+ expected_lora_lst.append(module)
if module == "experts":
- expected_lora_modules.append(module)
- expected_lora_modules = list(set(expected_lora_modules))
+ expected_lora_lst.append(module)
+ expected_lora_modules = set(expected_lora_lst)
lora_path = get_adapter_absolute_path(lora_request.lora_path)
peft_helper = PEFTHelper.from_local_dir(
diff --git a/vllm/model_executor/layers/attention_layer_base.py b/vllm/model_executor/layers/attention_layer_base.py
index ffbef470b1868..a60cf787135c0 100644
--- a/vllm/model_executor/layers/attention_layer_base.py
+++ b/vllm/model_executor/layers/attention_layer_base.py
@@ -3,14 +3,11 @@
"""Base class for attention-like layers."""
from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING
+from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.kv_cache_interface import KVCacheSpec
-if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionBackend
-
class AttentionLayerBase(ABC):
"""
@@ -22,7 +19,7 @@ class AttentionLayerBase(ABC):
"""
@abstractmethod
- def get_attn_backend(self) -> type["AttentionBackend"]:
+ def get_attn_backend(self) -> type[AttentionBackend]:
"""Get the attention backend class for this layer."""
pass
diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py
index 8b33727f05fbc..4154122636dcf 100644
--- a/vllm/model_executor/layers/batch_invariant.py
+++ b/vllm/model_executor/layers/batch_invariant.py
@@ -215,6 +215,139 @@ def matmul_persistent(
return c
+@triton.jit
+def bmm_kernel(
+ a_ptr, # (*, ) pointer to A, (B, M, K)
+ b_ptr, # (*, ) pointer to B, (B, K, N)
+ c_ptr, # (*, ) pointer to C, (B, M, N)
+ B, # int, batch size
+ M, # int, output rows
+ N, # int, output cols
+ K, # int, reduction dim
+ stride_ab,
+ stride_am,
+ stride_ak,
+ stride_bb,
+ stride_bk,
+ stride_bn,
+ stride_cb,
+ stride_cm,
+ stride_cn,
+ BLOCK_SIZE_M: tl.constexpr,
+ BLOCK_SIZE_N: tl.constexpr,
+ BLOCK_SIZE_K: tl.constexpr,
+ A_LARGE: tl.constexpr,
+ B_LARGE: tl.constexpr,
+ C_LARGE: tl.constexpr,
+):
+ """Batched GEMM: (B, M, K) x (B, K, N) -> (B, M, N)
+
+ Each program computes one (batch_idx, tile_m, tile_n) tile, accumulating
+ along K in a fixed order to preserve batch invariance.
+ """
+ pid_b = tl.program_id(0)
+ pid = tl.program_id(1)
+
+ if pid_b >= B:
+ return
+
+ # number of tiles along M / N
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+
+ pid_m = pid // num_pid_n
+ pid_n = pid % num_pid_n
+
+ if pid_m >= num_pid_m or pid_n >= num_pid_n:
+ return
+
+ # offs_m / offs_n: raw global row/col indices for this tile
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+ # masks for valid logical rows/cols within (M, N)
+ mask_m = offs_m < M # [BLOCK_SIZE_M]
+ mask_n = offs_n < N # [BLOCK_SIZE_N]
+
+ if A_LARGE or B_LARGE or C_LARGE:
+ offs_m = offs_m.to(tl.int64)
+ offs_n = offs_n.to(tl.int64)
+
+ offs_m = tl.where(mask_m, offs_m, 0)
+ offs_n = tl.where(mask_n, offs_n, 0)
+
+ # hint for triton contiguous memory
+ offs_m = tl.max_contiguous(tl.multiple_of(offs_m, BLOCK_SIZE_M), BLOCK_SIZE_M)
+ offs_n = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK_SIZE_N), BLOCK_SIZE_N)
+
+ # base pointers for current batch, shape-wise:
+ # a_batch_ptr points to A[pid_b, 0, 0]
+ # b_batch_ptr points to B[pid_b, 0, 0]
+ # c_batch_ptr points to C[pid_b, 0, 0]
+ a_batch_ptr = a_ptr + pid_b * stride_ab
+ b_batch_ptr = b_ptr + pid_b * stride_bb
+ c_batch_ptr = c_ptr + pid_b * stride_cb
+
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+ # number of K-blocks this tile iterates over
+ k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
+ offs_k_mask = tl.arange(0, BLOCK_SIZE_K)
+
+ for ki in range(k_tiles):
+ if A_LARGE or B_LARGE:
+ # offs_k: [BLOCK_SIZE_K], global K indices
+ offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
+ else:
+ offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
+
+ # a_ptrs: [BLOCK_SIZE_M, BLOCK_SIZE_K]
+ # element (i, j) points to A[pid_b, offs_m[i], offs_k[j]]
+ a_ptrs = a_batch_ptr + (
+ offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
+ )
+ # b_ptrs: [BLOCK_SIZE_K, BLOCK_SIZE_N]
+ # element (i, j) points to B[pid_b, offs_k[i], offs_n[j]]
+ b_ptrs = b_batch_ptr + (
+ offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
+ )
+
+ # valid K lanes for this block
+ k_valid = offs_k_mask < (K - ki * BLOCK_SIZE_K)
+ # A mask within (M, K): [BLOCK_SIZE_M, BLOCK_SIZE_K]
+ a_mask = mask_m[:, None] & k_valid[None, :]
+ # B mask within (K, N): [BLOCK_SIZE_K, BLOCK_SIZE_N]
+ b_mask = k_valid[:, None] & mask_n[None, :]
+
+ # a: [BLOCK_SIZE_M, BLOCK_SIZE_K] from A[offs_m, offs_k]
+ a = tl.load(
+ a_ptrs,
+ mask=a_mask,
+ other=0.0,
+ )
+ # b: [BLOCK_SIZE_K, BLOCK_SIZE_N] from B[offs_k, offs_n]
+ b = tl.load(
+ b_ptrs,
+ mask=b_mask,
+ other=0.0,
+ )
+ accumulator = tl.dot(a, b, accumulator)
+
+ # c_m / c_n: [BLOCK_SIZE_M] / [BLOCK_SIZE_N], row/col indices for C
+ c_m = offs_m
+ c_n = offs_n
+ if C_LARGE:
+ c_m = c_m.to(tl.int64)
+ c_n = c_n.to(tl.int64)
+
+ # c_ptrs: [BLOCK_SIZE_M, BLOCK_SIZE_N]
+ # element (i, j) points to C[pid_b, c_m[i], c_n[j]]
+ c_ptrs = c_batch_ptr + stride_cm * c_m[:, None] + stride_cn * c_n[None, :]
+ # mask out elements that fall outside logical (M, N) range
+ c_mask = mask_m[:, None] & mask_n[None, :]
+ # cast FP32 accumulator back to original dtype of C
+ c = accumulator.to(c_ptr.dtype.element_ty)
+ tl.store(c_ptrs, c, mask=c_mask)
+
+
@triton.jit
def _log_softmax_kernel(
input_ptr,
@@ -526,23 +659,91 @@ def matmul_batch_invariant(a, b, *, out=None):
def bmm_batch_invariant(a, b, *, out=None):
# Batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N)
- # Process each batch separately with our persistent kernel
- if a.ndim == 3 and b.ndim == 3:
- results = []
- for i in range(a.shape[0]):
- results.append(matmul_persistent(a[i], b[i]))
- result = torch.stack(results, dim=0)
-
- if out is not None:
- out.copy_(result)
- return out
- return result
- else:
+ if not (a.ndim == 3 and b.ndim == 3):
raise ValueError(
f"bmm_batch_invariant expects 3D tensors, "
f"got shapes {a.shape} and {b.shape}"
)
+ if a.shape[0] != b.shape[0]:
+ raise ValueError(
+ f"Batch dimensions of tensors must match, "
+ f"but got {a.shape[0]} and {b.shape[0]}."
+ )
+ if a.shape[2] != b.shape[1]:
+ raise ValueError(
+ f"Incompatible inner dimensions for matmul: got {a.shape} and {b.shape}."
+ )
+ if a.dtype != b.dtype:
+ raise ValueError(f"Incompatible dtypes: got {a.dtype} and {b.dtype}.")
+
+ B, M, K = a.shape
+ _, _, N = b.shape
+ dtype = a.dtype
+
+ if out is None:
+ c = torch.empty((B, M, N), device=a.device, dtype=dtype)
+ else:
+ assert out.shape == (B, M, N), "out tensor has incorrect shape"
+ assert out.dtype == dtype and out.device == a.device, "out tensor mismatch"
+ c = out
+
+ configs = {
+ torch.bfloat16: {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "num_stages": 3,
+ "num_warps": 8,
+ },
+ torch.float16: {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "num_stages": 3,
+ "num_warps": 8,
+ },
+ torch.float32: {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 32,
+ "num_stages": 3,
+ "num_warps": 8,
+ },
+ }
+
+ cfg = configs[dtype]
+ # grid = (B, num_tiles_per_matrix)
+ grid = (
+ B,
+ triton.cdiv(M, cfg["BLOCK_SIZE_M"]) * triton.cdiv(N, cfg["BLOCK_SIZE_N"]),
+ )
+
+ bmm_kernel[grid](
+ a,
+ b,
+ c,
+ B,
+ M,
+ N,
+ K,
+ a.stride(0),
+ a.stride(1),
+ a.stride(2),
+ b.stride(0),
+ b.stride(1),
+ b.stride(2),
+ c.stride(0),
+ c.stride(1),
+ c.stride(2),
+ A_LARGE=a.numel() > 2**31,
+ B_LARGE=b.numel() > 2**31,
+ C_LARGE=c.numel() > 2**31,
+ **cfg,
+ )
+
+ return c
+
def addmm_batch_invariant(bias, a, b):
return matmul_persistent(a, b, bias=bias)
@@ -812,19 +1013,19 @@ def override_envs_for_invariance():
# "TRITON_MLA",
]
if curr_attn_backend not in supported_backends:
- warning = (
- "Forcibly updating attention backend to"
- f" {supported_backends[0]} for batch_invariant. "
- f" Supported backends: {supported_backends}."
+ error = (
+ "VLLM batch_invariant mode requires an attention backend in "
+ f"{supported_backends}, but got '{curr_attn_backend}'. "
+ "Please set the 'VLLM_ATTENTION_BACKEND' environment variable "
+ "to one of the supported backends before enabling batch_invariant."
)
- logger.warning_once(warning)
- os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0]
+ raise RuntimeError(error)
if os.environ["VLLM_ATTENTION_BACKEND"] != supported_backends[0]:
warning = (
"You are using a decode-invariant form of batch invariance. "
"This will not be invariant between prefill and decode."
)
- logger.warning_once(warning)
+ logger.warning_once(warning, scope="local")
os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py
index 53d98d0650b43..669abcb3d6ff1 100644
--- a/vllm/model_executor/layers/fused_moe/__init__.py
+++ b/vllm/model_executor/layers/fused_moe/__init__.py
@@ -18,6 +18,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
+from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
+ UnquantizedFusedMoEMethod,
+)
from vllm.model_executor.layers.fused_moe.utils import activation_without_mul
from vllm.triton_utils import HAS_TRITON
@@ -41,6 +44,7 @@ __all__ = [
"FusedMoE",
"FusedMoEConfig",
"FusedMoEMethodBase",
+ "UnquantizedFusedMoEMethod",
"FusedMoeWeightScaleSupported",
"FusedMoEPermuteExpertsUnpermute",
"FusedMoEActivationFormat",
diff --git a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
index 572307052b489..659a2d4ee5b39 100644
--- a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
@@ -6,22 +6,7 @@ import torch
from torch.nn import functional as F
from vllm import _custom_ops as ops
-
-
-def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
- d = x.shape[-1] // 2
- return F.silu(x[..., :d]) * x[..., d:]
-
-
-def swigluoai_and_mul(
- x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0
-) -> torch.Tensor:
- d = x.shape[-1] // 2
- gate, up = x[..., :d], x[..., d:]
- gate = gate.clamp(max=limit)
- up = up.clamp(min=-limit, max=limit)
- glu = gate * torch.sigmoid(alpha * gate)
- return (up + 1) * glu
+from vllm.model_executor.layers.activation import SiluAndMul, SwigluOAIAndMul
def grouped_topk(
@@ -227,6 +212,11 @@ class CPUFusedMOE:
layer.w13_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
layer.w2_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
+ self.act_to_impl = {
+ "silu": SiluAndMul(),
+ "swigluoai": SwigluOAIAndMul(),
+ }
+
def __call__(
self,
layer: torch.nn.Module,
@@ -246,7 +236,7 @@ class CPUFusedMOE:
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
- assert activation in {"silu", "swigluoai"}, f"{activation} is not supported."
+ assert activation in self.act_to_impl, f"{activation} is not supported."
assert not apply_router_weight_on_input
topk_weights, topk_ids = select_experts(
hidden_states=x,
@@ -283,10 +273,7 @@ class CPUFusedMOE:
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
gate_up = layer.gate_up_linear[i](tokens_for_this_expert)
- if activation == "swigluoai":
- gate_up = swigluoai_and_mul(gate_up)
- else:
- gate_up = silu_and_mul(gate_up)
+ gate_up = self.act_to_impl[activation].forward_native(gate_up)
expert_out = layer.down_linear[i](gate_up)
outputs.append(expert_out)
start_idx = end_idx
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
index 073e90a4e6808..ef7090c349fc6 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
@@ -90,10 +90,14 @@ class FusedMoEMethodBase(QuantizeMethodBase):
def allow_inplace(self) -> bool:
return False
+ @property
+ def method_name(self) -> str:
+ return self.__class__.__name__
+
@abstractmethod
def apply(
self,
- layer: torch.nn.Module,
+ layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
index c6dc95acdb636..c23c41df226f0 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
@@ -66,6 +66,10 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
def allow_inplace(self) -> bool:
return self.old_quant_method.allow_inplace
+ @property
+ def method_name(self) -> str:
+ return self.old_quant_method.method_name
+
def create_weights(
self,
layer: torch.nn.Module,
@@ -84,7 +88,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
def apply(
self,
- layer: torch.nn.Module,
+ layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -105,42 +109,9 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- # Is getattr needed?
- zero_expert_num = getattr(layer, "zero_expert_num", 0)
- zero_expert_type = getattr(layer, "zero_expert_type", None)
-
- if enable_eplb:
- if self.supports_eplb:
- assert expert_load_view is not None
- assert logical_to_physical_map is not None
- assert logical_replica_count is not None
- else:
- raise NotImplementedError(
- "EPLB is not supported for "
- f"{self.old_quant_method.__class__.__name__}."
- )
-
topk_weights, topk_ids, zero_expert_result = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
- enable_eplb=enable_eplb,
- expert_map=expert_map,
- expert_load_view=expert_load_view,
- logical_to_physical_map=logical_to_physical_map,
- logical_replica_count=logical_replica_count,
- global_num_experts=global_num_experts,
- zero_expert_num=zero_expert_num,
- zero_expert_type=zero_expert_type,
)
result = self.fused_experts(
@@ -156,7 +127,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
expert_map=None if self.disable_expert_map else expert_map,
)
- if zero_expert_num != 0 and zero_expert_type is not None:
+ if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
assert not isinstance(result, tuple), (
"Shared + zero experts are mutually exclusive not yet supported"
)
diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
index badedfc54c382..128507639fdfd 100644
--- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -12,6 +13,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
+from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.triton_utils import tl, triton
from vllm.utils.import_utils import has_triton_kernels
@@ -88,14 +90,17 @@ def triton_kernel_moe_forward(
gating_output, topk, sm_first=not renormalize
)
+ output = torch.empty_like(hidden_states)
+
return triton_kernel_fused_experts(
- None,
+ output,
hidden_states,
w1,
w2,
routing_data,
gather_idx,
scatter_idx,
+ topk=topk,
activation=activation,
quant_config=quant_config,
apply_router_weight_on_input=apply_router_weight_on_input,
@@ -113,6 +118,7 @@ def triton_kernel_fused_experts(
routing_data, # RoutingData
gather_indx, # GatherIndx
scatter_indx, # ScatterIndx
+ topk: int,
activation: str = "silu",
quant_config: FusedMoEQuantConfig | None = None,
swiglu_alpha: float = 1.702,
@@ -120,6 +126,7 @@ def triton_kernel_fused_experts(
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
+ intermediate_cache: torch.Tensor | None = None,
a1q_scale: torch.Tensor | None = None,
) -> torch.Tensor:
if quant_config is None:
@@ -131,14 +138,30 @@ def triton_kernel_fused_experts(
assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32
# Shape check, only check non-mxfp4
+ assert hidden_states.ndim == 2
assert hidden_states.shape[-1] == w1.shape[-2]
assert w2.shape[-1] == w1.shape[1]
+ batch_dim = 1
+ M, K = hidden_states.shape[-2:]
E, _, N = w1.shape
if global_num_experts == -1:
global_num_experts = E
+ if intermediate_cache is None:
+ intermediate_cache = torch.empty(
+ (batch_dim, M * topk, N // 2),
+ device=hidden_states.device,
+ dtype=hidden_states.dtype,
+ )
+
+ # Add batch_dim to output buffer because matmul_ogs expects 3D output
+ intermediate_cache = _resize_cache(
+ intermediate_cache, (batch_dim, M * topk, N // 2)
+ )
+ output_tensor = _resize_cache(output_tensor, (batch_dim, M, K))
+
act = FusedActivation(
FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
(swiglu_alpha, swiglu_limit),
@@ -146,7 +169,7 @@ def triton_kernel_fused_experts(
)
gammas = routing_data.gate_scal if routing_data else None
- intermediate_cache1 = matmul_ogs(
+ matmul_ogs(
hidden_states,
w1,
quant_config.w1_bias,
@@ -155,10 +178,11 @@ def triton_kernel_fused_experts(
precision_config=quant_config.w1_precision,
gammas=gammas if apply_router_weight_on_input else None,
fused_activation=act,
+ y=intermediate_cache,
)
- intermediate_cache3 = matmul_ogs(
- intermediate_cache1,
+ matmul_ogs(
+ intermediate_cache.view(M * topk, N // 2),
w2,
quant_config.w2_bias,
routing_data,
@@ -167,7 +191,8 @@ def triton_kernel_fused_experts(
gammas=None if apply_router_weight_on_input else gammas,
y=output_tensor,
)
- return intermediate_cache3
+ output_tensor = output_tensor.view(M, K)
+ return output_tensor
def make_routing_data(
@@ -221,6 +246,42 @@ class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def supports_expert_map(self) -> bool:
return True
+ def moe_problem_size(
+ self,
+ a1: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_ids: torch.Tensor,
+ ) -> tuple[int, int, int, int, int]:
+ """
+ Extract the MoE problem size from the given tensor arguments:
+ - a: The hidden states, input to the MoE layer.
+ - w1: The first set of expert weights.
+ - w2: The second set of expert weights.
+ - topk_ids: The topk ids.
+ Note: extracting the problem shape from the weight and activation
+ tensors is not obvious. It needs to be done this way specifically
+ due to subtle issues with particular kernels, e.g. the int4 kernels
+ divide the trailing dimension by two, so it's not "correct" to
+ extract N or K from the trailing dimension of w1 or w2. Similarly,
+ some kernels transpose the weights, so this needs to be kept in mind.
+ Note: This implementation covers most cases. However, if experts
+ require a specialized implementation, like MarlinExperts, they are free
+ to override this function.
+ """
+ assert w1.dim() == 3 and w2.dim() == 3
+ E, _, N = w1.size()
+ K = a1.size(-1)
+
+ assert a1.dim() == 2
+ assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}"
+ M = a1.size(0)
+
+ assert topk_ids.dim() == 2
+ topk = topk_ids.size(1)
+
+ return E, M, N, K, topk
+
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Weight application and reduction happens in the fused_experts kernel.
return TopKWeightAndReduceNoOP()
@@ -263,8 +324,8 @@ class OAITritonExperts(BaseOAITritonExperts):
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# workspace are allocated inside the kernel
- workspace1 = (M, K)
- workspace2 = (0, 0)
+ workspace1 = (0, 0)
+ workspace2 = (M * topk, N // 2)
output = (M, K)
return (workspace1, workspace2, output)
@@ -297,20 +358,21 @@ class OAITritonExperts(BaseOAITritonExperts):
topk_ids, topk_weights, local_num_experts
)
- experts_output = triton_kernel_fused_experts(
- None,
+ topk = topk_ids.size(1)
+ triton_kernel_fused_experts(
+ output,
hidden_states,
w1,
w2,
routing_data,
gather_indx,
scatter_indx,
+ topk=topk,
activation=activation,
quant_config=self.quant_config,
apply_router_weight_on_input=False,
global_num_experts=local_num_experts,
expert_map=None, # applied already
+ intermediate_cache=workspace2,
a1q_scale=a1q_scale,
)
-
- output.copy_(experts_output, non_blocking=True)
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 6619b64b2bbc0..0ef3130b26333 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1510,30 +1510,11 @@ class FusedMoE(CustomOp):
logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
)
- @staticmethod
def select_experts(
+ self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
- top_k: int,
- use_grouped_topk: bool,
- renormalize: bool,
- topk_group: int | None = None,
- num_expert_group: int | None = None,
- custom_routing_function: Callable | None = None,
- scoring_func: str = "softmax",
- routed_scaling_factor: float = 1.0,
- e_score_correction_bias: torch.Tensor | None = None,
- indices_type: torch.dtype | None = None,
- enable_eplb: bool = False,
- expert_map: torch.Tensor | None = None,
- expert_load_view: torch.Tensor | None = None,
- logical_to_physical_map: torch.Tensor | None = None,
- logical_replica_count: torch.Tensor | None = None,
- global_num_experts: int | None = None,
- zero_expert_num: int | None = None,
- zero_expert_type: str | None = None,
- num_fused_shared_experts: int = 0,
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
"""
Route the input hidden states to the top-k experts based on the
router logits.
@@ -1552,6 +1533,27 @@ class FusedMoE(CustomOp):
fused_topk_bias,
)
+ if self.enable_eplb:
+ if self.quant_method.supports_eplb:
+ if self.expert_load_view is None:
+ raise ValueError(
+ "enable_eplb=True requiere expert_load_view != None"
+ )
+ if self.logical_to_physical_map is None:
+ raise ValueError(
+ "enable_eplb=True requiere logical_to_physical_map != None"
+ )
+ if self.logical_replica_count is None:
+ raise ValueError(
+ "enable_eplb=True requiere logical_replica_count != None"
+ )
+ else:
+ raise NotImplementedError(
+ f"EPLB is not supported for {self.quant_method.method_name}."
+ )
+
+ indices_type = self.quant_method.topk_indices_dtype
+
# Check if we should use a routing simulation strategy
routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
if routing_strategy != "":
@@ -1559,20 +1561,20 @@ class FusedMoE(CustomOp):
hidden_states=hidden_states,
router_logits=router_logits,
strategy_name=routing_strategy,
- top_k=top_k,
+ top_k=self.top_k,
indices_type=indices_type,
)
# DeepSeekv2 uses grouped_top_k
- elif use_grouped_topk:
- assert topk_group is not None
- assert num_expert_group is not None
+ elif self.use_grouped_topk:
+ assert self.topk_group is not None
+ assert self.num_expert_group is not None
if rocm_aiter_ops.is_fused_moe_enabled():
if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
- assert num_fused_shared_experts == 0
+ assert self.num_fused_shared_experts == 0
grouped_topk_impl = partial(
rocm_aiter_grouped_topk,
- num_fused_shared_experts=num_fused_shared_experts,
+ num_fused_shared_experts=self.num_fused_shared_experts,
)
else:
grouped_topk_impl = grouped_topk
@@ -1580,50 +1582,46 @@ class FusedMoE(CustomOp):
topk_weights, topk_ids = grouped_topk_impl(
hidden_states=hidden_states,
gating_output=router_logits,
- topk=top_k,
- renormalize=renormalize,
- num_expert_group=num_expert_group,
- topk_group=topk_group,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
+ topk=self.top_k,
+ renormalize=self.renormalize,
+ num_expert_group=self.num_expert_group,
+ topk_group=self.topk_group,
+ scoring_func=self.scoring_func,
+ routed_scaling_factor=self.routed_scaling_factor,
+ e_score_correction_bias=self.e_score_correction_bias,
)
- elif e_score_correction_bias is not None:
+ elif self.e_score_correction_bias is not None:
topk_weights, topk_ids = fused_topk_bias(
hidden_states=hidden_states,
gating_output=router_logits,
- e_score_correction_bias=e_score_correction_bias.data,
- topk=top_k,
- renormalize=renormalize,
+ e_score_correction_bias=self.e_score_correction_bias.data,
+ topk=self.top_k,
+ renormalize=self.renormalize,
)
- if routed_scaling_factor != 1.0:
- topk_weights *= routed_scaling_factor
- elif custom_routing_function is None:
+ if self.routed_scaling_factor != 1.0:
+ topk_weights *= self.routed_scaling_factor
+ elif self.custom_routing_function is None:
topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states=hidden_states,
gating_output=router_logits,
- topk=top_k,
- renormalize=renormalize,
+ topk=self.top_k,
+ renormalize=self.renormalize,
indices_type=indices_type,
)
else:
- topk_weights, topk_ids = custom_routing_function(
+ topk_weights, topk_ids = self.custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
- topk=top_k,
- renormalize=renormalize,
+ topk=self.top_k,
+ renormalize=self.renormalize,
)
- if enable_eplb:
- assert expert_load_view is not None
- assert logical_to_physical_map is not None
- assert logical_replica_count is not None
-
+ if self.enable_eplb:
topk_ids = eplb_map_to_physical_and_record(
topk_ids=topk_ids,
- expert_load_view=expert_load_view,
- logical_to_physical_map=logical_to_physical_map,
- logical_replica_count=logical_replica_count,
+ expert_load_view=self.expert_load_view,
+ logical_to_physical_map=self.logical_to_physical_map,
+ logical_replica_count=self.logical_replica_count,
)
if (indices_type is not None) and topk_ids.dtype != indices_type:
@@ -1633,16 +1631,16 @@ class FusedMoE(CustomOp):
# Compute zero expert result if needed
if (
- zero_expert_num is not None
- and zero_expert_num > 0
- and zero_expert_type is not None
- and global_num_experts is not None
+ self.zero_expert_num is not None
+ and self.zero_expert_num > 0
+ and self.zero_expert_type is not None
+ and self.global_num_experts is not None
):
zero_expert_result = zero_experts_compute_triton(
expert_indices=topk_ids,
expert_scales=topk_weights,
- num_experts=global_num_experts,
- zero_expert_type=zero_expert_type,
+ num_experts=self.global_num_experts,
+ zero_expert_type=self.zero_expert_type,
hidden_states=hidden_states,
)
else:
diff --git a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py
index 6ec8b33ed9309..9aaeec4f98a61 100644
--- a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py
@@ -38,7 +38,6 @@ class SharedFusedMoE(FusedMoE):
# TODO(wentao): find the root cause and remove this condition
self.enable_eplb
or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1)
- or self.use_marlin_kernels
)
and self._shared_experts is not None
)
diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
index 63b0e6f573d65..48e5a8907f926 100644
--- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
+++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
@@ -331,7 +331,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_cuda(
self,
- layer: torch.nn.Module,
+ layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
@@ -352,31 +352,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- zero_expert_num = getattr(layer, "zero_expert_num", 0)
- zero_expert_type = getattr(layer, "zero_expert_type", None)
-
topk_weights, topk_ids, zero_expert_result = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
- enable_eplb=enable_eplb,
- expert_map=expert_map,
- expert_load_view=expert_load_view,
- logical_to_physical_map=logical_to_physical_map,
- logical_replica_count=logical_replica_count,
- global_num_experts=global_num_experts,
- zero_expert_num=zero_expert_num,
- zero_expert_type=zero_expert_type,
- num_fused_shared_experts=layer.num_fused_shared_experts,
)
if self.rocm_aiter_moe_enabled:
@@ -415,7 +393,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map=expert_map,
)
- if zero_expert_num != 0 and zero_expert_type is not None:
+ if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
assert not isinstance(result, tuple), (
"Shared + zero experts are mutually exclusive not yet supported"
)
@@ -425,7 +403,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_cpu(
self,
- layer: torch.nn.Module,
+ layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
@@ -474,7 +452,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_xpu(
self,
- layer: torch.nn.Module,
+ layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
@@ -515,7 +493,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_tpu(
self,
- layer: torch.nn.Module,
+ layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py
index aa919d6fdc35c..74f4383e9c238 100644
--- a/vllm/model_executor/layers/mamba/abstract.py
+++ b/vllm/model_executor/layers/mamba/abstract.py
@@ -2,18 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import abstractmethod
from collections.abc import Iterable
-from typing import TYPE_CHECKING
import torch
+from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.selector import get_mamba_attn_backend
from vllm.config import VllmConfig
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
-if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionBackend
-
class MambaBase(AttentionLayerBase):
"""
@@ -66,6 +63,6 @@ class MambaBase(AttentionLayerBase):
),
)
- def get_attn_backend(self) -> type["AttentionBackend"]:
+ def get_attn_backend(self) -> type[AttentionBackend]:
"""Get the attention backend class for this Mamba layer."""
return get_mamba_attn_backend(self.mamba_type)
diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py
index d85b3e61c5d61..278713408c288 100644
--- a/vllm/model_executor/layers/mamba/linear_attn.py
+++ b/vllm/model_executor/layers/mamba/linear_attn.py
@@ -8,7 +8,7 @@ import torch.nn.functional as F
from einops import rearrange
from torch import nn
-from vllm.attention import AttentionMetadata
+from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index 3f6ea68072b40..66945e2d2a7c8 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -597,7 +597,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -618,24 +618,11 @@ class AWQMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError("EPLB not supported for `AWQMoEMethod` yet.")
-
assert activation == "silu", "Only SiLU activation is supported."
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
return fused_marlin_moe(
diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py
index e5a741e639ad9..1e57fa218b797 100644
--- a/vllm/model_executor/layers/quantization/bitsandbytes.py
+++ b/vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -495,7 +495,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -518,25 +518,11 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `BitsAndBytesMoEMethod` yet."
- )
-
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
+ # TODO(bnell): Do these need to be called on the hot path?
if self.quant_config.load_in_8bit:
w13, w2 = self._apply_8bit_dequant(layer)
else:
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index 6c7d4cd7bd9ab..f9d8f5883680b 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -18,6 +18,7 @@ from compressed_tensors.quantization import (
from compressed_tensors.transform import TransformConfig
import vllm.envs as envs
+from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (
@@ -131,8 +132,6 @@ class CompressedTensorsConfig(QuantizationConfig):
layer: torch.nn.Module,
prefix: str,
) -> Optional["QuantizeMethodBase"]:
- from vllm.attention.layer import Attention # Avoid circular import
-
if isinstance(layer, LinearBase):
# collect schemes
quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
@@ -158,9 +157,23 @@ class CompressedTensorsConfig(QuantizationConfig):
if isinstance(layer, Attention):
return CompressedTensorsKVCacheMethod(self)
if isinstance(layer, FusedMoE):
- return CompressedTensorsMoEMethod.get_moe_method(self, layer)
+ return CompressedTensorsMoEMethod.get_moe_method(self, layer, prefix)
return None
+ def _add_fused_moe_to_target_scheme_map(self):
+ """
+ Helper function to update target_scheme_map
+ since linear layers get fused into FusedMoE
+ targetting 'Linear' needs to also match
+ FusedMoE modules.
+ """
+ if (
+ "Linear" not in self.target_scheme_map
+ or "FusedMoE" in self.target_scheme_map
+ ):
+ return
+ self.target_scheme_map["FusedMoE"] = self.target_scheme_map["Linear"]
+
@classmethod
def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig":
ignore: list[str] = cast(list[str], config.get("ignore", []))
@@ -266,8 +279,9 @@ class CompressedTensorsConfig(QuantizationConfig):
def get_config_filenames(cls) -> list[str]:
return []
+ @staticmethod
def _check_scheme_supported(
- self, min_capability: int, error: bool = True, match_exact: bool = False
+ min_capability: int, error: bool = True, match_exact: bool = False
) -> bool:
capability_tuple = current_platform.get_device_capability()
@@ -293,9 +307,8 @@ class CompressedTensorsConfig(QuantizationConfig):
else:
return False
- def _is_fp4a4_nvfp4(
- self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
- ):
+ @staticmethod
+ def _is_fp4a4_nvfp4(weight_quant: QuantizationArgs, input_quant: QuantizationArgs):
if weight_quant is None or input_quant is None:
return False
@@ -322,9 +335,8 @@ class CompressedTensorsConfig(QuantizationConfig):
and is_symmetric
)
- def _is_fp4a16_nvfp4(
- self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
- ):
+ @staticmethod
+ def _is_fp4a16_nvfp4(weight_quant: QuantizationArgs, input_quant: QuantizationArgs):
is_weight_only = weight_quant is not None and input_quant is None
is_tensor_group_quant = (
weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value
@@ -344,8 +356,9 @@ class CompressedTensorsConfig(QuantizationConfig):
and is_symmetric
)
+ @staticmethod
def _is_static_tensor_w8a8(
- self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
+ weight_quant: QuantizationArgs, input_quant: QuantizationArgs
) -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
weight_strategy = (
@@ -362,8 +375,9 @@ class CompressedTensorsConfig(QuantizationConfig):
# Only symmetric weight quantization supported.
return is_8_bits and is_tensor and weight_quant.symmetric and is_static
+ @staticmethod
def _is_dynamic_token_w8a8(
- self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
+ weight_quant: QuantizationArgs, input_quant: QuantizationArgs
) -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
weight_strategy = (
@@ -379,8 +393,9 @@ class CompressedTensorsConfig(QuantizationConfig):
# Only symmetric weight quantization supported.
return is_8_bits and is_token and weight_quant.symmetric and is_dynamic
+ @staticmethod
def _is_dynamic_token_w4a8_int(
- self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
+ weight_quant: QuantizationArgs, input_quant: QuantizationArgs
) -> bool:
is_weight_4_bits = weight_quant.num_bits == 4
is_activation_8_bits = input_quant.num_bits == 8
@@ -403,8 +418,9 @@ class CompressedTensorsConfig(QuantizationConfig):
and is_dynamic
)
+ @staticmethod
def _is_fp8_w8a8(
- self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
+ weight_quant: QuantizationArgs, input_quant: QuantizationArgs
) -> bool:
# Confirm weights and activations quantized.
if weight_quant is None or input_quant is None:
@@ -439,8 +455,9 @@ class CompressedTensorsConfig(QuantizationConfig):
is_per_tensor_activation = input_quant.strategy == QuantizationStrategy.TENSOR
return is_symmetric_activation and is_per_tensor_activation
+ @staticmethod
def _is_fp8_w4a8(
- self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
+ weight_quant: QuantizationArgs, input_quant: QuantizationArgs
) -> bool:
if not weight_quant or not input_quant:
return False
@@ -462,29 +479,33 @@ class CompressedTensorsConfig(QuantizationConfig):
and is_dynamic
)
+ @classmethod
def _is_fp8_w4a8_sm90(
- self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
+ cls, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
) -> bool:
- return self._check_scheme_supported(
+ return cls._check_scheme_supported(
90, error=False, match_exact=True
- ) and self._is_fp8_w4a8(weight_quant, input_quant)
+ ) and cls._is_fp8_w4a8(weight_quant, input_quant)
+ @classmethod
def _is_fp8_w8a8_sm90(
- self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
+ cls, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
) -> bool:
- return self._check_scheme_supported(
+ return cls._check_scheme_supported(
90, error=False, match_exact=True
- ) and self._is_fp8_w8a8(weight_quant, input_quant)
+ ) and cls._is_fp8_w8a8(weight_quant, input_quant)
+ @classmethod
def _is_fp8_w8a8_sm100(
- self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
+ cls, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
) -> bool:
- return self._check_scheme_supported(
+ return cls._check_scheme_supported(
100, error=False, match_exact=True
- ) and self._is_fp8_w8a8(weight_quant, input_quant)
+ ) and cls._is_fp8_w8a8(weight_quant, input_quant)
+ @staticmethod
def _is_fp8_w8a16(
- self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
+ weight_quant: QuantizationArgs, input_quant: QuantizationArgs
) -> bool:
# Confirm weights quantized.
if weight_quant is None:
@@ -508,8 +529,9 @@ class CompressedTensorsConfig(QuantizationConfig):
and is_tensor_or_channel_or_block_weight
)
+ @staticmethod
def _is_wNa16_group_channel(
- self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
+ weight_quant: QuantizationArgs, input_quant: QuantizationArgs
) -> bool:
input_quant_none = input_quant is None
is_channel_group = (
@@ -646,25 +668,13 @@ class CompressedTensorsConfig(QuantizationConfig):
to select the CompressedTensorsScheme used for inference.
"""
- # Find the "target" in the compressed-tensors config
- # that our layer conforms to.
- # TODO (@kylesayrs): support ignore module names with ct matching utils
- if should_ignore_layer(
- layer_name, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
- ):
- return None
+ # Use the new get_quant_args method to extract QuantizationArgs
+ scheme_dict = self.get_scheme_dict(layer, layer_name)
- # Will be empty for models with only sparsity
- weight_quant = input_quant = None
- if self.target_scheme_map:
- matched_target = find_matched_target(
- layer_name=layer_name,
- module=layer,
- targets=self.target_scheme_map.keys(),
- fused_mapping=self.packed_modules_mapping,
- )
-
- scheme_dict = self.target_scheme_map[matched_target]
+ weight_quant = None
+ input_quant = None
+ format = None
+ if scheme_dict:
weight_quant = scheme_dict.get("weights")
input_quant = scheme_dict.get("input_activations")
format = scheme_dict.get("format")
@@ -723,6 +733,38 @@ class CompressedTensorsConfig(QuantizationConfig):
logger.debug("Using scheme: %s for %s", scheme.__class__.__name__, layer_name)
return scheme
+ def get_scheme_dict(
+ self, layer: torch.nn.Module, layer_name: str | None = None
+ ) -> dict[str, QuantizationArgs | str | None] | None:
+ """
+ Extract the QuantizationArgs for a given layer.
+
+ Returns:
+ dict with {
+ "weights": QuantizationArgs,
+ "input_activations": QuantizationArgs | None,
+ "format": str | None
+ } | None
+ """
+ # TODO (@kylesayrs): support ignore module names with ct matching utils
+ if should_ignore_layer(
+ layer_name, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
+ ):
+ return None
+
+ # Will be empty for models with only sparsity
+ if self.target_scheme_map:
+ matched_target = find_matched_target(
+ layer_name=layer_name,
+ module=layer,
+ targets=self.target_scheme_map.keys(),
+ fused_mapping=self.packed_modules_mapping,
+ )
+
+ return self.target_scheme_map[matched_target]
+
+ return None
+
def get_cache_scale(self, name: str) -> str | None:
"""
Check whether the param name matches the format for k/v cache scales
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index ad547dd409822..c7dfd1787cc8f 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -23,6 +23,7 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute,
FusedMoeWeightScaleSupported,
+ UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
@@ -45,9 +46,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compress
WNA16_SUPPORTED_BITS,
WNA16_SUPPORTED_TYPES_MAP,
)
-from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
- find_matched_target,
-)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
@@ -103,7 +101,7 @@ __all__ = [
"CompressedTensorsW8A8Int8MoEMethod",
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod",
- "CompressedTensorsW4A4MoeMethod",
+ "CompressedTensorsW4A4Nvfp4MoeMethod",
"CompressedTensorsW4A8Int8MoEMethod",
]
@@ -113,39 +111,35 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
def get_moe_method(
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
layer: torch.nn.Module,
+ prefix: str,
) -> "CompressedTensorsMoEMethod":
+ # FusedMoE was made by combining multiple Linears so need to
+ # make sure quantization config for Linear can target it
+ quant_config._add_fused_moe_to_target_scheme_map()
+ unfused_names = [
+ prefix + proj_name
+ for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"]
+ ]
+ # TODO: refactor this to use expert_mapping and check all layer numbers
+ all_scheme_dicts = [
+ quant_config.get_scheme_dict(layer, name) for name in unfused_names
+ ]
+ scheme_dict = all_scheme_dicts.pop()
+
+ # multiple schemes found
+ if not all([cur_dict == scheme_dict for cur_dict in all_scheme_dicts]):
+ raise ValueError(
+ "All MoE projections need to have same "
+ "quantization scheme but found multiple"
+ )
+
+ if scheme_dict is None: # ignored layer
+ return UnquantizedFusedMoEMethod(layer.moe_config)
+
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
- # Check if a using "Linear" to select schemes
- if "Linear" in quant_config.target_scheme_map:
- matched_target = "Linear"
- else:
- # May have instead defined the linear layers in the fused model
-
- fused_layers = ["re:.*down_proj.*", "re:.*gate_proj.*", "re:.*up_proj.*"]
- current_scheme = None
- for fused_layer in fused_layers:
- # Check if one of the fused layers are defined in quant_config
- matched_target = find_matched_target(
- layer_name=fused_layer,
- module=layer,
- targets=quant_config.target_scheme_map.keys(),
- fused_mapping=quant_config.packed_modules_mapping,
- )
-
- # Only valid if down_proj, gate_proj, and up_proj
- # are mapped to the same quant scheme in the quant_config
- if current_scheme is None:
- current_scheme = quant_config.target_scheme_map.get(matched_target)
- else:
- assert current_scheme == quant_config.target_scheme_map.get(
- matched_target
- )
-
- weight_quant = quant_config.target_scheme_map[matched_target].get("weights")
- input_quant = quant_config.target_scheme_map[matched_target].get(
- "input_activations"
- )
+ weight_quant = scheme_dict.get("weights")
+ input_quant = scheme_dict.get("input_activations")
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
# group_size=None means channelwise
@@ -171,7 +165,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
quant_config, layer.moe_config
)
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
- return CompressedTensorsW4A4MoeMethod(layer.moe_config)
+ return CompressedTensorsW4A4Nvfp4MoeMethod(layer.moe_config)
elif (
quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
@@ -188,7 +182,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
)
-class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
+class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
def __init__(self, moe: FusedMoEConfig):
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
detect_nvfp4_moe_support,
@@ -205,8 +199,12 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
logger.info_once(
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
- " for CompressedTensorsW4A4MoeMethod."
+ " for CompressedTensorsW4A4Nvfp4MoeMethod."
)
+ elif self.use_marlin:
+ logger.info_once("Using Marlin for CompressedTensorsW4A4Nvfp4MoeMethod.")
+ else:
+ logger.info_once("Using Cutlass for CompressedTensorsW4A4Nvfp4MoeMethod.")
def create_weights(
self,
@@ -511,7 +509,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -532,16 +530,17 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `CompressedTensorsW4A4MoeMethod` yet."
- )
assert activation == "silu", "Only SiLU activation is supported."
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
+ if enable_eplb:
+ raise NotImplementedError(
+ "EPLB not supported for `CompressedTensorsW4A4MoeMethod` yet."
+ )
+
return flashinfer_trtllm_fp4_moe(
layer=layer,
x=x,
@@ -554,19 +553,9 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
e_score_correction_bias=e_score_correction_bias,
)
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
if self.use_marlin:
@@ -621,7 +610,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
assert expert_map is None, (
"Expert Parallelism / expert_map "
"is currently not supported for "
- "CompressedTensorsW4A4MoeMethod."
+ "CompressedTensorsW4A4Nvfp4MoeMethod."
)
assert self.moe_quant_config is not None
@@ -1109,7 +1098,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -1130,31 +1119,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- assert expert_load_view is not None
- assert logical_to_physical_map is not None
- assert logical_replica_count is not None
- assert isinstance(layer, FusedMoE)
-
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
- num_fused_shared_experts=layer.num_fused_shared_experts,
- enable_eplb=enable_eplb,
- expert_map=expert_map,
- expert_load_view=expert_load_view,
- logical_to_physical_map=logical_to_physical_map,
- logical_replica_count=logical_replica_count,
)
per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
@@ -1377,7 +1344,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -1398,26 +1365,11 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `CompressedTensorsW8A8Int8MoEMethod` yet."
- )
-
from vllm.model_executor.layers.fused_moe import fused_experts
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
return fused_experts(
@@ -1738,7 +1690,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -1759,26 +1711,11 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `CompressedTensorsWNA16MarlinMoEMethod` yet."
- )
-
assert activation == "silu", f"{activation} not supported for Marlin MoE."
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
return fused_marlin_moe(
@@ -2001,7 +1938,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -2022,43 +1959,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- if expert_load_view is None:
- raise ValueError("enable_eplb=True requiere expert_load_view != None")
- if logical_to_physical_map is None:
- raise ValueError(
- "enable_eplb=True requiere logical_to_physical_map != None"
- )
- if logical_replica_count is None:
- raise ValueError(
- "enable_eplb=True requiere logical_replica_count != None"
- )
- if not isinstance(layer, FusedMoE):
- raise TypeError(
- "EPLB is only supported when `layer` is a instance of FusedMoE."
- )
-
from vllm.model_executor.layers.fused_moe import fused_experts
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
- num_fused_shared_experts=getattr(layer, "num_fused_shared_experts", 0),
- enable_eplb=enable_eplb,
- expert_map=expert_map,
- expert_load_view=expert_load_view,
- logical_to_physical_map=logical_to_physical_map,
- logical_replica_count=logical_replica_count,
)
return fused_experts(
diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py
index 5241f9a2301be..7ebe40ec84687 100644
--- a/vllm/model_executor/layers/quantization/experts_int8.py
+++ b/vllm/model_executor/layers/quantization/experts_int8.py
@@ -137,7 +137,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -158,26 +158,11 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `ExpertsInt8MoEMethod` yet."
- )
-
from vllm.model_executor.layers.fused_moe import fused_experts
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
return fused_experts(
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 91bd45bf879cb..7dfc8a9c36c3e 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -14,6 +14,7 @@ import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
+from vllm.attention.layer import Attention
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
@@ -28,6 +29,7 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEParallelConfig,
FusedMoEQuantConfig,
RoutingMethodType,
fp8_w8a8_moe_quant_config,
@@ -118,7 +120,9 @@ class Fp8MoeBackend(Enum):
TRITON = 6
-def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
+def get_fp8_moe_backend(
+ block_quant: bool, moe_parallel_config: FusedMoEParallelConfig
+) -> Fp8MoeBackend:
"""
Select the primary FP8 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
@@ -159,8 +163,19 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
logger.info_once("Using Marlin backend for FP8 MoE")
return Fp8MoeBackend.MARLIN
- # deepGEMM on supported platforms with block-quantized weights
- if envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM and block_quant:
+ # Determine if we should use DeepGEMM with block-quantized weights:
+ # - If explicitly set by user, respect their choice
+ # - If not explicitly set (default), disable when TP size is >= 8
+ moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM
+ if not envs.is_set("VLLM_MOE_USE_DEEP_GEMM") and moe_parallel_config.tp_size >= 8:
+ moe_use_deep_gemm = False
+ logger.info_once(
+ "DeepGEMM MoE is disabled by default when TP size is >= 8. "
+ "Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.",
+ scope="local",
+ )
+
+ if envs.VLLM_USE_DEEP_GEMM and moe_use_deep_gemm and block_quant:
if not has_deep_gemm():
logger.warning_once(
"DeepGEMM backend requested but not available.", scope="local"
@@ -263,7 +278,6 @@ class Fp8Config(QuantizationConfig):
def get_xpu_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
- from vllm.attention.layer import Attention
from vllm.model_executor.layers.quantization.ipex_quant import (
XPUFp8LinearMethod,
XPUFp8MoEMethod,
@@ -293,8 +307,6 @@ class Fp8Config(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
- from vllm.attention.layer import Attention # Avoid circular import
-
if current_platform.is_xpu():
return self.get_xpu_quant_method(layer, prefix)
if isinstance(layer, LinearBase):
@@ -641,7 +653,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.quant_config = quant_config
self.weight_block_size = self.quant_config.weight_block_size
self.block_quant: bool = self.weight_block_size is not None
- self.fp8_backend = get_fp8_moe_backend(self.block_quant)
+ self.fp8_backend = get_fp8_moe_backend(
+ self.block_quant, layer.moe_parallel_config
+ )
self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
@@ -1140,7 +1154,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -1216,31 +1230,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=apply_router_weight_on_input,
)
- zero_expert_num = getattr(layer, "zero_expert_num", 0)
- zero_expert_type = getattr(layer, "zero_expert_type", None)
-
- select_result = FusedMoE.select_experts(
+ select_result = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
- enable_eplb=enable_eplb,
- expert_map=expert_map,
- expert_load_view=expert_load_view,
- logical_to_physical_map=logical_to_physical_map,
- logical_replica_count=logical_replica_count,
- global_num_experts=global_num_experts,
- zero_expert_num=zero_expert_num,
- zero_expert_type=zero_expert_type,
- num_fused_shared_experts=layer.num_fused_shared_experts,
)
topk_weights, topk_ids, zero_expert_result = select_result
@@ -1322,7 +1314,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.allow_cutlass_block_scaled_grouped_gemm
),
)
- if zero_expert_num != 0 and zero_expert_type is not None:
+
+ if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
assert not isinstance(result, tuple), (
"Shared + zero experts are mutually exclusive not yet supported"
)
diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py
index 42d7a67371ae8..bcdfafb50fc5a 100644
--- a/vllm/model_executor/layers/quantization/gguf.py
+++ b/vllm/model_executor/layers/quantization/gguf.py
@@ -621,7 +621,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -642,9 +642,6 @@ class GGUFMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError("EPLB not supported for `GGUFMoEMethod` yet.")
-
assert activation == "silu", "Only SiLU activation is supported."
if apply_router_weight_on_input:
raise NotImplementedError(
@@ -652,19 +649,9 @@ class GGUFMoEMethod(FusedMoEMethodBase):
"fused GGUF MoE method."
)
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
return fused_moe_gguf(
x,
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 68a122fd46c6b..77b15db373a3a 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -722,7 +722,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -743,26 +743,11 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `GPTQMarlinMoEMethod` yet."
- )
-
assert activation == "silu", "Only SiLU activation is supported."
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
return fused_marlin_moe(
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 01a23168bdde3..80f8e3a03e7cf 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
+from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
@@ -149,8 +150,6 @@ class ModelOptQuantConfigBase(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
- from vllm.attention.layer import Attention # Avoid circular import
-
# handle kv-cache first so we can focus only on weight quantization thereafter
if isinstance(layer, Attention):
return self.KVCacheMethodCls(self)
@@ -696,7 +695,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -717,12 +716,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `ModelOptFp8MoEMethod` yet."
- )
-
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
+ if layer.enable_eplb:
+ raise NotImplementedError(
+ "EPLB not supported for `ModelOptFp8MoEMethod` yet."
+ )
assert activation == "silu", (
f"Expected 'silu' activation but got {activation}"
)
@@ -740,19 +738,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
)
# Expert selection
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
@@ -1143,6 +1131,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
" for ModelOptNvFp4FusedMoE."
)
+ elif self.use_marlin:
+ logger.info_once("Using Marlin for ModelOptNvFp4FusedMoE.")
+ else:
+ logger.info_once("Using Cutlass for ModelOptNvFp4FusedMoE.")
def maybe_make_prepare_finalize(
self,
@@ -1459,7 +1451,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -1480,16 +1472,16 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
- )
assert activation == "silu", "Only SiLU activation is supported."
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
+ if enable_eplb:
+ raise NotImplementedError(
+ "EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
+ )
return flashinfer_trtllm_fp4_moe(
layer=layer,
x=x,
@@ -1502,19 +1494,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
e_score_correction_bias=e_score_correction_bias,
)
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
if self.use_marlin:
diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py
index 2090c86f78dc8..cf348290a2716 100644
--- a/vllm/model_executor/layers/quantization/moe_wna16.py
+++ b/vllm/model_executor/layers/quantization/moe_wna16.py
@@ -359,7 +359,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -380,25 +380,12 @@ class MoeWNA16Method(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError("EPLB not supported for `MoeWNA16Method` yet.")
-
from vllm.model_executor.layers.fused_moe import fused_experts
assert activation == "silu", "Only SiLU activation is supported."
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
return fused_experts(
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index 66ae2e94c60a5..bc241ac692e23 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -8,6 +8,7 @@ import torch
from torch.nn.parameter import Parameter
from vllm import envs
+from vllm.attention.layer import Attention
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
@@ -132,12 +133,15 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
)
# If FlashInfer is not available, try either Marlin or Triton
- if (
- envs.VLLM_MXFP4_USE_MARLIN
- or current_platform.get_device_capability()[0] < 9
- or not has_triton_kernels()
- or not is_torch_equal_or_newer("2.8.0")
- ):
+ triton_kernels_supported = (
+ has_triton_kernels()
+ and is_torch_equal_or_newer("2.8.0")
+ # NOTE: triton_kernels are only confirmed to work on SM90 and SM100
+ # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
+ # SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
+ and (9, 0) <= current_platform.get_device_capability() < (11, 0)
+ )
+ if envs.VLLM_MXFP4_USE_MARLIN or not triton_kernels_supported:
logger.info_once("Using Marlin backend")
return Mxfp4Backend.MARLIN
else:
@@ -181,8 +185,6 @@ class Mxfp4Config(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
- from vllm.attention.layer import Attention # Avoid circular import
-
if isinstance(layer, LinearBase):
if self.ignored_layers and is_layer_skipped(
prefix=prefix,
@@ -193,9 +195,10 @@ class Mxfp4Config(QuantizationConfig):
# TODO: Add support for MXFP4 Linear Method.
# MXFP4 LinearMethod is available in AMD-Quark, refer to that implementation
# if you are interested in enabling MXFP4 here.
- logger.warning_once(
+ logger.debug_once(
"MXFP4 linear layer is not implemented - falling back to "
- "UnquantizedLinearMethod."
+ "UnquantizedLinearMethod.",
+ scope="local",
)
return UnquantizedLinearMethod()
elif isinstance(layer, FusedMoE):
@@ -205,9 +208,10 @@ class Mxfp4Config(QuantizationConfig):
return Mxfp4MoEMethod(layer.moe_config)
elif isinstance(layer, Attention):
# TODO: Add support for MXFP4 Attention.
- logger.warning_once(
+ logger.debug_once(
"MXFP4 attention layer is not implemented. "
- "Skipping quantization for this layer."
+ "Skipping quantization for this layer.",
+ scope="local",
)
return None
@@ -862,7 +866,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -887,18 +891,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
raise NotImplementedError("EPLB is not supported for mxfp4")
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
)
return fused_marlin_moe(
@@ -989,17 +984,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
):
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- e_score_correction_bias=e_score_correction_bias,
)
# Backend-specific preparation
diff --git a/vllm/model_executor/layers/quantization/petit.py b/vllm/model_executor/layers/quantization/petit.py
index 402cebc38c215..5ccc73166361a 100644
--- a/vllm/model_executor/layers/quantization/petit.py
+++ b/vllm/model_executor/layers/quantization/petit.py
@@ -8,6 +8,7 @@ import regex as re
import torch
from torch.nn.parameter import Parameter
+from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (
LinearBase,
@@ -159,8 +160,6 @@ class PetitNvFp4Config(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
- from vllm.attention.layer import Attention # Avoid circular import
-
exclude = self.require_exclude_modules()
if isinstance(layer, LinearBase):
diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py
index 26ba8e5b16bc0..ed8a2c7fa0841 100644
--- a/vllm/model_executor/layers/quantization/ptpc_fp8.py
+++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py
@@ -7,6 +7,7 @@ import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
+from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -65,8 +66,6 @@ class PTPCFp8Config(Fp8Config):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
- from vllm.attention.layer import Attention # Avoid circular import
-
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers):
return UnquantizedLinearMethod()
diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py
index f59e5e2a0af7a..3640e5c452786 100644
--- a/vllm/model_executor/layers/quantization/quark/quark.py
+++ b/vllm/model_executor/layers/quantization/quark/quark.py
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast
import torch
+from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (
@@ -102,8 +103,6 @@ class QuarkConfig(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
- from vllm.attention.layer import Attention # Avoid circular import
-
# Check if the layer is skipped for quantization.
exclude_layers = cast(list[str], self.quant_config.get("exclude"))
if should_ignore_layer(
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index 30772c3665b06..8be0299eaa66f 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -334,7 +334,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -355,24 +355,9 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet."
- )
-
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
if self.rocm_aiter_moe_enabled:
@@ -609,7 +594,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -630,24 +615,9 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `QuarkOCP_MX_MoEMethod` yet."
- )
-
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
if not self.emulate:
diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py
index 52656263a601b..7b51b828009fc 100644
--- a/vllm/model_executor/layers/quantization/rtn.py
+++ b/vllm/model_executor/layers/quantization/rtn.py
@@ -356,7 +356,7 @@ class RTNMoEMethod(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -377,22 +377,9 @@ class RTNMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError("EPLB not supported for `RTNMoEMethod` yet.")
-
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
return fused_marlin_moe(
diff --git a/vllm/model_executor/layers/rotary_embedding/__init__.py b/vllm/model_executor/layers/rotary_embedding/__init__.py
index 152d9401b8e94..0f10bff6ac4f5 100644
--- a/vllm/model_executor/layers/rotary_embedding/__init__.py
+++ b/vllm/model_executor/layers/rotary_embedding/__init__.py
@@ -17,6 +17,7 @@ from .llama4_vision_rope import Llama4VisionRotaryEmbedding
from .mrope import MRotaryEmbedding
from .ntk_scaling_rope import NTKScalingRotaryEmbedding
from .phi3_long_rope_scaled_rope import Phi3LongRoPEScaledRotaryEmbedding
+from .xdrope import XDRotaryEmbedding
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding
_ROPE_DICT: dict[tuple, RotaryEmbedding] = {}
@@ -184,6 +185,18 @@ def get_rope(
raise ValueError(
"Dynamic rope scaling must contain either 'alpha' or 'factor' field"
)
+ elif scaling_type == "xdrope":
+ scaling_alpha = rope_parameters["alpha"]
+ rotary_emb = XDRotaryEmbedding(
+ head_size,
+ rotary_dim,
+ max_position,
+ base,
+ is_neox_style,
+ scaling_alpha,
+ dtype,
+ xdrope_section=rope_parameters["xdrope_section"],
+ )
elif scaling_type == "yarn":
scaling_factor = rope_parameters["factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
diff --git a/vllm/model_executor/layers/rotary_embedding/xdrope.py b/vllm/model_executor/layers/rotary_embedding/xdrope.py
new file mode 100644
index 0000000000000..2432273faf195
--- /dev/null
+++ b/vllm/model_executor/layers/rotary_embedding/xdrope.py
@@ -0,0 +1,102 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import numpy as np
+import torch
+
+from .common import apply_rotary_emb_dispatch
+from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
+
+
+class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding):
+ """DynamicNTKAlphaRotaryEmbedding extended with MultiModal(XD) Sections.
+
+ Based on the original DynamicNTKAlphaRotaryEmbedding implementation.
+ """
+
+ def __init__(
+ self,
+ head_size: int,
+ rotary_dim: int,
+ max_position_embeddings: int,
+ base: float,
+ is_neox_style: bool,
+ scaling_alpha: float,
+ dtype: torch.dtype,
+ xdrope_section: list[int],
+ ) -> None:
+ self.xdrope_section = xdrope_section
+ super().__init__(
+ head_size,
+ rotary_dim,
+ max_position_embeddings,
+ base,
+ is_neox_style,
+ scaling_alpha,
+ dtype,
+ )
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ query: torch.Tensor,
+ key: torch.Tensor | None = None,
+ offsets: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ """PyTorch-native implementation equivalent to forward().
+
+ Args:
+ positions:
+ [4, num_tokens] (P/W/H/T positions with multimodal inputs)
+ query: [num_tokens, num_heads * head_size]
+ key: [num_tokens, num_kv_heads * head_size]
+ """
+ assert positions.ndim == 2
+ assert key is not None
+
+ num_tokens = positions.shape[-1]
+ cos_sin = self.cos_sin_cache[positions]
+ cos, sin = cos_sin.chunk(2, dim=-1)
+ cos = torch.cat(
+ [m[i] for i, m in enumerate(cos.split(self.xdrope_section, dim=-1))], dim=-1
+ )
+ sin = torch.cat(
+ [m[i] for i, m in enumerate(sin.split(self.xdrope_section, dim=-1))], dim=-1
+ )
+
+ query_shape = query.shape
+ query = query.view(num_tokens, -1, self.head_size)
+ query_rot = query[..., : self.rotary_dim]
+ query_pass = query[..., self.rotary_dim :]
+ query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
+ query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
+
+ key_shape = key.shape
+ key = key.view(num_tokens, -1, self.head_size)
+ key_rot = key[..., : self.rotary_dim]
+ key_pass = key[..., self.rotary_dim :]
+ key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
+ key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
+ return query, key
+
+ @staticmethod
+ def get_next_input_positions(
+ context_len: int,
+ seq_len: int,
+ xd_sections: int = 4,
+ ) -> list[list[int]]:
+ return [list(range(context_len, seq_len)) for _ in range(xd_sections)]
+
+ @staticmethod
+ def get_next_input_positions_tensor(
+ out: np.ndarray,
+ out_offset: int,
+ context_len: int,
+ num_new_tokens: int,
+ ):
+ values = np.arange(
+ context_len,
+ context_len + num_new_tokens,
+ dtype=out.dtype,
+ )
+ out[:, out_offset : out_offset + num_new_tokens] = values
diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py
index 2416836be03c4..74052f72ceab9 100644
--- a/vllm/model_executor/model_loader/gguf_loader.py
+++ b/vllm/model_executor/model_loader/gguf_loader.py
@@ -18,6 +18,7 @@ from vllm.model_executor.model_loader.utils import (
process_weights_after_loading,
)
from vllm.model_executor.model_loader.weight_utils import (
+ download_gguf,
get_gguf_extra_tensor_names,
get_gguf_weight_type_map,
gguf_quant_weights_iterator,
@@ -43,7 +44,8 @@ class GGUFModelLoader(BaseModelLoader):
f"load format {load_config.load_format}"
)
- def _prepare_weights(self, model_name_or_path: str):
+ def _prepare_weights(self, model_config: ModelConfig):
+ model_name_or_path = model_config.model
if os.path.isfile(model_name_or_path):
return model_name_or_path
# for raw HTTPS link
@@ -55,12 +57,23 @@ class GGUFModelLoader(BaseModelLoader):
if "/" in model_name_or_path and model_name_or_path.endswith(".gguf"):
repo_id, filename = model_name_or_path.rsplit("/", 1)
return hf_hub_download(repo_id=repo_id, filename=filename)
- else:
- raise ValueError(
- f"Unrecognised GGUF reference: {model_name_or_path} "
- "(expected local file, raw URL, or /.gguf)"
+ # repo_id:quant_type
+ elif "/" in model_name_or_path and ":" in model_name_or_path:
+ repo_id, quant_type = model_name_or_path.rsplit(":", 1)
+ return download_gguf(
+ repo_id,
+ quant_type,
+ cache_dir=self.load_config.download_dir,
+ revision=model_config.revision,
+ ignore_patterns=self.load_config.ignore_patterns,
)
+ raise ValueError(
+ f"Unrecognised GGUF reference: {model_name_or_path} "
+ "(expected local file, raw URL, /.gguf, "
+ "or :)"
+ )
+
def _get_gguf_weights_map(self, model_config: ModelConfig):
"""
GGUF uses this naming convention for their tensors from HF checkpoint:
@@ -244,7 +257,7 @@ class GGUFModelLoader(BaseModelLoader):
gguf_to_hf_name_map: dict[str, str],
) -> dict[str, str]:
weight_type_map = get_gguf_weight_type_map(
- model_config.model, gguf_to_hf_name_map
+ model_name_or_path, gguf_to_hf_name_map
)
is_multimodal = hasattr(model_config.hf_config, "vision_config")
if is_multimodal:
@@ -290,10 +303,10 @@ class GGUFModelLoader(BaseModelLoader):
yield from gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map)
def download_model(self, model_config: ModelConfig) -> None:
- self._prepare_weights(model_config.model)
+ self._prepare_weights(model_config)
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
- local_model_path = self._prepare_weights(model_config.model)
+ local_model_path = self._prepare_weights(model_config)
gguf_weights_map = self._get_gguf_weights_map(model_config)
model.load_weights(
self._get_weights_iterator(model_config, local_model_path, gguf_weights_map)
@@ -303,7 +316,7 @@ class GGUFModelLoader(BaseModelLoader):
self, vllm_config: VllmConfig, model_config: ModelConfig
) -> nn.Module:
device_config = vllm_config.device_config
- local_model_path = self._prepare_weights(model_config.model)
+ local_model_path = self._prepare_weights(model_config)
gguf_weights_map = self._get_gguf_weights_map(model_config)
# we can only know if tie word embeddings after mapping weights
if "lm_head.weight" in get_gguf_extra_tensor_names(
diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py
index 2021b68b8a60b..eeb2444150eef 100644
--- a/vllm/model_executor/model_loader/utils.py
+++ b/vllm/model_executor/model_loader/utils.py
@@ -11,8 +11,7 @@ import torch
from torch import nn
from typing_extensions import assert_never
-from vllm.attention import Attention
-from vllm.attention.layer import MLAAttention
+from vllm.attention.layer import Attention, MLAAttention
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index 4572ebe2ea11b..0809bdfa9d4c2 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -369,6 +369,52 @@ def get_sparse_attention_config(
return config
+def download_gguf(
+ repo_id: str,
+ quant_type: str,
+ cache_dir: str | None = None,
+ revision: str | None = None,
+ ignore_patterns: str | list[str] | None = None,
+) -> str:
+ # Use patterns that snapshot_download can handle directly
+ # Patterns to match:
+ # - *-{quant_type}.gguf (root)
+ # - *-{quant_type}-*.gguf (root sharded)
+ # - */*-{quant_type}.gguf (subdir)
+ # - */*-{quant_type}-*.gguf (subdir sharded)
+ allow_patterns = [
+ f"*-{quant_type}.gguf",
+ f"*-{quant_type}-*.gguf",
+ f"*/*-{quant_type}.gguf",
+ f"*/*-{quant_type}-*.gguf",
+ ]
+
+ # Use download_weights_from_hf which handles caching and downloading
+ folder = download_weights_from_hf(
+ model_name_or_path=repo_id,
+ cache_dir=cache_dir,
+ allow_patterns=allow_patterns,
+ revision=revision,
+ ignore_patterns=ignore_patterns,
+ )
+
+ # Find the downloaded file(s) in the folder
+ local_files = []
+ for pattern in allow_patterns:
+ # Convert pattern to glob pattern for local filesystem
+ glob_pattern = os.path.join(folder, pattern)
+ local_files.extend(glob.glob(glob_pattern))
+
+ if not local_files:
+ raise ValueError(
+ f"Downloaded GGUF files not found in {folder} for quant_type {quant_type}"
+ )
+
+ # Sort to ensure consistent ordering (prefer non-sharded files)
+ local_files.sort(key=lambda x: (x.count("-"), x))
+ return local_files[0]
+
+
def download_weights_from_hf(
model_name_or_path: str,
cache_dir: str | None,
diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py
index a9cc49451a1d3..5aba46f8614be 100644
--- a/vllm/model_executor/models/adapters.py
+++ b/vllm/model_executor/models/adapters.py
@@ -428,7 +428,7 @@ def load_weights_using_from_2_way_softmax(
)
if text_config.tie_word_embeddings:
# embed_tokens is the assumed name for input embeddings. If the model does not
- # have this attribute, we fallback to get_input_embeddings(), which is used by
+ # have this attribute, we fall back to get_input_embeddings(), which is used by
# the Transformers modeling backend.
embed_tokens = (
model.model.embed_tokens
@@ -486,7 +486,7 @@ def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Te
)
if text_config.tie_word_embeddings:
# embed_tokens is the assumed name for input embeddings. If the model does not
- # have this attribute, we fallback to get_input_embeddings(), which is used by
+ # have this attribute, we fall back to get_input_embeddings(), which is used by
# the Transformers modeling backend.
embed_tokens = (
model.model.embed_tokens
diff --git a/vllm/model_executor/models/afmoe.py b/vllm/model_executor/models/afmoe.py
index 4eb5665a71fc8..85827d54c911a 100644
--- a/vllm/model_executor/models/afmoe.py
+++ b/vllm/model_executor/models/afmoe.py
@@ -9,7 +9,8 @@ from itertools import islice
import torch
from torch import nn
-from vllm.attention import Attention, AttentionType
+from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (
diff --git a/vllm/model_executor/models/apertus.py b/vllm/model_executor/models/apertus.py
index b75e91319bbad..f38b09bf55068 100644
--- a/vllm/model_executor/models/apertus.py
+++ b/vllm/model_executor/models/apertus.py
@@ -32,7 +32,8 @@ import torch
from torch import nn
from transformers import ApertusConfig
-from vllm.attention import Attention, AttentionType
+from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.layer import Attention
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py
index b75a254761d4e..266d29a8d9b2b 100644
--- a/vllm/model_executor/models/arctic.py
+++ b/vllm/model_executor/models/arctic.py
@@ -8,7 +8,7 @@ from itertools import islice
import torch
from torch import nn
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py
index edf47270e5277..beb22995a0719 100644
--- a/vllm/model_executor/models/baichuan.py
+++ b/vllm/model_executor/models/baichuan.py
@@ -29,7 +29,7 @@ import torch
from torch import nn
from transformers import PretrainedConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
@@ -233,7 +233,7 @@ class BaiChuanDecoderLayer(nn.Module):
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
position_embedding=position_embedding,
- rope_parameters=config.rope_parameters,
+ rope_parameters=getattr(config, "rope_parameters", None),
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py
index cc10e936a2d3d..f7a5d4e7889e5 100644
--- a/vllm/model_executor/models/bailing_moe.py
+++ b/vllm/model_executor/models/bailing_moe.py
@@ -32,7 +32,7 @@ import torch.nn.functional as F
from torch import nn
from transformers.configuration_utils import PretrainedConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py
index 00fba93423d8e..507fbf1fdd0a8 100644
--- a/vllm/model_executor/models/bloom.py
+++ b/vllm/model_executor/models/bloom.py
@@ -27,7 +27,7 @@ import torch
from torch import nn
from transformers import BloomConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py
index b5a6d00dc309f..3aa01bb1905fe 100644
--- a/vllm/model_executor/models/chameleon.py
+++ b/vllm/model_executor/models/chameleon.py
@@ -16,7 +16,7 @@ from transformers import (
ChameleonVQVAEConfig,
)
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py
index dbfcd62d0bcab..3d485fdd0a2e1 100644
--- a/vllm/model_executor/models/chatglm.py
+++ b/vllm/model_executor/models/chatglm.py
@@ -12,7 +12,7 @@ import torch
from torch import nn
from torch.nn import LayerNorm
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py
index 5d611deb942d1..c2993b47dc3f9 100644
--- a/vllm/model_executor/models/clip.py
+++ b/vllm/model_executor/models/clip.py
@@ -14,8 +14,7 @@ from transformers import (
CLIPVisionConfig,
)
-from vllm.attention import Attention
-from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.layer import Attention, MultiHeadAttention
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py
index 5ed920927c772..f837502c468f1 100644
--- a/vllm/model_executor/models/commandr.py
+++ b/vllm/model_executor/models/commandr.py
@@ -30,7 +30,7 @@ import torch
from torch import nn
from transformers import Cohere2Config, CohereConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py
index 3cf4bf991e667..d7e802ba1aca0 100644
--- a/vllm/model_executor/models/config.py
+++ b/vllm/model_executor/models/config.py
@@ -289,9 +289,6 @@ class MambaModelConfig(VerifyAndUpdateConfig):
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
- if cache_config.mamba_block_size is None:
- cache_config.mamba_block_size = model_config.max_model_len
-
if cache_config.enable_prefix_caching:
if model_config.supports_mamba_prefix_caching:
logger.info(
@@ -299,6 +296,11 @@ class MambaModelConfig(VerifyAndUpdateConfig):
"Its support for Mamba layers is experimental. "
"Please report any issues you may observe."
)
+ # By default, mamba block size will be set to max_model_len (see
+ # below). When enabling prefix caching, we align mamba block size
+ # to the block size as the basic granularity for prefix caching.
+ if cache_config.mamba_block_size is None:
+ cache_config.mamba_block_size = cache_config.block_size
else:
logger.info(
"Hybrid or mamba-based model detected without "
@@ -306,6 +308,9 @@ class MambaModelConfig(VerifyAndUpdateConfig):
)
cache_config.enable_prefix_caching = False
+ if cache_config.mamba_block_size is None:
+ cache_config.mamba_block_size = model_config.max_model_len
+
# TODO(tdoublep): remove once cascade attention is supported
logger.info(
"Disabling cascade attention since it is not supported for hybrid models."
diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py
index 2c729019081a4..946baffc8817a 100644
--- a/vllm/model_executor/models/dbrx.py
+++ b/vllm/model_executor/models/dbrx.py
@@ -8,7 +8,7 @@ import torch
import torch.nn as nn
from transformers import DbrxConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
get_pp_group,
diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py
index e028dc497aa6a..6e23037b919ab 100644
--- a/vllm/model_executor/models/deepseek_mtp.py
+++ b/vllm/model_executor/models/deepseek_mtp.py
@@ -1,15 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Iterable
+import typing
+from collections.abc import Callable, Iterable
import torch
import torch.nn as nn
from transformers import PretrainedConfig
+from vllm._aiter_ops import rocm_aiter_ops
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -231,6 +233,9 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
return self.model.compute_logits(hidden_states, spec_step_idx)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ rocm_aiter_moe_shared_expert_enabled = (
+ rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
+ )
stacked_params_mapping = [
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
@@ -238,11 +243,16 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
]
- expert_params_mapping = FusedMoE.make_expert_params_mapping(
+ expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
- num_experts=self.config.n_routed_experts,
+ num_experts=self.config.n_routed_experts
+ + (
+ self.config.n_shared_experts
+ if rocm_aiter_moe_shared_expert_enabled
+ else 0
+ ),
)
params_dict = dict(self.named_parameters())
@@ -253,6 +263,9 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is None:
continue
+ is_fusion_moe_shared_experts_layer = (
+ rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
+ )
name = self._rewrite_spec_layer_name(spec_layer, name)
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
@@ -266,6 +279,8 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
+ if is_fusion_moe_shared_experts_layer:
+ continue
name_mapped = name.replace(weight_name, param_name)
# QKV fusion is optional, fall back to normal
@@ -286,45 +301,105 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
weight_loader(param, loaded_weight, shard_id)
break
else:
- for mapping in expert_params_mapping:
- param_name, weight_name, expert_id, shard_id = mapping
- if weight_name not in name:
- continue
- name = name.replace(weight_name, param_name)
-
- param = params_dict[name]
- weight_loader = param.weight_loader
- weight_loader(
- param,
- loaded_weight,
- name,
- shard_id=shard_id,
- expert_id=expert_id,
+ # Special handling: when AITER fusion_shared_experts is enabled,
+ # checkpoints may provide a single widened shared_experts tensor
+ # without explicit expert indices
+ # (e.g. ...mlp.shared_experts.gate_proj.weight).
+ # For models with multiple shared experts, split that tensor
+ # evenly into per-shared-expert slices and load them into
+ # appended expert slots mlp.experts.{n_routed_experts + j}.*
+ # accordingly.
+ num_chunks = 1
+ if is_fusion_moe_shared_experts_layer:
+ num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
+ # Determine split axis based on op type
+ # gate/up: ColumnParallel → split along dim 0
+ # down: RowParallel → split along dim 1
+ split_dim = 1 if "down_proj.weight" in name else 0
+ total = loaded_weight.shape[split_dim]
+ assert total % num_chunks == 0, (
+ f"Shared expert weight dim {total} "
+ f"not divisible by num_chunks {num_chunks}"
)
- break
- else:
- # Skip loading extra bias for GPTQ models.
- if name.endswith(".bias") and name not in params_dict:
- continue
+ chunk_size = total // num_chunks
- name = maybe_remap_kv_scale_name(name, params_dict)
- if name is None:
- continue
+ for j in range(num_chunks):
+ chunk_name = name
+ weight_to_load = loaded_weight
- # According to DeepSeek-V3 Technical Report, MTP modules
- # shares embedding layer. We only load the first weights.
- if (
- spec_layer != self.model.mtp_start_layer_idx
- and ".layers" not in name
- ):
- continue
+ if is_fusion_moe_shared_experts_layer:
+ if split_dim == 0:
+ weight_to_load = loaded_weight[
+ j * chunk_size : (j + 1) * chunk_size, :
+ ]
+ else:
+ weight_to_load = loaded_weight[
+ :, j * chunk_size : (j + 1) * chunk_size
+ ]
+ # Synthesize an expert-style name so expert mapping
+ # can route it
+ chunk_name = name.replace(
+ "mlp.shared_experts",
+ f"mlp.experts.{self.config.n_routed_experts + j}",
+ )
- param = params_dict[name]
- weight_loader = getattr(
- param, "weight_loader", default_weight_loader
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(name)
+ # Use expert_params_mapping to locate the destination
+ # param and delegate to its expert-aware weight_loader
+ # with expert_id.
+ for mapping in expert_params_mapping:
+ param_name, weight_name, expert_id, shard_id = mapping
+ if weight_name not in chunk_name:
+ continue
+
+ # Do not modify `name` since the loop may continue here
+ # Instead, create a new variable
+ name_mapped = chunk_name.replace(weight_name, param_name)
+
+ param = params_dict[name_mapped]
+ # We should ask the weight loader to return success or
+ # not here since otherwise we may skip experts with
+ # other available replicas.
+ weight_loader = typing.cast(
+ Callable[..., bool], param.weight_loader
+ )
+ success = weight_loader(
+ param,
+ weight_to_load,
+ name_mapped,
+ shard_id=shard_id,
+ expert_id=expert_id,
+ return_success=True,
+ )
+ if success:
+ if not is_fusion_moe_shared_experts_layer:
+ name = name_mapped
+ else:
+ loaded_params.add(name_mapped)
+ break
+ else:
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+
+ name = maybe_remap_kv_scale_name(name, params_dict)
+ if name is None:
+ continue
+
+ # According to DeepSeek-V3 Technical Report, MTP modules
+ # shares embedding layer. We only load the first weights.
+ if (
+ spec_layer != self.model.mtp_start_layer_idx
+ and ".layers" not in name
+ ):
+ continue
+
+ param = params_dict[name]
+ weight_loader = getattr(
+ param, "weight_loader", default_weight_loader
+ )
+ weight_loader(param, loaded_weight)
+ if not is_fusion_moe_shared_experts_layer:
+ loaded_params.add(name)
return loaded_params
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 7cfd381592b49..73cac2556c55a 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -33,8 +33,8 @@ from torch import nn
from transformers import DeepseekV2Config, DeepseekV3Config
from vllm._aiter_ops import rocm_aiter_ops
-from vllm.attention import Attention
from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.layer import Attention
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ParallelConfig, VllmConfig, get_current_vllm_config
@@ -1479,8 +1479,8 @@ class DeepseekV2ForCausalLM(
if spec_layer is not None:
continue # skip spec decode layers for main model
- is_fuse_shared_experts_layer = rocm_aiter_moe_shared_expert_enabled and (
- "mlp.shared_experts" in name
+ is_fusion_moe_shared_experts_layer = (
+ rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
)
for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -1495,7 +1495,7 @@ class DeepseekV2ForCausalLM(
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
- if is_fuse_shared_experts_layer:
+ if is_fusion_moe_shared_experts_layer:
continue
name_mapped = name.replace(weight_name, param_name)
@@ -1531,7 +1531,7 @@ class DeepseekV2ForCausalLM(
# appended expert slots mlp.experts.{n_routed_experts + j}.*
# accordingly.
num_chunks = 1
- if is_fuse_shared_experts_layer:
+ if is_fusion_moe_shared_experts_layer:
num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
# Determine split axis based on op type
# gate/up: ColumnParallel → split along dim 0
@@ -1548,7 +1548,7 @@ class DeepseekV2ForCausalLM(
chunk_name = name
weight_to_load = loaded_weight
- if is_fuse_shared_experts_layer:
+ if is_fusion_moe_shared_experts_layer:
if split_dim == 0:
weight_to_load = loaded_weight[
j * chunk_size : (j + 1) * chunk_size, :
@@ -1599,7 +1599,7 @@ class DeepseekV2ForCausalLM(
return_success=True,
)
if success:
- if not is_fuse_shared_experts_layer:
+ if not is_fusion_moe_shared_experts_layer:
name = name_mapped
else:
loaded_params.add(name_mapped)
@@ -1628,7 +1628,7 @@ class DeepseekV2ForCausalLM(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
- if not is_fuse_shared_experts_layer:
+ if not is_fusion_moe_shared_experts_layer:
loaded_params.add(name)
return loaded_params
diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py
index e7b48e0f4e554..1b6e4110039c4 100644
--- a/vllm/model_executor/models/deepseek_vl2.py
+++ b/vllm/model_executor/models/deepseek_vl2.py
@@ -48,7 +48,6 @@ from vllm.transformers_utils.configs.deepseek_vl2 import (
)
from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
-from vllm.utils.collection_utils import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.utils.torch_utils import set_default_torch_dtype
@@ -595,19 +594,9 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def _process_image_input(
self, image_input: DeepseekVL2ImageInputs
- ) -> list[torch.Tensor]:
+ ) -> torch.Tensor | list[torch.Tensor]:
if image_input["type"] == "image_embeds":
- image_data = image_input["data"]
- if is_list_of(image_data, torch.Tensor):
- # it's already a list of tensors
- return image_data
- if len(image_data.shape) == 3:
- # 3D tensor
- return list(torch.unbind(image_data, dim=0))
- raise ValueError(
- "We expect batched 2D tensors; "
- "this can be either a list of 2D tensors or a single 3D tensor."
- )
+ return image_input["data"]
pixel_values = image_input["data"]
images_spatial_crop = image_input["images_spatial_crop"]
diff --git a/vllm/model_executor/models/dots1.py b/vllm/model_executor/models/dots1.py
index e65c275106a4e..1c2abbe7b3a78 100644
--- a/vllm/model_executor/models/dots1.py
+++ b/vllm/model_executor/models/dots1.py
@@ -32,7 +32,7 @@ import torch
from torch import nn
from transformers import Dots1Config
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (
diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py
index 2d2251e83b5b1..5460018d0d67a 100644
--- a/vllm/model_executor/models/dots_ocr.py
+++ b/vllm/model_executor/models/dots_ocr.py
@@ -306,7 +306,6 @@ class DotsVisionAttention(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
- AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -324,7 +323,6 @@ class DotsVisionAttention(nn.Module):
rotary_pos_emb: torch.Tensor | None = None,
*,
max_seqlen: int | None = None,
- seqlens: list[int] | None = None,
) -> torch.Tensor:
# [S, C] -> [S, B=1, C]
x = hidden_states.unsqueeze(1)
@@ -374,16 +372,6 @@ class DotsVisionAttention(nn.Module):
out_i = out_i.permute(0, 2, 1, 3)
outputs.append(out_i)
context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- from xformers import ops as xops
- from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-
- attn_bias = BlockDiagonalMask.from_seqlens(
- q_seqlen=seqlens, kv_seqlen=None, device=q.device
- )
- context_layer = xops.memory_efficient_attention_forward(
- q, k, v, attn_bias=attn_bias, p=0, scale=None
- )
else:
raise RuntimeError("Unsupported attention backend")
@@ -545,14 +533,12 @@ class DotsVisionBlock(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: int | None = None,
- seqlens: list[int] | None = None,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
@@ -663,18 +649,14 @@ class DotsVisionTransformer(nn.Module):
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
- def compute_attn_mask_seqlen(
- self, cu_seqlens: torch.Tensor
- ) -> tuple[int | None, list[int] | None]:
- max_seqlen, seqlens = None, None
+ def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
+ max_seqlen = None
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
- return max_seqlen, seqlens
+ return max_seqlen
def forward(
self, hidden_states: torch.Tensor, grid_thw: list[list[int]]
@@ -694,14 +676,13 @@ class DotsVisionTransformer(nn.Module):
)
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
- max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+ max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
for blk in self.blocks:
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
if self.post_trunk_norm is not None:
diff --git a/vllm/model_executor/models/ernie45_moe.py b/vllm/model_executor/models/ernie45_moe.py
index a7df3509e3ecd..278ba45e9684c 100644
--- a/vllm/model_executor/models/ernie45_moe.py
+++ b/vllm/model_executor/models/ernie45_moe.py
@@ -32,7 +32,7 @@ import torch
from torch import nn
from transformers import PretrainedConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (
diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py
index daa5bf03ea4a9..07b34fbc8addb 100644
--- a/vllm/model_executor/models/ernie45_vl.py
+++ b/vllm/model_executor/models/ernie45_vl.py
@@ -214,7 +214,6 @@ class Ernie4_5_VisionAttention(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
- AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -259,7 +258,6 @@ class Ernie4_5_VisionAttention(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
- seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
@@ -311,20 +309,6 @@ class Ernie4_5_VisionAttention(nn.Module):
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- from xformers import ops as xops
- from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-
- attn_bias = BlockDiagonalMask.from_seqlens(
- q_seqlen=seqlens, kv_seqlen=None, device=q.device
- )
-
- context_layer = xops.memory_efficient_attention_forward(
- q, k, v, attn_bias=attn_bias, p=0, scale=None
- )
- context_layer = rearrange(
- context_layer, "b s h d -> s b (h d)"
- ).contiguous()
output, _ = self.proj(context_layer)
return output
@@ -404,14 +388,12 @@ class Ernie4_5_VisionBlock(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
- seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
@@ -562,18 +544,14 @@ class Ernie4_5_VisionTransformer(nn.Module):
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
- def compute_attn_mask_seqlen(
- self, cu_seqlens: torch.Tensor
- ) -> tuple[int | None, list[int] | None]:
- max_seqlen, seqlens = None, None
+ def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
+ max_seqlen = None
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
- return max_seqlen, seqlens
+ return max_seqlen
def forward(
self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, num_pad=0
@@ -598,8 +576,8 @@ class Ernie4_5_VisionTransformer(nn.Module):
if hidden_states.ndim == 2:
hidden_states = hidden_states.unsqueeze(dim=1)
- # pre-compute seqlens for attn mask to reduce cuMemcpy operations
- max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+ # pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
+ max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
for i, blk in enumerate(self.blocks):
hidden_states = blk(
@@ -607,7 +585,6 @@ class Ernie4_5_VisionTransformer(nn.Module):
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
final_output = self.ln(hidden_states)
diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py
index 50e033d77606d..72f9957fc8828 100644
--- a/vllm/model_executor/models/ernie45_vl_moe.py
+++ b/vllm/model_executor/models/ernie45_vl_moe.py
@@ -31,7 +31,7 @@ import torch
from torch import nn
from transformers import PretrainedConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
# from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py
index d13275488fe99..99002baa87529 100644
--- a/vllm/model_executor/models/exaone.py
+++ b/vllm/model_executor/models/exaone.py
@@ -32,7 +32,7 @@ import torch
from torch import nn
from transformers import PretrainedConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/exaone4.py b/vllm/model_executor/models/exaone4.py
index 70f3cce2b7c56..9d2c67d6c4f80 100644
--- a/vllm/model_executor/models/exaone4.py
+++ b/vllm/model_executor/models/exaone4.py
@@ -28,7 +28,7 @@ import torch
from torch import nn
from transformers import Exaone4Config
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py
index dc2d51f340c8c..32d9e7b925597 100644
--- a/vllm/model_executor/models/falcon.py
+++ b/vllm/model_executor/models/falcon.py
@@ -30,7 +30,7 @@ from torch import nn
from torch.nn import LayerNorm
from transformers import FalconConfig as HF_FalconConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index 00c7f59a08094..dd5a74c8ed005 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -26,7 +26,7 @@ import torch
from torch import nn
from transformers import GemmaConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py
index 9b6cfe6932300..cb36e04824588 100644
--- a/vllm/model_executor/models/gemma2.py
+++ b/vllm/model_executor/models/gemma2.py
@@ -23,7 +23,7 @@ import torch
from torch import nn
from transformers import Gemma2Config
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py
index 4ad6fc89dcaf2..73176eba95ed5 100644
--- a/vllm/model_executor/models/gemma3.py
+++ b/vllm/model_executor/models/gemma3.py
@@ -23,7 +23,8 @@ import torch.nn.functional as F
from torch import nn
from transformers import Gemma3TextConfig
-from vllm.attention import Attention, AttentionType
+from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py
index 8f1447ba34a81..f4427c9fd1d10 100644
--- a/vllm/model_executor/models/gemma3n.py
+++ b/vllm/model_executor/models/gemma3n.py
@@ -21,7 +21,7 @@ import torch
from torch import nn
from transformers.models.gemma3n.configuration_gemma3n import Gemma3nTextConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/glm4.py b/vllm/model_executor/models/glm4.py
index f8ef3b0385fb1..002cdb721e1db 100644
--- a/vllm/model_executor/models/glm4.py
+++ b/vllm/model_executor/models/glm4.py
@@ -29,7 +29,8 @@ import torch
from torch import nn
from transformers import Glm4Config
-from vllm.attention import Attention, AttentionType
+from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py
index d141e95498064..7e0370886884f 100644
--- a/vllm/model_executor/models/glm4_1v.py
+++ b/vllm/model_executor/models/glm4_1v.py
@@ -309,7 +309,6 @@ class Glm4vVisionAttention(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
- AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -345,7 +344,6 @@ class Glm4vVisionAttention(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
- seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
@@ -400,20 +398,6 @@ class Glm4vVisionAttention(nn.Module):
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- from xformers import ops as xops
- from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-
- attn_bias = BlockDiagonalMask.from_seqlens(
- q_seqlen=seqlens, kv_seqlen=None, device=q.device
- )
-
- context_layer = xops.memory_efficient_attention_forward(
- q, k, v, attn_bias=attn_bias, p=0, scale=None
- )
- context_layer = rearrange(
- context_layer, "b s h d -> s b (h d)"
- ).contiguous()
output, _ = self.proj(context_layer)
return output
@@ -461,7 +445,6 @@ class Glm4vVisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
- seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
x_attn = self.attn(
self.norm1(x),
@@ -469,7 +452,6 @@ class Glm4vVisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
x_fused_norm, residual = self.norm2(x, residual=x_attn)
x = residual + self.mlp(x_fused_norm)
@@ -803,15 +785,14 @@ class Glm4vVisionTransformer(nn.Module):
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
- ) -> tuple[int | None, list[int] | None]:
- max_seqlen, seqlens = None, None
- seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+ ) -> int | None:
+ max_seqlen = None
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
- return max_seqlen, seqlens
+ return max_seqlen
def forward(
self,
@@ -836,8 +817,9 @@ class Glm4vVisionTransformer(nn.Module):
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
- # pre-compute seqlens for attn mask to reduce cuMemcpy operations
- max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+ # pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
+ max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
+ seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
x = self.embeddings(
x, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]
)
@@ -851,7 +833,6 @@ class Glm4vVisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
# adapter
diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py
index 5aa51af54a00b..c99f824e1bd4d 100644
--- a/vllm/model_executor/models/glm4_moe.py
+++ b/vllm/model_executor/models/glm4_moe.py
@@ -31,7 +31,7 @@ import torch
from torch import nn
from transformers.models.glm4_moe import Glm4MoeConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (
diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py
index a5e8131c7fba9..da5d48a94ff3e 100644
--- a/vllm/model_executor/models/gpt2.py
+++ b/vllm/model_executor/models/gpt2.py
@@ -27,7 +27,7 @@ import torch
from torch import nn
from transformers import GPT2Config
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed.parallel_state import (
diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index cdf038ba25c92..a405fd184513f 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -28,7 +28,7 @@ import torch
from torch import nn
from transformers import GPTBigCodeConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py
index e94de8952fa63..f0a34c47da54c 100644
--- a/vllm/model_executor/models/gpt_j.py
+++ b/vllm/model_executor/models/gpt_j.py
@@ -26,7 +26,7 @@ import torch
from torch import nn
from transformers import GPTJConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -100,7 +100,7 @@ class GPTJAttention(nn.Module):
self.head_size,
rotary_dim=config.rotary_dim,
max_position=max_position_embeddings,
- rope_parameters=config.rope_parameters,
+ rope_parameters=getattr(config, "rope_parameters", None),
is_neox_style=False,
)
self.attn = Attention(
diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py
index 815c2fba4d9fe..b9959682cbcef 100644
--- a/vllm/model_executor/models/gpt_neox.py
+++ b/vllm/model_executor/models/gpt_neox.py
@@ -26,7 +26,7 @@ import torch
from torch import nn
from transformers import GPTNeoXConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py
index 1bc0ad38765d5..9de3e261941b1 100644
--- a/vllm/model_executor/models/gpt_oss.py
+++ b/vllm/model_executor/models/gpt_oss.py
@@ -7,7 +7,8 @@ import torch.distributed as dist
from torch import nn
from transformers import GptOssConfig
-from vllm.attention import Attention, AttentionType
+from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py
index cd7ce2fc8f00a..eac9ef9478a6a 100644
--- a/vllm/model_executor/models/granite.py
+++ b/vllm/model_executor/models/granite.py
@@ -31,7 +31,7 @@ import torch
from torch import nn
from transformers import GraniteConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py
index 8f4139d63c3f6..02c6c5862141f 100644
--- a/vllm/model_executor/models/granitemoe.py
+++ b/vllm/model_executor/models/granitemoe.py
@@ -31,7 +31,7 @@ from typing import Any
import torch
from torch import nn
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py
index 4bf23cd6fd19a..6f62a1d11e52e 100644
--- a/vllm/model_executor/models/grok1.py
+++ b/vllm/model_executor/models/grok1.py
@@ -31,7 +31,7 @@ import torch
import torch.nn.functional as F
from torch import nn
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -239,7 +239,7 @@ class Grok1DecoderLayer(nn.Module):
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
- rope_parameters=config.rope_parameters,
+ rope_parameters=getattr(config, "rope_parameters", None),
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py
index 9fa5e2bd33f21..ccdfa3fe175f1 100644
--- a/vllm/model_executor/models/hunyuan_v1.py
+++ b/vllm/model_executor/models/hunyuan_v1.py
@@ -33,7 +33,8 @@ import torch
from torch import nn
from transformers import PretrainedConfig
-from vllm.attention import Attention, AttentionType
+from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (
@@ -576,7 +577,16 @@ class HunYuanDecoderLayer(nn.Module):
return hidden_states, residual, ori_kv_states
-@support_torch_compile
+@support_torch_compile(
+ dynamic_arg_dims={
+ "input_ids": 0,
+ # positions is of shape (xd, seq_len) if xdrope is enabled for hunyuan-vl,
+ # otherwise (seq_len, ).
+ "positions": -1,
+ "intermediate_tensors": 0,
+ "inputs_embeds": 0,
+ }
+)
class HunYuanModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
diff --git a/vllm/model_executor/models/hunyuan_vision.py b/vllm/model_executor/models/hunyuan_vision.py
new file mode 100644
index 0000000000000..2950db571e6ee
--- /dev/null
+++ b/vllm/model_executor/models/hunyuan_vision.py
@@ -0,0 +1,1028 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# coding=utf-8
+# Copyright 2025 The HunYuan team.
+# Copyright 2025 The vLLM team.
+# Copyright 2025 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only HunYuan-VL model compatible with HuggingFace weights."""
+
+from collections.abc import Callable, Iterable, Mapping, Sequence
+from functools import partial
+from typing import Annotated, Any, Literal, TypeAlias
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import BatchFeature
+
+from vllm.attention.backends.registry import AttentionBackendEnum
+from vllm.attention.layer import MultiHeadAttention
+from vllm.config import MultiModalConfig, VllmConfig
+from vllm.config.multimodal import BaseDummyOptions
+from vllm.distributed import parallel_state
+from vllm.distributed import utils as dist_utils
+from vllm.logger import init_logger
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+ ColumnParallelLinear,
+ QKVParallelLinear,
+ RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.module_mapping import MultiModelKeys
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (
+ ImageItem,
+ ModalityData,
+ MultiModalDataDict,
+ MultiModalFeatureSpec,
+ MultiModalFieldConfig,
+ MultiModalKwargsItems,
+)
+from vllm.multimodal.parse import (
+ DictEmbeddingItems,
+ ImageSize,
+ MultiModalDataItems,
+ MultiModalDataParser,
+)
+from vllm.multimodal.processing import (
+ BaseMultiModalProcessor,
+ BaseProcessingInfo,
+ PromptReplacement,
+ PromptUpdate,
+)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.configs.hunyuan_vl import (
+ HunYuanVLConfig,
+ HunYuanVLVisionConfig,
+)
+from vllm.transformers_utils.processors.hunyuan_vl import HunYuanVLProcessor
+from vllm.transformers_utils.processors.hunyuan_vl_image import smart_resize
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
+
+from .interfaces import (
+ MultiModalEmbeddings,
+ SupportsLoRA,
+ SupportsMultiModal,
+ SupportsPP,
+ SupportsQuant,
+ SupportsXDRoPE,
+)
+from .utils import (
+ AutoWeightsLoader,
+ WeightsMapper,
+ init_vllm_registered_model,
+ maybe_prefix,
+)
+
+logger = init_logger(__name__)
+
+# === Vision Inputs === #
+
+
+class HunYuanVLImagePixelInputs(TensorSchema):
+ """
+ Dimensions:
+ - np: Number of patches
+ - ni: Number of images
+ - cps: Number of channels * patch_size * patch_size
+ """
+
+ type: Literal["pixel_values"]
+
+ pixel_values: Annotated[
+ torch.Tensor,
+ TensorShape("np", "cps"),
+ ]
+
+ image_grid_thw: Annotated[
+ torch.Tensor,
+ TensorShape("ni", 3),
+ ]
+
+
+class HunYuanVLImageEmbeddingInputs(TensorSchema):
+ """
+ Dimensions:
+ - nf: Number of image features
+ - hs: Hidden size
+ - ni: Number of images
+ """
+
+ type: Literal["image_embeds"]
+
+ image_embeds: Annotated[
+ torch.Tensor,
+ TensorShape("nf", "hs"),
+ ]
+
+ image_grid_thw: Annotated[
+ torch.Tensor,
+ TensorShape("ni", 3),
+ ]
+
+
+HunYuanVLImageInputs: TypeAlias = (
+ HunYuanVLImagePixelInputs | HunYuanVLImageEmbeddingInputs
+)
+
+# === Vision Encoder === #
+
+
+class HunYuanVisionMLP(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: int,
+ bias: bool = True,
+ act_fn: Callable[[torch.Tensor], torch.Tensor] = F.gelu,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ use_data_parallel: bool = False,
+ ):
+ super().__init__()
+ self.dense_h_to_4h = ColumnParallelLinear(
+ in_features,
+ hidden_features,
+ bias=bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.dense_h_to_4h",
+ disable_tp=use_data_parallel,
+ )
+ self.dense_4h_to_h = RowParallelLinear(
+ hidden_features,
+ in_features,
+ bias=bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.dense_4h_to_h",
+ disable_tp=use_data_parallel,
+ )
+ self.act_fn = act_fn
+
+ def forward(self, x: torch.Tensor):
+ x_up, _ = self.dense_h_to_4h(x)
+ x_down, _ = self.dense_4h_to_h(self.act_fn(x_up))
+ return x_down
+
+
+class HunYuanVisionAttention(nn.Module):
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ projection_size: int,
+ quant_config: QuantizationConfig | None = None,
+ multimodal_config: MultiModalConfig | None = None,
+ prefix: str = "",
+ use_data_parallel: bool = False,
+ ) -> None:
+ super().__init__()
+ # Per attention head and per partition values.
+ self.tp_size = (
+ 1
+ if use_data_parallel
+ else parallel_state.get_tensor_model_parallel_world_size()
+ )
+ self.hidden_size_per_attention_head = dist_utils.divide(
+ projection_size, num_heads
+ )
+ self.num_attention_heads_per_partition = dist_utils.divide(
+ num_heads, self.tp_size
+ )
+
+ self.qkv = QKVParallelLinear(
+ hidden_size=embed_dim,
+ head_size=self.hidden_size_per_attention_head,
+ total_num_heads=num_heads,
+ total_num_kv_heads=num_heads,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv",
+ disable_tp=use_data_parallel,
+ )
+
+ self.o_proj = RowParallelLinear(
+ input_size=projection_size,
+ output_size=embed_dim,
+ quant_config=quant_config,
+ prefix=f"{prefix}.o_proj",
+ disable_tp=use_data_parallel,
+ )
+
+ self.scale = self.hidden_size_per_attention_head**-0.5
+ self.attn = MultiHeadAttention(
+ self.num_attention_heads_per_partition,
+ self.hidden_size_per_attention_head,
+ self.scale,
+ prefix=f"{prefix}.attn",
+ multimodal_config=multimodal_config,
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ ) -> torch.Tensor:
+ qkv, _ = self.qkv(x)
+ q, k, v = qkv.chunk(3, dim=-1)
+ out = self.attn(q, k, v)
+ output, _ = self.o_proj(out)
+ return output
+
+
+class HunYuanVisionBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_hidden_dim: int,
+ act_fn: Callable[[torch.Tensor], torch.Tensor] = F.gelu,
+ norm_layer: Callable[[int], nn.Module] | None = None,
+ quant_config: QuantizationConfig | None = None,
+ multimodal_config: MultiModalConfig | None = None,
+ prefix: str = "",
+ use_data_parallel: bool = False,
+ ) -> None:
+ super().__init__()
+ if norm_layer is None:
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+ self.input_layernorm = norm_layer(dim)
+ self.post_attention_layernorm = norm_layer(dim)
+ self.self_attn = HunYuanVisionAttention(
+ embed_dim=dim,
+ num_heads=num_heads,
+ projection_size=dim,
+ quant_config=quant_config,
+ multimodal_config=multimodal_config,
+ prefix=f"{prefix}.self_attn",
+ use_data_parallel=use_data_parallel,
+ )
+ self.mlp = HunYuanVisionMLP(
+ dim,
+ mlp_hidden_dim,
+ act_fn=act_fn,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp",
+ use_data_parallel=use_data_parallel,
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ ) -> torch.Tensor:
+ x = x + self.self_attn(self.input_layernorm(x))
+ x = x + self.mlp(self.post_attention_layernorm(x))
+ return x
+
+
+class HunYuanVisionPatchEmbed(nn.Module):
+ def __init__(self, config: HunYuanVLVisionConfig):
+ super().__init__()
+
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.patch_size = config.patch_size
+ self.num_channels = config.num_channels
+ self.spatial_merge_size = config.spatial_merge_size
+ self.interpolate_mode = config.interpolate_mode
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ bias=True,
+ )
+
+ self.max_num_patches = (config.max_image_size // self.patch_size) ** 2
+
+ self.num_positions = self.max_num_patches + 1
+ self.position_edge = int(self.num_positions**0.5)
+ # first token is cls token, skip it
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+ self.patch_pos_embed = None
+
+ def forward(
+ self, pixel_values: torch.Tensor, grid_thw: list[list[int]]
+ ) -> torch.Tensor:
+ num_patches = pixel_values.size(0)
+ pixel_values = pixel_values.reshape(
+ num_patches, self.num_channels, self.patch_size, self.patch_size
+ )
+
+ patch_embeds = self.patch_embedding(pixel_values)
+ patch_embeds = patch_embeds.squeeze(-1).squeeze(-1).unsqueeze(0)
+
+ if self.patch_pos_embed is None:
+ patch_pos_shape = (
+ 1,
+ self.position_edge,
+ self.position_edge,
+ self.embed_dim,
+ )
+ self.patch_pos_embed = (
+ self.position_embedding.weight[1:, :]
+ .reshape(patch_pos_shape)
+ .permute(0, 3, 1, 2)
+ .float()
+ )
+
+ patch_pos_embed_list = []
+ for grid in grid_thw:
+ _, h0, w0 = grid
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ h0, w0 = h0 + 0.1, w0 + 0.1
+ patch_pos_embed = nn.functional.interpolate(
+ self.patch_pos_embed,
+ scale_factor=(h0 / self.position_edge, w0 / self.position_edge),
+ mode=self.interpolate_mode,
+ align_corners=False,
+ )
+
+ patch_pos_embed = (
+ patch_pos_embed.reshape(self.embed_dim, -1)
+ .transpose(0, 1)
+ .unsqueeze(0)
+ .to(patch_embeds.dtype)
+ )
+ patch_pos_embed_list.append(patch_pos_embed)
+
+ patch_pos_embed = torch.cat(patch_pos_embed_list, dim=1)
+ embeddings = patch_embeds + patch_pos_embed
+
+ return embeddings
+
+
+class HunYuanVisionPatchMerger(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ spatial_merge_size=2,
+ rms_norm_eps=1e-5,
+ prefix="",
+ ):
+ super().__init__()
+ self.spatial_merge_size = spatial_merge_size
+ embed_std = out_channels**-0.5
+
+ self.proj = nn.Sequential(
+ nn.Conv2d(
+ in_channels,
+ in_channels * 2,
+ kernel_size=spatial_merge_size,
+ stride=spatial_merge_size,
+ ),
+ nn.GELU(),
+ nn.Conv2d(in_channels * 2, in_channels * 4, kernel_size=1),
+ )
+ self.mlp = nn.Linear(in_channels * 4, out_channels)
+
+ self.image_newline = nn.Parameter(torch.randn(in_channels * 4) * embed_std)
+ self.image_begin = nn.Parameter(torch.randn(out_channels) * embed_std)
+ self.image_end = nn.Parameter(torch.randn(out_channels) * embed_std)
+ self.image_sep = nn.Parameter(torch.randn(out_channels) * embed_std)
+
+ self.before_rms = RMSNorm(in_channels, eps=rms_norm_eps)
+ self.after_rms = RMSNorm(out_channels, eps=rms_norm_eps)
+
+ def forward(self, x, size=(16, 16)):
+ x = self.before_rms(x)
+
+ h, w = size
+ dtype = x.dtype
+ x = x.permute(0, 2, 1).reshape(x.shape[0], -1, h, w)
+
+ x = self.proj(x) # b,c,h,w
+ b, c, h, w = x.shape
+ x = torch.cat(
+ [x, self.image_newline.reshape(1, c, 1, 1).expand(b, c, h, 1).to(dtype)],
+ dim=-1,
+ )
+ x = x.reshape(b, c, -1).permute(0, 2, 1)
+ x = self.mlp(x)
+
+ begin = self.image_begin.reshape(1, 1, -1).expand(b, 1, x.shape[-1]).to(dtype)
+ end = self.image_end.reshape(1, 1, -1).expand(b, 1, x.shape[-1]).to(dtype)
+ x = torch.cat([begin, x, end], dim=1)
+
+ return self.after_rms(x)
+
+
+class HunYuanVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ vision_config: HunYuanVLVisionConfig,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ use_data_parallel: bool = False,
+ multimodal_config: MultiModalConfig | None = None,
+ attn_backend_override: AttentionBackendEnum | None = None,
+ ) -> None:
+ super().__init__()
+
+ num_hidden_layers = vision_config.num_hidden_layers
+ self.hidden_size = vision_config.hidden_size
+ self.num_heads = vision_config.num_attention_heads
+ self.spatial_merge_size = vision_config.spatial_merge_size
+
+ from vllm.compilation.backends import set_model_tag
+
+ with set_model_tag("HunYuanVisionPatchEmbed"):
+ self.embeddings = HunYuanVisionPatchEmbed(vision_config)
+
+ norm_layer = partial(nn.LayerNorm, eps=vision_config.rms_norm_eps)
+
+ with set_model_tag("HunYuanVisionBlock"):
+ self.layers = nn.ModuleList(
+ [
+ HunYuanVisionBlock(
+ dim=vision_config.hidden_size,
+ num_heads=vision_config.num_attention_heads,
+ mlp_hidden_dim=vision_config.intermediate_size,
+ act_fn=get_act_fn(vision_config.hidden_act),
+ norm_layer=norm_layer,
+ quant_config=quant_config,
+ multimodal_config=multimodal_config,
+ prefix=f"{prefix}.layers.{layer_idx}",
+ use_data_parallel=use_data_parallel,
+ )
+ for layer_idx in range(num_hidden_layers)
+ ]
+ )
+
+ with set_model_tag("HunYuanVisionPatchMerger"):
+ self.perceive = HunYuanVisionPatchMerger(
+ vision_config.hidden_size,
+ vision_config.out_hidden_size,
+ spatial_merge_size=vision_config.spatial_merge_size,
+ rms_norm_eps=vision_config.rms_norm_eps,
+ prefix=f"{prefix}.perceive",
+ )
+
+ @property
+ def dtype(self) -> torch.dtype:
+ return self.embeddings.patch_embedding.weight.dtype
+
+ @property
+ def device(self) -> torch.device:
+ return self.embeddings.patch_embedding.weight.device
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ grid_thw: list[list[int]],
+ ) -> torch.Tensor:
+ # patchify
+ seq_len = x.size(0)
+ cu_seqlens: list = [0]
+
+ hidden_states = x.to(device=self.device, dtype=self.dtype)
+ hidden_states = self.embeddings(hidden_states, grid_thw)
+
+ for t, h, w in grid_thw:
+ t, h, w = int(t), int(h), int(w)
+ cu_seqlens.append(h * w)
+
+ cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32)
+ cu_seqlens = torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32)
+
+ cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True)
+
+ hidden_states = hidden_states.reshape(seq_len, -1)
+ hidden_states = hidden_states.unsqueeze(0)
+ for layer_num, layer in enumerate(self.layers):
+ hidden_states = layer(hidden_states)
+
+ # adapter
+ split_lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+ split_items = hidden_states.split(split_lengths, dim=1)
+ image_embeds_list = []
+ for grid, split_item in zip(grid_thw, split_items):
+ image_embeds_list.append(
+ self.perceive(split_item.contiguous(), size=grid[1:]).squeeze(0)
+ )
+
+ return image_embeds_list
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ (".qkv", ".q_proj", "q"),
+ (".qkv", ".k_proj", "k"),
+ (".qkv", ".v_proj", "v"),
+ ]
+ params_dict = dict(self.named_parameters(remove_duplicate=False))
+ loaded_params: set[str] = set()
+
+ for name, loaded_weight in weights:
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
+
+
+def _hunyuan_vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
+ image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
+ image_grid_sizes = image_grid_thw.prod(-1)
+ return dict(
+ pixel_values=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes),
+ image_embeds=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes),
+ image_grid_thw=MultiModalFieldConfig.batched("image"),
+ )
+
+
+class HunYuanVLMultiModalDataParser(MultiModalDataParser):
+ def _parse_image_data(
+ self,
+ data: dict[str, torch.Tensor] | ModalityData[ImageItem],
+ ):
+ if isinstance(data, dict):
+ return DictEmbeddingItems(
+ data,
+ modality="image",
+ required_fields={"image_embeds", "image_grid_thw"},
+ fields_factory=_hunyuan_vl_field_config,
+ )
+
+ return super()._parse_image_data(data)
+
+
+class HunYuanVLProcessingInfo(BaseProcessingInfo):
+ def get_hf_config(self):
+ return self.ctx.get_hf_config(HunYuanVLConfig)
+
+ def get_hf_processor(
+ self,
+ **kwargs: object,
+ ) -> HunYuanVLProcessor:
+ return self.ctx.get_hf_processor(
+ HunYuanVLProcessor,
+ use_fast=kwargs.pop("use_fast", True),
+ **kwargs,
+ )
+
+ def get_image_processor(
+ self,
+ **kwargs: object,
+ ) -> HunYuanVLProcessor:
+ return self.get_hf_processor(**kwargs).image_processor
+
+ def get_supported_mm_limits(self) -> Mapping[str, int | None]:
+ return {"image": None}
+
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
+ max_image_tokens = self.get_max_image_tokens()
+ # TODO: support video
+ max_video_tokens = 0
+ return {"image": max_image_tokens, "video": max_video_tokens}
+
+ def _get_vision_info(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ num_frames: int = 1,
+ do_resize: bool = True,
+ image_processor: HunYuanVLProcessor | None,
+ ) -> tuple[ImageSize, int]:
+ if image_processor is None:
+ image_processor = self.get_image_processor()
+
+ hf_config = self.get_hf_config()
+ vision_config = hf_config.vision_config
+ patch_size = vision_config.patch_size
+ spatial_merge_size = vision_config.spatial_merge_size
+
+ if do_resize:
+ resized_height, resized_width = smart_resize(
+ height=image_height,
+ width=image_width,
+ factor=patch_size * spatial_merge_size,
+ min_pixels=image_processor.min_pixels,
+ max_pixels=image_processor.max_pixels,
+ )
+ preprocessed_size = ImageSize(width=resized_width, height=resized_height)
+ else:
+ preprocessed_size = ImageSize(width=image_width, height=image_height)
+
+ grid_t = 1
+ grid_h = preprocessed_size.height // patch_size
+ grid_w = preprocessed_size.width // patch_size
+
+ num_vision_tokens = (
+ grid_t * grid_h // spatial_merge_size * (grid_w // spatial_merge_size + 1)
+ + 2
+ )
+
+ return preprocessed_size, num_vision_tokens
+
+ def get_num_image_tokens(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ image_processor: HunYuanVLProcessor | None,
+ ) -> int:
+ _, num_image_tokens = self._get_vision_info(
+ image_width=image_width,
+ image_height=image_height,
+ image_processor=image_processor,
+ )
+ return num_image_tokens
+
+ def get_image_size_with_most_features(self) -> ImageSize:
+ max_image_size, _ = self._get_vision_info(
+ image_width=512,
+ image_height=8192,
+ image_processor=None,
+ )
+ return max_image_size
+
+ def get_max_image_tokens(self) -> int:
+ target_width, target_height = self.get_image_size_with_most_features()
+ return self.get_num_image_tokens(
+ image_width=target_width,
+ image_height=target_height,
+ image_processor=None,
+ )
+
+
+class HunYuanVLDummyInputsBuilder(BaseDummyInputsBuilder[HunYuanVLProcessingInfo]):
+ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+ num_images = mm_counts.get("image", 0)
+
+ hf_processor = self.info.get_hf_processor()
+ image_token: str = hf_processor.image_token
+
+ return image_token * num_images
+
+ def get_dummy_mm_data(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ mm_options: Mapping[str, BaseDummyOptions] | None = None,
+ ) -> MultiModalDataDict:
+ num_images = mm_counts.get("image", 1)
+
+ target_width, target_height = self.info.get_image_size_with_most_features()
+
+ return {
+ "image": self._get_dummy_images(
+ width=target_width, height=target_height, num_images=num_images
+ ),
+ }
+
+
+class HunYuanVLMultiModalProcessor(BaseMultiModalProcessor[HunYuanVLProcessingInfo]):
+ def _get_data_parser(self) -> MultiModalDataParser:
+ return HunYuanVLMultiModalDataParser()
+
+ def _call_hf_processor(
+ self,
+ prompt: str,
+ mm_data: Mapping[str, object],
+ mm_kwargs: Mapping[str, object],
+ tok_kwargs: Mapping[str, object],
+ ) -> BatchFeature:
+ return self.info.ctx.call_hf_processor(
+ self.info.get_hf_processor(**mm_kwargs),
+ dict(text=prompt, **mm_data),
+ dict(**mm_kwargs, **tok_kwargs),
+ )
+
+ def _get_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, Any],
+ out_mm_kwargs: MultiModalKwargsItems,
+ ) -> Sequence[PromptUpdate]:
+ hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+ image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
+
+ placeholder = {
+ "image": hf_processor.image_token_id,
+ }
+
+ merge_size = image_processor.merge_size
+
+ def get_replacement_hunyuan_vl(item_idx: int, modality: str):
+ out_item = out_mm_kwargs[modality][item_idx]
+ grid_thw = out_item[f"{modality}_grid_thw"].data
+ assert isinstance(grid_thw, torch.Tensor)
+
+ _, grid_h, grid_w = grid_thw
+ num_tokens = (int(grid_h) // merge_size) * (
+ int(grid_w) // merge_size + 1
+ ) + 2
+ return [placeholder[modality]] * num_tokens
+
+ return [
+ PromptReplacement(
+ modality=modality,
+ target=[placeholder[modality]],
+ replacement=partial(get_replacement_hunyuan_vl, modality=modality),
+ )
+ for modality in ("image",)
+ ]
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ return _hunyuan_vl_field_config(hf_inputs)
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ HunYuanVLMultiModalProcessor,
+ info=HunYuanVLProcessingInfo,
+ dummy_inputs=HunYuanVLDummyInputsBuilder,
+)
+class HunYuanVLForConditionalGeneration(
+ nn.Module,
+ SupportsMultiModal,
+ SupportsLoRA,
+ SupportsPP,
+ SupportsQuant,
+ SupportsXDRoPE,
+):
+ multimodal_cpu_fields = {"image_grid_thw"}
+
+ # To ensure correct weight loading and mapping.
+ hf_to_vllm_mapper = WeightsMapper(
+ orig_to_new_prefix={
+ # mapping for new names in checkpoint saved after transformers v4.52
+ "vit.vit.": "visual.",
+ "vit.": "visual.",
+ "model.": "language_model.model.",
+ }
+ )
+
+ supports_encoder_tp_data = True
+
+ def get_xdrope_input_positions(
+ self,
+ input_tokens: list[int],
+ mm_features: list[MultiModalFeatureSpec],
+ ) -> torch.Tensor:
+ kwargs = MultiModalFeatureSpec.gather_kwargs(
+ mm_features,
+ {"image_grid_thw"},
+ )
+ image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
+
+ hf_config = self.config
+ image_start_token_id = hf_config.image_start_token_id
+ spatial_merge_size = hf_config.vision_config.spatial_merge_size
+ xd_num = len(hf_config.rope_scaling["xdrope_section"])
+
+ input_tokens_tensor = torch.tensor(input_tokens)
+ image_start_indices = torch.argwhere(
+ input_tokens_tensor == image_start_token_id
+ ).squeeze(1)
+
+ p_index = torch.arange(len(input_tokens_tensor))
+ w_index = torch.arange(len(input_tokens_tensor))
+ h_index = torch.arange(len(input_tokens_tensor))
+ t_index = torch.arange(len(input_tokens_tensor))
+ for image_index in range(len(image_start_indices)):
+ # +1 : first image_token, +2: for xdrope positions
+ pos = image_start_indices[image_index] + 2
+ t, h, w = image_grid_thw[image_index]
+ _, llm_grid_h, llm_grid_w = (
+ t,
+ h // spatial_merge_size,
+ w // spatial_merge_size,
+ )
+
+ token_num = (llm_grid_w + 1) * llm_grid_h
+ w_index[pos : pos + token_num].copy_(
+ torch.arange(0, llm_grid_w + 1)
+ .reshape(1, -1)
+ .expand(llm_grid_h, -1)
+ .reshape(-1)
+ )
+ h_index[pos : pos + token_num].copy_(
+ torch.arange(0, llm_grid_h)
+ .reshape(-1, 1)
+ .expand(-1, llm_grid_w + 1)
+ .reshape(-1)
+ )
+ t_index[pos : pos + token_num] = image_index
+
+ if xd_num == 4:
+ llm_positions = torch.stack([p_index, w_index, h_index, t_index])
+ elif xd_num == 3:
+ llm_positions = torch.stack([w_index, h_index, t_index])
+
+ return llm_positions
+
+ @classmethod
+ def get_placeholder_str(cls, modality: str, i: int) -> str | None:
+ if modality.startswith("image"):
+ return "<|hy_place▁holder▁no▁100|><|hy_place▁holder▁no▁102|><|hy_place▁holder▁no▁101|>" # noqa: E501
+
+ raise ValueError("Only image modality is supported")
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config: HunYuanVLConfig = vllm_config.model_config.hf_config
+ multimodal_config = vllm_config.model_config.multimodal_config
+
+ self.config = config
+ self.multimodal_config = multimodal_config
+
+ if multimodal_config.get_limit_per_prompt("image"):
+ attn_backend_override = (
+ multimodal_config.mm_encoder_attn_backend
+ if multimodal_config is not None
+ else None
+ )
+ self.visual = HunYuanVisionTransformer(
+ config.vision_config,
+ quant_config=self.quant_config,
+ prefix=maybe_prefix(prefix, "visual"),
+ multimodal_config=multimodal_config,
+ attn_backend_override=attn_backend_override,
+ )
+ else:
+ self.visual = None
+
+ self.language_model = init_vllm_registered_model(
+ vllm_config=vllm_config,
+ prefix=maybe_prefix(prefix, "language_model.model"),
+ architectures=[
+ "HunYuanDenseV1ForCausalLM",
+ "HunYuanMoEV1ForCausalLM",
+ ],
+ )
+
+ self.make_empty_intermediate_tensors = (
+ self.language_model.make_empty_intermediate_tensors
+ )
+
+ def _parse_and_validate_image_input(
+ self, **kwargs: object
+ ) -> HunYuanVLImageInputs | None:
+ pixel_values = kwargs.pop("pixel_values", None)
+ image_embeds = kwargs.pop("image_embeds", None)
+ image_grid_thw = kwargs.pop("image_grid_thw", None)
+
+ if pixel_values is None and image_embeds is None:
+ return None
+
+ # TODO: refine
+ if isinstance(pixel_values, list):
+ pixel_values = torch.cat(pixel_values, dim=0)
+ if len(pixel_values.shape) == 3:
+ last_dim = pixel_values.shape[-1]
+ pixel_values = pixel_values.reshape(-1, last_dim)
+ image_grid_thw = image_grid_thw.reshape(-1, 3)
+
+ if pixel_values is not None:
+ return HunYuanVLImagePixelInputs(
+ type="pixel_values",
+ pixel_values=pixel_values,
+ image_grid_thw=image_grid_thw,
+ )
+
+ if image_embeds is not None:
+ return HunYuanVLImageEmbeddingInputs(
+ type="image_embeds",
+ image_embeds=image_embeds,
+ image_grid_thw=image_grid_thw,
+ )
+
+ def _process_image_input(
+ self, image_input: HunYuanVLImageInputs
+ ) -> tuple[torch.Tensor, ...]:
+ grid_thw = image_input["image_grid_thw"]
+ assert grid_thw.ndim == 2
+ grid_thw_list = grid_thw.tolist()
+
+ if image_input["type"] == "image_embeds":
+ image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+ else:
+ pixel_values = image_input["pixel_values"]
+
+ # TODO: use_data_parallel (split image_embeds in visual)
+ image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
+
+ return image_embeds
+
+ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
+ mm_input_by_modality = {}
+
+ # Preserve the order of modalities if there are multiple of them
+ # from the order of kwargs.
+ for input_key in kwargs:
+ if (
+ input_key in ("pixel_values", "image_embeds")
+ and "image" not in mm_input_by_modality
+ ):
+ mm_input_by_modality["image"] = self._parse_and_validate_image_input(
+ **kwargs
+ )
+ return mm_input_by_modality
+
+ def get_language_model(self) -> torch.nn.Module:
+ return self.language_model
+
+ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
+ mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
+ if not mm_input_by_modality:
+ return []
+
+ # The result multimodal_embeddings is tuple of tensors, with each
+ # tensor correspoending to a multimodal data item (image or video).
+ multimodal_embeddings: tuple[torch.Tensor, ...] = ()
+
+ # NOTE: It is important to iterate over the keys in this dictionary
+ # to preserve the order of the modalities.
+ for modality in mm_input_by_modality:
+ multimodal_input = mm_input_by_modality[modality]
+ if modality == "image":
+ image_embeddings = self._process_image_input(multimodal_input)
+ multimodal_embeddings += tuple(image_embeddings)
+ return multimodal_embeddings
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None,
+ inputs_embeds: torch.Tensor | None,
+ **kwargs: object,
+ ) -> torch.Tensor | IntermediateTensors:
+ if intermediate_tensors is not None:
+ inputs_embeds = None
+
+ hidden_states = self.language_model(
+ input_ids=input_ids,
+ positions=positions,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ )
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor | None:
+ return self.language_model.compute_logits(hidden_states)
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ loader = AutoWeightsLoader(
+ self,
+ skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
+ )
+ return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
+ def get_mm_mapping(self) -> MultiModelKeys:
+ """
+ Get the module prefix in multimodal models
+ """
+ return MultiModelKeys.from_string_field(
+ language_model="language_model.model",
+ connector="visual.perceive",
+ tower_model="visual",
+ )
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index 9966498e1b4c9..cee0b79e5e5ac 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -336,6 +336,7 @@ class SupportsLoRA(Protocol):
There is no need to redefine this flag if this class is in the
MRO of your model class.
"""
+ is_3d_moe_weight: ClassVar[bool] = False
# The `embedding_module` and `embedding_padding_modules`
# are empty by default.
embedding_modules: ClassVar[dict[str, str]] = {}
@@ -1047,7 +1048,7 @@ class SupportsMRoPE(Protocol):
supports_mrope: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports M-RoPE.
-
+
Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
@@ -1088,3 +1089,52 @@ def supports_mrope(
model: type[object] | object,
) -> TypeIs[type[SupportsMRoPE]] | TypeIs[SupportsMRoPE]:
return isinstance(model, SupportsMRoPE)
+
+
+@runtime_checkable
+class SupportsXDRoPE(Protocol):
+ """The interface required for all models that support XD-RoPE."""
+
+ supports_xdrope: ClassVar[Literal[True]] = True
+ """
+ A flag that indicates this model supports XD-RoPE.
+
+ Note:
+ There is no need to redefine this flag if this class is in the
+ XDRope of your model class.
+ """
+
+ def get_xdrope_input_positions(
+ self,
+ input_tokens: list[int],
+ mm_features: list["MultiModalFeatureSpec"],
+ ) -> torch.Tensor:
+ """
+ Get XD-RoPE input positions and delta value for this specific model.
+
+ This method should be implemented by each model that supports XD-RoPE
+ to provide model-specific logic for computing input positions.
+
+ Args:
+ input_tokens: List of input token IDs
+ mm_features: Information about each multi-modal data item
+
+ Returns:
+ llm_positions: Tensor of shape `[xdrope_dim, num_tokens]` with
+ 4D(P/W/H/T) or 3D(W/H/T) positions.
+ """
+ ...
+
+
+@overload
+def supports_xdrope(model: type[object]) -> TypeIs[type[SupportsXDRoPE]]: ...
+
+
+@overload
+def supports_xdrope(model: object) -> TypeIs[SupportsXDRoPE]: ...
+
+
+def supports_xdrope(
+ model: type[object] | object,
+) -> TypeIs[type[SupportsXDRoPE]] | TypeIs[SupportsXDRoPE]:
+ return isinstance(model, SupportsXDRoPE)
diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py
index 4267b6c6598e2..85c5574bacf0a 100644
--- a/vllm/model_executor/models/interfaces_base.py
+++ b/vllm/model_executor/models/interfaces_base.py
@@ -167,8 +167,7 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
default_pooling_type: ClassVar[str] = "LAST"
"""
- Indicates the
- [vllm.model_executor.layers.pooler.PoolerConfig.pooling_type][]
+ Indicates the [vllm.config.pooler.PoolerConfig.pooling_type][]
to use by default.
You can use the
diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py
index dc8f821bd134f..c79934e121447 100644
--- a/vllm/model_executor/models/internlm2.py
+++ b/vllm/model_executor/models/internlm2.py
@@ -10,7 +10,7 @@ import torch
from torch import nn
from transformers import PretrainedConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py
index 5549a1fc1cd30..6012288814f15 100644
--- a/vllm/model_executor/models/jais.py
+++ b/vllm/model_executor/models/jais.py
@@ -28,7 +28,7 @@ from itertools import islice
import torch
from torch import nn
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py
index 8fc3db296aa79..302260b952992 100644
--- a/vllm/model_executor/models/keye.py
+++ b/vllm/model_executor/models/keye.py
@@ -9,6 +9,7 @@ from typing import Annotated, Any, Literal, TypeAlias, TypeVar
import numpy as np
import torch
import torch.nn as nn
+import torch.nn.functional as F
from einops import rearrange
from transformers import PretrainedConfig
from transformers.activations import GELUActivation
@@ -424,7 +425,7 @@ class KeyeSiglipAttention(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
- AttentionBackendEnum.XFORMERS,
+ AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -451,7 +452,6 @@ class KeyeSiglipAttention(nn.Module):
)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
- seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
batch_size = q.shape[0]
if rope_emb is None:
@@ -498,17 +498,21 @@ class KeyeSiglipAttention(nn.Module):
softmax_scale=self.scale,
)
context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- from xformers import ops as xops
- from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-
- attn_bias = BlockDiagonalMask.from_seqlens(
- q_seqlen=seqlens, kv_seqlen=None, device=q.device
- )
-
- context_layer = xops.memory_efficient_attention_forward(
- q, k, v, attn_bias=attn_bias, p=0, scale=None
- )
+ elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
+ outputs = []
+ for i in range(1, len(cu_seqlens)):
+ start_idx = cu_seqlens[i - 1]
+ end_idx = cu_seqlens[i]
+ q_i = q[:, start_idx:end_idx]
+ k_i = k[:, start_idx:end_idx]
+ v_i = v[:, start_idx:end_idx]
+ q_i, k_i, v_i = (
+ rearrange(x, "b s h d -> b h s d") for x in (q_i, k_i, v_i)
+ )
+ output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
+ output_i = rearrange(output_i, "b h s d -> b s h d ")
+ outputs.append(output_i)
+ context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]
context_layer = rearrange(context_layer, "b s h d -> b s (h d)").contiguous()
diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py
index 74bdde27ece5c..69615f8b6a099 100644
--- a/vllm/model_executor/models/lfm2.py
+++ b/vllm/model_executor/models/lfm2.py
@@ -7,7 +7,7 @@ import torch
import torch.nn as nn
from transformers import Lfm2Config
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/lfm2_moe.py b/vllm/model_executor/models/lfm2_moe.py
index c088a08211527..aaeb2cc38999e 100644
--- a/vllm/model_executor/models/lfm2_moe.py
+++ b/vllm/model_executor/models/lfm2_moe.py
@@ -6,7 +6,7 @@ from itertools import islice
import torch
import torch.nn as nn
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index ebf8addda4a54..6dfbde7a17f54 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -31,7 +31,8 @@ import torch
from torch import nn
from transformers import LlamaConfig
-from vllm.attention import Attention, AttentionType
+from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.layer import Attention
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
@@ -262,7 +263,7 @@ class LlamaAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
- rope_parameters=config.rope_parameters,
+ rope_parameters=getattr(config, "rope_parameters", None),
is_neox_style=is_neox_style,
partial_rotary_factor=self.partial_rotary_factor,
)
@@ -354,7 +355,17 @@ class LlamaDecoderLayer(nn.Module):
return vllm_config.quant_config
-@support_torch_compile
+def llama_model_invariants(
+ input_ids, positions, intermediate_tensors=None, inputs_embeds=None
+):
+ """Shape invariants for Llama model compilation, those are translated to
+ runtime assertions for unbacked dynamic shapes and are compiled away for
+ backed"""
+ if input_ids is not None:
+ torch._check(positions.size()[0] == input_ids.size()[0])
+
+
+@support_torch_compile(shape_invariants=llama_model_invariants)
class LlamaModel(nn.Module):
def __init__(
self,
diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py
index e1bdfc3405f70..423be45e80149 100644
--- a/vllm/model_executor/models/llama4.py
+++ b/vllm/model_executor/models/llama4.py
@@ -24,7 +24,7 @@ import torch
from torch import nn
from transformers import Llama4TextConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index 3eaf2d80082f1..7a57644db1b13 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -142,6 +142,12 @@ class LlamaModel(nn.Module):
# Get drafter's quantization config
self.quant_config = get_draft_quant_config(vllm_config)
+ eagle_config = getattr(self.config, "eagle_config", None)
+ if eagle_config is not None and "use_aux_hidden_state" in eagle_config:
+ self.use_aux_hidden_state = eagle_config["use_aux_hidden_state"]
+ else:
+ self.use_aux_hidden_state = True
+
current_vllm_config = get_current_vllm_config()
self.embed_tokens = VocabParallelEmbedding(
@@ -161,20 +167,20 @@ class LlamaModel(nn.Module):
for layer_idx in range(self.config.num_hidden_layers)
]
)
- if hasattr(self.config, "target_hidden_size"):
- fc_input_size = self.config.target_hidden_size * 3
- else:
- fc_input_size = self.config.hidden_size * 3
- self.fc = ReplicatedLinear(
- input_size=fc_input_size,
- output_size=self.config.hidden_size,
- bias=False,
- params_dtype=vllm_config.model_config.dtype,
- quant_config=self.quant_config,
- prefix=maybe_prefix(prefix, "fc"),
- return_bias=False,
- )
-
+ if self.use_aux_hidden_state:
+ if hasattr(self.config, "target_hidden_size"):
+ fc_input_size = self.config.target_hidden_size * 3
+ else:
+ fc_input_size = self.config.hidden_size * 3
+ self.fc = ReplicatedLinear(
+ input_size=fc_input_size,
+ output_size=self.config.hidden_size,
+ bias=False,
+ params_dtype=vllm_config.model_config.dtype,
+ quant_config=self.quant_config,
+ prefix=maybe_prefix(prefix, "fc"),
+ return_bias=False,
+ )
self.norm = RMSNorm(
self.config.hidden_size,
eps=self.config.rms_norm_eps,
@@ -332,6 +338,8 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
+ if not self.model.use_aux_hidden_state:
+ return hidden_states
# combine multiple auxiliary hidden states returned by eagle3
return self.model.fc(hidden_states)
@@ -357,6 +365,8 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
skip_substrs.append("draft_id_to_target_id")
if not includes_embed_tokens:
skip_substrs.append("embed_tokens")
+ if not self.model.use_aux_hidden_state:
+ skip_substrs.append("fc.")
loader = AutoWeightsLoader(
self,
skip_prefixes=None,
diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py
index 98b1b46045c3d..b995cac47ac1c 100644
--- a/vllm/model_executor/models/llava_next.py
+++ b/vllm/model_executor/models/llava_next.py
@@ -460,7 +460,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
image_input: LlavaNextImageInputs,
) -> torch.Tensor | list[torch.Tensor]:
if image_input["type"] == "image_embeds":
- return [image_input["data"]]
+ return image_input["data"]
patch_embeddings = self._process_image_pixels(image_input)
diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py
index 322bde94ff66d..4e243ade68358 100644
--- a/vllm/model_executor/models/llava_onevision.py
+++ b/vllm/model_executor/models/llava_onevision.py
@@ -763,7 +763,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
image_input: LlavaOnevisionImageInputs,
) -> torch.Tensor | list[torch.Tensor]:
if image_input["type"] == "image_embeds":
- return [image_input["data"]]
+ return image_input["data"]
patch_embeddings = self._process_image_pixels(image_input)
diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py
index 04923833065f3..67911ba8c1c8f 100644
--- a/vllm/model_executor/models/minicpm.py
+++ b/vllm/model_executor/models/minicpm.py
@@ -33,7 +33,7 @@ import torch
from torch import nn
from transformers import PretrainedConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py
index 2d775219fc972..0a2bcbd7f6084 100644
--- a/vllm/model_executor/models/minicpm3.py
+++ b/vllm/model_executor/models/minicpm3.py
@@ -29,7 +29,7 @@ import torch
from torch import nn
from transformers import PretrainedConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm
diff --git a/vllm/model_executor/models/minimax_m2.py b/vllm/model_executor/models/minimax_m2.py
index 4955c68c0cda8..dd98e36ec0851 100644
--- a/vllm/model_executor/models/minimax_m2.py
+++ b/vllm/model_executor/models/minimax_m2.py
@@ -30,7 +30,7 @@ import torch
from torch import nn
from transformers import PretrainedConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (
diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py
index 50f7396e2de60..390de78cc27b4 100644
--- a/vllm/model_executor/models/minimax_text_01.py
+++ b/vllm/model_executor/models/minimax_text_01.py
@@ -14,7 +14,8 @@ import torch
from torch import nn
from transformers import MiniMaxConfig
-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention.backends.abstract import AttentionMetadata
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed.parallel_state import (
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index 0a9c3f136964e..e21656dbd6350 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -32,7 +32,7 @@ import torch
from torch import nn
from transformers import MixtralConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (
diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index dc06938d5d6e1..7b53299cccbe4 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -17,8 +17,7 @@ from transformers import BatchFeature, PretrainedConfig, ProcessorMixin, TensorT
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput
-from vllm.attention import Attention
-from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.layer import Attention, MultiHeadAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
diff --git a/vllm/model_executor/models/moonvit.py b/vllm/model_executor/models/moonvit.py
index 2e3e6dc166ad8..63ea6b259a71d 100644
--- a/vllm/model_executor/models/moonvit.py
+++ b/vllm/model_executor/models/moonvit.py
@@ -56,10 +56,13 @@ from transformers.utils import is_flash_attn_2_available
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.models.utils import maybe_prefix
+from vllm.platforms import current_platform
from vllm.transformers_utils.configs.moonvit import MoonViTConfig
if is_flash_attn_2_available():
from flash_attn import flash_attn_varlen_func
+elif current_platform.is_xpu():
+ from vllm.attention.utils.fa_utils import flash_attn_varlen_func
else:
flash_attn_varlen_func = None
@@ -106,10 +109,10 @@ def multihead_attention(
q,
k,
v,
- q_cu_seqlens,
- k_cu_seqlens,
- max_seqlen_q,
- max_seqlen_k,
+ cu_seqlens_q=q_cu_seqlens,
+ cu_seqlens_k=k_cu_seqlens,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_k=max_seqlen_k,
causal=False,
)
attn_out = attn_out.flatten(start_dim=-2)
@@ -291,7 +294,12 @@ class Rope2DPosEmb(nn.Module):
"""
def __init__(
- self, dim: int, max_height: int, max_width: int, theta_base=10000, device="cuda"
+ self,
+ dim: int,
+ max_height: int,
+ max_width: int,
+ theta_base=10000,
+ device=current_platform.device_type,
):
super().__init__()
self.dim = dim
@@ -437,7 +445,7 @@ class MoonVitEncoderLayer(nn.Module):
self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
self.attn_implementation = attn_implementation
# use fa2 in vllm by default
- if is_flash_attn_2_available():
+ if is_flash_attn_2_available() or current_platform.is_xpu():
self.attn_implementation = "flash_attention_2"
self.norm0 = nn.LayerNorm(hidden_dim)
diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py
index 106ad971a321a..1e285646b9ec3 100644
--- a/vllm/model_executor/models/mpt.py
+++ b/vllm/model_executor/models/mpt.py
@@ -10,7 +10,7 @@ import torch
import torch.nn as nn
from transformers import MptConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py
index c3337bd1ea699..93ad2064a2fca 100644
--- a/vllm/model_executor/models/nemotron.py
+++ b/vllm/model_executor/models/nemotron.py
@@ -30,7 +30,7 @@ from itertools import islice
import torch
from torch import nn
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/nemotron_nas.py b/vllm/model_executor/models/nemotron_nas.py
index 2eebe38051cbd..34ea2945b711e 100644
--- a/vllm/model_executor/models/nemotron_nas.py
+++ b/vllm/model_executor/models/nemotron_nas.py
@@ -31,7 +31,7 @@ import torch
from torch import nn
from transformers import LlamaConfig
-from vllm.attention import AttentionType
+from vllm.attention.backends.abstract import AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py
index bd8a8e317544f..3bbb4dd242262 100644
--- a/vllm/model_executor/models/olmo.py
+++ b/vllm/model_executor/models/olmo.py
@@ -31,7 +31,7 @@ import torch
from torch import nn
from transformers import OlmoConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py
index f0f6b2f6b3e6d..88e9c2d8541a1 100644
--- a/vllm/model_executor/models/olmo2.py
+++ b/vllm/model_executor/models/olmo2.py
@@ -32,7 +32,7 @@ import torch
from torch import nn
from transformers import Olmo2Config
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py
index c39e338d72e22..1376583a99725 100644
--- a/vllm/model_executor/models/olmoe.py
+++ b/vllm/model_executor/models/olmoe.py
@@ -21,7 +21,7 @@ from itertools import islice
import torch
from torch import nn
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import (
diff --git a/vllm/model_executor/models/openpangu.py b/vllm/model_executor/models/openpangu.py
index 4124a181a14c2..bddd9fa50957a 100644
--- a/vllm/model_executor/models/openpangu.py
+++ b/vllm/model_executor/models/openpangu.py
@@ -29,7 +29,8 @@ import torch
from torch import nn
from transformers import PretrainedConfig
-from vllm.attention import Attention, AttentionType
+from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.distributed import (
diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py
index 5df700d1a2e17..bba5291ea5ef5 100644
--- a/vllm/model_executor/models/opt.py
+++ b/vllm/model_executor/models/opt.py
@@ -27,7 +27,7 @@ import torch
from torch import nn
from transformers import OPTConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py
index b30be93ca726f..544a44ed54681 100644
--- a/vllm/model_executor/models/orion.py
+++ b/vllm/model_executor/models/orion.py
@@ -15,7 +15,7 @@ import torch
from torch import nn
from transformers import PretrainedConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/ouro.py b/vllm/model_executor/models/ouro.py
index 63d2fff6ec8bc..dcae92ed20881 100644
--- a/vllm/model_executor/models/ouro.py
+++ b/vllm/model_executor/models/ouro.py
@@ -33,7 +33,8 @@ import torch
from torch import nn
from transformers import PretrainedConfig
-from vllm.attention import Attention, AttentionType
+from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py
index dee0c16ab0f63..74bb868492da9 100644
--- a/vllm/model_executor/models/paddleocr_vl.py
+++ b/vllm/model_executor/models/paddleocr_vl.py
@@ -38,7 +38,6 @@ from vllm.attention.layer import (
)
from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
- vit_xformers_attn_wrapper,
)
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
@@ -657,7 +656,6 @@ class SiglipAttention(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor | None,
max_seqlen: torch.Tensor | None,
- seqlens: torch.Tensor | None,
) -> torch.Tensor:
batch_size, _, _ = hidden_states.shape
@@ -703,10 +701,6 @@ class SiglipAttention(nn.Module):
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- if seqlens is None:
- raise ValueError("xFormers attention backend requires seqlens tensor.")
- context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
else:
raise RuntimeError(
f"PaddleOCR-VL does not support {self.attn_backend} backend now."
@@ -818,7 +812,6 @@ class SiglipEncoderLayer(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor | None,
max_seqlen: torch.Tensor | None,
- seqlens: torch.Tensor | None,
) -> torch.Tensor:
residual = hidden_states
@@ -828,7 +821,6 @@ class SiglipEncoderLayer(nn.Module):
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
hidden_states = residual + hidden_states
@@ -870,7 +862,6 @@ class SiglipEncoder(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
- AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -943,14 +934,11 @@ class SiglipEncoder(nn.Module):
cu_seqlens = cu_seqlens.to(device=device)
max_seqlen = None
- seqlens = None
if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
hidden_states = inputs_embeds
for encoder_layer in self.layers:
@@ -959,7 +947,6 @@ class SiglipEncoder(nn.Module):
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
return hidden_states
diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py
index 98963d52e4848..795cd25f16753 100644
--- a/vllm/model_executor/models/persimmon.py
+++ b/vllm/model_executor/models/persimmon.py
@@ -30,7 +30,7 @@ import torch
from torch import nn
from transformers import PersimmonConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py
index da476f621627b..70016d9ed246c 100644
--- a/vllm/model_executor/models/phi.py
+++ b/vllm/model_executor/models/phi.py
@@ -45,7 +45,7 @@ import torch
from torch import nn
from transformers import PhiConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py
index 8ffac95d93960..a5a669139b2f7 100644
--- a/vllm/model_executor/models/phimoe.py
+++ b/vllm/model_executor/models/phimoe.py
@@ -31,7 +31,7 @@ import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index 8a034fd72b02a..3464de472add5 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -74,6 +74,7 @@ from .vision import (
)
try:
+ # Note: vLLM does not install xformers by default.
from xformers import ops as xops
if current_platform.is_cuda() and current_platform.has_device_capability(100):
@@ -399,21 +400,30 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
prefix=maybe_prefix(prefix, "language_model"),
)
- self.vision_encoder = VisionTransformer(self.vision_args)
-
- if self.vision_args.add_pre_mm_projector_layer_norm:
- self.pre_mm_projector_norm = RMSNorm(self.vision_args.hidden_size, eps=1e-5)
-
- if self.vision_args.mm_projector_id == PATCH_MERGE:
- self.patch_merger = PatchMerger(
- vision_encoder_dim=self.vision_args.hidden_size,
- spatial_merge_size=self.vision_args.spatial_merge_size,
- use_mlp_bias=False,
+ if multimodal_config.get_limit_per_prompt("image"):
+ self.vision_encoder = VisionTransformer(self.vision_args)
+ self.pre_mm_projector_norm = (
+ RMSNorm(self.vision_args.hidden_size, eps=1e-5)
+ if self.vision_args.add_pre_mm_projector_layer_norm
+ else None
)
-
- self.vision_language_adapter = VisionLanguageAdapter(
- self.vision_args, dim=config.text_config.hidden_size
- )
+ self.patch_merger = (
+ PatchMerger(
+ vision_encoder_dim=self.vision_args.hidden_size,
+ spatial_merge_size=self.vision_args.spatial_merge_size,
+ use_mlp_bias=False,
+ )
+ if self.vision_args.mm_projector_id == PATCH_MERGE
+ else None
+ )
+ self.vision_language_adapter = VisionLanguageAdapter(
+ self.vision_args, dim=config.text_config.hidden_size
+ )
+ else:
+ self.vision_encoder = None
+ self.pre_mm_projector_norm = None
+ self.patch_merger = None
+ self.vision_language_adapter = None
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
@@ -435,13 +445,17 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
self,
image_input: PixtralImagePixelInputs,
) -> tuple[torch.Tensor, ...]:
+ assert (
+ self.vision_encoder is not None and self.vision_language_adapter is not None
+ )
+
images = image_input["images"]
image_features = self.vision_encoder(images)
feature_sizes = [image_feature.shape[0] for image_feature in image_features]
image_features = torch.cat(image_features)
- if self.vision_args.add_pre_mm_projector_layer_norm:
+ if self.pre_mm_projector_norm is not None:
image_features = self.pre_mm_projector_norm(image_features)
- if self.vision_args.mm_projector_id == PATCH_MERGE:
+ if self.patch_merger is not None:
patch_size = self.vision_args.patch_size
spatial_merge_size_square = self.vision_args.spatial_merge_size**2
img_patch_dims = [
@@ -507,41 +521,57 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
return weight[0].startswith("pre_mm_projector_norm")
# Get references to parameters for direct loading
- vision_encoder_dict = dict(self.vision_encoder.named_parameters())
+ vision_encoder_dict = (
+ dict(self.vision_encoder.named_parameters())
+ if self.vision_encoder is not None
+ else {}
+ )
patch_merger_dict = (
dict(self.patch_merger.named_parameters())
- if self.vision_args.mm_projector_id == PATCH_MERGE
- else dict()
+ if self.patch_merger is not None
+ else {}
)
pre_mm_projector_norm_dict = (
dict(self.pre_mm_projector_norm.named_parameters())
- if self.vision_args.add_pre_mm_projector_layer_norm
- else dict()
+ if self.pre_mm_projector_norm is not None
+ else {}
+ )
+ vision_lang_adapter_dict = (
+ dict(self.vision_language_adapter.named_parameters())
+ if self.vision_language_adapter is not None
+ else {}
)
- vision_lang_adapter_dict = dict(self.vision_language_adapter.named_parameters())
def llm_weights_generator():
# Single pass over weights
for name, w in weights:
if is_vision_encoder_weights((name, w)):
+ if self.vision_encoder is None:
+ continue
# Load vision encoder weights directly
trimmed_name = ".".join(name.split(".")[1:])
param = vision_encoder_dict[trimmed_name]
with torch.no_grad():
default_weight_loader(param, w)
elif is_patch_merger((name, w)):
+ if self.patch_merger is None:
+ continue
# Load vision patch merger weights directly
trimmed_name = ".".join(name.split(".")[1:])
param = patch_merger_dict[trimmed_name]
with torch.no_grad():
default_weight_loader(param, w)
elif is_pre_mm_projector_norm((name, w)):
+ if self.pre_mm_projector_norm is None:
+ continue
# Load vision pre_mm_projector_norm weights directly
trimmed_name = ".".join(name.split(".")[1:])
param = pre_mm_projector_norm_dict[trimmed_name]
with torch.no_grad():
default_weight_loader(param, w)
elif is_vision_lang_adapter_weights((name, w)):
+ if self.vision_language_adapter is None:
+ continue
# Load vision-language adapter weights directly
trimmed_name = ".".join(name.split(".")[1:])
param = vision_lang_adapter_dict[trimmed_name]
diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py
index c973e79170982..12285cf9c1968 100644
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -16,7 +16,7 @@ import torch
from torch import nn
from transformers import PretrainedConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index 32b6d6dd07b83..34c31d8deee23 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -33,7 +33,8 @@ import torch
from torch import nn
from transformers import Qwen2Config
-from vllm.attention import Attention, AttentionType
+from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.layer import Attention
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
@@ -274,6 +275,38 @@ class Qwen2DecoderLayer(nn.Module):
return hidden_states, residual
+def qwen_2_model_invariants(
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+):
+ """Shape invariants for Qwen2Model Model, those are translated to
+ runtime assertions for unbacked dynamic shapes and are compiled away for
+ backed"""
+ # All these should be equal.
+ # input_ids.size()[0]
+ # positions.size()[-1]
+ # intermediate_tensors["hidden_states"].size()[0]
+ # inputs_embeds.size()[0]
+ torch._check(input_ids.size()[0] == positions.size()[-1])
+ if intermediate_tensors is not None:
+ torch._check(
+ input_ids.size()[0] == intermediate_tensors["hidden_states"].size()[0]
+ )
+
+ if inputs_embeds is not None:
+ torch._check(input_ids.size()[0] == inputs_embeds.size()[0])
+
+ # Hidden dimensions should match (hidden_size)
+ # intermediate_tensors["hidden_states"].size()[1]
+ # inputs_embeds.size()[1]
+ if inputs_embeds is not None and intermediate_tensors is not None:
+ torch._check(
+ inputs_embeds.size()[1] == intermediate_tensors["hidden_states"].size()[1]
+ )
+
+
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
@@ -282,7 +315,8 @@ class Qwen2DecoderLayer(nn.Module):
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
- }
+ },
+ shape_invariants=qwen_2_model_invariants,
)
class Qwen2Model(nn.Module):
def __init__(
diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py
index 262ea771d9cdf..7506ee8656fda 100644
--- a/vllm/model_executor/models/qwen2_5_omni_thinker.py
+++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py
@@ -23,7 +23,6 @@
"""Inference-only Qwen2.5-Omni model (thinker part)."""
from collections.abc import Callable, Iterable, Mapping, Sequence
-from copy import copy
from functools import partial
from typing import Annotated, Any, Literal
@@ -387,15 +386,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
self._validate_mm_updates(mm_prompt_updates, mm_item_counts)
- use_audio_in_video = False
- if "video" in mm_kwargs:
- video_items = [item for item in mm_kwargs["video"] if item is not None]
- # only check video items (if there are any)
- if video_items:
- use_audio_in_video = all(
- item["use_audio_in_video"].data for item in video_items
- )
-
if is_update_applied:
mm_placeholders = self._find_mm_placeholders(
prompt_ids,
@@ -404,7 +394,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
self._validate_mm_placeholders(
mm_placeholders,
mm_item_counts,
- use_audio_in_video=use_audio_in_video,
)
else:
prompt_ids, mm_placeholders = self._apply_prompt_updates(
@@ -414,7 +403,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
self._validate_mm_placeholders(
mm_placeholders,
mm_item_counts,
- use_audio_in_video=use_audio_in_video,
)
return prompt_ids, mm_placeholders
@@ -640,19 +628,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
return mm_processed_data
- def _validate_mm_placeholders(
- self,
- mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
- mm_item_counts: Mapping[str, int],
- use_audio_in_video: bool = False,
- ) -> None:
- if use_audio_in_video:
- mm_item_counts = copy(mm_item_counts)
- if "video" in mm_item_counts:
- assert "audio" in mm_item_counts
- mm_item_counts["audio"] -= mm_item_counts["video"]
- super()._validate_mm_placeholders(mm_placeholders, mm_item_counts)
-
class Qwen2_5OmniConditionalGenerationMixin:
def _parse_and_validate_audio_input(
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 1500a437613cc..8c707c2561af1 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -46,7 +46,6 @@ from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
vit_torch_sdpa_wrapper,
- vit_xformers_attn_wrapper,
)
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
@@ -375,7 +374,6 @@ class Qwen2_5_VisionAttention(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
- seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
@@ -435,8 +433,6 @@ class Qwen2_5_VisionAttention(nn.Module):
v,
cu_seqlens,
)
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
output, _ = self.proj(context_layer)
return output
@@ -448,9 +444,7 @@ class Qwen2_5_VisionAttention(nn.Module):
"cu_seqlens": 0,
"rotary_pos_emb_cos": 0,
"rotary_pos_emb_sin": 0,
- "seqlens": 0,
},
- mark_unbacked_dims={"seqlens": 0},
enable_if=should_torch_compile_mm_vit,
)
class Qwen2_5_VisionBlock(nn.Module):
@@ -501,7 +495,6 @@ class Qwen2_5_VisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
- seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
x_attn = self.attn(
self.norm1(x),
@@ -509,7 +502,6 @@ class Qwen2_5_VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
x_fused_norm, residual = self.norm2(x, residual=x_attn)
x = residual + self.mlp(x_fused_norm)
@@ -670,7 +662,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
- AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -822,17 +813,14 @@ class Qwen2_5_VisionTransformer(nn.Module):
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
+ ) -> torch.Tensor:
max_seqlen = torch.zeros([], device=cu_seqlens.device)
- seqlens = torch.zeros(1, device=cu_seqlens.device)
if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
- return max_seqlen, seqlens
+ return max_seqlen
@staticmethod
def invert_permutation(perm: torch.Tensor) -> torch.Tensor:
@@ -897,10 +885,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
# transformers
# pre-compute seqlens for window/full attn to reduce cuMemcpy operations
- max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen(cu_seqlens)
- max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen(
- cu_window_seqlens
- )
+ max_seqlen_full = self.compute_attn_mask_seqlen(cu_seqlens)
+ max_seqlen_window = self.compute_attn_mask_seqlen(cu_window_seqlens)
cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True)
cu_window_seqlens = cu_window_seqlens.to(device=self.device, non_blocking=True)
@@ -927,11 +913,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
if layer_num in self.fullatt_block_indexes:
cu_seqlens_now = cu_seqlens
max_seqlen_now = max_seqlen_full
- seqlens_now = seqlens_full
else:
cu_seqlens_now = cu_window_seqlens
max_seqlen_now = max_seqlen_window
- seqlens_now = seqlens_window
hidden_states = blk(
hidden_states,
@@ -939,7 +923,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen_now,
- seqlens=seqlens_now,
)
# For Qwen2.5-VL-3B, float16 will overflow at last block
diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py
index 6b97d0b2ca2e3..5a428740082f6 100644
--- a/vllm/model_executor/models/qwen2_moe.py
+++ b/vllm/model_executor/models/qwen2_moe.py
@@ -34,7 +34,7 @@ import torch.nn.functional as F
from torch import nn
from transformers import Qwen2MoeConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 479a7871e364f..9d1d023aed172 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -348,7 +348,6 @@ class Qwen2VisionAttention(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
- AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -384,7 +383,6 @@ class Qwen2VisionAttention(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
- seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
# [s, b, c] --> [s, b, 3 * head * head_dim]
x, _ = self.qkv(x)
@@ -445,20 +443,6 @@ class Qwen2VisionAttention(nn.Module):
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- from xformers import ops as xops
- from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-
- attn_bias = BlockDiagonalMask.from_seqlens(
- q_seqlen=seqlens, kv_seqlen=None, device=q.device
- )
-
- context_layer = xops.memory_efficient_attention_forward(
- q, k, v, attn_bias=attn_bias, p=0, scale=None
- )
- context_layer = rearrange(
- context_layer, "b s h d -> s b (h d)"
- ).contiguous()
output, _ = self.proj(context_layer)
return output
@@ -509,7 +493,6 @@ class Qwen2VisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
- seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
@@ -517,7 +500,6 @@ class Qwen2VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
x = x + self.mlp(self.norm2(x))
@@ -728,18 +710,14 @@ class Qwen2VisionTransformer(nn.Module):
sin_combined = sin[pos_ids].flatten(1)
return cos_combined, sin_combined
- def compute_attn_mask_seqlen(
- self, cu_seqlens: torch.Tensor
- ) -> tuple[int | None, list[int] | None]:
- max_seqlen, seqlens = None, None
+ def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
+ max_seqlen = None
if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
- return max_seqlen, seqlens
+ return max_seqlen
def forward(
self,
@@ -771,7 +749,7 @@ class Qwen2VisionTransformer(nn.Module):
x = x.unsqueeze(1)
# pre-compute seqlens for attn mask to reduce cuMemcpy operations
- max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+ max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
for blk in self.blocks:
x = blk(
@@ -780,7 +758,6 @@ class Qwen2VisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
# adapter
diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py
index 93a629d81e8ff..7d2b3e5f9bc79 100644
--- a/vllm/model_executor/models/qwen3.py
+++ b/vllm/model_executor/models/qwen3.py
@@ -30,7 +30,8 @@ import torch
from torch import nn
from transformers import Qwen3Config
-from vllm.attention import Attention, AttentionType
+from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py
index 8ee3dd99e11db..6f520706a3176 100644
--- a/vllm/model_executor/models/qwen3_moe.py
+++ b/vllm/model_executor/models/qwen3_moe.py
@@ -31,7 +31,7 @@ from typing import Any
import torch
from torch import nn
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (
diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py
index bfed64728305e..661a182151d74 100644
--- a/vllm/model_executor/models/qwen3_next.py
+++ b/vllm/model_executor/models/qwen3_next.py
@@ -10,7 +10,8 @@ from einops import rearrange
from torch import nn
from transformers.activations import ACT2FN
-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention.backends.abstract import AttentionMetadata
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
CacheConfig,
diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py
index 54ef56f83344e..f5f88f66eff91 100755
--- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py
+++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py
@@ -68,11 +68,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems
from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataItems
from vllm.multimodal.processing import (
- BaseMultiModalProcessor,
MultiModalPromptUpdates,
PlaceholderFeaturesInfo,
PromptReplacement,
PromptUpdate,
+ PromptUpdateDetails,
)
from vllm.sequence import IntermediateTensors
@@ -87,7 +87,6 @@ from .qwen2_5_omni_thinker import (
Qwen2_5OmniConditionalGenerationMixin,
Qwen2_5OmniThinkerDummyInputsBuilder,
Qwen2_5OmniThinkerMultiModalProcessor,
- Qwen2_5OmniThinkerProcessingInfo,
)
from .qwen2_5_vl import (
Qwen2_5_VisionAttention,
@@ -224,7 +223,6 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
- seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
@@ -232,7 +230,6 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
x = x + self.mlp(self.norm2(x))
@@ -500,14 +497,11 @@ class Qwen3Omni_VisionTransformer(nn.Module):
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
+ ) -> torch.Tensor:
max_seqlen = torch.zeros([], device=cu_seqlens.device)
- seqlens = torch.zeros(1, device=cu_seqlens.device)
if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
- return max_seqlen, seqlens
+ return max_seqlen
def forward(
self,
@@ -533,7 +527,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
hidden_states = hidden_states.unsqueeze(1)
rotary_pos_emb_cos = rotary_pos_emb_cos.to(hidden_states.device)
rotary_pos_emb_sin = rotary_pos_emb_sin.to(hidden_states.device)
- max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+ max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
hidden_states_list = []
deepstack_visual_indexes = self.deepstack_visual_indexes
@@ -545,7 +539,6 @@ class Qwen3Omni_VisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
if (
deepstack_visual_indexes is not None
@@ -813,24 +806,8 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
else:
use_audio_in_video = False
- if use_audio_in_video and "video" in mm_item_counts:
- assert "audio" in mm_item_counts
- mm_item_counts["audio"] -= mm_item_counts["video"]
-
- # Special case with `use_audio_in_video=True`
- if use_audio_in_video:
- if is_update_applied:
- prompt_ids = self._get_raw_input_ids(prompt_ids, use_audio_in_video)
- (
- prompt_ids,
- mm_placeholders,
- ) = self._apply_prompt_updates(
- prompt_ids,
- mm_prompt_updates,
- )
- self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
# normal case with `use_audio_in_video=False`
- elif is_update_applied:
+ if is_update_applied:
mm_placeholders = self._find_mm_placeholders(
prompt_ids,
mm_prompt_updates,
@@ -840,10 +817,24 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
mm_item_counts,
)
else:
- prompt_ids, mm_placeholders = self._apply_prompt_updates(
- prompt_ids,
- mm_prompt_updates,
- )
+ if use_audio_in_video and "audio" in mm_prompt_updates:
+ filtered_updates = {
+ k: v for k, v in mm_prompt_updates.items() if k != "audio"
+ }
+ prompt_ids, mm_placeholders = self._apply_prompt_updates(
+ prompt_ids,
+ filtered_updates,
+ )
+ # Derive audio placeholders from video placeholders
+ mm_placeholders = self._derive_audio_from_video_placeholders(
+ mm_placeholders, mm_prompt_updates
+ )
+ else:
+ prompt_ids, mm_placeholders = self._apply_prompt_updates(
+ prompt_ids,
+ mm_prompt_updates,
+ )
+
self._validate_mm_placeholders(
mm_placeholders,
mm_item_counts,
@@ -968,7 +959,9 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
def get_replacement_qwen2_use_audio_in_video(item_idx: int):
nonlocal audio_in_video_item_idx
- audio_num_features = audio_output_lengths[audio_item_idx + item_idx]
+ audio_num_features = audio_output_lengths[
+ audio_in_video_item_idx + item_idx
+ ]
video_grid_thw = out_mm_data["video_grid_thw"][item_idx]
audio_in_video_item_idx += 1
@@ -977,14 +970,17 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
if second_per_grid_ts:
video_second_per_grid_t = second_per_grid_ts[item_idx]
else:
- video_second_per_grid_t = 1.0
+ video_second_per_grid_t = 2.0
- return self.get_updates_use_audio_in_video(
+ placeholder = self.get_updates_use_audio_in_video(
thinker_config=thinker_config,
audio_len=audio_num_features,
video_grid_thw=video_grid_thw,
video_second_per_grid_t=video_second_per_grid_t,
)
+ return PromptUpdateDetails.select_token_id(
+ placeholder, embed_token_id=video_token_id
+ )
video_replacement_fn = (
get_replacement_qwen2_use_audio_in_video
@@ -1010,14 +1006,50 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
),
]
- def _validate_mm_placeholders(
+ def _derive_audio_from_video_placeholders(
self,
- mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
- mm_item_counts: Mapping[str, int],
- ) -> None:
- BaseMultiModalProcessor[
- Qwen2_5OmniThinkerProcessingInfo
- ]._validate_mm_placeholders(self, mm_placeholders, mm_item_counts)
+ placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
+ mm_prompt_updates: MultiModalPromptUpdates,
+ ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
+ """
+ Helper to derive audio placeholders from video placeholders when
+ use_audio_in_video=True.
+ """
+ if "video" not in placeholders:
+ return placeholders
+
+ # Validate audio and video counts match
+ num_videos = len(placeholders["video"])
+ num_audios = len(mm_prompt_updates.get("audio", []))
+ if num_audios != num_videos:
+ raise ValueError(
+ f"use_audio_in_video requires equal number of audio and video items, "
+ f"got {num_audios=}, {num_videos=}"
+ )
+
+ tokenizer = self.info.get_tokenizer()
+ processor = self.info.get_hf_processor()
+ audio_token_id = tokenizer.get_vocab()[processor.audio_token]
+
+ result_placeholders = dict(placeholders)
+ audio_placeholders = []
+
+ # Each video is paired with one audio
+ for video_idx, video_placeholder in enumerate(placeholders["video"]):
+ # Create is_embed mask selecting only audio tokens
+ audio_is_embed = torch.tensor(video_placeholder.tokens) == audio_token_id
+
+ audio_placeholder = PlaceholderFeaturesInfo(
+ modality="audio",
+ item_idx=video_idx,
+ start_idx=video_placeholder.start_idx,
+ tokens=video_placeholder.tokens,
+ is_embed=audio_is_embed,
+ )
+ audio_placeholders.append(audio_placeholder)
+
+ result_placeholders["audio"] = audio_placeholders
+ return result_placeholders
def _get_raw_input_ids(
self,
@@ -1460,7 +1492,11 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
)
if not len(second_per_grid_ts) and len(video_grid_thw):
- second_per_grids = torch.ones(len(video_grid_thw), dtype=torch.float32)
+ second_per_grid_ts = 2.0
+ second_per_grids = (
+ torch.ones(len(video_grid_thw), dtype=torch.float32)
+ * second_per_grid_ts
+ )
else:
second_per_grids = torch.tensor(second_per_grid_ts, dtype=torch.float32)
diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index 90c4894d33e88..4cd6fa14c32df 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -235,7 +235,6 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
- seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
@@ -243,7 +242,6 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
x = x + self.mlp(self.norm2(x))
@@ -391,7 +389,6 @@ class Qwen3_VisionTransformer(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
- AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -531,17 +528,14 @@ class Qwen3_VisionTransformer(nn.Module):
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
+ ) -> torch.Tensor:
max_seqlen = torch.zeros([], device=cu_seqlens.device)
- seqlens = torch.zeros(1, device=cu_seqlens.device)
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
- return max_seqlen, seqlens
+ return max_seqlen
def forward(
self,
@@ -569,7 +563,7 @@ class Qwen3_VisionTransformer(nn.Module):
cu_seqlens = torch.from_numpy(cu_seqlens)
hidden_states = hidden_states.unsqueeze(1)
- max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+ max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
deepstack_feature_lists = []
@@ -580,7 +574,6 @@ class Qwen3_VisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
if layer_num in self.deepstack_visual_indexes:
deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num)
diff --git a/vllm/model_executor/models/qwen3_vl_moe.py b/vllm/model_executor/models/qwen3_vl_moe.py
index e2c129120b1a5..a054bd5b3831e 100644
--- a/vllm/model_executor/models/qwen3_vl_moe.py
+++ b/vllm/model_executor/models/qwen3_vl_moe.py
@@ -401,6 +401,7 @@ class Qwen3VLMoeMixtureOfExperts(MixtureOfExperts):
class Qwen3VLMoeForConditionalGeneration(
Qwen3VLForConditionalGeneration, Qwen3VLMoeMixtureOfExperts
):
+ is_3d_moe_weight: bool = True
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index b3da64af750c7..ba9f33819c950 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -5,7 +5,6 @@ Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
-import hashlib
import importlib
import json
import os
@@ -32,6 +31,7 @@ from vllm.config import (
from vllm.logger import init_logger
from vllm.logging_utils import logtime
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
+from vllm.utils.hashing import safe_hash
from .interfaces import (
has_inner_state,
@@ -170,6 +170,7 @@ _TEXT_GENERATION_MODELS = {
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
"SolarForCausalLM": ("solar", "SolarForCausalLM"),
+ "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
"TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
"XverseForCausalLM": ("llama", "LlamaForCausalLM"),
@@ -207,6 +208,7 @@ _EMBEDDING_MODELS = {
"Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
"RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
"RobertaModel": ("roberta", "RobertaEmbeddingModel"),
+ "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
# [Multimodal]
@@ -287,6 +289,10 @@ _MULTIMODAL_MODELS = {
"GraniteSpeechForConditionalGeneration",
),
"H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
+ "HunYuanVLForConditionalGeneration": (
+ "hunyuan_vision",
+ "HunYuanVLForConditionalGeneration",
+ ),
"InternVLChatModel": ("internvl", "InternVLChatModel"),
"NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
"OpenCUAForConditionalGeneration": (
@@ -650,7 +656,7 @@ class _LazyRegisteredModel(_BaseRegisteredModel):
if model_path.exists():
with open(model_path, "rb") as f:
- module_hash = hashlib.md5(f.read(), usedforsecurity=False).hexdigest()
+ module_hash = safe_hash(f.read(), usedforsecurity=False).hexdigest()
mi = self._load_modelinfo_from_cache(module_hash)
if mi is not None:
diff --git a/vllm/model_executor/models/seed_oss.py b/vllm/model_executor/models/seed_oss.py
index 4744d8e44f390..267c60157506d 100644
--- a/vllm/model_executor/models/seed_oss.py
+++ b/vllm/model_executor/models/seed_oss.py
@@ -30,7 +30,8 @@ import torch
from torch import nn
from transformers import PretrainedConfig as SeedOssConfig
-from vllm.attention import Attention, AttentionType
+from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py
index 7e9fc51036d2e..c576154b1ecfd 100644
--- a/vllm/model_executor/models/solar.py
+++ b/vllm/model_executor/models/solar.py
@@ -30,7 +30,7 @@ import torch
from torch import nn
from transformers import PretrainedConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py
index a738fcbb4ee28..6cb98b7b72a5b 100644
--- a/vllm/model_executor/models/stablelm.py
+++ b/vllm/model_executor/models/stablelm.py
@@ -29,7 +29,7 @@ import torch
from torch import nn
from transformers import StableLmConfig
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py
index 1118fca3cac91..46422f303ff43 100644
--- a/vllm/model_executor/models/starcoder2.py
+++ b/vllm/model_executor/models/starcoder2.py
@@ -28,7 +28,7 @@ import torch
from torch import nn
from transformers import Starcoder2Config
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py
index 3c377a2c539df..077cce84a98dd 100644
--- a/vllm/model_executor/models/step3_text.py
+++ b/vllm/model_executor/models/step3_text.py
@@ -9,7 +9,7 @@ from typing import Any
import torch
from torch import nn
-from vllm.attention import Attention
+from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (
diff --git a/vllm/model_executor/models/transformers/base.py b/vllm/model_executor/models/transformers/base.py
index f4ba4758bcc46..b33ce35427f5e 100644
--- a/vllm/model_executor/models/transformers/base.py
+++ b/vllm/model_executor/models/transformers/base.py
@@ -27,7 +27,8 @@ from torch import nn
from transformers import AutoModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
-from vllm.attention import Attention, AttentionType
+from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.layer import Attention
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.config.utils import getattr_iter
from vllm.distributed import get_pp_group, get_tp_group
diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index bb0f6bd036f14..26a8355cd22b5 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -116,7 +116,12 @@ class UltravoxProcessingInfo(BaseProcessingInfo):
def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
hf_processor = self.get_hf_processor(**kwargs)
+
+ # Changed in https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/commit/9a3c571b8fdaf1e66dd3ea61bbcb6db5c70a438e
audio_processor = hf_processor.audio_processor # type: ignore
+ if isinstance(audio_processor, WhisperFeatureExtractor):
+ return audio_processor
+
feature_extractor = audio_processor.feature_extractor # type: ignore
assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index ccefd7e66697f..f25ab9153a50d 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -10,7 +10,6 @@ import torch
import torch.nn as nn
from torch.func import functional_call
from transformers import PretrainedConfig
-from typing_extensions import deprecated
from vllm.config import VllmConfig
from vllm.distributed import (
@@ -481,54 +480,6 @@ def _merge_multimodal_embeddings(
return inputs_embeds
-@deprecated(
- "`merge_multimodal_embeddings` has been replaced with "
- "`SupportsMultiModal.embed_input_ids` and will be "
- "removed in v0.12."
-)
-def merge_multimodal_embeddings(
- input_ids: torch.Tensor,
- inputs_embeds: torch.Tensor,
- multimodal_embeddings: NestedTensors,
- placeholder_token_id: int | list[int],
-) -> torch.Tensor:
- """
- Merge `multimodal_embeddings` into `inputs_embeds` by overwriting the
- positions in `inputs_embeds` corresponding to placeholder tokens in
- `input_ids`.
-
- `placeholder_token_id` can be a list of token ids (e.g, token ids
- of img_start, img_break, and img_end tokens) when needed: This means
- the order of these tokens in the `input_ids` MUST MATCH the order of
- their embeddings in `multimodal_embeddings` since we need to
- slice-merge instead of individually scattering.
-
- For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
- - T is text token
- - S is image start token
- - I is image embedding token
- - B is image break token
- - E is image end token.
-
- Then the image embeddings (that correspond to I's) from vision encoder
- must be padded with embeddings of S, B, and E in the same order of
- input_ids for a correct embedding merge.
-
- Note:
- This updates `inputs_embeds` in place.
- """
- if isinstance(placeholder_token_id, list):
- is_multimodal = isin_list(input_ids, placeholder_token_id)
- else:
- is_multimodal = input_ids == placeholder_token_id
-
- return _merge_multimodal_embeddings(
- inputs_embeds,
- multimodal_embeddings=multimodal_embeddings,
- is_multimodal=is_multimodal,
- )
-
-
def isin_list(
elements: torch.Tensor,
test_elements_list: list[int],
diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py
index 50587c627160d..c72b5e1c091f2 100644
--- a/vllm/model_executor/models/whisper.py
+++ b/vllm/model_executor/models/whisper.py
@@ -16,8 +16,8 @@ from transformers import (
)
from transformers.models.whisper.modeling_whisper import sinusoids
-from vllm.attention import Attention, AttentionType
-from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.layer import Attention, MultiHeadAttention
from vllm.attention.layers.cross_attention import CrossAttention
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index ed655912d3964..5f9561366e0d5 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -14,6 +14,7 @@ import regex as re
import torch
from vllm import envs
+from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
from .interface import CpuArchEnum, Platform, PlatformEnum
@@ -21,10 +22,8 @@ from .interface import CpuArchEnum, Platform, PlatformEnum
logger = init_logger(__name__)
if TYPE_CHECKING:
- from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
else:
- AttentionBackendEnum = None
VllmConfig = None
@@ -135,8 +134,6 @@ class CpuPlatform(Platform):
use_sparse: bool,
attn_type: str | None = None,
) -> str:
- from vllm.attention.backends.registry import AttentionBackendEnum
-
if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla:
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index f9bf242b7194e..4bf9401b6b051 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -15,6 +15,8 @@ from typing_extensions import ParamSpec
# import custom ops, trigger op registration
import vllm._C # noqa
import vllm.envs as envs
+from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
from vllm.utils.import_utils import import_pynvml
from vllm.utils.torch_utils import cuda_device_count_stateless
@@ -22,11 +24,9 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
- from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
else:
- AttentionBackendEnum = None
VllmConfig = None
CacheDType = None
@@ -48,8 +48,6 @@ def _get_backend_priorities(
device_capability: DeviceCapability,
) -> list[AttentionBackendEnum]:
"""Get backend priorities with lazy import to avoid circular dependency."""
- from vllm.attention.backends.registry import AttentionBackendEnum
-
if use_mla:
if device_capability.major == 10:
return [
@@ -265,24 +263,18 @@ class CudaPlatformBase(Platform):
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
- from vllm.attention.backends.registry import AttentionBackendEnum
-
# Try FlashAttention first
- try:
- backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
- if backend_class.supports_head_size(
- head_size
- ) and backend_class.supports_dtype(dtype):
- return AttentionBackendEnum.FLASH_ATTN
- except ImportError:
- pass
+ if (cc := cls.get_device_capability()) and cc.major >= 8:
+ try:
+ backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
+ if backend_class.supports_head_size(
+ head_size
+ ) and backend_class.supports_dtype(dtype):
+ return AttentionBackendEnum.FLASH_ATTN
+ except ImportError:
+ pass
- if cls.has_device_capability(100):
- # xFormers doesn't support Blackwell, fall back to SDPA
- # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
- return AttentionBackendEnum.TORCH_SDPA
- else:
- return AttentionBackendEnum.XFORMERS
+ return AttentionBackendEnum.TORCH_SDPA
@classmethod
def get_valid_backends(
@@ -340,8 +332,6 @@ class CudaPlatformBase(Platform):
use_sparse: bool,
attn_type: str | None = None,
) -> str:
- from vllm.attention import AttentionType
-
if attn_type is None:
attn_type = AttentionType.DECODER
@@ -412,9 +402,6 @@ class CudaPlatformBase(Platform):
# We have found some valid backends. Select the one with the
# highest priority.
- logger.info(
- "Valid backends: %s", [b[0].name for b in valid_backends_priorities]
- )
sorted_indices = sorted(
range(len(valid_backends_priorities)),
key=lambda i: valid_backends_priorities[i][1],
@@ -422,8 +409,9 @@ class CudaPlatformBase(Platform):
selected_index = sorted_indices[0]
selected_backend = valid_backends_priorities[selected_index][0]
logger.info(
- "Using %s backend.",
+ "Using %s attention backend out of potential backends: %s",
selected_backend.name,
+ [b[0].name for b in valid_backends_priorities],
)
return selected_backend.get_path()
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 0471c20429b1d..27c6fac09f498 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -12,12 +12,12 @@ from typing import TYPE_CHECKING, Any, NamedTuple
import numpy as np
import torch
+from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
if TYPE_CHECKING:
from torch.distributed import PrefixStore, ProcessGroup
- from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.inputs import ProcessorInputs, PromptType
@@ -134,6 +134,11 @@ class Platform:
_global_graph_pool: Any | None = None
+ @property
+ def pass_key(self) -> str:
+ """Inductor config key for the PassManager custom pass"""
+ return "post_grad_custom_post_pass"
+
@property
def supported_dtypes(self) -> list[torch.dtype]:
"""Returns the supported dtypes for the current platform."""
@@ -177,6 +182,21 @@ class Platform:
# all ROCm platforms for now.
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
+ @classmethod
+ def get_pass_manager_cls(cls) -> str:
+ """
+ Get the pass manager class for this platform.
+ It will be registered as a custom pass under the current_platform.pass_key.
+ """
+ return "vllm.compilation.pass_manager.PostGradPassManager"
+
+ @classmethod
+ def get_compile_backend(cls) -> str:
+ """
+ Get the custom compile backend for current platform.
+ """
+ return cls.simple_compile_backend
+
@classmethod
def device_id_to_physical_device_id(cls, device_id: int):
# Treat empty device control env var as unset. This is a valid
@@ -206,9 +226,6 @@ class Platform:
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
- # Import AttentionBackendEnum here to avoid circular import.
- from vllm.attention.backends.registry import AttentionBackendEnum
-
return AttentionBackendEnum.TORCH_SDPA
@classmethod
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index f9005fd7d044c..ccf3446a3a6e5 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -8,16 +8,14 @@ from typing import TYPE_CHECKING
import torch
import vllm.envs as envs
+from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
- from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
-else:
- AttentionBackendEnum = None
logger = init_logger(__name__)
@@ -196,7 +194,6 @@ class RocmPlatform(Platform):
from importlib.util import find_spec
from vllm._aiter_ops import rocm_aiter_ops
- from vllm.attention.backends.registry import AttentionBackendEnum
if rocm_aiter_ops.is_mha_enabled():
# Note: AITER FA is only supported for Qwen-VL models.
@@ -222,7 +219,6 @@ class RocmPlatform(Platform):
attn_type: str | None = None,
) -> str:
from vllm._aiter_ops import rocm_aiter_ops
- from vllm.attention.backends.registry import AttentionBackendEnum
if use_sparse:
if kv_cache_dtype.startswith("fp8"):
@@ -264,28 +260,66 @@ class RocmPlatform(Platform):
if selected_backend == AttentionBackendEnum.FLEX_ATTENTION:
logger.info("Using FlexAttention backend.")
- return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
- if (
- rocm_aiter_ops.is_mha_enabled()
- ) or selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
- logger.info("Using Aiter Flash Attention backend.")
- return AttentionBackendEnum.ROCM_AITER_FA.get_path()
- if (
- rocm_aiter_ops.is_triton_unified_attn_enabled()
- ) or selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
- logger.info("Using Aiter Unified Attention backend.")
- return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
- if (
- envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
- or selected_backend == AttentionBackendEnum.ROCM_ATTN
- ):
- # rocm specific backend, with aiter and/or
- # triton prefix-prefill
- logger.info("Using Rocm Attention backend.")
+ return AttentionBackendEnum.FLEX_ATTENTION.get_path()
+
+ if selected_backend == AttentionBackendEnum.TRITON_ATTN:
+ logger.info("Using Triton Attention backend on V1 engine.")
+ return AttentionBackendEnum.TRITON_ATTN.get_path()
+
+ if selected_backend == AttentionBackendEnum.ROCM_ATTN:
+ logger.info("Using Rocm Attention backend on V1 engine.")
return AttentionBackendEnum.ROCM_ATTN.get_path()
- # default case, using triton unified attention
- logger.info("Using Triton Attention backend.")
- return AttentionBackendEnum.TRITON_ATTN.get_path()
+
+ if selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
+ if on_gfx9():
+ logger.info("Using Aiter Flash Attention backend on V1 engine.")
+ return AttentionBackendEnum.ROCM_AITER_FA.get_path()
+ else:
+ raise ValueError(
+ f"The selected backend, {selected_backend.name}, "
+ "is only supported on gfx9 architectures."
+ )
+
+ if selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
+ logger.info("Using Aiter Unified Attention backend on V1 engine.")
+ return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
+
+ # Handle automatic backend selection based on environment variables
+ if selected_backend is None:
+ # Priority 1: Check for AITER Unified Attention (must check before MHA)
+ if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
+ logger.info("Using Aiter Unified Attention backend on V1 engine.")
+ return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
+
+ # Priority 2: Check for AITER MHA (Flash Attention)
+ # Only use if explicitly enabled (not just VLLM_ROCM_USE_AITER=1)
+ if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
+ logger.info("Using Aiter Flash Attention backend on V1 engine.")
+ return AttentionBackendEnum.ROCM_AITER_FA.get_path()
+
+ # Priority 3: Check for ROCM_ATTN (prefill-decode split)
+ if envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION:
+ logger.info("Using Rocm Attention backend on V1 engine.")
+ return AttentionBackendEnum.ROCM_ATTN.get_path()
+
+ # Priority 4: Check for AITER enabled without specific flags
+ # This defaults to AITER FA only if MHA is not explicitly disabled
+ if (
+ envs.VLLM_ROCM_USE_AITER
+ and on_gfx9()
+ and envs.VLLM_ROCM_USE_AITER_MHA is not False
+ ):
+ logger.info("Using Aiter Flash Attention backend on V1 engine.")
+ return AttentionBackendEnum.ROCM_AITER_FA.get_path()
+
+ # Default: Triton Unified Attention
+ logger.info("Using Triton Attention backend on V1 engine.")
+ return AttentionBackendEnum.TRITON_ATTN.get_path()
+
+ raise RuntimeError(
+ f"Attention backend {selected_backend.name} is not supported on "
+ "ROCm. Note that V0 attention backends have been removed."
+ )
@classmethod
def set_device(cls, device: torch.device) -> None:
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index 944344a229578..cbc0a996f3661 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, cast
import torch
from tpu_info import device
+from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger
@@ -15,7 +16,6 @@ from .interface import Platform, PlatformEnum
if TYPE_CHECKING:
from typing import TypeAlias
- from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
from vllm.config.cache import BlockSize
from vllm.pooling_params import PoolingParams
@@ -26,7 +26,6 @@ else:
BlockSize = None
VllmConfig = None
PoolingParams = None
- AttentionBackendEnum = None
ParamsType = None
logger = init_logger(__name__)
@@ -67,8 +66,6 @@ class TpuPlatform(Platform):
use_sparse,
attn_type: str | None = None,
) -> str:
- from vllm.attention.backends.registry import AttentionBackendEnum
-
if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on TPU.")
if selected_backend != AttentionBackendEnum.PALLAS:
@@ -267,7 +264,9 @@ class TpuPlatform(Platform):
try:
- from tpu_inference.platforms import TpuPlatform as TpuInferencePlatform
+ from tpu_inference.platforms import (
+ TpuPlatform as TpuInferencePlatform,
+ )
TpuPlatform = TpuInferencePlatform # type: ignore
USE_TPU_INFERENCE = True
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index 18a3186b142f1..768714fb16726 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -8,16 +8,15 @@ from typing import TYPE_CHECKING
import torch
import vllm.envs as envs
+from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
- from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
else:
VllmConfig = None
- AttentionBackendEnum = None
logger = init_logger(__name__)
@@ -60,8 +59,6 @@ class XPUPlatform(Platform):
"only NHD layout is supported by XPU attention kernels."
)
- from vllm.attention.backends.registry import AttentionBackendEnum
-
if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on XPU.")
if selected_backend == AttentionBackendEnum.TRITON_ATTN:
@@ -116,8 +113,6 @@ class XPUPlatform(Platform):
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
- from vllm.attention.backends.registry import AttentionBackendEnum
-
return AttentionBackendEnum.FLASH_ATTN
@classmethod
diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py
index 5c3dfa8ac9cbc..c2094a2d920a2 100644
--- a/vllm/pooling_params.py
+++ b/vllm/pooling_params.py
@@ -57,7 +57,7 @@ class PoolingParams(
## Internal use only
task: PoolingTask | None = None
requires_token_ids: bool = False
- skip_reading_prefix_cache: bool = None
+ skip_reading_prefix_cache: bool | None = None
extra_kwargs: dict[str, Any] | None = None
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
@@ -219,6 +219,7 @@ class PoolingParams(
f"step_tag_id={self.step_tag_id}, "
f"returned_token_ids={self.returned_token_ids}, "
f"requires_token_ids={self.requires_token_ids}, "
+ f"skip_reading_prefix_cache={self.skip_reading_prefix_cache}, "
f"extra_kwargs={self.extra_kwargs})"
)
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index fbbe3d4cabb9a..8de961e62db1b 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -3,7 +3,6 @@
"""Sampling parameters for text generation."""
import copy
-import warnings
from dataclasses import field
from enum import Enum, IntEnum
from functools import cached_property
@@ -100,19 +99,6 @@ class StructuredOutputsParams:
)
-@dataclass
-class GuidedDecodingParams(StructuredOutputsParams):
- def __post_init__(self):
- warnings.warn(
- "GuidedDecodingParams is deprecated. This will be removed in "
- "v0.12.0 or v1.0.0, which ever is soonest. Please use "
- "StructuredOutputsParams instead.",
- DeprecationWarning,
- stacklevel=2,
- )
- return super().__post_init__()
-
-
class RequestOutputKind(Enum):
# Return entire output so far in every RequestOutput
CUMULATIVE = 0
@@ -234,8 +220,6 @@ class SamplingParams(
# Fields used to construct logits processors
structured_outputs: StructuredOutputsParams | None = None
"""Parameters for configuring structured outputs."""
- guided_decoding: GuidedDecodingParams | None = None
- """Deprecated alias for structured_outputs."""
logit_bias: dict[int, float] | None = None
"""If provided, the engine will construct a logits processor that applies
these logit biases."""
@@ -254,7 +238,7 @@ class SamplingParams(
generated token can complete the sequence."""
_bad_words_token_ids: list[list[int]] | None = None
- skip_reading_prefix_cache: bool = None
+ skip_reading_prefix_cache: bool | None = None
@staticmethod
def from_optional(
@@ -283,7 +267,6 @@ class SamplingParams(
truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None,
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
structured_outputs: StructuredOutputsParams | None = None,
- guided_decoding: GuidedDecodingParams | None = None,
logit_bias: dict[int, float] | dict[str, float] | None = None,
allowed_token_ids: list[int] | None = None,
extra_args: dict[str, Any] | None = None,
@@ -295,16 +278,6 @@ class SamplingParams(
int(token): min(100.0, max(-100.0, bias))
for token, bias in logit_bias.items()
}
- if guided_decoding is not None:
- warnings.warn(
- "guided_decoding is deprecated. This will be removed in "
- "v0.12.0 or v1.0.0, which ever is soonest. Please use "
- "structured_outputs instead.",
- DeprecationWarning,
- stacklevel=2,
- )
- structured_outputs = guided_decoding
- guided_decoding = None
return SamplingParams(
n=1 if n is None else n,
@@ -387,17 +360,6 @@ class SamplingParams(
# eos_token_id is added to this by the engine
self._all_stop_token_ids.update(self.stop_token_ids)
- if self.guided_decoding is not None:
- warnings.warn(
- "guided_decoding is deprecated. This will be removed in "
- "v0.12.0 or v1.0.0, which ever is soonest. Please use "
- "structured_outputs instead.",
- DeprecationWarning,
- stacklevel=2,
- )
- self.structured_outputs = self.guided_decoding
- self.guided_decoding = None
-
if self.skip_reading_prefix_cache is None:
# If prefix caching is enabled,
# the output of prompt logprobs may less than n_prompt_tokens,
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 3d282da8c6112..66680f410cb3c 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -42,7 +42,10 @@ from vllm.logger import init_logger
from vllm.transformers_utils.config_parser_base import ConfigParserBase
from vllm.transformers_utils.utils import (
check_gguf_file,
+ is_gguf,
+ is_remote_gguf,
parse_safetensors_file_metadata,
+ split_remote_gguf,
)
if envs.VLLM_USE_MODELSCOPE:
@@ -86,6 +89,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
deepseek_vl_v2="DeepseekVLV2Config",
deepseek_v32="DeepseekV3Config",
flex_olmo="FlexOlmoConfig",
+ hunyuan_vl="HunYuanVLConfig",
kimi_linear="KimiLinearConfig",
kimi_vl="KimiVLConfig",
RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct)
@@ -452,51 +456,55 @@ def set_default_rope_theta(config: PretrainedConfig, default_theta: float) -> No
def patch_rope_parameters(config: PretrainedConfig) -> None:
"""Provide backwards compatibility for RoPE."""
- # Retrieve rope_parameters differently based on Transformers version
+ # Patch rope_parameters differently based on Transformers version
if Version(version("transformers")) >= Version("5.0.0.dev0"):
- from transformers.modeling_rope_utils import RopeParameters
-
- rope_parameters: RopeParameters | dict[str, RopeParameters] | None = getattr(
- config, "rope_parameters", None
+ from transformers.modeling_rope_utils import (
+ rope_config_validation,
+ standardize_rope_params,
)
- elif hasattr(config, "rope_parameters"):
- # We are in Transformers v4 and rope_parameters
- # has already been patched for this config
- return
+
+ # When Transformers v5 is installed, legacy rope_theta may be present
+ # when using custom code models written for Transformers v4
+ if (rope_theta := getattr(config, "rope_theta", None)) is not None:
+ standardize_rope_params(config, rope_theta=rope_theta)
+ rope_config_validation(config)
+ # Delete rope_theta to avoid confusion in downstream code
+ del config.rope_theta
else:
- # Convert Transformers v4 rope_theta and rope_scaling into rope_parameters
- rope_theta: float | None = getattr(config, "rope_theta", None)
- rope_scaling: dict | None = getattr(config, "rope_scaling", None)
- rope_parameters = rope_scaling
- # Move rope_theta into rope_parameters
- if rope_theta is not None:
- rope_parameters = rope_parameters or {"rope_type": "default"}
- rope_parameters["rope_theta"] = rope_theta
- # Add original_max_position_embeddings if present
- if rope_parameters and (
- ompe := getattr(config, "original_max_position_embeddings", None)
- ):
- rope_parameters["original_max_position_embeddings"] = ompe
- # Write back to config
- config.rope_parameters = rope_parameters
+ # When Transformers v4 is installed, legacy rope_scaling may be present
+ if (rope_scaling := getattr(config, "rope_scaling", None)) is not None:
+ config.rope_parameters = rope_scaling
+ # When Transformers v4 is installed, legacy rope_theta may be present
+ if (rope_theta := getattr(config, "rope_theta", None)) is not None:
+ if not hasattr(config, "rope_parameters"):
+ config.rope_parameters = {"rope_type": "default"}
+ config.rope_parameters["rope_theta"] = rope_theta
# No RoPE parameters to patch
- if rope_parameters is None:
+ if not hasattr(config, "rope_parameters"):
return
+ # Add original_max_position_embeddings if present
+ if ompe := getattr(config, "original_max_position_embeddings", None):
+ config.rope_parameters["original_max_position_embeddings"] = ompe
+
# Handle nested rope_parameters in interleaved sliding attention models
- if set(rope_parameters.keys()).issubset(ALLOWED_LAYER_TYPES):
- for rope_parameters_layer_type in rope_parameters.values():
+ if set(config.rope_parameters.keys()).issubset(ALLOWED_LAYER_TYPES):
+ for rope_parameters_layer_type in config.rope_parameters.values():
patch_rope_parameters_dict(rope_parameters_layer_type)
else:
- patch_rope_parameters_dict(rope_parameters)
+ patch_rope_parameters_dict(config.rope_parameters)
def patch_rope_parameters_dict(rope_parameters: dict[str, Any]) -> None:
if "rope_type" in rope_parameters and "type" in rope_parameters:
rope_type = rope_parameters["rope_type"]
rope_type_legacy = rope_parameters["type"]
- if rope_type != rope_type_legacy:
+ if (rope_type_legacy == "su" and rope_type == "longrope") or (
+ rope_type_legacy == "mrope" and rope_type == "default"
+ ):
+ pass # No action needed
+ elif rope_type != rope_type_legacy:
raise ValueError(
f"Found conflicts between 'rope_type={rope_type}' (modern "
f"field) and 'type={rope_type_legacy}' (legacy field). "
@@ -549,6 +557,23 @@ def thinker_uses_mrope(config: PretrainedConfig) -> bool:
return uses_mrope(thinker_text_config)
+def uses_xdrope_dim(config: PretrainedConfig) -> int:
+ """Detect if the model with this config uses XD-ROPE."""
+ xdrope_section = getattr(config, "xdrope_section", None)
+ if xdrope_section is not None and isinstance(xdrope_section, list):
+ return len(xdrope_section)
+ rope_scaling = getattr(config, "rope_scaling", None)
+ if rope_scaling is None:
+ return 0
+
+ if isinstance(rope_scaling, dict) and "xdrope_section" in rope_scaling:
+ xdrope_section = rope_scaling["xdrope_section"]
+ if xdrope_section is not None and isinstance(xdrope_section, list):
+ return len(xdrope_section)
+
+ return 0
+
+
def is_encoder_decoder(config: PretrainedConfig) -> bool:
"""Detect if the model with this config is used as an encoder/decoder."""
@@ -611,10 +636,12 @@ def maybe_override_with_speculators(
Returns:
Tuple of (resolved_model, resolved_tokenizer, speculative_config)
"""
- is_gguf = check_gguf_file(model)
- if is_gguf:
+ if check_gguf_file(model):
kwargs["gguf_file"] = Path(model).name
gguf_model_repo = Path(model).parent
+ elif is_remote_gguf(model):
+ repo_id, _ = split_remote_gguf(model)
+ gguf_model_repo = Path(repo_id)
else:
gguf_model_repo = None
kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE
@@ -660,10 +687,18 @@ def get_config(
) -> PretrainedConfig:
# Separate model folder from file path for GGUF models
- is_gguf = check_gguf_file(model)
- if is_gguf:
- kwargs["gguf_file"] = Path(model).name
- model = Path(model).parent
+ _is_gguf = is_gguf(model)
+ _is_remote_gguf = is_remote_gguf(model)
+ if _is_gguf:
+ if check_gguf_file(model):
+ # Local GGUF file
+ kwargs["gguf_file"] = Path(model).name
+ model = Path(model).parent
+ elif _is_remote_gguf:
+ # Remote GGUF - extract repo_id from repo_id:quant_type format
+ # The actual GGUF file will be downloaded later by GGUFModelLoader
+ # Keep model as repo_id:quant_type for download, but use repo_id for config
+ model, _ = split_remote_gguf(model)
if config_format == "auto":
try:
@@ -671,10 +706,25 @@ def get_config(
# Transformers implementation.
if file_or_path_exists(model, MISTRAL_CONFIG_NAME, revision=revision):
config_format = "mistral"
- elif is_gguf or file_or_path_exists(
+ elif (_is_gguf and not _is_remote_gguf) or file_or_path_exists(
model, HF_CONFIG_NAME, revision=revision
):
config_format = "hf"
+ # Remote GGUF models must have config.json in repo,
+ # otherwise the config can't be parsed correctly.
+ # FIXME(Isotr0py): Support remote GGUF repos without config.json
+ elif _is_remote_gguf and not file_or_path_exists(
+ model, HF_CONFIG_NAME, revision=revision
+ ):
+ err_msg = (
+ "Could not find config.json for remote GGUF model repo. "
+ "To load remote GGUF model through `:`, "
+ "ensure your model has config.json (HF format) file. "
+ "Otherwise please specify --hf-config-path "
+ "in engine args to fetch config from unquantized hf model."
+ )
+ logger.error(err_msg)
+ raise ValueError(err_msg)
else:
raise ValueError(
"Could not detect config format for no config file found. "
@@ -695,9 +745,6 @@ def get_config(
"'config.json'.\n"
" - For Mistral models: ensure the presence of a "
"'params.json'.\n"
- "3. For GGUF: pass the local path of the GGUF checkpoint.\n"
- " Loading GGUF from a remote repo directly is not yet "
- "supported.\n"
).format(model=model)
raise ValueError(error_message) from e
@@ -711,7 +758,7 @@ def get_config(
**kwargs,
)
# Special architecture mapping check for GGUF models
- if is_gguf:
+ if _is_gguf:
if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
raise RuntimeError(f"Can't get gguf config for {config.model_type}.")
model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
@@ -871,6 +918,8 @@ def get_pooling_config(model: str, revision: str | None = "main") -> dict | None
A dictionary containing the pooling type and whether
normalization is used, or None if no pooling configuration is found.
"""
+ if is_remote_gguf(model):
+ model, _ = split_remote_gguf(model)
modules_file_name = "modules.json"
@@ -1090,6 +1139,8 @@ def get_hf_image_processor_config(
# Separate model folder from file path for GGUF models
if check_gguf_file(model):
model = Path(model).parent
+ elif is_remote_gguf(model):
+ model, _ = split_remote_gguf(model)
return get_image_processor_config(
model, token=hf_token, revision=revision, **kwargs
)
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index d28fd8d033373..109f2b6986514 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -23,6 +23,11 @@ from vllm.transformers_utils.configs.eagle import EAGLEConfig
# `FalconConfig` class from the official HuggingFace transformers library.
from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.flex_olmo import FlexOlmoConfig
+from vllm.transformers_utils.configs.hunyuan_vl import (
+ HunYuanVLConfig,
+ HunYuanVLTextConfig,
+ HunYuanVLVisionConfig,
+)
from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.kimi_linear import KimiLinearConfig
from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig
@@ -53,6 +58,9 @@ __all__ = [
"DotsOCRConfig",
"EAGLEConfig",
"FlexOlmoConfig",
+ "HunYuanVLConfig",
+ "HunYuanVLTextConfig",
+ "HunYuanVLVisionConfig",
"RWConfig",
"JAISConfig",
"Lfm2MoeConfig",
diff --git a/vllm/transformers_utils/configs/hunyuan_vl.py b/vllm/transformers_utils/configs/hunyuan_vl.py
new file mode 100644
index 0000000000000..a826ed9b5155d
--- /dev/null
+++ b/vllm/transformers_utils/configs/hunyuan_vl.py
@@ -0,0 +1,322 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# adapted from https://github.com/ManaEstras/transformers/blob/v4.57.1.hyvl/src/transformers/models/hunyuan_vl/configuration_hunyuan_vl.py
+
+from transformers import PretrainedConfig
+
+
+class HunYuanVLVisionConfig(PretrainedConfig):
+ model_type = "hunyuan_vl"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ hidden_act="gelu",
+ hidden_size=1152,
+ intermediate_size=4304,
+ interpolate_mode="bilinear",
+ rms_norm_eps=1e-05,
+ learnable_mlp_pooling_size=0,
+ num_attention_heads=16,
+ num_key_value_heads=None,
+ num_channels=3,
+ num_hidden_layers=27,
+ out_hidden_size=4096,
+ patch_size=16,
+ remove_prenorm=True,
+ spatial_merge_size=2,
+ temporal_patch_size=1,
+ resize_resolution=2048,
+ img_max_token_num=4096,
+ max_image_size=2048,
+ video_max_image_size=768,
+ video_min_image_size=256,
+ min_image_size=512,
+ anyres_vit_max_image_size=2048,
+ max_vit_seq_len=16384,
+ text_hidden_size=3072,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_act = hidden_act
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.interpolate_mode = interpolate_mode
+ self.learnable_mlp_pooling_size = learnable_mlp_pooling_size
+ self.num_attention_heads = num_attention_heads
+ if not num_key_value_heads:
+ self.num_key_value_heads = num_attention_heads
+ else:
+ self.num_key_value_heads = num_key_value_heads
+ self.num_channels = num_channels
+ self.num_hidden_layers = num_hidden_layers
+ self.out_hidden_size = out_hidden_size
+ self.patch_size = patch_size
+ self.remove_prenorm = remove_prenorm
+ self.spatial_merge_size = spatial_merge_size
+ self.temporal_patch_size = temporal_patch_size
+ self.rms_norm_eps = rms_norm_eps
+
+ self.resize_resolution = resize_resolution
+ self.img_max_token_num = img_max_token_num
+ self.max_image_size = max_image_size
+ self.min_image_size = min_image_size
+ self.video_max_image_size = video_max_image_size
+ self.video_min_image_size = video_min_image_size
+ self.anyres_vit_max_image_size = anyres_vit_max_image_size
+ self.max_vit_seq_len = max_vit_seq_len
+ self.text_hidden_size = text_hidden_size
+
+
+class HunYuanVLTextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`HunYuanVLTextConfig`]. It is used to instantiate an
+ HunYuan model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the HunYuan-7B.
+ Hunyuan-7B-Instruct [tencent/Hunyuan-7B-Instruct](https://huggingface.co/tencent/Hunyuan-7B-Instruct).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 290943):
+ Vocabulary size of the HunYuan model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`HunYuanVLTextConfig`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 11008):
+ Dimension of the MLP representations or shared MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ eod_token_id (int, *optional*, defaults to 3):
+ Token ID representing the end-of-document marker. Used to indicate the termination of a text sequence.
+ Example: In multi-document processing, this token helps the model distinguish between separate documents.
+ pretraining_tp (`int`, *optional*, defaults to 1):
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
+ issue](https://github.com/pytorch/pytorch/issues/76232).
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
+ these scaling strategies behave:
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
+ experimental feature, subject to breaking API changes in future versions.
+ attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ head_dim (`int`, *optional*, defaults to 128):
+ The attention head dimension.
+ """ # noqa: E501
+
+ model_type = "hunyuan_vl_text"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=290943,
+ hidden_size=4096,
+ intermediate_size: int = 11008,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ hidden_act="silu",
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ eod_token_id=3,
+ pretraining_tp=1,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ head_dim=None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.head_dim = head_dim
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.pretraining_tp = pretraining_tp
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ # self._rope_scaling_validation() # TODO: Need validation?
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+ def _rope_scaling_validation(self):
+ """
+ Validate the `rope_scaling` configuration.
+ """
+ if self.rope_scaling is None:
+ return
+
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
+ raise ValueError(
+ "`rope_scaling` must be a dictionary with with two fields, `type` and "
+ f"`factor` or `type` and `alpha`, got {self.rope_scaling}"
+ )
+ rope_scaling_type = self.rope_scaling.get("type", None)
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
+ rope_scaling_alpha = self.rope_scaling.get("alpha", None)
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
+ raise ValueError(
+ "`rope_scaling`'s type field must be one of ['linear', 'dynamic'], "
+ f"got {rope_scaling_type}"
+ )
+ if rope_scaling_factor is None and rope_scaling_alpha is None:
+ raise ValueError(
+ "`rope_scaling`'s factor or alpha field must be have one, "
+ "got both of none"
+ )
+ if rope_scaling_factor is not None and (
+ not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0
+ ):
+ raise ValueError(
+ "`rope_scaling`'s factor field must be a float > 1.0, "
+ f"got {rope_scaling_factor}"
+ )
+ if rope_scaling_alpha is not None and (
+ not isinstance(rope_scaling_alpha, float) or rope_scaling_alpha <= 1.0
+ ):
+ raise ValueError(
+ "`rope_scaling`'s alpha field must be a float > 1.0, "
+ f"got {rope_scaling_alpha}"
+ )
+
+
+class HunYuanVLConfig(PretrainedConfig):
+ model_type = "hunyuan_vl"
+ sub_configs = {
+ "vision_config": HunYuanVLVisionConfig,
+ "text_config": HunYuanVLTextConfig,
+ }
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ text_config=None,
+ vision_config=None,
+ im_start_id=120118,
+ im_end_id=120119,
+ image_token_id=120120,
+ im_newline_id=120121,
+ video_start_id=120122,
+ video_end_id=120123,
+ **kwargs,
+ ):
+ # We need to init super() here so that it does not reset values
+ # that are in text config to the BaseClass defaults. The Base
+ # config has many text related defaults and not all defaults are
+ # same as for `HunYuanVLTextConfig`.
+ super().__init__(**kwargs)
+
+ if isinstance(vision_config, dict):
+ self.vision_config = self.sub_configs["vision_config"](**vision_config)
+ elif vision_config is None:
+ self.vision_config = self.sub_configs["vision_config"]()
+
+ if isinstance(text_config, dict):
+ self.text_config = self.sub_configs["text_config"](**text_config)
+ elif text_config is None:
+ # For BC use all kwargs to init `TextConfig`
+ self.text_config = self.sub_configs["text_config"](**kwargs)
+
+ self.image_token_id = image_token_id
+ self.im_start_id = im_start_id
+ self.im_end_id = im_end_id
+ self.im_newline_id = im_newline_id
+ self.video_start_id = video_start_id
+ self.video_end_id = video_end_id
+
+ self.vision_config.text_hidden_size = self.text_config.hidden_size
+
+ # Attention implementation to use. It sets it recursively on sub-configs
+ # so we call it again in the end.
+ self._attn_implementation = kwargs.pop("attn_implementation", None)
+
+ def __setattr__(self, key, value):
+ if (
+ (text_config := super().__getattribute__("__dict__").get("text_config"))
+ is not None
+ and key not in ["dtype", "_attn_implementation_internal"]
+ and key in text_config.__dict__
+ ):
+ setattr(text_config, key, value)
+ else:
+ super().__setattr__(key, value)
+
+ def __getattribute__(self, key):
+ if "text_config" in super().__getattribute__("__dict__") and key not in [
+ "_name_or_path",
+ "model_type",
+ "dtype",
+ "_attn_implementation_internal",
+ ]:
+ text_config = super().__getattribute__("text_config")
+ if key in text_config.__dict__:
+ return getattr(text_config, key)
+
+ return super().__getattribute__(key)
diff --git a/vllm/transformers_utils/gguf_utils.py b/vllm/transformers_utils/gguf_utils.py
index 2bf59c91a3bb1..f727b1b4726bb 100644
--- a/vllm/transformers_utils/gguf_utils.py
+++ b/vllm/transformers_utils/gguf_utils.py
@@ -9,6 +9,7 @@ from gguf.constants import Keys, VisionProjectorType
from transformers import Gemma3Config, PretrainedConfig, SiglipVisionConfig
from vllm.logger import init_logger
+from vllm.transformers_utils.config import list_filtered_repo_files
logger = init_logger(__name__)
@@ -164,3 +165,44 @@ def maybe_patch_hf_config_from_gguf(
hf_config = new_hf_config
return hf_config
+
+
+def get_gguf_file_path_from_hf(
+ repo_id: str | Path,
+ quant_type: str,
+ revision: str | None = None,
+) -> str:
+ """Get the GGUF file path from HuggingFace Hub based on repo_id and quant_type.
+
+ Args:
+ repo_id: The HuggingFace repository ID (e.g., "Qwen/Qwen3-0.6B")
+ quant_type: The quantization type (e.g., "Q4_K_M", "F16")
+ revision: Optional revision/branch name
+
+ Returns:
+ The path to the GGUF file on HuggingFace Hub (e.g., "filename.gguf"),
+ """
+ repo_id = str(repo_id)
+ gguf_patterns = [
+ f"*-{quant_type}.gguf",
+ f"*-{quant_type}-*.gguf",
+ f"*/*-{quant_type}.gguf",
+ f"*/*-{quant_type}-*.gguf",
+ ]
+ matching_files = list_filtered_repo_files(
+ repo_id,
+ allow_patterns=gguf_patterns,
+ revision=revision,
+ )
+
+ if len(matching_files) == 0:
+ raise ValueError(
+ "Could not find GGUF file for repo %s with quantization %s.",
+ repo_id,
+ quant_type,
+ )
+
+ # Sort to ensure consistent ordering (prefer non-sharded files)
+ matching_files.sort(key=lambda x: (x.count("-"), x))
+ gguf_filename = matching_files[0]
+ return gguf_filename
diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py
index 8deacb5b07913..63cdf63370342 100644
--- a/vllm/transformers_utils/processor.py
+++ b/vllm/transformers_utils/processor.py
@@ -18,7 +18,7 @@ from transformers.processing_utils import ProcessorMixin
from transformers.video_processing_utils import BaseVideoProcessor
from typing_extensions import TypeVar
-from vllm.transformers_utils.utils import check_gguf_file, convert_model_repo_to_path
+from vllm.transformers_utils.utils import convert_model_repo_to_path, is_gguf
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
if TYPE_CHECKING:
@@ -236,8 +236,8 @@ def cached_processor_from_config(
processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin,
**kwargs: Any,
) -> _P:
- if check_gguf_file(model_config.model):
- assert not check_gguf_file(model_config.tokenizer), (
+ if is_gguf(model_config.model):
+ assert not is_gguf(model_config.tokenizer), (
"For multimodal GGUF models, the original tokenizer "
"should be used to correctly load processor."
)
@@ -350,8 +350,8 @@ def cached_image_processor_from_config(
model_config: "ModelConfig",
**kwargs: Any,
):
- if check_gguf_file(model_config.model):
- assert not check_gguf_file(model_config.tokenizer), (
+ if is_gguf(model_config.model):
+ assert not is_gguf(model_config.tokenizer), (
"For multimodal GGUF models, the original tokenizer "
"should be used to correctly load image processor."
)
diff --git a/vllm/transformers_utils/processors/__init__.py b/vllm/transformers_utils/processors/__init__.py
index 76b6d3dc9c99a..b49fdbe9ce776 100644
--- a/vllm/transformers_utils/processors/__init__.py
+++ b/vllm/transformers_utils/processors/__init__.py
@@ -9,7 +9,15 @@ reasons:
"""
from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor
+from vllm.transformers_utils.processors.hunyuan_vl import HunYuanVLProcessor
+from vllm.transformers_utils.processors.hunyuan_vl_image import HunYuanVLImageProcessor
from vllm.transformers_utils.processors.ovis import OvisProcessor
from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor
-__all__ = ["DeepseekVLV2Processor", "OvisProcessor", "Ovis2_5Processor"]
+__all__ = [
+ "DeepseekVLV2Processor",
+ "HunYuanVLProcessor",
+ "HunYuanVLImageProcessor",
+ "OvisProcessor",
+ "Ovis2_5Processor",
+]
diff --git a/vllm/transformers_utils/processors/hunyuan_vl.py b/vllm/transformers_utils/processors/hunyuan_vl.py
new file mode 100644
index 0000000000000..615a8bff85912
--- /dev/null
+++ b/vllm/transformers_utils/processors/hunyuan_vl.py
@@ -0,0 +1,233 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# adapted from https://github.com/ManaEstras/transformers/blob/v4.57.1.hyvl/src/transformers/models/hunyuan_vl/processing_hunyuan_vl.py
+
+import numpy as np
+import torch
+from transformers import AutoProcessor
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.image_utils import ImageInput
+from transformers.processing_utils import ProcessorMixin
+from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
+from transformers.video_utils import VideoInput
+
+
+class HunYuanVLProcessor(ProcessorMixin):
+ attributes = ["image_processor", "tokenizer"]
+ valid_kwargs = ["chat_template"]
+ image_processor_class = "AutoImageProcessor"
+ tokenizer_class = "AutoTokenizer" # ("AutoTokenizer", None)
+
+ def __init__(
+ self,
+ image_processor=None,
+ tokenizer=None,
+ video_processor=None,
+ chat_template=None,
+ **kwargs,
+ ):
+ # TODO Fix the init
+ self.tokenizer = tokenizer
+ self.image_token_id = 120120 # self.tokenizer.image_token_id
+ self.image_token = self.tokenizer.convert_ids_to_tokens(self.image_token_id)
+ self.im_start_token_id = 120118 # self.tokenizer.im_start_id
+ self.im_start_token = self.tokenizer.convert_ids_to_tokens(
+ self.im_start_token_id
+ )
+ self.im_end_token_id = 120119 # self.tokenizer.im_end_id
+ self.im_end_token = self.tokenizer.convert_ids_to_tokens(self.im_end_token_id)
+ self.placeholder_token = self.tokenizer.convert_ids_to_tokens(
+ self.tokenizer.vocab_size - 1
+ )
+ self.pad_id = 120002 # self.tokenizer.pad_token_id
+
+ super().__init__(
+ image_processor, tokenizer, video_processor, chat_template=chat_template
+ )
+
+ def __call__(
+ self,
+ images: ImageInput = None,
+ text: TextInput
+ | PreTokenizedInput
+ | list[TextInput]
+ | list[PreTokenizedInput] = None,
+ videos: VideoInput = None,
+ **kwargs,
+ ) -> BatchFeature:
+ image_inputs = {}
+ if images is not None:
+ image_inputs = self.image_processor(images=images)
+ image_grid_thw = image_inputs["image_grid_thw"]
+
+ if not isinstance(text, list):
+ text = [text]
+
+ text = text.copy() # below lines change text in-place
+
+ image_tokens_cumsum = [0]
+ if images is not None:
+ index = 0
+ for i in range(len(text)):
+ while self.image_token in text[i]:
+ grid_h, grid_w = image_grid_thw[index][-2:]
+ patch_h = grid_h // self.image_processor.merge_size
+ patch_w = grid_w // self.image_processor.merge_size
+ num_image_tokens = patch_h * (patch_w + 1) + 2
+ image_tokens_cumsum.append(
+ image_tokens_cumsum[-1] + num_image_tokens
+ )
+ # text[i] = text[i].replace(self.image_token, self.im_start_token + self.placeholder_token * num_image_tokens + self.im_end_token, 1) # noqa: E501
+ text[i] = text[i].replace(
+ self.image_token, self.placeholder_token * num_image_tokens, 1
+ )
+ index += 1
+ text[i] = text[i].replace(self.placeholder_token, self.image_token)
+ # text[i] = self.tokenizer.bos_token + text[i]
+
+ text_inputs = self.tokenizer(text, add_special_tokens=False, **kwargs)
+ self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
+
+ input_ids = text_inputs["input_ids"]
+ position_ids = torch.arange(len(input_ids[0]))
+ position_ids_w = torch.arange(len(input_ids[0]))
+ position_ids_h = torch.arange(len(input_ids[0]))
+ position_ids_t = torch.arange(len(input_ids[0]))
+
+ if images is not None:
+ image_token_pos_indices = torch.where(input_ids[0] == self.image_token_id)[
+ 0
+ ]
+ for i in range(len(image_grid_thw)):
+ grid_h, grid_w = image_grid_thw[i][-2:]
+ patch_h = grid_h // self.image_processor.merge_size
+ patch_w = grid_w // self.image_processor.merge_size
+ start_pos = image_token_pos_indices[image_tokens_cumsum[i]].item() + 1
+ replace_num = (patch_w + 1) * patch_h
+ position_ids_w[start_pos : start_pos + replace_num] = torch.tensor(
+ list(range(patch_w + 1)) * patch_h, dtype=torch.int64
+ )
+ patch_h_list = []
+ for h in range(patch_h):
+ patch_h_list += [h] * (patch_w + 1)
+ position_ids_h[start_pos : start_pos + replace_num] = torch.tensor(
+ patch_h_list, dtype=torch.int64
+ )
+ position_ids_t[start_pos : start_pos + replace_num] = 0
+
+ position_ids = torch.stack(
+ [position_ids, position_ids_w, position_ids_h, position_ids_t]
+ ).unsqueeze(0)
+ text_inputs["position_ids"] = position_ids
+
+ attention_mask = input_ids.ne(self.pad_id)
+ text_inputs["attention_mask"] = attention_mask
+ text_inputs["imgs_pos"] = [self.get_imgs_pos(input_ids)]
+ # image_inputs["imgs"] = [[image_inputs["pixel_values"]]]
+
+ return_tensors = kwargs.pop("return_tensors", None)
+ return BatchFeature(
+ data={**text_inputs, **image_inputs},
+ tensor_type=return_tensors,
+ )
+
+ def batch_decode(self, *args, **kwargs):
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ return self.tokenizer.decode(*args, **kwargs)
+
+ def post_process_image_text_to_text(
+ self,
+ generated_outputs,
+ skip_special_tokens=True,
+ clean_up_tokenization_spaces=False,
+ **kwargs,
+ ):
+ assert 0
+
+ def apply_chat_template(self, *args, **kwargs):
+ token_ids = self.tokenizer.apply_chat_template(*args, **kwargs)
+ return token_ids
+
+ def get_imgs_pos(self, doc_ids):
+ doc_ids = np.array(doc_ids, dtype=np.int64)
+ img_begin_index = np.where(doc_ids == self.im_start_token_id)[0]
+ img_end_index = np.where(doc_ids == self.im_end_token_id)[0]
+ imgs_pos = np.concatenate(
+ (
+ np.reshape(img_begin_index + 1, (-1, 1)),
+ np.reshape(img_end_index, (-1, 1)),
+ ),
+ axis=-1,
+ ).tolist()
+ return imgs_pos
+
+ @property
+ def model_input_names(self):
+ tokenizer_input_names = self.tokenizer.model_input_names
+ image_processor_input_names = self.image_processor.model_input_names
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
+
+
+def split_image_into_patch_blocks(
+ pixel_values: torch.Tensor, # shape: [batch_size, 3, H, W]
+ patch_size: int = 16, # e.g. 16
+ adaptor_patch_div: int = 4, # e.g. 4 --> each patch_size is cut into 4x4 small regions, i.e. patch_size // 4 # noqa: E501
+) -> torch.Tensor:
+ """
+ Split the input image tensor (supporting batch) into large patches of size `patch_size`,
+ and then further divide each large patch into smaller regions of size
+ (patch_size // adaptor_patch_div) x (patch_size // adaptor_patch_div).
+ Each small region is extracted as a tensor of shape [3, patch_size, patch_size].
+ The final output contains all such small region tensors.
+
+ Args:
+ pixel_values: Input image tensor of shape [batch_size, 3, H, W].
+ patch_size: Size of the large patch, e.g., 16.
+ adaptor_patch_div: Each large patch is divided into
+ (patch_size // adaptor_patch_div) x (patch_size // adaptor_patch_div)
+ smaller regions.
+
+ Returns:
+ patches: A tensor of shape [N, 3, patch_size, patch_size],
+ where N = batch_size * (H // patch_size) * (W // patch_size) * (patch_size // adaptor_patch_div)^2.
+ Each element in the batch corresponds to one small image region.
+ """ # noqa: E501
+ batch_size, channels, height, width = pixel_values.shape
+ assert channels == 3, "Pixel values must have 3 channels in dim=1"
+ assert height % patch_size == 0 and width % patch_size == 0, (
+ "H and W must be divisible by patch_size"
+ )
+
+ patch_height_num = height // patch_size
+ patch_width_num = width // patch_size
+
+ # Reshape to [B, 3, ph, ps, pw, ps]
+ img = pixel_values.reshape(
+ batch_size, 3, patch_height_num, patch_size, patch_width_num, patch_size
+ )
+
+ # Further split each psxps patch into (ps//aps)x(ps//aps) small regions
+ img = img.reshape(
+ batch_size,
+ 3,
+ patch_height_num,
+ patch_size // adaptor_patch_div, # ps // aps
+ adaptor_patch_div,
+ patch_width_num,
+ patch_size // adaptor_patch_div, # ps // aps
+ adaptor_patch_div,
+ )
+
+ # Permute to group the small regions: [B, ph, pw, ps//aps, ps//aps, 3, aps, aps]
+ img = img.permute(0, 2, 5, 3, 6, 1, 4, 7)
+
+ # Reshape into [B * ph * pw * (ps//aps)^2, 3, patch_size, patch_size]
+ patches = img.reshape(-1, 3, patch_size, patch_size)
+
+ return patches
+
+
+AutoProcessor.register("HunYuanVLProcessor", HunYuanVLProcessor)
diff --git a/vllm/transformers_utils/processors/hunyuan_vl_image.py b/vllm/transformers_utils/processors/hunyuan_vl_image.py
new file mode 100644
index 0000000000000..0b10ae249dbb6
--- /dev/null
+++ b/vllm/transformers_utils/processors/hunyuan_vl_image.py
@@ -0,0 +1,477 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# adapted from https://github.com/ManaEstras/transformers/blob/v4.57.1.hyvl/src/transformers/models/hunyuan_vl/image_processing_hunyuan_vl.py
+"""Image processor class for HunYuanVL."""
+
+# isort conflicts with ruff for transformers imports
+# isort: skip_file
+import math
+
+import numpy as np
+import torchvision.transforms as transforms
+from transformers import AutoImageProcessor
+from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
+from transformers.image_transforms import (
+ convert_to_rgb,
+)
+from transformers.image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ make_flat_list_of_images,
+ make_list_of_images,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from transformers.utils import TensorType, logging
+from transformers.video_utils import VideoInput, make_batched_videos
+
+logger = logging.get_logger(__name__)
+
+
+def smart_resize(
+ height: int,
+ width: int,
+ factor: int = 16,
+ min_pixels: int = 512 * 512,
+ max_pixels: int = 2048 * 2048,
+):
+ """Rescales the image so that the following conditions are met:
+
+ 1. Both dimensions (height and width) are divisible by 'factor'.
+
+ 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+
+ 3. The aspect ratio of the image is maintained as closely as possible.
+
+ """
+ if max(height, width) / min(height, width) > 200:
+ raise ValueError(
+ "absolute aspect ratio must be smaller than 200, got "
+ f"{max(height, width) / min(height, width)}"
+ )
+ h_bar = round(height / factor) * factor
+ w_bar = round(width / factor) * factor
+ if h_bar * w_bar > max_pixels:
+ beta = math.sqrt((height * width) / max_pixels)
+ h_bar = max(factor, math.floor(height / beta / factor) * factor)
+ w_bar = max(factor, math.floor(width / beta / factor) * factor)
+ elif h_bar * w_bar < min_pixels:
+ beta = math.sqrt(min_pixels / (height * width))
+ h_bar = math.ceil(height * beta / factor) * factor
+ w_bar = math.ceil(width * beta / factor) * factor
+ return h_bar, w_bar
+
+
+class HunYuanVLImageProcessor(BaseImageProcessor):
+ model_input_names = [
+ "pixel_values",
+ "image_grid_thw",
+ "pixel_values_videos",
+ "video_grid_thw",
+ ]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: dict[str, int] | None = None,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_rescale: bool = True,
+ rescale_factor: int | float = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: float | list[float] | None = None,
+ image_std: float | list[float] | None = None,
+ do_convert_rgb: bool = True,
+ min_pixels: int | None = None,
+ max_pixels: int | None = None,
+ patch_size: int = 16,
+ temporal_patch_size: int = 2,
+ merge_size: int = 2,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ if size is not None and (
+ "shortest_edge" not in size or "longest_edge" not in size
+ ):
+ raise ValueError(
+ "size must contain 'shortest_edge' and 'longest_edge' keys."
+ )
+ else:
+ size = {"shortest_edge": 512 * 512, "longest_edge": 2048 * 2048}
+ # backward compatibility: override size with min_pixels and max_pixels
+ # if they are provided.
+ if min_pixels is not None:
+ size["shortest_edge"] = min_pixels
+ if max_pixels is not None:
+ size["longest_edge"] = max_pixels
+ self.min_pixels = size["shortest_edge"]
+ self.max_pixels = size["longest_edge"]
+ self.size = size
+
+ self.do_resize = do_resize
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
+
+ self.patch_size = patch_size
+ self.temporal_patch_size = temporal_patch_size
+ self.merge_size = merge_size
+ self.do_convert_rgb = do_convert_rgb
+
+ # hard-code
+
+ def _preprocess(
+ self,
+ images: ImageInput | VideoInput,
+ do_resize: bool | None = None,
+ size: dict[str, int] | None = None,
+ resample: PILImageResampling = None,
+ do_rescale: bool | None = None,
+ rescale_factor: float | None = None,
+ do_normalize: bool | None = None,
+ image_mean: float | list[float] | None = None,
+ image_std: float | list[float] | None = None,
+ patch_size: int = 16,
+ temporal_patch_size: int = 2,
+ merge_size: int = 2,
+ do_convert_rgb: bool | None = None,
+ data_format: ChannelDimension | None = ChannelDimension.FIRST,
+ input_data_format: str | ChannelDimension | None = None,
+ ):
+ """
+ Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
+
+ Args:
+ images (`ImageInput`):
+ Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present.
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Scale factor to use if rescaling the image.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
+ patch_size (`int`, *optional*, defaults to `self.patch_size`):
+ The spatial patch size of the vision encoder.
+ temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
+ The temporal patch size of the vision encoder.
+ merge_size (`int`, *optional*, defaults to `self.merge_size`):
+ The merge size of the vision encoder to llm encoder.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """ # noqa: E501
+ images = make_list_of_images(images)
+
+ if do_convert_rgb:
+ images = [convert_to_rgb(image) for image in images]
+
+ width, height = images[0].width, images[0].height
+ resized_width, resized_height = width, height
+ processed_images = []
+ for image in images:
+ if do_resize:
+ resized_height, resized_width = smart_resize(
+ height=height,
+ width=width,
+ factor=patch_size * merge_size,
+ min_pixels=self.min_pixels,
+ max_pixels=self.max_pixels,
+ )
+ image = image.resize((resized_width, resized_height))
+
+ if do_normalize:
+ image = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize(self.image_mean, self.image_std),
+ ]
+ )(image)
+ processed_images.append(image)
+
+ patches = np.array(processed_images)
+ channel = patches.shape[1]
+ grid_t = patches.shape[0] // temporal_patch_size
+ grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
+ patches = patches.reshape(
+ 1,
+ channel,
+ grid_h // merge_size,
+ merge_size,
+ patch_size,
+ grid_w // merge_size,
+ merge_size,
+ patch_size,
+ )
+ patches = patches.transpose(0, 2, 3, 5, 6, 1, 4, 7)
+ flatten_patches = patches.reshape(
+ 1 * grid_h * grid_w, channel * patch_size * patch_size
+ )
+
+ return flatten_patches, (grid_t, grid_h, grid_w)
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ videos: VideoInput = None,
+ do_resize: bool | None = None,
+ size: dict[str, int] | None = None,
+ min_pixels: int | None = None,
+ max_pixels: int | None = None,
+ resample: PILImageResampling = None,
+ do_rescale: bool | None = None,
+ rescale_factor: float | None = None,
+ do_normalize: bool | None = None,
+ image_mean: float | list[float] | None = None,
+ image_std: float | list[float] | None = None,
+ patch_size: int | None = None,
+ temporal_patch_size: int | None = None,
+ merge_size: int | None = None,
+ do_convert_rgb: bool | None = None,
+ return_tensors: str | TensorType | None = None,
+ data_format: ChannelDimension | None = ChannelDimension.FIRST,
+ input_data_format: str | ChannelDimension | None = None,
+ ):
+ """
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ videos (`VideoInput`):
+ Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If
+ passing in videos with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
+ the longest edge resized to keep the input aspect ratio.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+ has an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
+ `True`.
+ min_pixels (`int`, *optional*, defaults to `self.min_pixels`):
+ The min pixels of the image to resize the image.
+ max_pixels (`int`, *optional*, defaults to `self.max_pixels`):
+ The max pixels of the image to resize the image.
+ patch_size (`int`, *optional*, defaults to `self.patch_size`):
+ The spatial patch size of the vision encoder.
+ temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
+ The temporal patch size of the vision encoder.
+ merge_size (`int`, *optional*, defaults to `self.merge_size`):
+ The merge size of the vision encoder to llm encoder.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+ """ # noqa: E501
+ min_pixels = min_pixels if min_pixels is not None else self.min_pixels
+ max_pixels = max_pixels if max_pixels is not None else self.max_pixels
+
+ if size is not None:
+ if "shortest_edge" not in size or "longest_edge" not in size:
+ raise ValueError(
+ "size must contain 'shortest_edge' and 'longest_edge' keys."
+ )
+ min_pixels = size["shortest_edge"]
+ elif min_pixels is not None and max_pixels is not None:
+ # backward compatibility: override size with min_pixels and max_pixels
+ # if they are provided.
+ size = {"shortest_edge": min_pixels, "longest_edge": max_pixels}
+ else:
+ size = {**self.size}
+
+ do_resize = do_resize if do_resize is not None else self.do_resize
+
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = (
+ rescale_factor if rescale_factor is not None else self.rescale_factor
+ )
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ patch_size = patch_size if patch_size is not None else self.patch_size
+ temporal_patch_size = (
+ temporal_patch_size
+ if temporal_patch_size is not None
+ else self.temporal_patch_size
+ )
+ merge_size = merge_size if merge_size is not None else self.merge_size
+ do_convert_rgb = (
+ do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+ )
+
+ if images is not None:
+ images = make_flat_list_of_images(images)
+
+ if images is not None and not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ validate_preprocess_arguments(
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ data = {}
+ if images is not None:
+ pixel_values, vision_grid_thws = [], []
+ for image in images:
+ patches, image_grid_thw = self._preprocess(
+ image,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ patch_size=patch_size,
+ temporal_patch_size=temporal_patch_size,
+ merge_size=merge_size,
+ data_format=data_format,
+ do_convert_rgb=do_convert_rgb,
+ input_data_format=input_data_format,
+ )
+ pixel_values.extend(patches)
+ vision_grid_thws.append(image_grid_thw)
+ pixel_values = np.array(pixel_values)
+ vision_grid_thws = np.array(vision_grid_thws)
+ data.update(
+ {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}
+ )
+
+ # kept for BC only and should be removed after v5.0
+ if videos is not None:
+ logger.warning(
+ "`HunYuanVLV1ImageProcessor` works only with image inputs "
+ "and doesn't process videos anymore. "
+ "This is a deprecated behavior and will be removed in v5.0. "
+ "Your videos should be forwarded to `HunYuanVLV1VideoProcessor`. "
+ )
+ videos = make_batched_videos(videos)
+ pixel_values_videos, vision_grid_thws_videos = [], []
+ for images in videos:
+ patches, video_grid_thw = self._preprocess(
+ images,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ patch_size=patch_size,
+ temporal_patch_size=temporal_patch_size,
+ merge_size=merge_size,
+ data_format=data_format,
+ do_convert_rgb=do_convert_rgb,
+ input_data_format=input_data_format,
+ )
+ pixel_values_videos.extend(patches)
+ vision_grid_thws_videos.append(video_grid_thw)
+ data.update(
+ {
+ "pixel_values_videos": np.array(pixel_values_videos),
+ "video_grid_thw": np.array(vision_grid_thws_videos),
+ }
+ )
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
+ """
+ A utility that returns number of image patches for a given image size.
+
+ Args:
+ height (`int`):
+ Height of the input image.
+ width (`int`):
+ Width of the input image.
+ images_kwargs (`dict`, *optional*):
+ Any kwargs to override defaults of the image processor.
+ Returns:
+ `int`: Number of image patches per image.
+ """
+ min_pixels = (
+ images_kwargs["min_pixels"]
+ if "min_pixels" in images_kwargs
+ else self.size["shortest_edge"]
+ )
+ max_pixels = (
+ images_kwargs["max_pixels"]
+ if "max_pixels" in images_kwargs
+ else self.size["longest_edge"]
+ )
+ patch_size = images_kwargs.get("patch_size", self.patch_size)
+ merge_size = images_kwargs.get("merge_size", self.merge_size)
+
+ factor = patch_size * merge_size
+ resized_height, resized_width = smart_resize(
+ height, width, factor, min_pixels=min_pixels, max_pixels=max_pixels
+ )
+ grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
+ return grid_h * (grid_w + 1) + 2
+
+
+AutoImageProcessor.register("HunYuanVLImageProcessor", HunYuanVLImageProcessor)
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index 233076741503d..929dc8bf481cb 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -19,8 +19,14 @@ from vllm.transformers_utils.config import (
get_sentence_transformer_tokenizer_config,
list_filtered_repo_files,
)
+from vllm.transformers_utils.gguf_utils import get_gguf_file_path_from_hf
from vllm.transformers_utils.tokenizers import MistralTokenizer
-from vllm.transformers_utils.utils import check_gguf_file
+from vllm.transformers_utils.utils import (
+ check_gguf_file,
+ is_gguf,
+ is_remote_gguf,
+ split_remote_gguf,
+)
if TYPE_CHECKING:
from vllm.config import ModelConfig
@@ -180,10 +186,19 @@ def get_tokenizer(
kwargs["truncation_side"] = "left"
# Separate model folder from file path for GGUF models
- is_gguf = check_gguf_file(tokenizer_name)
- if is_gguf:
- kwargs["gguf_file"] = Path(tokenizer_name).name
- tokenizer_name = Path(tokenizer_name).parent
+ if is_gguf(tokenizer_name):
+ if check_gguf_file(tokenizer_name):
+ kwargs["gguf_file"] = Path(tokenizer_name).name
+ tokenizer_name = Path(tokenizer_name).parent
+ elif is_remote_gguf(tokenizer_name):
+ tokenizer_name, quant_type = split_remote_gguf(tokenizer_name)
+ # Get the HuggingFace Hub path for the GGUF file
+ gguf_file = get_gguf_file_path_from_hf(
+ tokenizer_name,
+ quant_type,
+ revision=revision,
+ )
+ kwargs["gguf_file"] = gguf_file
# if `tokenizer_mode` == "auto", check if tokenizer can be loaded via Mistral format
# first to use official Mistral tokenizer if possible.
diff --git a/vllm/transformers_utils/utils.py b/vllm/transformers_utils/utils.py
index 901a64d9d2633..45a873c9f7001 100644
--- a/vllm/transformers_utils/utils.py
+++ b/vllm/transformers_utils/utils.py
@@ -9,6 +9,8 @@ from os import PathLike
from pathlib import Path
from typing import Any
+from gguf import GGMLQuantizationType
+
import vllm.envs as envs
from vllm.logger import init_logger
@@ -46,6 +48,57 @@ def check_gguf_file(model: str | PathLike) -> bool:
return False
+@cache
+def is_remote_gguf(model: str | Path) -> bool:
+ """Check if the model is a remote GGUF model."""
+ model = str(model)
+ return (
+ (not is_cloud_storage(model))
+ and (not model.startswith(("http://", "https://")))
+ and ("/" in model and ":" in model)
+ and is_valid_gguf_quant_type(model.rsplit(":", 1)[1])
+ )
+
+
+def is_valid_gguf_quant_type(gguf_quant_type: str) -> bool:
+ """Check if the quant type is a valid GGUF quant type."""
+ return getattr(GGMLQuantizationType, gguf_quant_type, None) is not None
+
+
+def split_remote_gguf(model: str | Path) -> tuple[str, str]:
+ """Split the model into repo_id and quant type."""
+ model = str(model)
+ if is_remote_gguf(model):
+ parts = model.rsplit(":", 1)
+ return (parts[0], parts[1])
+ raise ValueError(
+ "Wrong GGUF model or invalid GGUF quant type: %s.\n"
+ "- It should be in repo_id:quant_type format.\n"
+ "- Valid GGMLQuantizationType values: %s",
+ model,
+ GGMLQuantizationType._member_names_,
+ )
+
+
+def is_gguf(model: str | Path) -> bool:
+ """Check if the model is a GGUF model.
+
+ Args:
+ model: Model name, path, or Path object to check.
+
+ Returns:
+ True if the model is a GGUF model, False otherwise.
+ """
+ model = str(model)
+
+ # Check if it's a local GGUF file
+ if check_gguf_file(model):
+ return True
+
+ # Check if it's a remote GGUF model (repo_id:quant_type format)
+ return is_remote_gguf(model)
+
+
def modelscope_list_repo_files(
repo_id: str,
revision: str | None = None,
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index 3ef44e7703204..fddcc27204307 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -49,13 +49,14 @@ STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
# Possible string values of STR_BACKEND_ENV_VAR
# register, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
-STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"
+MASK_64_BITS = (1 << 64) - 1
+
def random_uuid() -> str:
- return str(uuid.uuid4().hex)
+ return f"{uuid.uuid4().int & MASK_64_BITS:016x}" # 16 hex chars
def length_from_prompt_token_ids_or_embeds(
diff --git a/vllm/utils/argparse_utils.py b/vllm/utils/argparse_utils.py
index 3d105a3685b37..b68157f02f6cc 100644
--- a/vllm/utils/argparse_utils.py
+++ b/vllm/utils/argparse_utils.py
@@ -73,14 +73,6 @@ class FlexibleArgumentParser(ArgumentParser):
# Enable the deprecated kwarg for Python 3.12 and below
def parse_known_args(self, args=None, namespace=None):
- if args is not None and "--disable-log-requests" in args:
- # Special case warning because the warning below won't trigger
- # if –-disable-log-requests because its value is default.
- logger.warning_once(
- "argument '--disable-log-requests' is deprecated and "
- "replaced with '--enable-log-requests'. This will be "
- "removed in v0.12.0."
- )
namespace, args = super().parse_known_args(args, namespace)
for action in FlexibleArgumentParser._deprecated:
if (
@@ -255,16 +247,16 @@ class FlexibleArgumentParser(ArgumentParser):
elif arg.startswith("-O") and arg != "-O" and arg[2] != ".":
# allow -O flag to be used without space, e.g. -O3 or -Odecode
# -O.<...> handled later
- # also handle -O= here
- mode = arg[3:] if arg[2] == "=" else arg[2:]
- processed_args.append(f"-O.mode={mode}")
+ # also handle -O= here
+ optimization_level = arg[3:] if arg[2] == "=" else arg[2:]
+ processed_args += ["--optimization-level", optimization_level]
elif (
arg == "-O"
and i + 1 < len(args)
and args[i + 1] in {"0", "1", "2", "3"}
):
- # Convert -O to -O.mode
- processed_args.append("-O.mode")
+ # Convert -O to --optimization-level
+ processed_args.append("--optimization-level")
else:
processed_args.append(arg)
@@ -302,10 +294,24 @@ class FlexibleArgumentParser(ArgumentParser):
delete = set[int]()
dict_args = defaultdict[str, dict[str, Any]](dict)
duplicates = set[str]()
+ # Track regular arguments (non-dict args) for duplicate detection
+ regular_args_seen = set[str]()
for i, processed_arg in enumerate(processed_args):
if i in delete: # skip if value from previous arg
continue
+ if processed_arg.startswith("--") and "." not in processed_arg:
+ if "=" in processed_arg:
+ arg_name = processed_arg.split("=", 1)[0]
+ else:
+ arg_name = processed_arg
+
+ if arg_name in regular_args_seen:
+ duplicates.add(arg_name)
+ else:
+ regular_args_seen.add(arg_name)
+ continue
+
if processed_arg.startswith("-") and "." in processed_arg:
if "=" in processed_arg:
processed_arg, value_str = processed_arg.split("=", 1)
diff --git a/vllm/utils/hashing.py b/vllm/utils/hashing.py
index 49f4f13d115f3..edf1e9cb34e56 100644
--- a/vllm/utils/hashing.py
+++ b/vllm/utils/hashing.py
@@ -5,6 +5,7 @@ from __future__ import annotations
import hashlib
import pickle
+from _hashlib import HASH, UnsupportedDigestmodError
from collections.abc import Callable
from typing import Any
@@ -61,3 +62,20 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
return sha256_cbor
raise ValueError(f"Unsupported hash function: {hash_fn_name}")
+
+
+def safe_hash(data: bytes, usedforsecurity: bool = True) -> HASH:
+ """Hash for configs, defaulting to md5 but falling back to sha256
+ in FIPS constrained environments.
+
+ Args:
+ data: bytes
+ usedforsecurity: Whether the hash is used for security purposes
+
+ Returns:
+ Hash object
+ """
+ try:
+ return hashlib.md5(data, usedforsecurity=usedforsecurity)
+ except (UnsupportedDigestmodError, ValueError):
+ return hashlib.sha256(data)
diff --git a/vllm/utils/system_utils.py b/vllm/utils/system_utils.py
index cc872040b6c5f..a4eb8f4d4fd7d 100644
--- a/vllm/utils/system_utils.py
+++ b/vllm/utils/system_utils.py
@@ -56,6 +56,39 @@ def set_env_var(key: str, value: str) -> Iterator[None]:
os.environ[key] = old
+@contextlib.contextmanager
+def suppress_stdout():
+ """
+ Suppress stdout from C libraries at the file descriptor level.
+
+ Only suppresses stdout, not stderr, to preserve error messages.
+ Suppression is disabled when VLLM_LOGGING_LEVEL is set to DEBUG.
+
+ Example:
+ with suppress_stdout():
+ # C library calls that would normally print to stdout
+ torch.distributed.new_group(ranks, backend="gloo")
+ """
+ # Don't suppress if logging level is DEBUG
+ if envs.VLLM_LOGGING_LEVEL == "DEBUG":
+ yield
+ return
+
+ stdout_fd = sys.stdout.fileno()
+ stdout_dup = os.dup(stdout_fd)
+ devnull_fd = os.open(os.devnull, os.O_WRONLY)
+
+ try:
+ sys.stdout.flush()
+ os.dup2(devnull_fd, stdout_fd)
+ yield
+ finally:
+ sys.stdout.flush()
+ os.dup2(stdout_dup, stdout_fd)
+ os.close(stdout_dup)
+ os.close(devnull_fd)
+
+
# File path utilities
diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py
index 590bf91b0d057..fed7dcdf293bd 100644
--- a/vllm/v1/attention/backends/cpu_attn.py
+++ b/vllm/v1/attention/backends/cpu_attn.py
@@ -51,8 +51,6 @@ class CPUAttentionBackend(AttentionBackend):
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""CPU attention supports decoder and encoder-only attention."""
- from vllm.attention import AttentionType
-
return attn_type in (
AttentionType.DECODER,
AttentionType.ENCODER,
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index a9a4af5ac1183..fb080b0b33bc0 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -84,8 +84,6 @@ class FlashAttentionBackend(AttentionBackend):
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""FlashAttention supports all attention types."""
- from vllm.attention import AttentionType
-
return attn_type in (
AttentionType.DECODER,
AttentionType.ENCODER,
@@ -328,7 +326,6 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
max_seq_len = common_attn_metadata.max_seq_len
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
- seq_lens_cpu = common_attn_metadata.seq_lens_cpu
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
causal = common_attn_metadata.causal
@@ -401,20 +398,23 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
prefix_scheduler_metadata = None
if self.dcp_world_size > 1:
- query_kv_lens_cpu = (
- common_attn_metadata.query_start_loc_cpu[1:]
- - common_attn_metadata.query_start_loc_cpu[:-1]
- )
- dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu
+ query_kv_lens = query_start_loc[1:] - query_start_loc[:-1]
+ dcp_context_kv_lens = seq_lens - query_kv_lens
- dcp_context_kv_lens_cpu = get_dcp_local_seq_lens(
- dcp_context_kv_lens_cpu,
+ dcp_context_kv_lens = get_dcp_local_seq_lens(
+ dcp_context_kv_lens,
self.dcp_world_size,
self.dcp_rank,
self.cp_kv_cache_interleave_size,
)
- dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
- max_dcp_context_kv_len = dcp_context_kv_lens.max().item()
+ # After DCP distribution, the maximum number of tokens for any rank is
+ # ceil(L / (N * I)) * I, where L is max_seq_len, N is dcp_world_size,
+ # and I is cp_kv_cache_interleave_size.
+ # This eliminates GPU->CPU sync while minimizing workspace over-allocation.
+ num_partitions = self.dcp_world_size * self.cp_kv_cache_interleave_size
+ max_dcp_context_kv_len = (
+ (max_seq_len + num_partitions - 1) // num_partitions
+ ) * self.cp_kv_cache_interleave_size
scheduler_metadata = schedule(
batch_size=num_reqs,
@@ -431,9 +431,8 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
prefix_kv_lens = torch.tensor(
[common_prefix_len], dtype=torch.int32, device=self.device
)
- suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to(
- self.device, non_blocking=True
- )
+ # Use GPU tensor directly - no CPU sync needed
+ suffix_kv_lens = seq_lens[:num_reqs] - common_prefix_len
prefix_scheduler_metadata = schedule(
batch_size=1,
cu_query_lens=cu_prefix_query_lens,
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 8159f4096107f..777398bf8a20e 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -930,31 +930,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
if num_decodes > 0:
pure_decode = num_prefills == 0
- # possible required padding for cudagraph replay
use_cudagraph = (
self.enable_cuda_graph
and pure_decode
and num_decode_tokens <= self._decode_cudagraph_max_bs
)
- if use_cudagraph:
- num_input_tokens = self.vllm_config.pad_for_cudagraph(
- num_decode_tokens
- )
- # Carefully fulfill the padding region with reasonable value
- # on cpu.
- # Make sure paged_kv_indptr_cpu is not decreasing
- self.paged_kv_indptr_cpu[
- 1 + num_decodes : 1 + num_input_tokens
- ].fill_(paged_kv_indptr_cpu[-1])
- # Fill the remaining paged_kv_last_page_len_cpu with 1.
- # This is because flashinfer treats 0 as a full page
- # instead of empty.
- self.paged_kv_last_page_len_cpu[num_decodes:num_input_tokens].fill_(
- 1
- )
-
- else:
- num_input_tokens = num_decode_tokens
+ num_input_tokens = num_decode_tokens
attn_metadata.decode_wrapper = self._get_decode_wrapper(
num_input_tokens, use_cudagraph
@@ -1527,7 +1508,7 @@ def fast_plan_decode(
qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
try:
- # Make sure we pass exactly 18 arguments for tensor core version
+ # Make sure we pass exactly 19 arguments for tensor core version
self._plan_info = self._cached_module.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
@@ -1547,6 +1528,7 @@ def fast_plan_decode(
window_left,
fixed_split_size,
disable_split_kv,
+ 0,
)
except Exception as e:
raise RuntimeError(f"Error in tensor core plan: {e}") from e
diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index 7768827d26dc3..8de0a0a11471f 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -87,8 +87,6 @@ class FlexAttentionBackend(AttentionBackend):
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""FlexAttention supports both decoder and encoder-only attention."""
- from vllm.attention import AttentionType
-
return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY)
@staticmethod
diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py
index 1900c50849eca..004baa2d09cde 100644
--- a/vllm/v1/attention/backends/linear_attn.py
+++ b/vllm/v1/attention/backends/linear_attn.py
@@ -7,6 +7,7 @@ import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (
+ AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills,
@@ -35,6 +36,8 @@ class LinearAttentionMetadata:
class LinearAttentionMetadataBuilder(AttentionMetadataBuilder[LinearAttentionMetadata]):
reorder_batch_threshold: int = 1
+ _cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
+
def __init__(
self,
kv_cache_spec: AttentionSpec,
diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py
index 0d875565fc99a..a9705db59f19d 100644
--- a/vllm/v1/attention/backends/mamba_attn.py
+++ b/vllm/v1/attention/backends/mamba_attn.py
@@ -107,6 +107,8 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
)
# -1 in case it's non-computed and causes later issues with indexing
block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0)
+ # -1 in the case we have a padded request (0 seq-len)
+ block_idx_last_scheduled_token = block_idx_last_scheduled_token.clamp(min=0)
return (
block_idx_last_computed_token,
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 43aef8a7cca91..d94ed9183f639 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -340,6 +340,8 @@ class MLACommonPrefillMetadata:
max_seq_lens: list[int]
seq_lens: torch.Tensor
workspace: torch.Tensor
+ token_to_seq: torch.Tensor
+ chunk_total_token: list[int]
# for mla DCP
padded_local_chunk_seq_lens: list[list[int]] | None = None
@@ -839,6 +841,19 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
torch.cumsum(
chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32
)
+ chunk_total_token = cu_seq_lens_cpu[:, -1]
+
+ max_token_num_over_chunk = chunk_total_token.max().item()
+ token_to_seq_tensor_cpu = torch.zeros(
+ [num_chunks, max_token_num_over_chunk], dtype=torch.int32
+ )
+ range_idx = torch.arange(num_prefills, dtype=torch.int32)
+ for i in range(num_chunks):
+ chunk_token_to_seq_tensor = torch.repeat_interleave(
+ range_idx, chunk_seq_lens[i]
+ )
+ chunk_len = chunk_token_to_seq_tensor.shape[0]
+ token_to_seq_tensor_cpu[i, :chunk_len] = chunk_token_to_seq_tensor
if self.dcp_world_size > 1:
local_context_lens_allranks = get_dcp_local_seq_lens(
@@ -906,6 +921,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
seq_lens=chunk_seq_lens,
+ token_to_seq=token_to_seq_tensor_cpu.to(
+ device, non_blocking=True
+ ),
+ chunk_total_token=chunk_total_token.tolist(),
workspace=self.chunked_prefill_workspace,
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
local_context_lens_allranks=local_context_lens_allranks.tolist(),
@@ -922,6 +941,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
seq_lens=chunk_seq_lens,
+ token_to_seq=token_to_seq_tensor_cpu.to(
+ device, non_blocking=True
+ ),
+ chunk_total_token=chunk_total_token,
workspace=self.chunked_prefill_workspace,
)
@@ -1215,15 +1238,13 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
+
if self.is_aiter_triton_fp8_bmm_enabled:
+ out = out.view(-1, self.num_heads, self.v_head_dim)
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
x = rocm_aiter_ops.triton_fp8_bmm(
- x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
+ x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out
)
- # Convert from (B, N, V) to (B, N * V)
- x = x.reshape(-1, self.num_heads * self.v_head_dim)
- # Copy result
- out.copy_(x)
else:
# Convert from (B, N * V) to (N, B, V)
out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
@@ -1638,16 +1659,15 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
output = None
iters = len(prefill_metadata.chunked_context.seq_tot)
workspace = prefill_metadata.chunked_context.workspace
-
for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i]
-
ops.gather_and_maybe_dequant_cache(
src_cache=kv_c_and_k_pe_cache,
dst=workspace,
block_table=prefill_metadata.block_table,
cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
- batch_size=attn_metadata.num_prefills,
+ token_to_seq=prefill_metadata.chunked_context.token_to_seq[i],
+ num_tokens=prefill_metadata.chunked_context.chunk_total_token[i],
kv_cache_dtype=self.kv_cache_dtype,
scale=k_scale,
seq_starts=prefill_metadata.chunked_context.starts[i],
@@ -1802,7 +1822,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor,
- ) -> torch.Tensor:
+ output: torch.Tensor,
+ ) -> None:
# TODO (zyongye): Prefill function here
assert attn_metadata.prefill is not None
assert self.dcp_world_size is not None
@@ -1815,7 +1836,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
- output = self._run_prefill_new_tokens(
+ output_prefill = self._run_prefill_new_tokens(
prefill=attn_metadata.prefill,
q=q,
k=k,
@@ -1824,7 +1845,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)
if has_context:
- suffix_output, suffix_lse = output
+ suffix_output, suffix_lse = output_prefill
if self.dcp_world_size > 1:
context_output, context_lse = (
self._context_parallel_compute_prefill_context(
@@ -1840,7 +1861,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
q, kv_c_and_k_pe_cache, attn_metadata, k_scale
)
- output = torch.empty_like(suffix_output)
+ # unpad if necessary
+ if self._pad_v:
+ context_output = context_output[..., : v.shape[-1]]
+ suffix_output = suffix_output[..., : v.shape[-1]]
+
+ output = output.view(-1, self.num_heads, self.v_head_dim)
merge_attn_states(
output=output,
prefix_output=context_output,
@@ -1848,12 +1874,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
suffix_output=suffix_output,
suffix_lse=suffix_lse,
)
-
- # unpad if necessary
- if self._pad_v:
- output = output[..., : v.shape[-1]]
-
- return output.flatten(start_dim=-2)
+ else:
+ output_prefill = output_prefill[..., : v.shape[-1]].flatten(start_dim=-2)
+ output.copy_(output_prefill)
@abstractmethod
def _forward_decode(
@@ -1948,13 +1971,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
kv_cache = kv_cache.view(current_platform.fp8_dtype())
if has_prefill:
- output[num_decode_tokens:] = self._forward_prefill(
+ self._forward_prefill(
prefill_q,
prefill_k_c_normed,
prefill_k_pe,
kv_cache,
attn_metadata,
layer._k_scale,
+ output=output[num_decode_tokens:],
)
if has_decode:
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index 56f9c7a281e7f..00a0a77a1c2f7 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -49,6 +49,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
paged_kv_last_page_len: torch.Tensor | None = None
# The query indptr, shape : [num_decode + 1]
qo_indptr: torch.Tensor | None = None
+ # The dtype of MLA out tensor
+ attn_out_dtype: torch.dtype = torch.bfloat16
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
@@ -74,6 +76,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
)
self.compilation_config = vllm_config.compilation_config
+ self.decode_attn_out_dtype = vllm_config.model_config.dtype
# kernel block size is always 1.
max_num_pages_per_req = vllm_config.model_config.max_model_len
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
@@ -162,6 +165,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
paged_kv_last_page_len=paged_kv_last_page_len,
qo_indptr=qo_indptr,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
+ attn_out_dtype=self.decode_attn_out_dtype,
)
return attn_metadata
@@ -242,7 +246,11 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
assert isinstance(q, torch.Tensor)
B = q.shape[0]
o = torch.zeros(
- B, self.num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device
+ B,
+ self.num_heads,
+ self.kv_lora_rank,
+ dtype=attn_metadata.decode.attn_out_dtype,
+ device=q.device,
)
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
@@ -260,6 +268,8 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
attn_metadata.decode.paged_kv_indptr,
attn_metadata.decode.paged_kv_indices,
attn_metadata.decode.paged_kv_last_page_len,
+ q_scale=layer._q_scale,
+ kv_scale=layer._k_scale,
)
return o, None
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 540a8e2b1d016..27f07218d9b2e 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -24,12 +24,15 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.utils.math_utils import cdiv
if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionImpl
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
import vllm.envs as envs
-from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
+from vllm.attention.backends.abstract import (
+ AttentionBackend,
+ AttentionImpl,
+ AttentionMetadata,
+)
from vllm.distributed.kv_transfer.kv_connector.utils import (
get_kv_connector_cache_layout,
)
@@ -72,6 +75,7 @@ class CommonAttentionMetadata:
num_reqs: int
"""Number of requests"""
+ # TODO(lucas): rename to num_tokens since it may be padded and this is misleading
num_actual_tokens: int
"""Total number of tokens in batch"""
max_query_len: int
@@ -89,12 +93,39 @@ class CommonAttentionMetadata:
num_logits_indices: int | None = None
# Needed by CrossAttentionBuilder
- encoder_seq_lens: np.ndarray | None = None
+ encoder_seq_lens: torch.Tensor | None = None
+ encoder_seq_lens_cpu: np.ndarray | None = None
dcp_local_seq_lens: torch.Tensor | None = None
dcp_local_seq_lens_cpu: torch.Tensor | None = None
"""Sequence lengths of the local rank in decode context parallelism world"""
+ # TODO(lucas): remove once we have FULL-CG spec-decode support
+ def unpadded(
+ self, num_actual_tokens: int, num_actual_reqs: int
+ ) -> "CommonAttentionMetadata":
+ maybe_slice_reqs = lambda x: x[:num_actual_reqs] if x is not None else None
+ return CommonAttentionMetadata(
+ query_start_loc=self.query_start_loc[: num_actual_reqs + 1],
+ query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1],
+ seq_lens=self.seq_lens[:num_actual_reqs],
+ seq_lens_cpu=self.seq_lens_cpu[:num_actual_reqs],
+ num_computed_tokens_cpu=self.num_computed_tokens_cpu[:num_actual_reqs],
+ num_reqs=num_actual_reqs,
+ num_actual_tokens=num_actual_tokens,
+ max_query_len=self.max_query_len,
+ max_seq_len=self.max_seq_len,
+ block_table_tensor=self.block_table_tensor[:num_actual_reqs],
+ slot_mapping=self.slot_mapping[:num_actual_tokens],
+ causal=self.causal,
+ logits_indices_padded=self.logits_indices_padded,
+ num_logits_indices=self.num_logits_indices,
+ encoder_seq_lens=maybe_slice_reqs(self.encoder_seq_lens),
+ encoder_seq_lens_cpu=maybe_slice_reqs(self.encoder_seq_lens_cpu),
+ dcp_local_seq_lens=maybe_slice_reqs(self.dcp_local_seq_lens),
+ dcp_local_seq_lens_cpu=maybe_slice_reqs(self.dcp_local_seq_lens_cpu),
+ )
+
def slice_query_start_locs(
query_start_loc: torch.Tensor,
@@ -856,7 +887,9 @@ def split_decodes_and_prefills(
if require_uniform:
is_prefill = query_lens != query_lens[0]
else:
- is_prefill = query_lens > decode_threshold
+ # 0-query len indicates a padded request; leave this at the back
+ # of the batch with the prefills
+ is_prefill = (query_lens > decode_threshold) | (query_lens == 0)
if not torch.any(is_prefill):
return num_reqs, 0, num_tokens, 0
@@ -1091,12 +1124,14 @@ def get_dcp_local_seq_lens(
num_requests = seq_lens.size(0)
if dcp_rank is None:
rank_offsets = (
- torch.arange(dcp_size, dtype=torch.int32)
+ torch.arange(dcp_size, dtype=torch.int32, device=seq_lens.device)
.unsqueeze(0)
.repeat(num_requests, 1)
)
else:
- rank_offsets = torch.Tensor([[dcp_rank]]).to(dtype=torch.int32)
+ rank_offsets = torch.tensor(
+ [[dcp_rank]], dtype=torch.int32, device=seq_lens.device
+ )
seq_lens_tiled = (
seq_lens.to(torch.int32).unsqueeze(-1).repeat(1, rank_offsets.shape[1])
)
diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py
deleted file mode 100644
index 5039c44b9c3e6..0000000000000
--- a/vllm/v1/attention/backends/xformers.py
+++ /dev/null
@@ -1,420 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Attention layer with XFormersAttention."""
-
-from dataclasses import dataclass
-from typing import ClassVar, Optional
-
-import torch
-
-from vllm.attention.backends.abstract import (
- AttentionBackend,
- AttentionImpl,
- AttentionType,
- MultipleOf,
-)
-from vllm.attention.ops.triton_unified_attention import unified_attention
-from vllm.config import VllmConfig
-from vllm.logger import init_logger
-from vllm.v1.attention.backends.utils import (
- AttentionMetadataBuilder,
- CommonAttentionMetadata,
- split_decodes_and_prefills,
-)
-from vllm.v1.kv_cache_interface import AttentionSpec
-
-try:
- from xformers import ops as xops
- from xformers.ops.fmha.attn_bias import (
- AttentionBias,
- PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
- )
-
- XFORMERS_AVAILABLE = True
-except ImportError:
- XFORMERS_AVAILABLE = False
-
-from vllm import _custom_ops as ops
-
-logger = init_logger(__name__)
-
-
-class XFormersAttentionBackend(AttentionBackend):
- accept_output_buffer: bool = True
- supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-
- @staticmethod
- def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
- return [MultipleOf(16)]
-
- @classmethod
- def get_supported_head_sizes(cls) -> list[int]:
- return [
- 32,
- 40,
- 48,
- 56,
- 64,
- 72,
- 80,
- 88,
- 96,
- 104,
- 112,
- 120,
- 128,
- 136,
- 144,
- 152,
- 160,
- 168,
- 176,
- 184,
- 192,
- 200,
- 208,
- 216,
- 224,
- 232,
- 240,
- 248,
- 256,
- ]
-
- @staticmethod
- def get_name() -> str:
- return "XFORMERS"
-
- @staticmethod
- def get_impl_cls() -> type["XFormersAttentionImpl"]:
- return XFormersAttentionImpl
-
- @staticmethod
- def get_kv_cache_shape(
- num_blocks: int,
- block_size: int,
- num_kv_heads: int,
- head_size: int,
- cache_dtype_str: str = "auto",
- ) -> tuple[int, ...]:
- if block_size % 16 != 0:
- raise ValueError("Block size must be a multiple of 16.")
- return (2, num_blocks, block_size, num_kv_heads, head_size)
-
- @staticmethod
- def get_builder_cls() -> type["XFormersAttentionMetadataBuilder"]:
- return XFormersAttentionMetadataBuilder
-
- @staticmethod
- def use_cascade_attention(*args, **kwargs) -> bool:
- return False
-
-
-@dataclass
-class XFormersAttentionMetadata:
- num_actual_tokens: int # Number of tokens excluding padding.
- max_query_len: int
- query_start_loc: torch.Tensor
- max_seq_len: int
- seq_lens: torch.Tensor
- block_table: torch.Tensor
- slot_mapping: torch.Tensor
-
- num_prefill_tokens: int = 0
- num_decode_tokens: int = 0
- num_prefills: int = 0
- num_decodes: int = 0
-
- # Biases for different attention types.
- attn_bias: Optional["AttentionBias"] = None
-
- # Self-attention prefill/decode metadata cache
- _cached_prefill_metadata: Optional["XFormersAttentionMetadata"] = None
- _cached_decode_metadata: Optional["XFormersAttentionMetadata"] = None
-
- @property
- def prefill_metadata(self) -> Optional["XFormersAttentionMetadata"]:
- if self.num_prefills == 0:
- return None
-
- if self._cached_prefill_metadata is not None:
- # Recover cached prefill-phase attention
- # metadata structure
- return self._cached_prefill_metadata
-
- q_start_loc = self.query_start_loc[self.num_decodes :]
- q_seqlens = torch.diff(q_start_loc)
- kv_seqlens = self.seq_lens[self.num_decodes :]
- # Construct & cache prefill-phase attention metadata structure
- self._cached_prefill_metadata = XFormersAttentionMetadata(
- num_actual_tokens=self.num_prefill_tokens,
- max_query_len=int(q_seqlens.max().item()),
- query_start_loc=q_start_loc - q_start_loc[0],
- max_seq_len=int(kv_seqlens.max().item()),
- seq_lens=kv_seqlens,
- block_table=self.block_table[self.num_decodes :],
- slot_mapping=self.slot_mapping[self.num_decode_tokens :],
- )
- return self._cached_prefill_metadata
-
- @property
- def decode_metadata(self) -> Optional["XFormersAttentionMetadata"]:
- if self.num_decode_tokens == 0:
- return None
-
- if self._cached_decode_metadata is not None:
- # Recover cached decode-phase attention
- # metadata structure
- return self._cached_decode_metadata
-
- q_start_loc = self.query_start_loc
- q_seqlens = torch.diff(q_start_loc)
- decode_kv_seqlens = self.seq_lens[: self.num_decodes]
- # Construct & cache decode-phase attention metadata structure
- self._cached_decode_metadata = XFormersAttentionMetadata(
- num_actual_tokens=self.num_decode_tokens,
- max_query_len=int(q_seqlens[: self.num_decodes].max().item()),
- query_start_loc=q_start_loc[: self.num_decodes + 1],
- max_seq_len=int(decode_kv_seqlens.max().item()),
- seq_lens=decode_kv_seqlens,
- block_table=self.block_table[: self.num_decodes],
- slot_mapping=self.slot_mapping[: self.num_decode_tokens],
- attn_bias=self.attn_bias,
- )
- return self._cached_decode_metadata
-
-
-class XFormersAttentionMetadataBuilder(
- AttentionMetadataBuilder[XFormersAttentionMetadata]
-):
- reorder_batch_threshold: int = 1
-
- def __init__(
- self,
- kv_cache_spec: AttentionSpec,
- layer_names: list[str],
- vllm_config: VllmConfig,
- device: torch.device,
- ):
- super().__init__(kv_cache_spec, layer_names, vllm_config, device)
-
- assert XFORMERS_AVAILABLE
- self.block_size = kv_cache_spec.block_size
- self._num_decodes = 0
- self._num_decode_tokens = 0
-
- def build(
- self,
- common_prefix_len: int,
- common_attn_metadata: CommonAttentionMetadata,
- fast_build: bool = False,
- ) -> XFormersAttentionMetadata:
- num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
- split_decodes_and_prefills(
- common_attn_metadata, decode_threshold=self.reorder_batch_threshold
- )
- )
-
- num_actual_tokens = common_attn_metadata.num_actual_tokens
- q_start_loc = common_attn_metadata.query_start_loc
- q_seqlens = torch.diff(q_start_loc)
- max_query_len = common_attn_metadata.max_query_len
- kv_seqlens = common_attn_metadata.seq_lens
- max_seq_len = common_attn_metadata.max_seq_len
- block_table = common_attn_metadata.block_table_tensor
- slot_mapping = common_attn_metadata.slot_mapping
-
- bias = None
- if num_decodes > 0:
- # Construct the decoder bias.
- decode_q_seqlens = q_seqlens[:num_decodes]
- decode_kv_seqlens = kv_seqlens[:num_decodes]
- bias = PagedBlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
- q_seqlen=decode_q_seqlens.tolist(),
- kv_seqlen=decode_kv_seqlens.tolist(),
- page_size=self.block_size,
- block_tables=block_table[:num_decodes],
- device=block_table.device,
- )
-
- return XFormersAttentionMetadata(
- num_actual_tokens=num_actual_tokens,
- num_prefill_tokens=num_prefill_tokens,
- num_decode_tokens=num_decode_tokens,
- num_prefills=num_prefills,
- num_decodes=num_decodes,
- max_query_len=max_query_len,
- query_start_loc=q_start_loc,
- max_seq_len=max_seq_len,
- seq_lens=kv_seqlens,
- block_table=block_table,
- slot_mapping=slot_mapping,
- attn_bias=bias,
- )
-
-
-class XFormersAttentionImpl(AttentionImpl):
- def __init__(
- self,
- num_heads: int,
- head_size: int,
- scale: float,
- num_kv_heads: int,
- alibi_slopes: list[float] | None,
- sliding_window: int | None,
- kv_cache_dtype: str,
- logits_soft_cap: float | None = None,
- attn_type: AttentionType = AttentionType.DECODER,
- kv_sharing_target_layer_name: str | None = None,
- ) -> None:
- if kv_sharing_target_layer_name is not None:
- raise NotImplementedError("KV sharing is not supported in V0.")
- if alibi_slopes is not None:
- raise NotImplementedError("XFormers does not support alibi slopes yet.")
- self.num_heads = num_heads
- self.head_size = head_size
- self.scale = float(scale)
- self.num_kv_heads = num_kv_heads
- self.num_queries_per_kv = self.num_heads // self.num_kv_heads
- self.kv_cache_dtype = kv_cache_dtype
- self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
- if alibi_slopes is not None:
- alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
- self.alibi_slopes = alibi_slopes
- if sliding_window is None:
- self.sliding_window = (-1, -1)
- else:
- self.sliding_window = (sliding_window - 1, 0)
- if logits_soft_cap is None:
- # Setting logits_soft_cap to 0 means no soft cap.
- logits_soft_cap = 0
- self.logits_soft_cap = logits_soft_cap
-
- if attn_type != AttentionType.DECODER:
- raise NotImplementedError(
- "Encoder self-attention and "
- "encoder/decoder cross-attention "
- "are not implemented for "
- "XFormersAttentionImpl."
- )
-
- def forward(
- self,
- layer: torch.nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: XFormersAttentionMetadata,
- output: torch.Tensor | None = None,
- output_scale: torch.Tensor | None = None,
- output_block_scale: torch.Tensor | None = None,
- ) -> torch.Tensor:
- """Forward pass with XFormers.
-
- Args:
- query: shape = [num_tokens, num_heads, head_size]
- key: shape = [num_tokens, num_kv_heads, head_size]
- value: shape = [num_tokens, num_kv_heads, head_size]
- kv_cache: shape =
- [2, num_blocks, block_size, num_kv_heads, head_size]
- attn_metadata: Metadata for attention.
- Returns:
- shape = [num_tokens, num_heads * head_size]
- """
- assert output is not None, "Output tensor must be provided."
-
- if output_scale is not None or output_block_scale is not None:
- raise NotImplementedError(
- "fused output quantization is not yet supported"
- " for XFormersAttentionImpl"
- )
-
- if attn_metadata is None:
- # Profiling run.
- return output.fill_(0)
-
- # Cache the input KVs.
- key_cache, value_cache = kv_cache.unbind(0)
- if self.kv_sharing_target_layer_name is None:
- # Reshape the input keys and values and store them in the cache.
- # Skip this if sharing KV cache with an earlier attention layer.
- # NOTE(woosuk): Here, key and value are padded while slot_mapping is
- # not padded. However, we don't need to do key[:num_actual_tokens]
- # and value[:num_actual_tokens] because the reshape_and_cache_flash
- # op uses the slot_mapping's shape to determine the number of
- # actual tokens.
- ops.reshape_and_cache_flash(
- key,
- value,
- key_cache,
- value_cache,
- attn_metadata.slot_mapping,
- self.kv_cache_dtype,
- layer._k_scale,
- layer._v_scale,
- )
-
- num_actual_tokens = attn_metadata.num_actual_tokens
- num_decode_tokens = attn_metadata.num_decode_tokens
- if prefill_meta := attn_metadata.prefill_metadata:
- descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, key.shape[1])
- unified_attention(
- q=query[num_decode_tokens:num_actual_tokens],
- k=key_cache,
- v=value_cache,
- out=output[num_decode_tokens:num_actual_tokens],
- cu_seqlens_q=prefill_meta.query_start_loc,
- max_seqlen_q=prefill_meta.max_query_len,
- seqused_k=prefill_meta.seq_lens,
- max_seqlen_k=prefill_meta.max_seq_len,
- softmax_scale=self.scale,
- causal=True,
- alibi_slopes=self.alibi_slopes,
- window_size=self.sliding_window,
- block_table=prefill_meta.block_table,
- softcap=self.logits_soft_cap,
- q_descale=None, # Not supported
- k_descale=layer._k_scale.expand(descale_shape),
- v_descale=layer._v_scale.expand(descale_shape),
- )
-
- if decode_meta := attn_metadata.decode_metadata:
- # Query for decode. KV is not needed because it is already cached.
- decode_query = query[:num_decode_tokens]
- # Reshape query to [1, B_T, G, H, D].
- q = decode_query.view(
- 1, -1, self.num_kv_heads, self.num_queries_per_kv, self.head_size
- )
- # Reshape the k and v caches to [1, Bkv_T, G, H, D]
- cache_k = key_cache.view(
- 1, -1, self.num_kv_heads, 1, self.head_size
- ).expand(
- 1,
- -1,
- self.num_kv_heads,
- self.num_queries_per_kv,
- self.head_size,
- )
- cache_v = value_cache.view(
- 1, -1, self.num_kv_heads, 1, self.head_size
- ).expand(
- 1,
- -1,
- self.num_kv_heads,
- self.num_queries_per_kv,
- self.head_size,
- )
-
- attn_bias = decode_meta.attn_bias
- output[:num_decode_tokens] = xops.memory_efficient_attention_forward(
- q,
- cache_k,
- cache_v,
- attn_bias=attn_bias,
- p=0.0,
- scale=self.scale,
- ).view(decode_query.shape)
-
- # Reshape the output tensor.
- return output
diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py
index 55710ad5cc693..8b0e8fd3a2410 100644
--- a/vllm/v1/core/block_pool.py
+++ b/vllm/v1/core/block_pool.py
@@ -13,6 +13,8 @@ from vllm.distributed.kv_events import (
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import (
BlockHash,
+ BlockHashList,
+ BlockHashListWithBlockSize,
BlockHashWithGroupId,
ExternalBlockHash,
FreeKVCacheBlockQueue,
@@ -133,6 +135,10 @@ class BlockPool:
Args:
num_gpu_blocks: The number of blocks in the pool.
enable_caching: Whether to enable prefix caching.
+ hash_block_size: The block size of which the block hashes are computed.
+ The actual block size usually equals hash_block_size, but in cases
+ where different KV cache groups have different block sizes, the
+ actual block size can be a multiple of hash_block_size.
enable_kv_cache_events: Whether to enable kv cache events.
"""
@@ -140,11 +146,13 @@ class BlockPool:
self,
num_gpu_blocks: int,
enable_caching: bool,
+ hash_block_size: int,
enable_kv_cache_events: bool = False,
):
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
self.num_gpu_blocks = num_gpu_blocks
self.enable_caching = enable_caching
+ self.hash_block_size = hash_block_size
# All kv-cache blocks.
self.blocks: list[KVCacheBlock] = [
KVCacheBlock(idx) for idx in range(num_gpu_blocks)
@@ -223,8 +231,20 @@ class BlockPool:
return
new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
assert len(request.block_hashes) >= num_full_blocks
- new_block_hashes = request.block_hashes[num_cached_blocks:]
+ if block_size == self.hash_block_size:
+ # Common case.
+ block_hashes: BlockHashList = request.block_hashes
+ else:
+ # block_size is a multiple of hash_block_size. This happens when
+ # different KV cache groups have different block sizes.
+ assert block_size % self.hash_block_size == 0
+ # Recalculate block_hashes at the granularity of block_size, using
+ # the original block_hashes (at the granularity of hash_block_size).
+ block_hashes = BlockHashListWithBlockSize(
+ request.block_hashes, self.hash_block_size, block_size
+ )
+ new_block_hashes = block_hashes[num_cached_blocks:]
new_hashes: list[ExternalBlockHash] | None = (
[] if self.enable_kv_cache_events else None
)
diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py
index 1531b61f88fe2..fd1ec8e27fba2 100644
--- a/vllm/v1/core/kv_cache_coordinator.py
+++ b/vllm/v1/core/kv_cache_coordinator.py
@@ -2,15 +2,25 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Sequence
+from math import lcm
from vllm.v1.core.block_pool import BlockPool
-from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
+from vllm.v1.core.kv_cache_utils import (
+ BlockHash,
+ BlockHashList,
+ BlockHashListWithBlockSize,
+ KVCacheBlock,
+)
from vllm.v1.core.single_type_kv_cache_manager import (
CrossAttentionManager,
FullAttentionManager,
get_manager_for_kv_cache_spec,
)
-from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheSpec
+from vllm.v1.kv_cache_interface import (
+ FullAttentionSpec,
+ KVCacheConfig,
+ KVCacheSpec,
+)
from vllm.v1.request import Request
@@ -28,13 +38,17 @@ class KVCacheCoordinator(ABC):
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
+ hash_block_size: int,
):
self.kv_cache_config = kv_cache_config
self.max_model_len = max_model_len
self.enable_caching = enable_caching
self.block_pool = BlockPool(
- kv_cache_config.num_blocks, enable_caching, enable_kv_cache_events
+ kv_cache_config.num_blocks,
+ enable_caching,
+ hash_block_size,
+ enable_kv_cache_events,
)
# Needs special handling for find_longest_cache_hit if eagle is enabled
@@ -213,6 +227,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
+ hash_block_size: int,
):
super().__init__(
kv_cache_config,
@@ -222,6 +237,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
+ hash_block_size=hash_block_size,
)
self.num_single_type_manager = len(self.single_type_managers)
@@ -255,6 +271,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
+ hash_block_size: int,
):
super().__init__(
kv_cache_config,
@@ -264,6 +281,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
+ hash_block_size=hash_block_size,
)
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec
self.block_size = self.kv_cache_spec.block_size
@@ -273,6 +291,11 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
self.block_size *= dcp_world_size
if pcp_world_size > 1:
self.block_size *= pcp_world_size
+ # For models using only Mamba, block_size is set to max_model_len when
+ # prefix caching is disabled, and hash_block_size validation is skipped.
+ assert not enable_caching or (hash_block_size == self.block_size), (
+ "UnitaryKVCacheCoordinator assumes hash_block_size == block_size"
+ )
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
"UnitaryKVCacheCoordinator assumes only one kv cache group"
)
@@ -289,6 +312,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
block_pool=self.block_pool,
kv_cache_spec=self.kv_cache_spec,
use_eagle=self.use_eagle,
+ alignment_tokens=self.block_size,
dcp_world_size=self.dcp_world_size,
pcp_world_size=self.pcp_world_size,
)
@@ -313,6 +337,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
+ hash_block_size: int,
):
super().__init__(
kv_cache_config,
@@ -322,7 +347,17 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
+ hash_block_size=hash_block_size,
)
+ # hash_block_size: the block size used to compute block hashes.
+ # The actual block size usually equals hash_block_size, but in cases where
+ # different KV cache groups have different block sizes, the actual block size
+ # can be a multiple of hash_block_size.
+ self.hash_block_size = hash_block_size
+ assert all(
+ g.kv_cache_spec.block_size % hash_block_size == 0
+ for g in kv_cache_config.kv_cache_groups
+ ), "block_size must be divisible by hash_block_size"
assert dcp_world_size == 1, "DCP not support hybrid attn now."
assert pcp_world_size == 1, "PCP not support hybrid attn now."
self.verify_and_split_kv_cache_groups()
@@ -373,14 +408,12 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
self.other_spec = other_spec
self.full_attention_block_size = self.full_attention_spec.block_size
self.other_block_size = self.other_spec.block_size
-
- if self.enable_caching:
- # this requirement is only needed for the prefix caching logic
- divisible = self.other_block_size % self.full_attention_block_size
- assert divisible == 0, (
- "KVCacheCoordinator assumes the block_size of full "
- "attention layers is divisible by other layers now."
- )
+ # The LCM of the block sizes of full attention and other attention.
+ # The cache hit length must be a multiple of the LCM of the block sizes
+ # to make sure the cache hit length is a multiple of the block size of
+ # each attention type. Requiring this because we don't support partial
+ # block cache hit yet.
+ self.lcm_block_size = lcm(self.full_attention_block_size, self.other_block_size)
if max(self.full_attention_group_ids) < min(self.other_group_ids):
self.full_attn_first = True
@@ -414,25 +447,48 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
- The number of tokens of the longest cache hit.
"""
# First, find the longest cache hit for full attention.
+ if self.full_attention_spec.block_size == self.hash_block_size:
+ # Common case.
+ full_attention_block_hashes: BlockHashList = block_hashes
+ else:
+ # block_size is a multiple of hash_block_size. This happens when different
+ # KV cache groups have different block sizes. In this case, we need to
+ # recalculate block_hashes at the granularity of block_size, using the
+ # original block_hashes (at the granularity of hash_block_size).
+ full_attention_block_hashes = BlockHashListWithBlockSize(
+ block_hashes, self.hash_block_size, self.full_attention_spec.block_size
+ )
hit_blocks_full_attn = self.full_attention_manager_cls.find_longest_cache_hit(
- block_hashes=block_hashes,
+ block_hashes=full_attention_block_hashes,
max_length=max_cache_hit_length,
kv_cache_group_ids=self.full_attention_group_ids,
block_pool=self.block_pool,
kv_cache_spec=self.full_attention_spec,
use_eagle=self.use_eagle,
+ alignment_tokens=self.lcm_block_size,
)
hit_length = len(hit_blocks_full_attn[0]) * self.full_attention_block_size
# Next, find the cache hit for the other attention WITHIN
# the cache hit of full attention.
+ if self.other_spec.block_size == self.hash_block_size:
+ # Common case.
+ other_block_hashes: BlockHashList = block_hashes
+ else:
+ # Similar to the full attention case, here we need to recalculate
+ # block_hashes at the granularity of block_size, using the original
+ # block_hashes (at the granularity of hash_block_size).
+ other_block_hashes = BlockHashListWithBlockSize(
+ block_hashes, self.hash_block_size, self.other_spec.block_size
+ )
hit_blocks_other_attn = self.other_attention_cls.find_longest_cache_hit(
- block_hashes=block_hashes,
+ block_hashes=other_block_hashes,
max_length=hit_length,
kv_cache_group_ids=self.other_group_ids,
block_pool=self.block_pool,
kv_cache_spec=self.other_spec,
use_eagle=self.use_eagle,
+ alignment_tokens=self.lcm_block_size,
)
hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size
@@ -466,6 +522,7 @@ def get_kv_cache_coordinator(
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
+ hash_block_size: int,
) -> KVCacheCoordinator:
if not enable_caching:
return KVCacheCoordinatorNoPrefixCache(
@@ -473,8 +530,9 @@ def get_kv_cache_coordinator(
max_model_len,
use_eagle,
enable_kv_cache_events,
- dcp_world_size=dcp_world_size,
- pcp_world_size=pcp_world_size,
+ dcp_world_size,
+ pcp_world_size,
+ hash_block_size,
)
if len(kv_cache_config.kv_cache_groups) == 1:
return UnitaryKVCacheCoordinator(
@@ -483,8 +541,9 @@ def get_kv_cache_coordinator(
use_eagle,
enable_caching,
enable_kv_cache_events,
- dcp_world_size=dcp_world_size,
- pcp_world_size=pcp_world_size,
+ dcp_world_size,
+ pcp_world_size,
+ hash_block_size,
)
return HybridKVCacheCoordinator(
kv_cache_config,
@@ -492,6 +551,7 @@ def get_kv_cache_coordinator(
use_eagle,
enable_caching,
enable_kv_cache_events,
- dcp_world_size=dcp_world_size,
- pcp_world_size=pcp_world_size,
+ dcp_world_size,
+ pcp_world_size,
+ hash_block_size,
)
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
index 2012c3fef88bc..b061e5cc831dd 100644
--- a/vllm/v1/core/kv_cache_manager.py
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -95,6 +95,7 @@ class KVCacheManager:
self,
kv_cache_config: KVCacheConfig,
max_model_len: int,
+ hash_block_size: int,
enable_caching: bool = True,
use_eagle: bool = False,
log_stats: bool = False,
@@ -107,28 +108,11 @@ class KVCacheManager:
self.enable_caching = enable_caching
self.use_eagle = use_eagle
self.log_stats = log_stats
- # FIXME: make prefix cache stats conditional on log_stats
+ # FIXME: make prefix cache stats conditional on log_stats. We still need
+ # this comment because when the log stats is enabled there are still
+ # potential configs we could expose in the future.
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
- self.block_size: int | None = None
- if self.enable_caching:
- assert (
- len(
- set(
- g.kv_cache_spec.block_size
- for g in kv_cache_config.kv_cache_groups
- )
- )
- == 1
- ), "Only one block size is supported for now"
- self.block_size = kv_cache_config.kv_cache_groups[
- 0
- ].kv_cache_spec.block_size
-
- if dcp_world_size * pcp_world_size > 1:
- assert len(kv_cache_config.kv_cache_groups) == 1
- self.block_size *= dcp_world_size * pcp_world_size
-
self.coordinator = get_kv_cache_coordinator(
kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len,
@@ -137,6 +121,7 @@ class KVCacheManager:
enable_kv_cache_events=enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
+ hash_block_size=hash_block_size,
)
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
self.block_pool = self.coordinator.block_pool
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index b18ba8e8b2c7b..602eb81beb010 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -5,9 +5,9 @@
import copy
import os
from collections import defaultdict
-from collections.abc import Callable, Iterable, Sequence
-from dataclasses import dataclass
-from typing import Any, NewType, TypeAlias
+from collections.abc import Callable, Iterable, Iterator, Sequence
+from dataclasses import dataclass, replace
+from typing import Any, NewType, TypeAlias, overload
from vllm import envs
from vllm.config import VllmConfig
@@ -825,11 +825,11 @@ def get_num_blocks(
return num_blocks
-def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int:
+def get_uniform_page_size(kv_cache_specs: Iterable[KVCacheSpec]) -> int:
"""
Get the page size of the KV cache.
"""
- page_sizes = set(layer.page_size_bytes for layer in kv_cache_spec.values())
+ page_sizes = {layer.page_size_bytes for layer in kv_cache_specs}
assert len(page_sizes) == 1
return page_sizes.pop()
@@ -882,6 +882,46 @@ def is_kv_cache_page_size_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool
return len(page_sizes) == 1
+def unify_kv_cache_spec_page_size(
+ kv_cache_spec: dict[str, KVCacheSpec],
+) -> dict[str, KVCacheSpec]:
+ """
+ Unify the page size of the given KVCacheSpec. If the page size of all layers
+ are the same, return the original KVCacheSpec. If not same, unify the page
+ size by increasing the block size of layers with smaller page size. Raise
+ NotImplementedError if failed to unify the page size.
+
+ Args:
+ kv_cache_spec: The KVCacheSpec of each attention layer in the model
+
+ Returns:
+ The updated KVCacheSpec with the same page_size_bytes.
+ """
+ page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
+ if len(page_sizes) <= 1:
+ # All layers have the same page size, no need to unify.
+ return kv_cache_spec
+
+ max_page_size = max(page_sizes)
+ new_kv_cache_spec = {}
+ for layer_name, layer_spec in kv_cache_spec.items():
+ if layer_spec.page_size_bytes == max_page_size:
+ new_kv_cache_spec[layer_name] = layer_spec
+ else:
+ layer_page_size = layer_spec.page_size_bytes
+ if max_page_size % layer_page_size != 0:
+ raise NotImplementedError(
+ "The page size of the layer is not divisible by the "
+ "maximum page size. Cannot unify by adjusting block_size."
+ )
+ ratio = max_page_size // layer_page_size
+ new_block_size = layer_spec.block_size * ratio
+ new_spec = replace(layer_spec, block_size=new_block_size)
+ assert new_spec.page_size_bytes == max_page_size
+ new_kv_cache_spec[layer_name] = new_spec
+ return new_kv_cache_spec
+
+
def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
# kv_cache_spec is an empty dict for attention free models
return not kv_cache_spec
@@ -971,7 +1011,16 @@ def _get_kv_cache_groups_uniform_page_size(
# is the minimum number of layers among all attention types. Need a better
# strategy if we want to support more complex patterns (e.g., 20 full + 30
# sw, where the group size should be 10).
- group_size = min([len(layers) for layers in same_type_layers.values()])
+ min_num_layers = min([len(layers) for layers in same_type_layers.values()])
+ group_size = min_num_layers
+ max_num_layers = max([len(layers) for layers in same_type_layers.values()])
+ if max_num_layers < min_num_layers * 1.25:
+ # If the number of layers is not much larger than the minimum number of layers,
+ # use the maximum number of layers as the group size to avoid too many padding
+ # layers. A typical example is gpt-oss-20b + eagle, with 12 sw + 13 full. We
+ # pad it to (13 sw, 13 full) instead of (12 sw, 24 full). 1.25 is just a
+ # magic number to avoid too many padding layers.
+ group_size = max_num_layers
grouped_layers = []
for layers in same_type_layers.values():
num_padding_layers = group_size - len(layers) % group_size
@@ -1001,7 +1050,6 @@ def _get_kv_cache_groups_uniform_page_size(
def get_kv_cache_config_from_groups(
vllm_config: VllmConfig,
kv_cache_groups: list[KVCacheGroupSpec],
- kv_cache_specs: dict[str, KVCacheSpec],
available_memory: int,
) -> KVCacheConfig:
"""
@@ -1011,7 +1059,6 @@ def get_kv_cache_config_from_groups(
Args:
vllm_config: The global VllmConfig
kv_cache_groups: The KV cache groups
- kv_cache_specs: The KV cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes
Returns:
The generated KVCacheConfig
@@ -1055,7 +1102,9 @@ def get_kv_cache_config_from_groups(
# full.1, sw.2: share another Tensor with size=available_memory//2
group_size = max(len(group.layer_names) for group in kv_cache_groups)
- page_size = get_uniform_page_size(kv_cache_specs)
+ page_size = get_uniform_page_size(
+ [group.kv_cache_spec for group in kv_cache_groups]
+ )
assert group_size > 0, "group_size must be greater than 0"
num_blocks = get_num_blocks(
vllm_config, group_size, available_memory, page_size
@@ -1157,7 +1206,8 @@ def get_kv_cache_groups(
# This returns an empty list to allow for the KVCacheManager to handle
# attention free models.
return []
- elif is_kv_cache_spec_uniform(kv_cache_spec):
+
+ if is_kv_cache_spec_uniform(kv_cache_spec):
# KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for
# each layer.
@@ -1167,14 +1217,16 @@ def get_kv_cache_groups(
# full attention, or all layers are sliding window attention with the
# same window size). Put all layers into one group.
return _get_kv_cache_groups_uniform_type(uniform_spec)
- elif is_kv_cache_page_size_uniform(kv_cache_spec):
- # Model contains multiple attention types, but KV cache of all layers
- # have the same physical memory per block per layer. Split the layers
- # into groups with the same number of layers, and thus same total page
- # size.
- return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
- raise NotImplementedError
+ # As KVCacheManager can only allocate memory of one size, we need to unify
+ # the page size of the layers. For cases cannot be unified, this function
+ # will raise an error.
+ kv_cache_spec = unify_kv_cache_spec_page_size(kv_cache_spec)
+ # Model contains multiple attention types, but KV cache of all layers
+ # have the same physical memory per block per layer. Split the layers
+ # into groups with the same number of layers, and thus same total page
+ # size.
+ return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
def generate_scheduler_kv_cache_config(
@@ -1318,10 +1370,7 @@ def get_kv_cache_configs(
) == len(kv_cache_spec_one_worker), "Some layers are not assigned to any group."
kv_cache_configs.append(
get_kv_cache_config_from_groups(
- vllm_config,
- kv_cache_groups_one_worker,
- kv_cache_spec_one_worker,
- available_memory_one_worker,
+ vllm_config, kv_cache_groups_one_worker, available_memory_one_worker
)
)
@@ -1344,3 +1393,79 @@ def get_kv_cache_configs(
_report_kv_cache_config(vllm_config, kv_cache_config)
return kv_cache_configs
+
+
+class BlockHashListWithBlockSize:
+ """
+ Convert block-hash granularity from `hash_block_size` to `target_block_size`.
+ Used when KV cache groups have different block sizes: `hash_block_size`
+ is the size used to compute the original `block_hashes`; `target_block_size`
+ is the group's actual block size.
+
+ Currently, only scaling up by an integer factor is supported (i.e.,
+ `target_block_size` is a multiple of `hash_block_size`). Conversion is
+ performed lazily on access for efficiency, by concatenating consecutive
+ hashes at `hash_block_size` to form each hash at `target_block_size`.
+
+ Example (`hash_block_size` = 16, `target_block_size` = 32):
+ concatenating two 16-size hashes yields one 32-size hash:
+
+ Block hashes with block_size 16:
+ | Token Range | 0-15 | 16-31 | 32-47 | 48-63 |
+ |-------------|------|-------|-------|-------|
+ | Hash | A | B | C | D |
+
+ Block hashes with block_size 32:
+ | Token Range | 0-31 | 32-63 |
+ |-------------|------|-------|
+ | Hash | AB | CD |
+
+ Args:
+ block_hashes: Block hashes to convert, computed at `hash_block_size`.
+ hash_block_size: Block size at which `block_hashes` were computed.
+ target_block_size: Desired block size; must be a multiple of `hash_block_size`.
+ """
+
+ def __init__(
+ self,
+ block_hashes: list[BlockHash],
+ hash_block_size: int,
+ target_block_size: int,
+ ):
+ self.block_hashes = block_hashes
+ assert target_block_size % hash_block_size == 0
+ self.scale_factor = target_block_size // hash_block_size
+
+ def __len__(self) -> int:
+ return len(self.block_hashes) // self.scale_factor
+
+ @overload
+ def __getitem__(self, idx: int) -> BlockHash: ...
+
+ @overload
+ def __getitem__(self, idx: slice) -> list[BlockHash]: ...
+
+ def __getitem__(self, idx):
+ if isinstance(idx, int):
+ return self._get_value_at(idx)
+
+ if isinstance(idx, slice):
+ start, stop, step = idx.indices(len(self))
+ return [self._get_value_at(i) for i in range(start, stop, step)]
+
+ raise TypeError(f"Invalid index type: {type(idx)!r}")
+
+ def __iter__(self) -> Iterator[BlockHash]:
+ for i in range(len(self)):
+ yield self._get_value_at(i)
+
+ def _get_value_at(self, idx: int) -> BlockHash:
+ base = idx * self.scale_factor
+ end = base + self.scale_factor
+ merged_hash: bytes = self.block_hashes[base]
+ for i in range(base + 1, end):
+ merged_hash += self.block_hashes[i]
+ return BlockHash(merged_hash)
+
+
+BlockHashList = list[BlockHash] | BlockHashListWithBlockSize
diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py
index 7902513dce49a..abfab43499b2a 100644
--- a/vllm/v1/core/sched/output.py
+++ b/vllm/v1/core/sched/output.py
@@ -126,12 +126,12 @@ class CachedRequestData:
return len(self.req_ids)
@cached_property
- @deprecated("use resumed_req_ids field")
+ @deprecated("This will be removed in v0.14, use `resumed_req_ids` instead.")
def resumed_from_preemption(self) -> list[bool]:
return [req_id in self.resumed_req_ids for req_id in self.req_ids]
@cached_property
- @deprecated("use all_token_ids field")
+ @deprecated("This will be removed in v0.14, use `all_token_ids` instead.")
def resumed_req_token_ids(self) -> list[list[int] | None]:
return [
self.all_token_ids[req_id] if req_id in self.resumed_req_ids else None
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index a7ec0de372631..0304a8ec48bf7 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -186,6 +186,7 @@ class Scheduler(SchedulerInterface):
enable_kv_cache_events=self.enable_kv_cache_events,
dcp_world_size=self.dcp_world_size,
pcp_world_size=self.pcp_world_size,
+ hash_block_size=self.block_size,
)
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
@@ -508,9 +509,9 @@ class Scheduler(SchedulerInterface):
not self.scheduler_config.enable_chunked_prefill
and num_new_tokens > token_budget
):
- self.waiting.pop_request()
- skipped_waiting_requests.prepend_request(request)
- continue
+ # If chunked_prefill is disabled,
+ # we can stop the scheduling here.
+ break
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
@@ -1088,8 +1089,6 @@ class Scheduler(SchedulerInterface):
and request.sampling_params.logprobs is not None
and logprobs
):
- # NOTE: once we support N tokens per step (spec decode),
- # the outer lists can be of length > 1.
new_logprobs = logprobs.slice(req_index, req_index + 1)
if new_token_ids and self.structured_output_manager.should_advance(request):
diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py
index d90ec550f7666..4aeb17a156bb3 100644
--- a/vllm/v1/core/single_type_kv_cache_manager.py
+++ b/vllm/v1/core/single_type_kv_cache_manager.py
@@ -7,7 +7,7 @@ from collections.abc import Sequence
from vllm.utils.math_utils import cdiv
from vllm.v1.core.block_pool import BlockPool
-from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
+from vllm.v1.core.kv_cache_utils import BlockHashList, KVCacheBlock
from vllm.v1.kv_cache_interface import (
ChunkedLocalAttentionSpec,
CrossAttentionSpec,
@@ -207,12 +207,13 @@ class SingleTypeKVCacheManager(ABC):
@abstractmethod
def find_longest_cache_hit(
cls,
- block_hashes: list[BlockHash],
+ block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
+ alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
@@ -232,6 +233,11 @@ class SingleTypeKVCacheManager(ABC):
block_pool: The block pool.
kv_cache_spec: The kv cache spec.
use_eagle: Whether to use eagle.
+ alignment_tokens: The returned cache hit length (in tokens) should
+ be a multiple of this value (in tokens). By default, it should
+ be set to the block_size.
+ dcp_world_size: The world size of decode context parallelism.
+ pcp_world_size: The world size of prefill context parallelism.
Returns:
A list of cached blocks with skipped blocks replaced by null block
@@ -299,17 +305,18 @@ class FullAttentionManager(SingleTypeKVCacheManager):
@classmethod
def find_longest_cache_hit(
cls,
- block_hashes: list[BlockHash],
+ block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
+ alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(
- kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec)
+ kv_cache_spec, FullAttentionSpec | ChunkedLocalAttentionSpec
), (
"FullAttentionManager can only be used for full attention "
"and chunked local attention groups"
@@ -333,6 +340,13 @@ class FullAttentionManager(SingleTypeKVCacheManager):
else:
break
if use_eagle and computed_blocks[0]:
+ # Need to drop the last matched block if eagle is enabled.
+ for computed in computed_blocks:
+ computed.pop()
+ while (
+ block_size != alignment_tokens # Faster for common case.
+ and len(computed_blocks[0]) * block_size % alignment_tokens != 0
+ ):
for computed in computed_blocks:
computed.pop()
return computed_blocks
@@ -359,12 +373,13 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
@classmethod
def find_longest_cache_hit(
cls,
- block_hashes: list[BlockHash],
+ block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
+ alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
@@ -396,6 +411,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
[block_pool.null_block] * max_num_blocks
for _ in range(len(kv_cache_group_ids))
)
+ block_size = kv_cache_spec.block_size
num_contiguous_blocks = 0
match_found = False
# Search from right to left and early stop when a match is found.
@@ -403,6 +419,15 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
if cached_block := block_pool.get_cached_block(
block_hashes[i], kv_cache_group_ids
):
+ # Skip prefix matching check if the block is not aligned with
+ # `alignment_tokens`.
+ if (
+ num_contiguous_blocks == 0
+ and block_size != alignment_tokens # Faster for common case.
+ and (i + 1) * block_size % alignment_tokens != 0
+ ):
+ continue
+ # Add the cached block to the computed blocks.
for computed, cached in zip(computed_blocks, cached_block):
computed[i] = cached
num_contiguous_blocks += 1
@@ -421,7 +446,16 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
for computed in computed_blocks:
del computed[num_contiguous_blocks:]
+ while (
+ block_size != alignment_tokens # Faster for common case.
+ and len(computed_blocks[0]) * block_size % alignment_tokens != 0
+ ):
+ for computed in computed_blocks:
+ computed.pop()
if use_eagle and computed_blocks[0]:
+ assert kv_cache_spec.block_size == alignment_tokens, (
+ "aligned_length is not compatible with eagle now"
+ )
for computed in computed_blocks:
computed.pop()
return computed_blocks
@@ -475,12 +509,13 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
@classmethod
def find_longest_cache_hit(
cls,
- block_hashes: list[BlockHash],
+ block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
+ alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
@@ -511,6 +546,10 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
block_pool: The block pool.
kv_cache_spec: The kv cache spec.
use_eagle: Whether to use eagle.
+ dcp_world_size: The world size of decode context parallelism.
+ pcp_world_size: The world size of prefill context parallelism.
+ alignment_tokens: The returned cache hit length (in tokens) should
+ be a multiple of this value (in tokens).
Returns:
A list of cached blocks
@@ -524,6 +563,10 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
)
assert dcp_world_size == 1, "DCP not support chunked local attn now."
assert pcp_world_size == 1, "PCP not support chunked local attn now."
+ assert kv_cache_spec.block_size == alignment_tokens, (
+ "KV cache groups with different block sizes are not compatible with "
+ "chunked local attention now"
+ )
max_num_blocks = max_length // kv_cache_spec.block_size
if max_length > 0:
local_attention_start_idx = (
@@ -612,12 +655,13 @@ class MambaManager(SingleTypeKVCacheManager):
@classmethod
def find_longest_cache_hit(
cls,
- block_hashes: list[BlockHash],
+ block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
+ alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
@@ -630,12 +674,21 @@ class MambaManager(SingleTypeKVCacheManager):
[] for _ in range(len(kv_cache_group_ids))
)
- max_num_blocks = max_length // kv_cache_spec.block_size
+ block_size = kv_cache_spec.block_size
+ max_num_blocks = max_length // block_size
# Search from right to left and early stop when a match is found.
for i in range(max_num_blocks - 1, -1, -1):
if cached_block := block_pool.get_cached_block(
block_hashes[i], kv_cache_group_ids
):
+ # When enable Mamba prefix caching, `block_size` will be aligned
+ # across full attention layers and Mamba layers to ensure the
+ # prefix hit length aligned at block
+ if (
+ block_size != alignment_tokens # Faster for common case.
+ and (i + 1) * block_size % alignment_tokens != 0
+ ):
+ continue
for computed, cached in zip(computed_blocks, cached_block):
# the hit length logic later assumes:
# hit_length = len(hit_blocks_other_attn[0])
@@ -708,12 +761,13 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
@classmethod
def find_longest_cache_hit(
cls,
- block_hashes: list[BlockHash],
+ block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
+ alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py
index b480ac78f23cf..ef0f8d9e67452 100644
--- a/vllm/v1/cudagraph_dispatcher.py
+++ b/vllm/v1/cudagraph_dispatcher.py
@@ -4,6 +4,9 @@ from itertools import product
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
class CudagraphDispatcher:
@@ -28,7 +31,11 @@ class CudagraphDispatcher:
def __init__(self, vllm_config: VllmConfig):
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
- self.cudagraph_mode = self.compilation_config.cudagraph_mode
+ self.uniform_decode_query_len = (
+ 1
+ if not self.vllm_config.speculative_config
+ else 1 + self.vllm_config.speculative_config.num_speculative_tokens
+ )
# Dict to store valid cudagraph dispatching keys.
self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = {
@@ -36,25 +43,42 @@ class CudagraphDispatcher:
CUDAGraphMode.FULL: set(),
}
- not_use_piecewise_compilation = (
- not self.cudagraph_mode.requires_piecewise_compilation()
- )
-
assert (
- not_use_piecewise_compilation
+ not self.compilation_config.cudagraph_mode.requires_piecewise_compilation()
or self.compilation_config.is_attention_compiled_piecewise()
), (
"Compilation mode should be CompilationMode.VLLM_COMPILE when "
"cudagraph_mode piecewise cudagraphs is used, "
"and attention should be in splitting_ops or "
"inductor splitting should be used. "
- f"cudagraph_mode={self.cudagraph_mode}, "
+ f"cudagraph_mode={self.compilation_config.cudagraph_mode}, "
f"compilation_mode={self.compilation_config.mode}, "
f"splitting_ops={self.compilation_config.splitting_ops}"
)
self.keys_initialized = False
+ def _create_padded_batch_descriptor(
+ self, num_tokens: int, uniform_decode: bool, has_lora: bool
+ ) -> BatchDescriptor:
+ max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
+ uniform_decode_query_len = self.uniform_decode_query_len
+ num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens)
+
+ if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL):
+ num_reqs = num_tokens_padded // uniform_decode_query_len
+ assert num_tokens_padded % uniform_decode_query_len == 0
+ else:
+ uniform_decode = False
+ num_reqs = min(num_tokens_padded, max_num_seqs)
+
+ return BatchDescriptor(
+ num_tokens=num_tokens_padded,
+ num_reqs=num_reqs,
+ uniform=uniform_decode,
+ has_lora=has_lora,
+ )
+
def add_cudagraph_key(
self, runtime_mode: CUDAGraphMode, batch_descriptor: BatchDescriptor
):
@@ -66,7 +90,9 @@ class CudagraphDispatcher:
def initialize_cudagraph_keys(
self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int
):
- # This should be called only after attention backend is initialized.
+ # This should be called only after attention backend is initialized. So we can
+ # get the correct cudagraph mode after backend support is resolved.
+ self.cudagraph_mode = cudagraph_mode
# LoRA activation cases to specialize the cuda graphs on
if self.vllm_config.lora_config:
@@ -86,9 +112,9 @@ class CudagraphDispatcher:
):
self.add_cudagraph_key(
cudagraph_mode.mixed_mode(),
- BatchDescriptor(
- num_tokens=bs, uniform_decode=False, has_lora=has_lora
- ),
+ self._create_padded_batch_descriptor(
+ bs, False, has_lora
+ ).relax_for_mixed_batch_cudagraphs(),
)
# if decode cudagraph mode is FULL, and we don't already have mixed
@@ -109,40 +135,49 @@ class CudagraphDispatcher:
for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
self.add_cudagraph_key(
CUDAGraphMode.FULL,
- BatchDescriptor(
- num_tokens=bs, uniform_decode=True, has_lora=has_lora
- ),
+ self._create_padded_batch_descriptor(bs, True, has_lora),
)
+
self.keys_initialized = True
def dispatch(
- self, batch_descriptor: BatchDescriptor, use_cascade_attn: bool = False
- ) -> tuple[CUDAGraphMode, BatchDescriptor | None]:
+ self,
+ num_tokens: int,
+ uniform_decode: bool,
+ has_lora: bool,
+ use_cascade_attn: bool = False,
+ ) -> tuple[CUDAGraphMode, BatchDescriptor]:
"""
Given conditions(e.g.,batch descriptor and if using cascade attention),
dispatch to a cudagraph runtime mode and the valid batch descriptor.
A new batch descriptor is returned as we might dispatch a uniform batch
to a graph that supports a more general batch (uniform to non-uniform).
"""
- # if not initialized, just skip dispatching.
- if not self.keys_initialized:
- return CUDAGraphMode.NONE, None
+ if (
+ not self.keys_initialized
+ or self.cudagraph_mode == CUDAGraphMode.NONE
+ or num_tokens > self.compilation_config.max_cudagraph_capture_size
+ ):
+ return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
+
+ batch_desc = self._create_padded_batch_descriptor(
+ num_tokens, uniform_decode, has_lora
+ )
+ relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()
- non_uniform_key = batch_descriptor.non_uniform
- # if a batch use cascade attention, bypass checking full cudagraphs
if not use_cascade_attn:
# check if key exists for full cudagraph
- if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
- return CUDAGraphMode.FULL, batch_descriptor
+ if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
+ return CUDAGraphMode.FULL, batch_desc
- # otherwise, check if non-uniform key exists
- if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]:
- return CUDAGraphMode.FULL, non_uniform_key
+ # otherwise, check if the relaxed key exists
+ if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
+ return CUDAGraphMode.FULL, relaxed_batch_desc
- # also check if non-uniform key exists for more "general"
+ # also check if the relaxed key exists for more "general"
# piecewise cudagraph
- if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
- return CUDAGraphMode.PIECEWISE, non_uniform_key
+ if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
+ return CUDAGraphMode.PIECEWISE, relaxed_batch_desc
- # finally, just return no cudagraphs
- return CUDAGraphMode.NONE, None
+ # finally, just return no cudagraphs and a trivial batch descriptor
+ return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index 3f621d77c0241..ce2aae77108da 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -72,6 +72,14 @@ class EngineCoreRequest(
trace_headers: Mapping[str, str] | None = None
+ @property
+ def params(self) -> SamplingParams | PoolingParams:
+ """Return the processed params (sampling or pooling)."""
+ if self.sampling_params is not None:
+ return self.sampling_params
+ assert self.pooling_params is not None
+ return self.pooling_params
+
class EngineCoreEventType(enum.IntEnum):
"""The type of engine core request event."""
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index c64b3cccfc652..827a2736af284 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -31,7 +31,6 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_
from vllm.usage.usage_lib import UsageContext
from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.collection_utils import as_list
-from vllm.utils.func_utils import deprecate_kwargs
from vllm.utils.math_utils import cdiv
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient
@@ -195,12 +194,6 @@ class AsyncLLM(EngineClient):
self.profiler = None
@classmethod
- @deprecate_kwargs(
- "disable_log_requests",
- additional_message=(
- "This argument will have no effect. Use `enable_log_requests` instead."
- ),
- )
def from_vllm_config(
cls,
vllm_config: VllmConfig,
@@ -213,7 +206,6 @@ class AsyncLLM(EngineClient):
client_addresses: dict[str, str] | None = None,
client_count: int = 1,
client_index: int = 0,
- disable_log_requests: bool = True, # Deprecated, will be removed
) -> "AsyncLLM":
# Create the LLMEngine.
return cls(
@@ -321,14 +313,15 @@ class AsyncLLM(EngineClient):
elif isinstance(prompt, Mapping):
prompt_text = cast(str | None, prompt.get("prompt"))
+ # Use cloned params that may have been updated in process_inputs()
+ params = request.params
+
if is_pooling or params.n == 1:
await self._add_request(request, prompt_text, None, 0, queue)
return queue
- # Get the updated SamplingParams from the request, which
- # were cloned/updated in processor.process_inputs above.
- parent_params = request.sampling_params
- assert parent_params is not None
+ parent_params = params
+ assert isinstance(parent_params, SamplingParams)
# Fan out child requests (for n>1).
parent_request = ParentRequest(request_id, parent_params)
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index e403cea87788b..dffe05445ee46 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -250,6 +250,9 @@ class LLMEngine:
elif isinstance(prompt, Mapping):
prompt_text = cast(str | None, prompt.get("prompt"))
+ # Use cloned params that may have been updated in process_inputs()
+ params = request.params
+
n = params.n if isinstance(params, SamplingParams) else 1
if n == 1:
@@ -262,10 +265,10 @@ class LLMEngine:
# Fan out child requests (for n>1).
parent_req = ParentRequest(request_id, params)
for idx in range(n):
- request_id, params = parent_req.get_child_info(idx)
+ request_id, child_params = parent_req.get_child_info(idx)
child_request = request if idx == n - 1 else copy(request)
child_request.request_id = request_id
- child_request.sampling_params = params
+ child_request.sampling_params = child_params
# Make a new RequestState and queue.
self.output_processor.add_request(
diff --git a/vllm/v1/kv_offload/cpu.py b/vllm/v1/kv_offload/cpu.py
index 86747299eb107..2f2e85c0ff332 100644
--- a/vllm/v1/kv_offload/cpu.py
+++ b/vllm/v1/kv_offload/cpu.py
@@ -4,7 +4,7 @@ from collections.abc import Iterator
import torch
-from vllm.attention import AttentionBackend
+from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.platforms import current_platform
from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
diff --git a/vllm/v1/kv_offload/spec.py b/vllm/v1/kv_offload/spec.py
index c1813a4ff4ea9..2cdd5ba5ffe5c 100644
--- a/vllm/v1/kv_offload/spec.py
+++ b/vllm/v1/kv_offload/spec.py
@@ -6,12 +6,12 @@ from typing import TYPE_CHECKING
import torch
+from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
from vllm.v1.kv_offload.worker.worker import OffloadingHandler
if TYPE_CHECKING:
- from vllm.attention import AttentionBackend
from vllm.config import VllmConfig
logger = init_logger(__name__)
@@ -51,7 +51,7 @@ class OffloadingSpec(ABC):
def get_handlers(
self,
kv_caches: dict[str, torch.Tensor],
- attn_backends: dict[str, type["AttentionBackend"]],
+ attn_backends: dict[str, type[AttentionBackend]],
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
"""
Get offloading handlers along with their respective src and dst types.
diff --git a/vllm/v1/kv_offload/worker/cpu_gpu.py b/vllm/v1/kv_offload/worker/cpu_gpu.py
index bb163f0043fc6..461458c1f6ce8 100644
--- a/vllm/v1/kv_offload/worker/cpu_gpu.py
+++ b/vllm/v1/kv_offload/worker/cpu_gpu.py
@@ -5,7 +5,7 @@ import numpy as np
import torch
from vllm import _custom_ops as ops
-from vllm.attention import AttentionBackend
+from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py
index e2d82241ce210..bd18a152ffc08 100644
--- a/vllm/v1/metrics/loggers.py
+++ b/vllm/v1/metrics/loggers.py
@@ -440,57 +440,6 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
# Setting default values
self.record_sleep_state()
- # GPU cache
- #
- # Deprecated in 0.9.2 - Renamed as vllm:kv_cache_usage_perc
- # With 0.11.x you can enable with --show-hidden-metrics-for-version=0.10
- # TODO: remove in 0.12.0
- if self.show_hidden_metrics:
- gauge_gpu_cache_usage = self._gauge_cls(
- name="vllm:gpu_cache_usage_perc",
- documentation=(
- "GPU KV-cache usage. 1 means 100 percent usage."
- "DEPRECATED: Use vllm:kv_cache_usage_perc instead."
- ),
- multiprocess_mode="mostrecent",
- labelnames=labelnames,
- )
- self.gauge_gpu_cache_usage = make_per_engine(
- gauge_gpu_cache_usage, engine_indexes, model_name
- )
-
- # Deprecated in 0.9.2 - Renamed as vllm:prefix_cache_queries
- # With 0.11.x you can enable with --show-hidden-metrics-for-version=0.10
- # TODO: remove in 0.12.0
- if self.show_hidden_metrics:
- counter_gpu_prefix_cache_queries = self._counter_cls(
- name="vllm:gpu_prefix_cache_queries",
- documentation=(
- "GPU prefix cache queries, in terms of number of queried"
- "tokens. DEPRECATED: Use vllm:prefix_cache_queries instead."
- ),
- labelnames=labelnames,
- )
- self.counter_gpu_prefix_cache_queries = make_per_engine(
- counter_gpu_prefix_cache_queries, engine_indexes, model_name
- )
-
- # Deprecated in 0.9.2 - Renamed as vllm:prefix_cache_hits
- # With 0.11.x you can enable with --show-hidden-metrics-for-version=0.10
- # TODO: remove in 0.12.0
- if self.show_hidden_metrics:
- counter_gpu_prefix_cache_hits = self._counter_cls(
- name="vllm:gpu_prefix_cache_hits",
- documentation=(
- "GPU prefix cache hits, in terms of number of cached "
- "tokens. DEPRECATED: Use vllm:prefix_cache_hits instead."
- ),
- labelnames=labelnames,
- )
- self.counter_gpu_prefix_cache_hits = make_per_engine(
- counter_gpu_prefix_cache_hits, engine_indexes, model_name
- )
-
gauge_kv_cache_usage = self._gauge_cls(
name="vllm:kv_cache_usage_perc",
documentation="KV-cache usage. 1 means 100 percent usage.",
@@ -735,39 +684,41 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
)
# Deprecated in 0.11 - Renamed as vllm:inter_token_latency_seconds
- # TODO: in 0.12, only enable if show_hidden_metrics=True
- histogram_time_per_output_token = self._histogram_cls(
- name="vllm:time_per_output_token_seconds",
- documentation=(
- "Histogram of time per output token in seconds."
- "DEPRECATED: Use vllm:inter_token_latency_seconds instead."
- ),
- buckets=[
- 0.01,
- 0.025,
- 0.05,
- 0.075,
- 0.1,
- 0.15,
- 0.2,
- 0.3,
- 0.4,
- 0.5,
- 0.75,
- 1.0,
- 2.5,
- 5.0,
- 7.5,
- 10.0,
- 20.0,
- 40.0,
- 80.0,
- ],
- labelnames=labelnames,
- )
- self.histogram_time_per_output_token = make_per_engine(
- histogram_time_per_output_token, engine_indexes, model_name
- )
+ # With 0.12.x you can enable with --show-hidden-metrics-for-version=0.11
+ # TODO: remove in 0.13.0
+ if self.show_hidden_metrics:
+ histogram_time_per_output_token = self._histogram_cls(
+ name="vllm:time_per_output_token_seconds",
+ documentation=(
+ "Histogram of time per output token in seconds."
+ "DEPRECATED: Use vllm:inter_token_latency_seconds instead."
+ ),
+ buckets=[
+ 0.01,
+ 0.025,
+ 0.05,
+ 0.075,
+ 0.1,
+ 0.15,
+ 0.2,
+ 0.3,
+ 0.4,
+ 0.5,
+ 0.75,
+ 1.0,
+ 2.5,
+ 5.0,
+ 7.5,
+ 10.0,
+ 20.0,
+ 40.0,
+ 80.0,
+ ],
+ labelnames=labelnames,
+ )
+ self.histogram_time_per_output_token = make_per_engine(
+ histogram_time_per_output_token, engine_indexes, model_name
+ )
histogram_inter_token_latency = self._histogram_cls(
name="vllm:inter_token_latency_seconds",
@@ -966,20 +917,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
self.gauge_scheduler_waiting[engine_idx].set(
scheduler_stats.num_waiting_reqs
)
- if self.show_hidden_metrics:
- self.gauge_gpu_cache_usage[engine_idx].set(
- scheduler_stats.kv_cache_usage
- )
self.gauge_kv_cache_usage[engine_idx].set(scheduler_stats.kv_cache_usage)
- if self.show_hidden_metrics:
- self.counter_gpu_prefix_cache_queries[engine_idx].inc(
- scheduler_stats.prefix_cache_stats.queries
- )
- self.counter_gpu_prefix_cache_hits[engine_idx].inc(
- scheduler_stats.prefix_cache_stats.hits
- )
-
self.counter_prefix_cache_queries[engine_idx].inc(
scheduler_stats.prefix_cache_stats.queries
)
@@ -1050,7 +989,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
self.histogram_time_to_first_token[engine_idx].observe(ttft)
for itl in iteration_stats.inter_token_latencies_iter:
self.histogram_inter_token_latency[engine_idx].observe(itl)
- self.histogram_time_per_output_token[engine_idx].observe(itl)
+ if self.show_hidden_metrics:
+ self.histogram_time_per_output_token[engine_idx].observe(itl)
for finished_request in iteration_stats.finished_requests:
self.counter_request_success[finished_request.finish_reason][
diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py
index 926305d25f56b..ccaf07e18c468 100644
--- a/vllm/v1/sample/rejection_sampler.py
+++ b/vllm/v1/sample/rejection_sampler.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Sequence
from dataclasses import replace
import torch
@@ -204,7 +205,9 @@ class RejectionSampler(nn.Module):
def parse_output(
output_token_ids: torch.Tensor,
vocab_size: int,
- ) -> list[list[int]]:
+ discard_req_indices: Sequence[int] = (),
+ return_cu_num_tokens: bool = False,
+ ) -> tuple[list[list[int]], list[int] | None]:
"""Parse the output of the rejection sampler.
Args:
output_token_ids: The sampled token IDs in shape
@@ -212,6 +215,8 @@ class RejectionSampler(nn.Module):
replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
and will be filtered out in this function.
vocab_size: The size of the vocabulary.
+ discard_req_indices: Optional row indices to discard tokens in.
+ return_cu_num_tokens: Whether to also return cumulative token counts.
Returns:
A list of lists of token IDs.
"""
@@ -220,10 +225,15 @@ class RejectionSampler(nn.Module):
valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
output_token_ids_np < vocab_size
)
+ cu_num_tokens = None
+ if return_cu_num_tokens:
+ cu_num_tokens = [0] + valid_mask.sum(axis=1).cumsum().tolist()
+ if len(discard_req_indices) > 0:
+ valid_mask[discard_req_indices] = False
outputs = [
row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
]
- return outputs
+ return outputs, cu_num_tokens
def apply_logits_processors(
self,
diff --git a/vllm/v1/sample/tpu/sampler.py b/vllm/v1/sample/tpu/sampler.py
index 8f0463c76ce15..6d992bb37a59d 100644
--- a/vllm/v1/sample/tpu/sampler.py
+++ b/vllm/v1/sample/tpu/sampler.py
@@ -181,7 +181,7 @@ def apply_top_k_top_p(
after thresholding the logit using this cut-off, the remaining elements
shall constitute the top-p set.
- Note: in the case of tie (i.e. multipple cut-off elements present in the
+ Note: in the case of tie (i.e. multiple cut-off elements present in the
logit), all tie elements are included in the top-p set. In other words,
this function does not break ties. Instead, these tie tokens have equal
chance of being chosen during final sampling, so we can consider the tie
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 3de418f1d13c8..7600df48150ac 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -8,6 +8,7 @@ import numpy as np
import torch
import torch.nn as nn
+from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import (
CompilationMode,
CUDAGraphMode,
@@ -40,6 +41,7 @@ from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import _SAMPLING_EPS
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer
+from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
logger = init_logger(__name__)
@@ -65,6 +67,7 @@ class EagleProposer:
self.dtype = vllm_config.model_config.dtype
self.max_model_len = vllm_config.model_config.max_model_len
self.block_size = vllm_config.cache_config.block_size
+ self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
self.token_arange_np = np.arange(self.max_num_tokens)
@@ -83,6 +86,9 @@ class EagleProposer:
self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None
self.attn_layer_names: list[str] = []
self.indexer_layer_names: list[str] = []
+ self.eagle3_use_aux_hidden_state: bool = (
+ self._get_eagle3_use_aux_hidden_state_from_config()
+ )
self.use_cuda_graph = False
@@ -152,8 +158,6 @@ class EagleProposer:
)
# Determine allowed attention backends once during initialization.
- from vllm.attention.backends.registry import AttentionBackendEnum
-
self.allowed_attn_types: tuple | None = None
if current_platform.is_rocm():
rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
@@ -268,15 +272,24 @@ class EagleProposer:
assert draft_indexer_metadata is not None
per_layer_attn_metadata[layer_name] = draft_indexer_metadata
+ num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
+ num_tokens_unpadded=num_tokens,
+ num_tokens_padded=num_tokens,
+ )
+
cudagraph_runtime_mode = CUDAGraphMode.NONE
if (
self.use_cuda_graph
- and num_tokens <= self.compilation_config.max_cudagraph_capture_size
+ and num_tokens_dp_padded
+ <= self.compilation_config.max_cudagraph_capture_size
):
- num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
+ num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens_dp_padded)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
else:
- num_input_tokens = num_tokens
+ num_input_tokens = num_tokens_dp_padded
+ if num_tokens_across_dp is not None:
+ num_tokens_across_dp[self.dp_rank] = num_input_tokens
+
# copy inputs to buffer for cudagraph
self._set_positions(num_tokens, target_positions)
self.hidden_states[:num_tokens] = target_hidden_states
@@ -300,6 +313,7 @@ class EagleProposer:
per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
+ num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
):
ret_hidden_states = self.model(
@@ -362,15 +376,23 @@ class EagleProposer:
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
+ batch_size_dp_padded, batch_size_across_dp = self._pad_batch_across_dp(
+ num_tokens_unpadded=batch_size,
+ num_tokens_padded=batch_size,
+ )
+
if (
self.use_cuda_graph
- and batch_size <= self.compilation_config.max_cudagraph_capture_size
+ and batch_size_dp_padded
+ <= self.compilation_config.max_cudagraph_capture_size
):
- input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
+ input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size_dp_padded)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
else:
- input_batch_size = batch_size
+ input_batch_size = batch_size_dp_padded
cudagraph_runtime_mode = CUDAGraphMode.NONE
+ if batch_size_across_dp is not None:
+ batch_size_across_dp[self.dp_rank] = input_batch_size
common_attn_metadata.num_actual_tokens = batch_size
common_attn_metadata.max_query_len = 1
@@ -471,6 +493,7 @@ class EagleProposer:
per_layer_attn_metadata,
self.vllm_config,
num_tokens=input_batch_size,
+ num_tokens_across_dp=batch_size_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
):
ret_hidden_states = self.model(
@@ -1031,11 +1054,11 @@ class EagleProposer:
elif (
isinstance(target_embed_tokens.weight, torch.Tensor)
and isinstance(self.model.model.embed_tokens.weight, torch.Tensor)
- and torch.allclose(
+ # TODO: Offload to CPU for comparison to avoid extra GPU memory
+ # usage in CI testing environments with limited GPU memory
+ and torch.equal(
target_embed_tokens.weight.cpu(),
self.model.model.embed_tokens.weight.cpu(),
- rtol=1e-5,
- atol=1e-7,
)
):
share_embeddings = True
@@ -1081,8 +1104,11 @@ class EagleProposer:
hasattr(target_language_model, "lm_head")
and isinstance(target_language_model.lm_head.weight, torch.Tensor)
and isinstance(self.model.lm_head.weight, torch.Tensor)
+ # TODO: Offload to CPU for comparison to avoid extra GPU memory
+ # usage in CI testing environments with limited GPU memory
and torch.equal(
- target_language_model.lm_head.weight, self.model.lm_head.weight
+ target_language_model.lm_head.weight.cpu(),
+ self.model.lm_head.weight.cpu(),
)
):
share_lm_head = True
@@ -1113,36 +1139,56 @@ class EagleProposer:
self,
num_tokens: int,
use_cudagraphs=True,
+ is_graph_capturing=False,
) -> None:
# Determine if CUDA graphs should be used for this run.
cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph
- if (
- cudagraphs_enabled
- and num_tokens <= self.compilation_config.max_cudagraph_capture_size
- ):
- num_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
- with set_forward_context(
- None,
- self.vllm_config,
- num_tokens=num_tokens,
- cudagraph_runtime_mode=(
- CUDAGraphMode.PIECEWISE if cudagraphs_enabled else CUDAGraphMode.NONE
- ),
+ # FIXME: when using tree-based specdec, adjust number of forward-passes
+ # according to the depth of the tree.
+ for fwd_idx in range(
+ self.num_speculative_tokens if not is_graph_capturing else 1
):
- if self.supports_mm_inputs:
- input_ids = None
- inputs_embeds = self.inputs_embeds[:num_tokens]
- else:
- input_ids = self.input_ids[:num_tokens]
- inputs_embeds = None
+ if fwd_idx <= 1:
+ num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
+ num_tokens_unpadded=num_tokens,
+ num_tokens_padded=num_tokens,
+ )
+ if (
+ cudagraphs_enabled
+ and num_tokens_dp_padded
+ <= self.compilation_config.max_cudagraph_capture_size
+ ):
+ num_input_tokens = self.vllm_config.pad_for_cudagraph(
+ num_tokens_dp_padded
+ )
+ else:
+ num_input_tokens = num_tokens_dp_padded
+ if num_tokens_across_dp is not None:
+ num_tokens_across_dp[self.dp_rank] = num_input_tokens
- self.model(
- input_ids=input_ids,
- positions=self._get_positions(num_tokens),
- hidden_states=self.hidden_states[:num_tokens],
- inputs_embeds=inputs_embeds,
- )
+ with set_forward_context(
+ None,
+ self.vllm_config,
+ num_tokens=num_input_tokens,
+ num_tokens_across_dp=num_tokens_across_dp,
+ cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE
+ if cudagraphs_enabled
+ else CUDAGraphMode.NONE,
+ ):
+ if self.supports_mm_inputs:
+ input_ids = None
+ inputs_embeds = self.inputs_embeds[:num_input_tokens]
+ else:
+ input_ids = self.input_ids[:num_input_tokens]
+ inputs_embeds = None
+
+ self.model(
+ input_ids=input_ids,
+ positions=self._get_positions(num_input_tokens),
+ hidden_states=self.hidden_states[:num_input_tokens],
+ inputs_embeds=inputs_embeds,
+ )
def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
"""Find and return the attention metadata builders for EAGLE layers.
@@ -1169,6 +1215,22 @@ class EagleProposer:
)
return builder
+ def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
+ """
+ Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary
+ hidden states and directly uses the last layer output just like eagle1.
+ They might indicate this by setting "use_aux_hidden_state" to False
+ inside the "eagle_config" dict of their hf_config.
+ """
+ if self.method != "eagle3":
+ return False
+ # Assume that eagle3 heads use aux hidden states by default
+ use_aux_hidden_state = True
+ eagle_config = getattr(self.draft_model_config.hf_config, "eagle_config", None)
+ if eagle_config is not None:
+ use_aux_hidden_state = eagle_config.get("use_aux_hidden_state", True)
+ return use_aux_hidden_state
+
def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
"""
Validate that all eagle layers belong to the same KVCacheGroup.
@@ -1192,6 +1254,28 @@ class EagleProposer:
== 1
), "All eagle layers should belong to the same kv cache group"
+ def _pad_batch_across_dp(
+ self,
+ num_tokens_unpadded: int,
+ num_tokens_padded: int,
+ ) -> tuple[int, torch.Tensor]:
+ # TODO(Flechman): support DBO ubatching
+ ubatch_slices, num_toks_across_dp = coordinate_batch_across_dp(
+ num_tokens_unpadded=num_tokens_unpadded,
+ parallel_config=self.vllm_config.parallel_config,
+ allow_microbatching=False,
+ allow_dp_padding=self.use_cuda_graph,
+ num_tokens_padded=num_tokens_padded,
+ uniform_decode=None,
+ num_scheduled_tokens_per_request=None,
+ )
+ assert ubatch_slices is None, "DBO ubatching not implemented for EAGLE"
+
+ num_tokens_dp_padded = num_tokens_padded
+ if num_toks_across_dp is not None:
+ num_tokens_dp_padded = int(num_toks_across_dp[self.dp_rank].item())
+ return num_tokens_dp_padded, num_toks_across_dp
+
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
diff --git a/vllm/v1/worker/dp_utils.py b/vllm/v1/worker/dp_utils.py
index 464fbf11a21ad..c1509de821b05 100644
--- a/vllm/v1/worker/dp_utils.py
+++ b/vllm/v1/worker/dp_utils.py
@@ -9,6 +9,7 @@ from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import get_dp_group
from vllm.logger import init_logger
from vllm.v1.worker.ubatch_utils import (
+ UBatchSlice,
UBatchSlices,
check_ubatch_thresholds,
create_ubatch_slices,
@@ -23,12 +24,14 @@ def _get_device_and_group(parallel_config: ParallelConfig):
device = get_dp_group().device
group = get_dp_group().device_group
- # Transfering this tensor from GPU to CPU will introduce a GPU sync
+ # Transferring this tensor from GPU to CPU will introduce a GPU sync
# point that could adversely affect performance of vllm with asynch
# scheduling. This environment variable exists to quickly disable
# this optimization if we run into this case.
if parallel_config.disable_nccl_for_dp_synchronization:
- logger.info_once("Using CPU all reduce to syncronize DP padding between ranks.")
+ logger.info_once(
+ "Using CPU all reduce to synchronize DP padding between ranks."
+ )
device = "cpu"
group = get_dp_group().cpu_group
return device, group
@@ -88,6 +91,17 @@ def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch
return num_tokens_across_dp.cpu()
+# This just pads the second ubatch slice out to the total number of tokens
+# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
+def _pad_out_ubatch_slice(ubatch_slices: UBatchSlices, num_total_tokens: int):
+ padded_second_ubatch_slice = slice(
+ ubatch_slices[1].token_slice.start, num_total_tokens
+ )
+ ubatch_slices[1] = UBatchSlice(
+ padded_second_ubatch_slice, padded_second_ubatch_slice
+ )
+
+
def _synchronize_dp_ranks(
num_tokens_unpadded: int,
num_tokens_padded: int,
@@ -220,11 +234,14 @@ def coordinate_batch_across_dp(
# to the second ubatch in pad_out_ubatch_slice after attention
# metadata creation
assert num_tokens_after_padding is not None
- token_split_point = int(num_tokens_after_padding[0].item()) // 2
+ num_tokens_padded = int(num_tokens_after_padding[0].item())
+ token_split_point = int(num_tokens_padded) // 2
assert num_scheduled_tokens_per_request is not None
ubatch_slices = create_ubatch_slices(
num_scheduled_tokens_per_request, token_split_point
)
+ ubatch_slices = _pad_out_ubatch_slice(ubatch_slices, num_tokens_padded)
+ assert sum(s.num_tokens for s in ubatch_slices) == num_tokens_padded
return (ubatch_slices, num_tokens_after_padding)
diff --git a/vllm/v1/worker/gpu/async_utils.py b/vllm/v1/worker/gpu/async_utils.py
index 421fb29a7f87f..f6bc607c1ae67 100644
--- a/vllm/v1/worker/gpu/async_utils.py
+++ b/vllm/v1/worker/gpu/async_utils.py
@@ -21,6 +21,9 @@ class AsyncOutput(AsyncModelRunnerOutput):
copy_stream: torch.cuda.Stream,
copy_event: torch.cuda.Event,
):
+ # NOTE(woosuk): We must retain references to the GPU tensors,
+ # as the copy operations are performed on a different CUDA stream than
+ # the one where the tensors were created.
self.model_runner_output = model_runner_output
self.sampler_output = sampler_output
self.num_sampled_tokens = num_sampled_tokens
@@ -51,7 +54,9 @@ class AsyncOutput(AsyncModelRunnerOutput):
)
else:
self.logprobs_tensors = None
- self.num_sampled_tokens = num_sampled_tokens.to("cpu", non_blocking=True)
+ self.num_sampled_tokens_cpu = num_sampled_tokens.to(
+ "cpu", non_blocking=True
+ )
self.prompt_logprobs_dict: dict[str, LogprobsTensors | None] = {}
if self.model_runner_output.prompt_logprobs_dict:
for k, v in self.model_runner_output.prompt_logprobs_dict.items():
@@ -63,7 +68,7 @@ class AsyncOutput(AsyncModelRunnerOutput):
def get_output(self) -> ModelRunnerOutput:
self.copy_event.synchronize()
- num_sampled_tokens_np = self.num_sampled_tokens.numpy()
+ num_sampled_tokens_np = self.num_sampled_tokens_cpu.numpy()
# NOTE(woosuk): The following code is to ensure compatibility with
# the existing model runner.
diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py
index 4510a1c5ca1e9..5aa1a33d851cc 100644
--- a/vllm/v1/worker/gpu/attn_utils.py
+++ b/vllm/v1/worker/gpu/attn_utils.py
@@ -18,7 +18,6 @@ from vllm.v1.kv_cache_interface import (
KVCacheConfig,
KVCacheSpec,
)
-from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.utils import bind_kv_cache
@@ -145,7 +144,8 @@ def build_attn_metadata(
attn_metadata_builders: list[AttentionMetadataBuilder],
num_reqs: int,
num_tokens: int,
- query_start_loc: CpuGpuBuffer,
+ query_start_loc_gpu: torch.Tensor,
+ query_start_loc_cpu: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_np: np.ndarray,
num_computed_tokens_cpu: torch.Tensor | None,
@@ -153,9 +153,7 @@ def build_attn_metadata(
slot_mappings: torch.Tensor,
kv_cache_config: KVCacheConfig,
) -> dict[str, Any]:
- query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
- query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1]
- max_query_len = int(query_start_loc.np[: num_reqs + 1].max())
+ max_query_len = int(query_start_loc_cpu.max())
seq_lens = seq_lens[:num_reqs]
seq_lens_cpu = torch.from_numpy(seq_lens_np)
max_seq_len = int(seq_lens_np.max())
diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py
index ba783e2d0c6fb..eb8e610ae4710 100644
--- a/vllm/v1/worker/gpu/cudagraph_utils.py
+++ b/vllm/v1/worker/gpu/cudagraph_utils.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from unittest.mock import patch
+from collections.abc import Callable, Iterable
+from typing import Any
import numpy as np
import torch
@@ -16,6 +17,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.block_table import BlockTables
+from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers
@@ -31,117 +33,69 @@ class CudaGraphManager:
self.max_model_len = vllm_config.model_config.max_model_len
self.max_num_reqs = self.scheduler_config.max_num_seqs
+ self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.compilation_config = vllm_config.compilation_config
assert self.compilation_config is not None
-
+ self.cudagraph_mode: CUDAGraphMode
if self.compilation_config.cudagraph_mode is None:
self.cudagraph_mode = CUDAGraphMode.NONE
else:
self.cudagraph_mode = self.compilation_config.cudagraph_mode
- if self.compilation_config.cudagraph_capture_sizes is not None:
- cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
- # Limit the cudagraph sizes to the max decode batch size.
- self.cudagraph_sizes = [
- x for x in cudagraph_sizes if x <= self.max_num_reqs
- ]
- else:
- self.cudagraph_sizes = []
- self.padded_sizes = self._init_padded_sizes()
+ self.cudagraph_sizes = get_cudagraph_sizes(
+ self.compilation_config.cudagraph_capture_sizes,
+ self.max_num_reqs,
+ self.max_num_tokens,
+ self.cudagraph_mode,
+ )
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
self.pool = torch.cuda.graph_pool_handle()
self.hidden_states: torch.Tensor | None = None
- def _init_padded_sizes(self) -> dict[int, int]:
- if not self.cudagraph_mode.has_full_cudagraphs():
- # Full cuda graphs are not used.
- return {}
- if not self.cudagraph_sizes:
- return {}
-
- padded_sizes: dict[int, int] = {}
- for i in range(1, self.cudagraph_sizes[-1] + 1):
- for x in self.cudagraph_sizes:
- if i <= x:
- padded_sizes[i] = x
- break
- return padded_sizes
-
def needs_capture(self) -> bool:
- return len(self.padded_sizes) > 0
+ return len(self.cudagraph_sizes) > 0
def get_cudagraph_size(
self,
scheduler_output: SchedulerOutput,
num_tokens_after_padding: int,
) -> int | None:
- if not self.cudagraph_mode.has_full_cudagraphs():
- return None
- if self.cudagraph_mode != CUDAGraphMode.FULL:
- # TODO(woosuk): Support uniform decode with multiple tokens (spec decoding).
- all_decode = all(
- x == 1 for x in scheduler_output.num_scheduled_tokens.values()
- )
- if not all_decode:
- # Prefill is included.
- return None
- return self.padded_sizes.get(num_tokens_after_padding)
+ return get_cudagraph_size(
+ num_tokens_after_padding,
+ scheduler_output.num_scheduled_tokens.values(),
+ self.cudagraph_sizes,
+ self.cudagraph_mode,
+ )
def capture_graph(
self,
- batch_size: int,
+ num_tokens: int,
model: nn.Module,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
) -> None:
- assert batch_size not in self.graphs
-
- # Prepare dummy inputs.
- input_ids = input_buffers.input_ids.gpu[:batch_size]
- positions = input_buffers.positions[:batch_size]
-
- input_buffers.query_start_loc.np[: batch_size + 1] = np.arange(batch_size + 1)
- input_buffers.query_start_loc.np[batch_size:] = batch_size
- input_buffers.query_start_loc.copy_to_gpu()
- # HACK(woosuk): To optimize warmup time, we use 1 (instead of max_model_len)
- # for seq_lens. This leads to a mismatch between seq_lens (GPU) and
- # seq_lens_np (CPU), which might cause issues in some attention backends.
- input_buffers.seq_lens[:batch_size] = 1
- input_buffers.seq_lens[batch_size:] = 0
-
- input_block_tables = [x[:batch_size] for x in block_tables.input_block_tables]
- slot_mappings = block_tables.slot_mappings[:, :batch_size]
-
- attn_metadata = build_attn_metadata(
- attn_metadata_builders=attn_metadata_builders,
- num_reqs=batch_size,
- num_tokens=batch_size,
- query_start_loc=input_buffers.query_start_loc,
- seq_lens=input_buffers.seq_lens,
- seq_lens_np=np.full(batch_size, self.max_model_len, dtype=np.int32),
- num_computed_tokens_cpu=None, # FIXME
- block_tables=input_block_tables,
- slot_mappings=slot_mappings,
- kv_cache_config=kv_cache_config,
+ num_reqs = min(num_tokens, self.max_num_reqs)
+ input_ids = input_buffers.input_ids.gpu[:num_tokens]
+ positions = input_buffers.positions[:num_tokens]
+ attn_metadata = prepare_inputs_to_capture(
+ num_reqs,
+ num_tokens,
+ input_buffers,
+ block_tables,
+ attn_metadata_builders,
+ self.max_model_len,
+ kv_cache_config,
)
- if self.dp_size > 1:
- num_tokens_across_dp = torch.full(
- (self.dp_size,),
- batch_size,
- dtype=torch.int32,
- device="cpu",
- )
- else:
- num_tokens_across_dp = None
+ num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
# Warm up.
with set_forward_context(
attn_metadata,
self.vllm_config,
- num_tokens=batch_size,
+ num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp,
):
@@ -153,13 +107,13 @@ class CudaGraphManager:
self.hidden_states = torch.empty_like(hidden_states)
# Capture the graph.
+ assert num_tokens not in self.graphs
graph = torch.cuda.CUDAGraph()
with (
- patch("torch.cuda.empty_cache", lambda: None),
set_forward_context(
attn_metadata,
self.vllm_config,
- num_tokens=batch_size,
+ num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp,
),
@@ -169,8 +123,8 @@ class CudaGraphManager:
input_ids=input_ids,
positions=positions,
)
- self.hidden_states[:batch_size] = hidden_states
- self.graphs[batch_size] = graph
+ self.hidden_states[:num_tokens] = hidden_states
+ self.graphs[num_tokens] = graph
@torch.inference_mode()
def capture(
@@ -181,25 +135,125 @@ class CudaGraphManager:
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
) -> None:
- assert self.needs_capture()
- # Capture larger graphs first.
- sizes_to_capture = sorted(self.cudagraph_sizes, reverse=True)
- if is_global_first_rank():
- sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
+ capture_graphs(
+ self.cudagraph_sizes,
+ self.device,
+ self.capture_graph,
+ model=model,
+ input_buffers=input_buffers,
+ block_tables=block_tables,
+ attn_metadata_builders=attn_metadata_builders,
+ kv_cache_config=kv_cache_config,
+ )
- with graph_capture(device=self.device):
- for batch_size in sizes_to_capture:
- self.capture_graph(
- batch_size,
- model,
- input_buffers,
- block_tables,
- attn_metadata_builders,
- kv_cache_config,
- )
-
- def run(self, batch_size: int) -> torch.Tensor:
- assert batch_size in self.graphs
- self.graphs[batch_size].replay()
+ def run(self, num_tokens: int) -> torch.Tensor:
+ assert num_tokens in self.graphs
+ self.graphs[num_tokens].replay()
assert self.hidden_states is not None
- return self.hidden_states[:batch_size]
+ return self.hidden_states[:num_tokens]
+
+
+def get_cudagraph_sizes(
+ capture_sizes: list[int] | None,
+ max_num_reqs: int,
+ max_num_tokens: int,
+ cudagraph_mode: CUDAGraphMode,
+) -> dict[int, int]:
+ if not cudagraph_mode.has_full_cudagraphs():
+ return {}
+ if not capture_sizes:
+ return {}
+
+ capture_sizes = sorted(capture_sizes)
+ # Limit the capture sizes to the max number of requests or tokens.
+ upper_bound = (
+ max_num_reqs
+ if cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
+ else max_num_tokens
+ )
+ capture_sizes = [x for x in capture_sizes if x <= upper_bound]
+ if not capture_sizes:
+ return {}
+
+ cudagraph_sizes: dict[int, int] = {}
+ for i in range(1, capture_sizes[-1] + 1):
+ for x in capture_sizes:
+ if i <= x:
+ cudagraph_sizes[i] = x
+ break
+ return cudagraph_sizes
+
+
+def get_cudagraph_size(
+ num_tokens_after_dp_padding: int,
+ num_tokens_per_request: Iterable[int],
+ cudagraph_sizes: dict[int, int],
+ cudagraph_mode: CUDAGraphMode,
+) -> int | None:
+ size = cudagraph_sizes.get(num_tokens_after_dp_padding)
+ if size is None:
+ # No CUDA graph for this size.
+ return None
+ if cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+ all_decode = all(x == 1 for x in num_tokens_per_request)
+ if not all_decode:
+ # Prefill is included.
+ return None
+ return size
+
+
+def capture_graphs(
+ cudagraph_sizes: dict[int, int],
+ device: torch.device,
+ capture_fn: Callable,
+ **capture_kwargs,
+) -> None:
+ # Capture larger graphs first.
+ sizes_to_capture = sorted(set(cudagraph_sizes.values()), reverse=True)
+ if is_global_first_rank():
+ sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
+
+ with graph_capture(device=device):
+ for size in sizes_to_capture:
+ capture_fn(size, **capture_kwargs)
+
+
+def prepare_inputs_to_capture(
+ num_reqs: int,
+ num_tokens: int,
+ input_buffers: InputBuffers,
+ block_tables: BlockTables,
+ attn_metadata_builders: list[AttentionMetadataBuilder],
+ max_model_len: int,
+ kv_cache_config: KVCacheConfig,
+) -> dict[str, Any]:
+ num_tokens_per_req = num_tokens // num_reqs
+ query_start_loc = input_buffers.query_start_loc
+ query_start_loc.np[: num_reqs + 1] = np.arange(num_reqs + 1) * num_tokens_per_req
+ query_start_loc.np[num_reqs:] = num_tokens
+ query_start_loc.copy_to_gpu()
+ seq_lens_np = np.full(num_reqs, max_model_len, dtype=np.int32)
+ # HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
+ # rather than max_model_len. This introduces a discrepancy between
+ # seq_lens (on GPU) and seq_lens_np (on CPU), which may cause issues for
+ # certain attention backends.
+ input_buffers.seq_lens[:num_reqs] = num_tokens
+ input_buffers.seq_lens[num_reqs:] = 0
+
+ input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
+ slot_mappings = block_tables.slot_mappings[:, :num_tokens]
+
+ attn_metadata = build_attn_metadata(
+ attn_metadata_builders=attn_metadata_builders,
+ num_reqs=num_reqs,
+ num_tokens=num_tokens,
+ query_start_loc_gpu=query_start_loc.gpu[: num_reqs + 1],
+ query_start_loc_cpu=query_start_loc.cpu[: num_reqs + 1],
+ seq_lens=input_buffers.seq_lens,
+ seq_lens_np=seq_lens_np,
+ num_computed_tokens_cpu=None, # FIXME
+ block_tables=input_block_tables,
+ slot_mappings=slot_mappings,
+ kv_cache_config=kv_cache_config,
+ )
+ return attn_metadata
diff --git a/vllm/v1/worker/gpu/dp_utils.py b/vllm/v1/worker/gpu/dp_utils.py
index 9bfc7f25bef3a..d71d91d1e5cb8 100644
--- a/vllm/v1/worker/gpu/dp_utils.py
+++ b/vllm/v1/worker/gpu/dp_utils.py
@@ -20,3 +20,12 @@ def get_batch_metadata_across_dp(
tensor[1][dp_rank] = cudagraph_size
dist.all_reduce(tensor, group=group)
return tensor[0], tensor[1]
+
+
+def make_num_tokens_across_dp(
+ dp_size: int,
+ num_tokens: int,
+) -> torch.Tensor | None:
+ if dp_size == 1:
+ return None
+ return torch.full((dp_size,), num_tokens, dtype=torch.int32, device="cpu")
diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py
index 7675cb45170b5..2a7048ae3c0e0 100644
--- a/vllm/v1/worker/gpu/input_batch.py
+++ b/vllm/v1/worker/gpu/input_batch.py
@@ -4,7 +4,6 @@ from dataclasses import dataclass
from typing import Any
import numba
-import numba.types as types
import numpy as np
import torch
@@ -37,6 +36,9 @@ class InputBuffers:
self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
self.cu_num_logits = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
+ # Spec decoding.
+ self.next_prefill_tokens = self._make_buffer(max_num_reqs, dtype=torch.int32)
+
# Structured outputs.
self.bitmask_indices = self._make_buffer(max_num_reqs, dtype=torch.int32)
self.grammar_bitmask = self._make_buffer(
@@ -144,80 +146,42 @@ class InputBatch:
)
-# NOTE: With the type annotations, this function is pre-compiled
-# before the first call.
-@numba.jit(
- [
- types.none(
- types.int32[:], # idx_mapping
- types.int32[:], # num_scheduled_tokens
- types.int32[:, :], # prefill_token_ids
- types.int32[:], # num_computed_prefill_tokens
- types.int32[:], # prefill_len
- types.int32[:], # input_ids
- types.int32[:], # query_start_loc
- )
- ],
- nopython=True,
- cache=True,
-)
+@numba.njit(cache=True)
def _prepare_prefill_inputs(
- idx_mapping: np.ndarray, # batch_idx -> req_idx
- num_scheduled_tokens: np.ndarray, # [B]
+ idx_mapping: np.ndarray, # [B]
+ query_lens: np.ndarray, # [B]
+ query_start_loc: np.ndarray, # [B + 1]
prefill_token_ids: np.ndarray, # [N, max_model_len]
num_computed_prefill_tokens: np.ndarray, # [N]
- prefill_len: np.ndarray, # [N]
input_ids: np.ndarray, # [num_input_tokens]
- query_start_loc: np.ndarray, # [B + 1]
) -> None:
- num_reqs = num_scheduled_tokens.shape[0]
- query_start_loc[0] = 0
-
- cu_num_tokens = 0
+ num_reqs = idx_mapping.shape[0]
+ query_starts = query_start_loc[:num_reqs]
+ query_ends = query_start_loc[1 : num_reqs + 1]
+ starts = num_computed_prefill_tokens[idx_mapping]
+ ends = starts + query_lens
for i in range(num_reqs):
- req_idx = idx_mapping[i]
- query_len = num_scheduled_tokens[i]
-
- start = num_computed_prefill_tokens[req_idx]
- end = min(start + query_len, prefill_len[req_idx])
- n = end - start
-
- start_idx = cu_num_tokens
- input_ids[start_idx : start_idx + n] = prefill_token_ids[req_idx, start:end]
-
- cu_num_tokens = start_idx + query_len
- query_start_loc[i + 1] = cu_num_tokens
-
- # Pad the inputs for CUDA graphs.
- # Note: pad query_start_loc to be non-decreasing, as kernels
- # like FlashAttention requires that
- query_start_loc[num_reqs + 1 :].fill(cu_num_tokens)
+ input_ids[query_starts[i] : query_ends[i]] = prefill_token_ids[
+ idx_mapping[i], starts[i] : ends[i]
+ ]
def prepare_prefill_inputs(
idx_mapping: np.ndarray,
num_scheduled_tokens: np.ndarray,
- total_num_tokens: int,
+ query_start_loc: np.ndarray,
prefill_token_ids: np.ndarray,
num_computed_prefill_tokens: np.ndarray,
- prefill_len: np.ndarray,
- input_ids: CpuGpuBuffer,
- query_start_loc: CpuGpuBuffer,
+ input_ids: np.ndarray,
) -> None:
_prepare_prefill_inputs(
idx_mapping,
num_scheduled_tokens,
+ query_start_loc,
prefill_token_ids,
num_computed_prefill_tokens,
- prefill_len,
- input_ids.np,
- query_start_loc.np,
+ input_ids,
)
- input_ids.copy_to_gpu(total_num_tokens)
- # NOTE(woosuk): We should copy the whole query_start_loc and seq_lens
- # tensors from CPU to GPU, because they may include paddings needed
- # for full CUDA graph mode.
- query_start_loc.copy_to_gpu()
@triton.jit
@@ -380,8 +344,8 @@ def _post_update_kernel(
sampled_tokens_ptr,
sampled_tokens_stride,
num_sampled_ptr,
+ num_rejected_ptr,
query_start_loc_ptr,
- cu_num_logits_ptr,
):
req_id = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + req_id)
@@ -396,17 +360,10 @@ def _post_update_kernel(
query_start = tl.load(query_start_loc_ptr + req_id)
query_end = tl.load(query_start_loc_ptr + req_id + 1)
query_len = query_end - query_start
+ num_rejected = tl.load(num_rejected_ptr + req_id)
num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
- num_computed += query_len
- # Consider the rejected tokens in spec decoding.
- if num_sampled > 0:
- # NOTE(woosuk): We must skip num_sampled == 0 to account for chunked prefills.
- logits_start = tl.load(cu_num_logits_ptr + req_id)
- logits_end = tl.load(cu_num_logits_ptr + req_id + 1)
- num_logits = logits_end - logits_start
- num_rejected = num_logits - num_sampled
- num_computed -= num_rejected
+ num_computed += query_len - num_rejected
tl.store(num_computed_tokens_ptr + req_state_idx, num_computed)
@@ -421,10 +378,10 @@ def post_update(
sampled_tokens: torch.Tensor,
# [num_reqs]
num_sampled: torch.Tensor,
+ # [num_reqs]
+ num_rejected: torch.Tensor,
# [num_reqs + 1]
query_start_loc: torch.Tensor,
- # [num_reqs + 1]
- cu_num_logits: torch.Tensor,
) -> None:
num_reqs = idx_mapping.shape[0]
_post_update_kernel[(num_reqs,)](
@@ -434,7 +391,7 @@ def post_update(
sampled_tokens,
sampled_tokens.stride(0),
num_sampled,
+ num_rejected,
query_start_loc,
- cu_num_logits,
num_warps=1,
)
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 6e332ee4b75b8..0c9fdd0077f4a 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -35,7 +35,10 @@ from vllm.v1.worker.gpu.attn_utils import (
)
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
-from vllm.v1.worker.gpu.dp_utils import get_batch_metadata_across_dp
+from vllm.v1.worker.gpu.dp_utils import (
+ get_batch_metadata_across_dp,
+ make_num_tokens_across_dp,
+)
from vllm.v1.worker.gpu.input_batch import (
InputBatch,
InputBuffers,
@@ -45,7 +48,11 @@ from vllm.v1.worker.gpu.input_batch import (
prepare_prefill_inputs,
)
from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs
-from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
+from vllm.v1.worker.gpu.spec_decode import init_speculator
+from vllm.v1.worker.gpu.spec_decode.rejection_sample import (
+ get_num_rejected,
+ rejection_sample,
+)
from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
@@ -97,16 +104,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if self.use_async_scheduling:
self.input_prep_event = torch.cuda.Event()
self.structured_outputs_event = torch.cuda.Event()
+ self.spec_decode_event = torch.cuda.Event()
else:
self.input_prep_event = None
self.structured_outputs_event = None
+ self.spec_decode_event = None
if self.speculative_config is not None:
self.do_spec_decode = True
self.num_speculative_steps = self.speculative_config.num_speculative_tokens
+ self.speculator = init_speculator(self.vllm_config, self.device)
else:
self.do_spec_decode = False
self.num_speculative_steps = 0
+ self.speculator = None
self.req_states = RequestState(
max_num_reqs=self.max_num_reqs,
@@ -129,10 +140,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
# CUDA graphs.
- self.cudagraph_manager = CudaGraphManager(
- vllm_config=self.vllm_config,
- device=self.device,
- )
+ self.cudagraph_manager = CudaGraphManager(self.vllm_config, self.device)
def get_supported_tasks(self) -> tuple[str]:
return ("generate",)
@@ -153,6 +161,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.vllm_config,
self.device,
)
+ if self.do_spec_decode:
+ self.speculator.load_model(self.model)
time_after_load = time.perf_counter()
self.model_memory_usage = m.consumed_memory
@@ -190,6 +200,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.vllm_config,
self.device,
)
+ if self.do_spec_decode:
+ # HACK(woosuk)
+ self.speculator.set_attn(
+ self.kv_cache_config,
+ self.attn_metadata_builders,
+ self.block_tables,
+ )
+
# TODO(woosuk): Support other backends.
if not all(b.get_name() == "FLASH_ATTN" for b in self.attn_backends.values()):
raise NotImplementedError("Only FLASH_ATTN backend is supported currently.")
@@ -213,11 +231,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_computed_tokens = torch.zeros(
input_batch.num_reqs, dtype=torch.int32, device=self.device
)
+ query_start_loc = self.input_buffers.query_start_loc
+ query_start_loc_gpu = query_start_loc.gpu[: input_batch.num_reqs + 1]
+ query_start_loc_cpu = query_start_loc.cpu[: input_batch.num_reqs + 1]
attn_metadata = build_attn_metadata(
attn_metadata_builders=self.attn_metadata_builders,
num_reqs=input_batch.num_reqs,
num_tokens=input_batch.num_tokens,
- query_start_loc=self.input_buffers.query_start_loc,
+ query_start_loc_gpu=query_start_loc_gpu,
+ query_start_loc_cpu=query_start_loc_cpu,
seq_lens=self.input_buffers.seq_lens,
seq_lens_np=input_batch.seq_lens_np,
num_computed_tokens_cpu=num_computed_tokens,
@@ -245,12 +267,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if not skip_attn:
self.prepare_dummy_attn_metadata(input_batch)
- if self.dp_size == 1:
- num_tokens_across_dp: torch.Tensor | None = None
- else:
- num_tokens_across_dp = torch.full(
- (self.dp_size,), num_tokens, dtype=torch.int32, device="cpu"
- )
+ num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
num_sampled_tokens = np.ones(input_batch.num_reqs, dtype=np.int32)
with (
self.maybe_dummy_run_with_lora(
@@ -292,6 +309,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
skip_attn=True,
)
self._dummy_sampler_run(sample_hidden_states)
+ if self.do_spec_decode:
+ num_tokens_across_dp = make_num_tokens_across_dp(
+ self.dp_size, self.max_num_tokens
+ )
+ self.speculator.run_model(
+ self.max_num_tokens,
+ attn_metadata=None,
+ num_tokens_across_dp=num_tokens_across_dp,
+ )
torch.cuda.synchronize()
del hidden_states, sample_hidden_states
gc.collect()
@@ -325,6 +351,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata_builders=self.attn_metadata_builders,
kv_cache_config=self.kv_cache_config,
)
+ if self.do_spec_decode:
+ self.speculator.capture_model()
end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
@@ -466,20 +494,29 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables = self.block_tables.gather_block_tables(idx_mapping)
- # Copy prefill tokens from CPU to GPU and get query_start_loc.
+ # Get query_start_loc.
+ np.cumsum(
+ num_scheduled_tokens,
+ out=self.input_buffers.query_start_loc.np[1 : num_reqs + 1],
+ )
+ # Pad for full CUDA graph mode.
+ # Some attention backends like FA3 require query_start_loc to be non-decreasing.
+ self.input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens
+ self.input_buffers.query_start_loc.copy_to_gpu()
+ query_start_loc_gpu = self.input_buffers.query_start_loc.gpu[: num_reqs + 1]
+ query_start_loc_cpu = self.input_buffers.query_start_loc.cpu[: num_reqs + 1]
+ query_start_loc_np = self.input_buffers.query_start_loc.np[: num_reqs + 1]
+
+ # Copy prefill tokens from CPU to GPU.
prepare_prefill_inputs(
idx_mapping_np,
num_scheduled_tokens,
- num_tokens,
+ query_start_loc_np,
self.req_states.prefill_token_ids,
self.req_states.num_computed_prefill_tokens,
- self.req_states.prefill_len.np,
- self.input_buffers.input_ids,
- self.input_buffers.query_start_loc,
+ self.input_buffers.input_ids.np,
)
- query_start_loc = self.input_buffers.query_start_loc
- query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
- query_start_loc_np = query_start_loc.np[: num_reqs + 1]
+ self.input_buffers.input_ids.copy_to_gpu(num_tokens)
# Prepare positions and seq_lens.
prepare_pos_seq_lens(
@@ -525,7 +562,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata_builders=self.attn_metadata_builders,
num_reqs=num_reqs,
num_tokens=num_tokens,
- query_start_loc=self.input_buffers.query_start_loc,
+ query_start_loc_gpu=query_start_loc_gpu,
+ query_start_loc_cpu=query_start_loc_cpu,
seq_lens=self.input_buffers.seq_lens,
seq_lens_np=seq_lens_np,
num_computed_tokens_cpu=num_computed_tokens,
@@ -562,7 +600,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_batch: InputBatch,
sampling_metadata: SamplingMetadata,
grammar_output: GrammarOutput | None,
- ) -> tuple[SamplerOutput, torch.Tensor]:
+ ) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]:
sample_hidden_states = hidden_states[input_batch.logits_indices]
logits = self.model.compute_logits(sample_hidden_states)
if grammar_output is not None:
@@ -588,6 +626,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# No draft tokens (common case).
# 0 if chunked-prefilling, 1 if not.
num_sampled = (~is_chunked_prefilling).int()
+ num_rejected = torch.zeros_like(num_sampled)
else:
# Draft tokens for spec decoding.
input_ids = input_batch.input_ids[input_batch.logits_indices]
@@ -598,9 +637,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.num_speculative_steps,
)
num_sampled *= ~is_chunked_prefilling
+ num_rejected = get_num_rejected(
+ input_batch.cu_num_logits,
+ num_sampled,
+ )
sampler_output.sampled_token_ids = sampled_tokens
# TODO(woosuk): Support logprobs with spec decoding.
- return sampler_output, num_sampled
+ return sampler_output, num_sampled, num_rejected
def compute_prompt_logprobs(
self,
@@ -706,6 +749,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_batch: InputBatch,
sampled_tokens: torch.Tensor,
num_sampled: torch.Tensor,
+ num_rejected: torch.Tensor,
) -> None:
# Update the number of computed tokens.
post_update(
@@ -714,8 +758,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.req_states.last_sampled_tokens,
sampled_tokens,
num_sampled,
+ num_rejected,
input_batch.query_start_loc,
- input_batch.cu_num_logits,
)
# Update the number of computed prefill tokens.
@@ -727,6 +771,42 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.req_states.prefill_len.np[idx_mapping_np],
)
+ @torch.inference_mode()
+ def propose_draft(
+ self,
+ input_batch: InputBatch,
+ sampling_metadata: SamplingMetadata,
+ last_hidden_states: torch.Tensor,
+ aux_hidden_states: list[torch.Tensor] | None,
+ num_sampled: torch.Tensor,
+ num_rejected: torch.Tensor,
+ ) -> torch.Tensor:
+ num_reqs = input_batch.num_reqs
+ idx_mapping_np = input_batch.idx_mapping_np
+ with async_barrier(self.spec_decode_event):
+ self.input_buffers.next_prefill_tokens.np[:num_reqs] = (
+ self.req_states.prefill_token_ids[
+ idx_mapping_np,
+ self.req_states.num_computed_prefill_tokens[idx_mapping_np],
+ ]
+ )
+ next_prefill_tokens = self.input_buffers.next_prefill_tokens.copy_to_gpu(
+ num_reqs
+ )
+
+ assert self.speculator is not None
+ draft_tokens = self.speculator.propose(
+ input_batch,
+ sampling_metadata,
+ last_hidden_states,
+ aux_hidden_states,
+ num_sampled,
+ num_rejected,
+ self.req_states.last_sampled_tokens,
+ next_prefill_tokens,
+ )
+ return draft_tokens
+
def get_cudagraph_and_dp_padding(
self,
scheduler_output: SchedulerOutput,
@@ -879,7 +959,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.execute_model_state = None # type: ignore
assert sampling_metadata is not None
- sampler_output, num_sampled_tokens = self.sample(
+ sampler_output, num_sampled, num_rejected = self.sample(
hidden_states, input_batch, sampling_metadata, grammar_output
)
prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch)
@@ -900,7 +980,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
async_output = AsyncOutput(
model_runner_output=model_runner_output,
sampler_output=sampler_output,
- num_sampled_tokens=num_sampled_tokens,
+ num_sampled_tokens=num_sampled,
copy_stream=self.output_copy_stream,
copy_event=self.output_copy_event,
)
@@ -911,8 +991,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# This sequencing may slightly reduce latency as async D2H copy does not
# need to wait for the postprocess to finish.
self.postprocess(
- input_batch, sampler_output.sampled_token_ids, num_sampled_tokens
+ input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
)
+ if self.do_spec_decode:
+ draft_tokens = self.propose_draft(
+ input_batch,
+ sampling_metadata,
+ hidden_states,
+ None, # aux_hidden_states
+ num_sampled,
+ num_rejected,
+ )
+ self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
if self.use_async_scheduling:
return async_output
diff --git a/vllm/v1/worker/gpu/sampler.py b/vllm/v1/worker/gpu/sampler.py
index c48ed2d8ca167..d8676079ab951 100644
--- a/vllm/v1/worker/gpu/sampler.py
+++ b/vllm/v1/worker/gpu/sampler.py
@@ -100,8 +100,9 @@ def _gumbel_sample_kernel(
mask=mask,
other=float("-inf"),
)
+ logits = logits.to(tl.float32)
- temp = tl.load(temp_ptr + req_idx)
+ temp = tl.load(temp_ptr + req_idx).to(tl.float32)
if temp != 0.0:
# Calculate the seed for gumbel noise.
seed = tl.load(seeds_ptr + req_idx)
@@ -116,7 +117,7 @@ def _gumbel_sample_kernel(
# Apply temperature.
if APPLY_TEMPERATURE:
# NOTE(woosuk): Use div_rn to match the behavior of torch.
- logits = tl.div_rn(logits, temp.to(tl.float32))
+ logits = tl.div_rn(logits, temp)
# Apply gumbel noise.
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
diff --git a/vllm/v1/worker/gpu/spec_decode/__init__.py b/vllm/v1/worker/gpu/spec_decode/__init__.py
index e69de29bb2d1d..15b85204e05ce 100644
--- a/vllm/v1/worker/gpu/spec_decode/__init__.py
+++ b/vllm/v1/worker/gpu/spec_decode/__init__.py
@@ -0,0 +1,18 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+
+from vllm.config import VllmConfig
+
+
+def init_speculator(
+ vllm_config: VllmConfig,
+ device: torch.device,
+):
+ speculative_config = vllm_config.speculative_config
+ assert speculative_config is not None
+ if speculative_config.use_eagle():
+ from vllm.v1.worker.gpu.spec_decode.eagle import EagleSpeculator
+
+ return EagleSpeculator(vllm_config, device)
+ raise NotImplementedError(f"{speculative_config.method} is not supported yet.")
diff --git a/vllm/v1/worker/gpu/spec_decode/eagle.py b/vllm/v1/worker/gpu/spec_decode/eagle.py
new file mode 100644
index 0000000000000..daf2775e8b92d
--- /dev/null
+++ b/vllm/v1/worker/gpu/spec_decode/eagle.py
@@ -0,0 +1,567 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from vllm.config import VllmConfig
+from vllm.config.compilation import CUDAGraphMode
+from vllm.forward_context import set_forward_context
+from vllm.logger import init_logger
+from vllm.model_executor.model_loader import get_model
+from vllm.triton_utils import tl, triton
+from vllm.utils.platform_utils import is_pin_memory_available
+from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
+from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
+from vllm.v1.worker.gpu.block_table import BlockTables
+from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
+from vllm.v1.worker.gpu.sampler import gumbel_sample
+from vllm.v1.worker.gpu.spec_decode.eagle_cudagraph import EagleCudaGraphManager
+from vllm.v1.worker.gpu.states import SamplingMetadata
+
+logger = init_logger(__name__)
+
+
+class EagleSpeculator:
+ def __init__(self, vllm_config: VllmConfig, device: torch.device):
+ self.vllm_config = vllm_config
+ self.device = device
+
+ self.speculative_config = vllm_config.speculative_config
+ assert self.speculative_config is not None
+ self.method = self.speculative_config.method
+ self.num_speculative_steps = self.speculative_config.num_speculative_tokens
+ self.draft_model_config = self.speculative_config.draft_model_config
+
+ self.scheduler_config = vllm_config.scheduler_config
+ self.max_num_reqs = self.scheduler_config.max_num_seqs
+ self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
+ self.max_model_len = vllm_config.model_config.max_model_len
+ # We need to get the hidden size from the draft model config because
+ # the draft model's hidden size can be different from the target model's
+ # hidden size (e.g., Llama 3.3 70B).
+ self.hidden_size = self.draft_model_config.get_hidden_size()
+ self.vocab_size = self.draft_model_config.get_vocab_size()
+ self.pin_memory = is_pin_memory_available()
+ self.dtype = vllm_config.model_config.dtype
+
+ self.input_buffers = InputBuffers(
+ max_num_reqs=self.max_num_reqs,
+ max_num_tokens=self.max_num_tokens,
+ hidden_size=self.hidden_size,
+ vocab_size=self.vocab_size,
+ dtype=self.dtype,
+ device=device,
+ pin_memory=self.pin_memory,
+ )
+ self.hidden_states = torch.zeros(
+ self.max_num_tokens,
+ self.hidden_size,
+ dtype=self.dtype,
+ device=device,
+ )
+ self.temperature = torch.zeros(
+ self.max_num_reqs,
+ dtype=torch.float32,
+ device=device,
+ )
+ self.seeds = torch.zeros(
+ self.max_num_reqs,
+ dtype=torch.int64,
+ device=device,
+ )
+ self.draft_tokens = torch.zeros(
+ self.max_num_reqs,
+ self.num_speculative_steps,
+ dtype=torch.int64,
+ device=device,
+ )
+
+ self.cudagraph_manager = EagleCudaGraphManager(vllm_config, device)
+
+ def load_model(self, target_model: nn.Module) -> None:
+ from vllm.compilation.backends import set_model_tag
+
+ with set_model_tag("eagle_head"):
+ self.model = get_model(
+ vllm_config=self.vllm_config, model_config=self.draft_model_config
+ )
+
+ share_lm_head = True
+ if share_lm_head and hasattr(target_model, "lm_head"):
+ if hasattr(self.model, "lm_head"):
+ del self.model.lm_head
+ self.model.lm_head = target_model.lm_head
+
+ def set_attn(
+ self,
+ kv_cache_config: KVCacheConfig,
+ attn_metadata_builders: list[AttentionMetadataBuilder],
+ block_tables: BlockTables,
+ ) -> None:
+ self.kv_cache_config = kv_cache_config
+ self.attn_metadata_builders = attn_metadata_builders
+ self.block_tables = block_tables
+
+ @torch.inference_mode()
+ def run_model(
+ self,
+ num_tokens: int,
+ attn_metadata: dict[str, Any],
+ num_tokens_across_dp: torch.Tensor | None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ with set_forward_context(
+ attn_metadata,
+ self.vllm_config,
+ num_tokens=num_tokens,
+ cudagraph_runtime_mode=CUDAGraphMode.NONE,
+ num_tokens_across_dp=num_tokens_across_dp,
+ ):
+ ret_hidden_states = self.model(
+ input_ids=self.input_buffers.input_ids.gpu[:num_tokens],
+ positions=self.input_buffers.positions[:num_tokens],
+ hidden_states=self.hidden_states[:num_tokens],
+ )
+ if self.method == "mtp":
+ last_hidden_states = ret_hidden_states
+ hidden_states = ret_hidden_states
+ else:
+ last_hidden_states, hidden_states = ret_hidden_states
+ return last_hidden_states, hidden_states
+
+ def generate_draft(
+ self,
+ num_reqs: int,
+ attn_metadata: dict[str, Any],
+ num_tokens_across_dp: torch.Tensor | None,
+ ) -> None:
+ pos = self.input_buffers.positions[:num_reqs]
+ query_start_loc = self.input_buffers.query_start_loc.gpu[: num_reqs + 1]
+ for step in range(1, self.num_speculative_steps):
+ # Run the eagle model.
+ last_hidden_states, hidden_states = self.run_model(
+ num_reqs, attn_metadata, num_tokens_across_dp
+ )
+ logits = self.model.compute_logits(last_hidden_states)
+
+ # NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
+ # used for draft and target sampling.
+ draft_tokens = gumbel_sample(
+ logits,
+ self.temperature[:num_reqs],
+ self.seeds[:num_reqs],
+ pos + 1,
+ apply_temperature=True,
+ )
+ self.draft_tokens[:num_reqs, step] = draft_tokens
+
+ if step < self.num_speculative_steps - 1:
+ # Update the inputs for the next step.
+ update_eagle_inputs(
+ draft_tokens,
+ hidden_states,
+ self.input_buffers,
+ self.hidden_states,
+ self.max_model_len,
+ )
+ self.block_tables.compute_slot_mappings(query_start_loc, pos)
+
+ def capture_model(self) -> None:
+ if self.num_speculative_steps == 1:
+ return
+ logger.info("Capturing model for Eagle speculator...")
+ self.cudagraph_manager.capture(
+ self.generate_draft,
+ self.input_buffers,
+ self.block_tables,
+ self.attn_metadata_builders,
+ self.kv_cache_config,
+ )
+
+ @torch.inference_mode()
+ def propose(
+ self,
+ input_batch: InputBatch,
+ sampling_metadata: SamplingMetadata,
+ # [num_tokens, hidden_size]
+ last_hidden_states: torch.Tensor,
+ # num_layers x [num_tokens, hidden_size]
+ aux_hidden_states: list[torch.Tensor] | None,
+ # [num_reqs]
+ num_sampled: torch.Tensor,
+ # [num_reqs]
+ num_rejected: torch.Tensor,
+ # [max_num_reqs, 1]
+ last_sampled: torch.Tensor,
+ # [num_reqs]
+ next_prefill_tokens: torch.Tensor,
+ ) -> torch.Tensor:
+ # NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the
+ # number of rejected tokens, we maintain the size of eagle's input_ids and
+ # hidden_states the same as the target model's. This means, we pad each
+ # request's query length to include any rejected positions. By doing so,
+ # we can also reuse the attention metadata (e.g., query_start_loc,
+ # seq_lens) of the target model.
+ if aux_hidden_states:
+ assert self.method == "eagle3"
+ hidden_states = self.model.combine_hidden_states(
+ torch.cat(aux_hidden_states, dim=-1)
+ )
+ else:
+ hidden_states = last_hidden_states
+ num_tokens = input_batch.num_tokens_after_padding
+ self.hidden_states[:num_tokens] = hidden_states
+
+ # Get the input ids and last token indices for the speculator.
+ last_token_indices = prepare_eagle_inputs(
+ self.input_buffers,
+ input_batch,
+ num_sampled,
+ num_rejected,
+ last_sampled,
+ next_prefill_tokens,
+ )
+
+ # Prefill: Run the eagle speculator with eager mode.
+ # TODO(woosuk): Support CUDA graph for prefill.
+ last_hidden_states, hidden_states = self.run_model(
+ num_tokens,
+ input_batch.attn_metadata,
+ num_tokens_across_dp=None, # FIXME
+ )
+ sample_hidden_states = last_hidden_states[last_token_indices]
+ logits = self.model.compute_logits(sample_hidden_states)
+
+ num_reqs = input_batch.num_reqs
+ cu_num_logits = input_batch.cu_num_logits[:num_reqs]
+ # NOTE(woosuk): For draft sampling, we only consider the temperature
+ # and ignore the other sampling parameters such as top_k and top_p,
+ # for simplicity and performance.
+ # While this may slightly degrade the acceptance rate, it does not
+ # affect the output distribution after rejection sampling.
+ temperature = self.temperature[:num_reqs]
+ seeds = self.seeds[:num_reqs]
+ pos = self.input_buffers.positions[:num_reqs]
+ # Gather the values and copy them to the pre-allocated buffers.
+ torch.gather(sampling_metadata.temperature, 0, cu_num_logits, out=temperature)
+ torch.gather(sampling_metadata.seeds, 0, cu_num_logits, out=seeds)
+ torch.gather(input_batch.positions, 0, last_token_indices, out=pos)
+ # NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
+ # used for draft and target sampling.
+ draft_tokens = gumbel_sample(
+ logits, temperature, seeds, pos + 1, apply_temperature=True
+ )
+ if self.num_speculative_steps == 1:
+ # Early exit.
+ return draft_tokens.view(-1, 1)
+
+ # Save the draft tokens for the first step.
+ self.draft_tokens[:num_reqs, 0] = draft_tokens
+ # Prepare the inputs for the decode steps.
+ prepare_eagle_decode(
+ draft_tokens,
+ hidden_states,
+ last_token_indices,
+ input_batch.seq_lens,
+ num_rejected,
+ self.input_buffers,
+ self.hidden_states,
+ self.max_model_len,
+ self.max_num_reqs,
+ )
+ query_start_loc = self.input_buffers.query_start_loc
+ query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
+ slot_mappings = self.block_tables.compute_slot_mappings(
+ query_start_loc_gpu, pos
+ )
+
+ cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs)
+ if cudagraph_size is not None:
+ # Run CUDA graph.
+ self.cudagraph_manager.run(cudagraph_size)
+ return self.draft_tokens[:num_reqs]
+
+ # Run eager mode.
+ query_start_loc.np[: num_reqs + 1] = np.arange(num_reqs + 1)
+ query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1]
+ # HACK(woosuk)
+ seq_lens_np = np.full(num_reqs, self.max_model_len, dtype=np.int32)
+ block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
+
+ # FIXME(woosuk): This is UNSAFE!!
+ attn_metadata = build_attn_metadata(
+ attn_metadata_builders=self.attn_metadata_builders,
+ num_reqs=num_reqs,
+ num_tokens=num_reqs,
+ query_start_loc_gpu=query_start_loc_gpu,
+ query_start_loc_cpu=query_start_loc_cpu,
+ seq_lens=self.input_buffers.seq_lens[:num_reqs],
+ seq_lens_np=seq_lens_np,
+ num_computed_tokens_cpu=None, # FIXME
+ block_tables=block_tables,
+ slot_mappings=slot_mappings,
+ kv_cache_config=self.kv_cache_config,
+ )
+ self.generate_draft(num_reqs, attn_metadata, num_tokens_across_dp=None) # FIXME
+ return self.draft_tokens[:num_reqs]
+
+
+@triton.jit
+def _prepare_eagle_inputs_kernel(
+ last_token_indices_ptr,
+ eagle_input_ids_ptr,
+ eagle_positions_ptr,
+ target_input_ids_ptr,
+ target_positions_ptr,
+ idx_mapping_ptr,
+ last_sampled_ptr,
+ next_prefill_tokens_ptr,
+ num_sampled_ptr,
+ num_rejected_ptr,
+ query_start_loc_ptr,
+ BLOCK_SIZE: tl.constexpr,
+):
+ batch_idx = tl.program_id(0)
+ query_start = tl.load(query_start_loc_ptr + batch_idx)
+ query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
+ query_len = query_end - query_start
+
+ # Get the true query length and next token after accounting for rejected tokens.
+ num_rejected = tl.load(num_rejected_ptr + batch_idx)
+ query_len -= num_rejected
+
+ num_sampled = tl.load(num_sampled_ptr + batch_idx)
+ if num_sampled > 0:
+ req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
+ next_token = tl.load(last_sampled_ptr + req_state_idx).to(tl.int32)
+ else:
+ # Chunked prefilling.
+ # Get the next prefill token.
+ next_token = tl.load(next_prefill_tokens_ptr + batch_idx)
+
+ # Shift target_input_ids by one.
+ for i in range(1, query_len, BLOCK_SIZE):
+ block = i + tl.arange(0, BLOCK_SIZE)
+ mask = block < query_len
+ input_ids = tl.load(target_input_ids_ptr + query_start + block, mask=mask)
+ tl.store(eagle_input_ids_ptr + query_start + block - 1, input_ids, mask=mask)
+
+ last_token_index = query_start + query_len - 1
+ tl.store(last_token_indices_ptr + batch_idx, last_token_index)
+ tl.store(eagle_input_ids_ptr + last_token_index, next_token)
+
+ # Copy positions.
+ for i in range(0, query_len, BLOCK_SIZE):
+ block = i + tl.arange(0, BLOCK_SIZE)
+ mask = block < query_len
+ target_pos = tl.load(target_positions_ptr + query_start + block, mask=mask)
+ tl.store(eagle_positions_ptr + query_start + block, target_pos, mask=mask)
+
+
+def prepare_eagle_inputs(
+ input_buffers: InputBuffers,
+ input_batch: InputBatch,
+ # [num_reqs]
+ num_sampled: torch.Tensor,
+ # [num_reqs]
+ num_rejected: torch.Tensor,
+ # [max_num_reqs, 1]
+ last_sampled: torch.Tensor,
+ # [max_num_reqs]
+ next_prefill_tokens: torch.Tensor,
+) -> torch.Tensor:
+ num_reqs = input_batch.num_reqs
+ last_token_indices = torch.empty(
+ num_reqs,
+ dtype=torch.int64,
+ device=num_sampled.device,
+ )
+ _prepare_eagle_inputs_kernel[(num_reqs,)](
+ last_token_indices,
+ input_buffers.input_ids.gpu,
+ input_buffers.positions,
+ input_batch.input_ids,
+ input_batch.positions,
+ input_batch.idx_mapping,
+ last_sampled,
+ next_prefill_tokens,
+ num_sampled,
+ num_rejected,
+ input_batch.query_start_loc,
+ BLOCK_SIZE=1024,
+ )
+ return last_token_indices
+
+
+@triton.jit
+def _prepare_eagle_docode_kernel(
+ draft_tokens_ptr,
+ output_hidden_states_ptr,
+ output_hidden_states_stride,
+ last_token_indices_ptr,
+ target_seq_lens_ptr,
+ num_rejected_ptr,
+ input_ids_ptr,
+ positions_ptr,
+ input_hidden_states_ptr,
+ input_hidden_states_stride,
+ query_start_loc_ptr,
+ seq_lens_ptr,
+ hidden_size,
+ max_model_len,
+ max_num_reqs,
+ BLOCK_SIZE: tl.constexpr,
+):
+ req_idx = tl.program_id(0)
+ num_reqs = tl.num_programs(0) - 1
+ if req_idx == num_reqs:
+ # Compute query_start_loc. Pad it with the last query_start_loc
+ # for CUDA graphs.
+ for i in range(0, max_num_reqs + 1, BLOCK_SIZE):
+ block = i + tl.arange(0, BLOCK_SIZE)
+ q = tl.where(block < num_reqs, block, num_reqs)
+ mask = block < max_num_reqs + 1
+ tl.store(query_start_loc_ptr + block, q, mask=mask)
+ # Pad seq_lens for CUDA graphs.
+ for i in range(req_idx, max_num_reqs, BLOCK_SIZE):
+ block = i + tl.arange(0, BLOCK_SIZE)
+ mask = block < max_num_reqs
+ tl.store(seq_lens_ptr + block, 0, mask=mask)
+ return
+
+ # draft token -> input id.
+ draft_token = tl.load(draft_tokens_ptr + req_idx)
+ tl.store(input_ids_ptr + req_idx, draft_token)
+
+ # output hidden states -> input hidden states.
+ src_idx = tl.load(last_token_indices_ptr + req_idx)
+ for i in range(0, hidden_size, BLOCK_SIZE):
+ block = i + tl.arange(0, BLOCK_SIZE)
+ mask = block < hidden_size
+ output_hidden_states = tl.load(
+ output_hidden_states_ptr + src_idx * output_hidden_states_stride + block,
+ mask=mask,
+ )
+ tl.store(
+ input_hidden_states_ptr + req_idx * input_hidden_states_stride + block,
+ output_hidden_states,
+ mask=mask,
+ )
+
+ # Compute position and seq_lens.
+ # NOTE(woosuk): To prevent out-of-range access, we clamp these values
+ # if they reach the max model length.
+ position = tl.load(positions_ptr + req_idx)
+ position = tl.minimum(position + 1, max_model_len - 1)
+ tl.store(positions_ptr + req_idx, position)
+
+ target_seq_len = tl.load(target_seq_lens_ptr + req_idx)
+ num_rejected = tl.load(num_rejected_ptr + req_idx)
+ seq_len = target_seq_len - num_rejected
+ seq_len = tl.minimum(seq_len + 1, max_model_len)
+ tl.store(seq_lens_ptr + req_idx, seq_len)
+
+
+def prepare_eagle_decode(
+ draft_tokens: torch.Tensor,
+ output_hidden_states: torch.Tensor,
+ last_token_indices: torch.Tensor,
+ target_seq_lens: torch.Tensor,
+ num_rejected: torch.Tensor,
+ input_buffers: InputBuffers,
+ input_hidden_states: torch.Tensor,
+ max_model_len: int,
+ max_num_reqs: int,
+):
+ num_reqs = draft_tokens.shape[0]
+ hidden_size = output_hidden_states.shape[-1]
+ _prepare_eagle_docode_kernel[(num_reqs + 1,)](
+ draft_tokens,
+ output_hidden_states,
+ output_hidden_states.stride(0),
+ last_token_indices,
+ target_seq_lens,
+ num_rejected,
+ input_buffers.input_ids.gpu,
+ input_buffers.positions,
+ input_hidden_states,
+ input_hidden_states.stride(0),
+ input_buffers.query_start_loc.gpu,
+ input_buffers.seq_lens,
+ hidden_size,
+ max_model_len,
+ max_num_reqs,
+ BLOCK_SIZE=1024,
+ )
+
+
+@triton.jit
+def _update_eagle_inputs_kernel(
+ input_ids_ptr,
+ positions_ptr,
+ input_hidden_states_ptr,
+ input_hidden_states_stride,
+ seq_lens_ptr,
+ max_model_len,
+ draft_tokens_ptr,
+ output_hidden_states_ptr,
+ output_hidden_states_stride,
+ hidden_size,
+ BLOCK_SIZE: tl.constexpr,
+):
+ req_idx = tl.program_id(0)
+
+ # Draft token -> Input ID.
+ draft_token = tl.load(draft_tokens_ptr + req_idx)
+ tl.store(input_ids_ptr + req_idx, draft_token)
+
+ # Output hidden states -> Input hidden states.
+ for i in range(0, hidden_size, BLOCK_SIZE):
+ block = i + tl.arange(0, BLOCK_SIZE)
+ mask = block < hidden_size
+ output_hidden_states = tl.load(
+ output_hidden_states_ptr + req_idx * output_hidden_states_stride + block,
+ mask=mask,
+ )
+ tl.store(
+ input_hidden_states_ptr + req_idx * input_hidden_states_stride + block,
+ output_hidden_states,
+ mask=mask,
+ )
+
+ # Increment position and seq_lens.
+ # NOTE(woosuk): To prevent out-of-range access, we clamp these values
+ # if they reach the max model length.
+ position = tl.load(positions_ptr + req_idx)
+ position = tl.minimum(position + 1, max_model_len - 1)
+ tl.store(positions_ptr + req_idx, position)
+
+ seq_len = tl.load(seq_lens_ptr + req_idx)
+ seq_len = tl.minimum(seq_len + 1, max_model_len)
+ tl.store(seq_lens_ptr + req_idx, seq_len)
+
+
+def update_eagle_inputs(
+ draft_tokens: torch.Tensor,
+ output_hidden_states: torch.Tensor,
+ input_buffers: InputBuffers,
+ hidden_states: torch.Tensor,
+ max_model_len: int,
+):
+ num_reqs, hidden_size = output_hidden_states.shape
+ _update_eagle_inputs_kernel[(num_reqs,)](
+ input_buffers.input_ids.gpu,
+ input_buffers.positions,
+ hidden_states,
+ hidden_states.stride(0),
+ input_buffers.seq_lens,
+ max_model_len,
+ draft_tokens,
+ output_hidden_states,
+ output_hidden_states.stride(0),
+ hidden_size,
+ BLOCK_SIZE=1024,
+ )
diff --git a/vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py b/vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py
new file mode 100644
index 0000000000000..dcdeedda60a77
--- /dev/null
+++ b/vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py
@@ -0,0 +1,115 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Callable
+
+import torch
+
+from vllm.config import VllmConfig
+from vllm.config.compilation import CUDAGraphMode
+from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
+from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.worker.gpu.block_table import BlockTables
+from vllm.v1.worker.gpu.cudagraph_utils import (
+ capture_graphs,
+ get_cudagraph_sizes,
+ prepare_inputs_to_capture,
+)
+from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
+from vllm.v1.worker.gpu.input_batch import InputBuffers
+
+
+class EagleCudaGraphManager:
+ def __init__(
+ self,
+ vllm_config: VllmConfig,
+ device: torch.device,
+ ):
+ self.vllm_config = vllm_config
+ self.scheduler_config = vllm_config.scheduler_config
+ self.device = device
+
+ self.max_model_len = vllm_config.model_config.max_model_len
+ self.max_num_reqs = self.scheduler_config.max_num_seqs
+ self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
+ self.dp_size = vllm_config.parallel_config.data_parallel_size
+ self.compilation_config = vllm_config.compilation_config
+ assert self.compilation_config is not None
+
+ cudagraph_mode: CUDAGraphMode
+ if self.compilation_config.cudagraph_mode is None:
+ cudagraph_mode = CUDAGraphMode.NONE
+ else:
+ cudagraph_mode = self.compilation_config.cudagraph_mode
+ if cudagraph_mode == CUDAGraphMode.FULL:
+ # NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
+ cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY
+
+ self.cudagraph_mode = cudagraph_mode
+
+ self.cudagraph_sizes = get_cudagraph_sizes(
+ self.compilation_config.cudagraph_capture_sizes,
+ self.max_num_reqs,
+ self.max_num_tokens,
+ self.cudagraph_mode,
+ )
+
+ self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
+ self.pool = torch.cuda.graph_pool_handle()
+
+ def get_cudagraph_size(self, num_tokens: int) -> int | None:
+ return self.cudagraph_sizes.get(num_tokens)
+
+ def capture_graph(
+ self,
+ num_tokens: int,
+ generate_fn: Callable,
+ input_buffers: InputBuffers,
+ block_tables: BlockTables,
+ attn_metadata_builders: list[AttentionMetadataBuilder],
+ kv_cache_config: KVCacheConfig,
+ ) -> None:
+ num_reqs = min(num_tokens, self.max_num_reqs)
+ attn_metadata = prepare_inputs_to_capture(
+ num_reqs,
+ num_tokens,
+ input_buffers,
+ block_tables,
+ attn_metadata_builders,
+ self.max_model_len,
+ kv_cache_config,
+ )
+ num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
+
+ # Warm up.
+ generate_fn(num_tokens, attn_metadata, num_tokens_across_dp)
+
+ # Capture the graph.
+ assert num_tokens not in self.graphs
+ graph = torch.cuda.CUDAGraph()
+ with torch.cuda.graph(graph, self.pool):
+ generate_fn(num_tokens, attn_metadata, num_tokens_across_dp)
+ self.graphs[num_tokens] = graph
+
+ @torch.inference_mode()
+ def capture(
+ self,
+ generate_fn: Callable,
+ input_buffers: InputBuffers,
+ block_tables: BlockTables,
+ attn_metadata_builders: list[AttentionMetadataBuilder],
+ kv_cache_config: KVCacheConfig,
+ ) -> None:
+ capture_graphs(
+ self.cudagraph_sizes,
+ self.device,
+ self.capture_graph,
+ generate_fn=generate_fn,
+ input_buffers=input_buffers,
+ block_tables=block_tables,
+ attn_metadata_builders=attn_metadata_builders,
+ kv_cache_config=kv_cache_config,
+ )
+
+ def run(self, num_tokens: int) -> None:
+ assert num_tokens in self.graphs
+ self.graphs[num_tokens].replay()
diff --git a/vllm/v1/worker/gpu/spec_decode/rejection_sample.py b/vllm/v1/worker/gpu/spec_decode/rejection_sample.py
index 8a7bf28bacbd4..43c6ac518bccc 100644
--- a/vllm/v1/worker/gpu/spec_decode/rejection_sample.py
+++ b/vllm/v1/worker/gpu/spec_decode/rejection_sample.py
@@ -69,3 +69,15 @@ def rejection_sample(
num_warps=1,
)
return sampled, num_sampled
+
+
+@torch.compile(dynamic=True)
+def get_num_rejected(
+ cu_num_logits: torch.Tensor,
+ num_sampled: torch.Tensor,
+) -> torch.Tensor:
+ num_logits = cu_num_logits[1:] - cu_num_logits[:-1]
+ num_rejected = num_logits - num_sampled
+ # No token is rejected for chunked prefills.
+ num_rejected *= num_sampled > 0
+ return num_rejected
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index d6fef450c028a..e7991baeaa1b8 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -43,6 +43,8 @@ class CachedRequestState:
mrope_positions: torch.Tensor | None = None
mrope_position_delta: int | None = None
+ xdrope_positions: torch.Tensor | None = None
+
lora_request: LoRARequest | None = None
prompt_embeds: torch.Tensor | None = None
@@ -525,7 +527,7 @@ class InputBatch:
# NOTE: the following is unsafe
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
# self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
- # instead, we need to temporiarily copy the data for one of the indices
+ # instead, we need to temporarily copy the data for one of the indices
# TODO(lucas): optimize this by only copying valid indices
tmp = self.token_ids_cpu[i1, ...].copy()
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 6a54e02f861e9..6bff83658b45a 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -19,12 +19,13 @@ import torch.nn as nn
from tqdm import tqdm
import vllm.envs as envs
-from vllm.attention import Attention, AttentionType
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionMetadata,
+ AttentionType,
MultipleOf,
)
+from vllm.attention.layer import Attention
from vllm.compilation.counter import compilation_counter
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
@@ -50,16 +51,21 @@ from vllm.distributed.parallel_state import (
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
-from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
+from vllm.model_executor.layers.rotary_embedding import (
+ MRotaryEmbedding,
+ XDRotaryEmbedding,
+)
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.interfaces import (
SupportsMRoPE,
SupportsMultiModal,
+ SupportsXDRoPE,
is_mixture_of_experts,
supports_eagle3,
supports_mrope,
supports_multimodal_pruning,
supports_transcription,
+ supports_xdrope,
)
from vllm.model_executor.models.interfaces_base import (
VllmModelForPooling,
@@ -145,7 +151,6 @@ from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.ubatch_utils import (
- UBatchSlice,
UBatchSlices,
check_ubatch_thresholds,
)
@@ -178,7 +183,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
self,
model_runner_output: ModelRunnerOutput,
sampled_token_ids: torch.Tensor,
- logprobs_tensors: torch.Tensor | None,
+ logprobs_tensors: LogprobsTensors | None,
invalid_req_indices: list[int],
async_output_copy_stream: torch.cuda.Stream,
vocab_size: int,
@@ -214,28 +219,29 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
This function blocks until the copy is finished.
"""
+ max_gen_len = self.sampled_token_ids_cpu.shape[-1]
self.async_copy_ready_event.synchronize()
# Release the device tensors once the copy has completed.
del self._logprobs_tensors
del self._sampled_token_ids
- max_gen_len = self.sampled_token_ids_cpu.shape[-1]
if max_gen_len == 1:
valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist()
+ for i in self._invalid_req_indices:
+ valid_sampled_token_ids[i].clear()
+ cu_num_tokens = None
else:
- valid_sampled_token_ids = RejectionSampler.parse_output(
+ valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output(
self.sampled_token_ids_cpu,
self.vocab_size,
+ self._invalid_req_indices,
+ return_cu_num_tokens=self._logprobs_tensors_cpu is not None,
)
- for i in self._invalid_req_indices:
- valid_sampled_token_ids[i].clear()
output = self._model_runner_output
output.sampled_token_ids = valid_sampled_token_ids
if self._logprobs_tensors_cpu:
- # NOTE(nick): this will need to be updated to use cu_num_accepted_tokens
- # for async sched + spec decode + logprobs compatibility.
- output.logprobs = self._logprobs_tensors_cpu.tolists()
+ output.logprobs = self._logprobs_tensors_cpu.tolists(cu_num_tokens)
return output
@@ -324,6 +330,7 @@ class GPUModelRunner(
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope
+ self.uses_xdrope_dim = model_config.uses_xdrope_dim
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
model_config
)
@@ -375,7 +382,9 @@ class GPUModelRunner(
elif self.speculative_config.use_eagle():
self.drafter = EagleProposer(self.vllm_config, self.device, self)
if self.speculative_config.method == "eagle3":
- self.use_aux_hidden_state_outputs = True
+ self.use_aux_hidden_state_outputs = (
+ self.drafter.eagle3_use_aux_hidden_state
+ )
elif self.speculative_config.method == "medusa":
self.drafter = MedusaProposer(
vllm_config=self.vllm_config, device=self.device
@@ -467,6 +476,7 @@ class GPUModelRunner(
self.max_num_reqs + 1, dtype=torch.int32
)
self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
+ self.encoder_seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
if self.dcp_world_size > 1:
self.dcp_local_seq_lens = self._make_buffer(
self.max_num_reqs, dtype=torch.int32
@@ -510,6 +520,13 @@ class GPUModelRunner(
(3, self.max_num_tokens + 1), dtype=torch.int64
)
+ # Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
+ if self.uses_xdrope_dim > 0:
+ # Similar to mrope but use assigned dimension number for RoPE, 4 as default.
+ self.xdrope_positions = self._make_buffer(
+ (self.uses_xdrope_dim, self.max_num_tokens + 1), dtype=torch.int64
+ )
+
# None in the first PP rank. The rest are set after load_model.
self.intermediate_tensors: IntermediateTensors | None = None
@@ -591,10 +608,14 @@ class GPUModelRunner(
if isinstance(num_tokens, int):
if self.uses_mrope:
return self.mrope_positions.gpu[:, :num_tokens]
+ if self.uses_xdrope_dim > 0:
+ return self.xdrope_positions.gpu[:, :num_tokens]
return self.positions.gpu[:num_tokens]
else:
if self.uses_mrope:
return self.mrope_positions.gpu[:, num_tokens]
+ if self.uses_xdrope_dim > 0:
+ return self.xdrope_positions.gpu[:, num_tokens]
return self.positions.gpu[num_tokens]
def _make_buffer(
@@ -770,6 +791,10 @@ class GPUModelRunner(
if self.uses_mrope:
self._init_mrope_positions(req_state)
+ # Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
+ if self.uses_xdrope_dim > 0:
+ self._init_xdrope_positions(req_state)
+
reqs_to_add.append(req_state)
# Update the states of the running/resumed requests.
@@ -985,6 +1010,19 @@ class GPUModelRunner(
)
)
+ def _init_xdrope_positions(self, req_state: CachedRequestState):
+ model = self.get_model()
+ xdrope_model = cast(SupportsXDRoPE, model)
+ assert req_state.prompt_token_ids is not None, (
+ "XD-RoPE requires prompt_token_ids to be available."
+ )
+ assert supports_xdrope(model), "XD-RoPE support is not implemented."
+
+ req_state.xdrope_positions = xdrope_model.get_xdrope_input_positions(
+ req_state.prompt_token_ids,
+ req_state.mm_features,
+ )
+
def _extract_mm_kwargs(
self,
scheduler_output: "SchedulerOutput",
@@ -1166,37 +1204,47 @@ class GPUModelRunner(
def _get_encoder_seq_lens(
self,
- scheduled_encoder_inputs: dict[str, list[int]],
+ num_scheduled_tokens: dict[str, int],
kv_cache_spec: KVCacheSpec,
num_reqs: int,
- ) -> np.ndarray | None:
+ ) -> tuple[torch.Tensor | None, np.ndarray | None]:
if not isinstance(kv_cache_spec, CrossAttentionSpec):
- return None
+ return None, None
# Build encoder_seq_lens array mapping request indices to
# encoder lengths for inputs scheduled in this batch
- encoder_seq_lens = np.zeros(num_reqs, dtype=np.int32)
- for req_id in scheduled_encoder_inputs:
+ for req_id in num_scheduled_tokens:
req_index = self.input_batch.req_id_to_index[req_id]
- encoder_seq_lens[req_index] = self.max_encoder_len
+ req_state = self.requests[req_id]
+ if req_state.mm_features is None:
+ self.encoder_seq_lens.np[req_index] = 0
+ continue
- return encoder_seq_lens
+ # Get the total number of encoder input tokens for running encoder requests
+ # whether encoding is finished or not so that cross-attention knows how
+ # many encoder tokens to attend to.
+ encoder_input_tokens = sum(
+ feature.mm_position.length for feature in req_state.mm_features
+ )
+ self.encoder_seq_lens.np[req_index] = encoder_input_tokens
+
+ self.encoder_seq_lens.copy_to_gpu(num_reqs)
+ encoder_seq_lens = self.encoder_seq_lens.gpu[:num_reqs]
+ encoder_seq_lens_cpu = self.encoder_seq_lens.np[:num_reqs]
+
+ return encoder_seq_lens, encoder_seq_lens_cpu
def _prepare_inputs(
self,
scheduler_output: "SchedulerOutput",
num_scheduled_tokens: np.ndarray,
- max_num_scheduled_tokens: int,
) -> tuple[
torch.Tensor,
SpecDecodeMetadata | None,
- UBatchSlices | None,
- torch.Tensor | None,
]:
"""
:return: tuple[
logits_indices, spec_decode_metadata,
- ubatch_slices, num_tokens_across_dp,
]
"""
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -1229,6 +1277,11 @@ class GPUModelRunner(
if self.uses_mrope:
self._calc_mrope_positions(scheduler_output)
+ # Calculate XD-RoPE positions.
+ # Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
+ if self.uses_xdrope_dim > 0:
+ self._calc_xdrope_positions(scheduler_output)
+
# Get token indices.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
@@ -1306,28 +1359,6 @@ class GPUModelRunner(
self.query_start_loc.copy_to_gpu()
query_start_loc = self.query_start_loc.gpu[: num_reqs + 1]
- num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
- num_tokens_padded = self._get_num_input_tokens(num_tokens_unpadded)
- uniform_decode = (
- max_num_scheduled_tokens == self.uniform_decode_query_len
- ) and (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
-
- # Disable DP padding when running eager to avoid excessive padding when
- # running prefills. This lets us set enforce_eager on the prefiller in
- # a P/D setup and still use CUDA graphs (enabled by this padding) on the
- # decoder.
- allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
-
- ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
- num_tokens_unpadded=num_tokens_unpadded,
- parallel_config=self.parallel_config,
- allow_microbatching=True,
- allow_dp_padding=allow_dp_padding,
- num_tokens_padded=num_tokens_padded,
- uniform_decode=uniform_decode,
- num_scheduled_tokens_per_request=num_scheduled_tokens,
- )
-
self.seq_lens.np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens
)
@@ -1362,6 +1393,12 @@ class GPUModelRunner(
self.mrope_positions.cpu[:, :total_num_scheduled_tokens],
non_blocking=True,
)
+ elif self.uses_xdrope_dim > 0:
+ # Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
+ self.xdrope_positions.gpu[:, :total_num_scheduled_tokens].copy_(
+ self.xdrope_positions.cpu[:, :total_num_scheduled_tokens],
+ non_blocking=True,
+ )
else:
# Common case (1D positions)
self.positions.copy_to_gpu(total_num_scheduled_tokens)
@@ -1422,25 +1459,28 @@ class GPUModelRunner(
return (
logits_indices,
spec_decode_metadata,
- ubatch_slices,
- num_tokens_across_dp,
)
def _build_attention_metadata(
self,
- total_num_scheduled_tokens: int,
- max_num_scheduled_tokens: int,
+ num_tokens: int,
num_reqs: int,
+ max_query_len: int,
+ num_tokens_padded: int | None = None,
+ num_reqs_padded: int | None = None,
ubatch_slices: UBatchSlices | None = None,
logits_indices: torch.Tensor | None = None,
use_spec_decode: bool = False,
for_cudagraph_capture: bool = False,
- scheduled_encoder_inputs: dict[str, list[int]] | None = None,
+ num_scheduled_tokens: dict[str, int] | None = None,
cascade_attn_prefix_lens: list[list[int]] | None = None,
) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]:
"""
:return: tuple[attn_metadata, spec_decode_common_attn_metadata]
"""
+ num_tokens_padded = num_tokens_padded or num_tokens
+ num_reqs_padded = num_reqs_padded or num_reqs
+
logits_indices_padded = None
num_logits_indices = None
if logits_indices is not None:
@@ -1458,28 +1498,13 @@ class GPUModelRunner(
self.dcp_rank,
self.parallel_config.cp_kv_cache_interleave_size,
)
- self.dcp_local_seq_lens.copy_to_gpu(num_reqs)
+ self.dcp_local_seq_lens.cpu[num_reqs:].fill_(0)
+ self.dcp_local_seq_lens.copy_to_gpu(num_reqs_padded)
attn_metadata: PerLayerAttnMetadata = {}
if ubatch_slices is not None:
attn_metadata = [dict() for _ in range(len(ubatch_slices))]
- # Used in the below loop
- query_start_loc = self.query_start_loc.gpu[: num_reqs + 1]
- query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs + 1]
- seq_lens = self.seq_lens.gpu[:num_reqs]
- seq_lens_cpu = self.seq_lens.cpu[:num_reqs]
- num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
- :num_reqs
- ]
-
- dcp_local_seq_lens, dcp_local_seq_lens_cpu = None, None
- if self.dcp_world_size > 1:
- dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs]
- dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[:num_reqs]
-
- spec_decode_common_attn_metadata = None
-
if for_cudagraph_capture:
# For some attention backends (e.g. FA) with sliding window models we need
# to make sure the backend see a max_seq_len that is larger to the sliding
@@ -1495,38 +1520,55 @@ class GPUModelRunner(
self.num_accepted_tokens.np[num_reqs:].fill(1)
self.num_accepted_tokens.copy_to_gpu()
+ # Used in the below loop, uses padded shapes
+ query_start_loc = self.query_start_loc.gpu[: num_reqs_padded + 1]
+ query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs_padded + 1]
+ seq_lens = self.seq_lens.gpu[:num_reqs_padded]
+ seq_lens_cpu = self.seq_lens.cpu[:num_reqs_padded]
+ num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
+ :num_reqs_padded
+ ]
+
+ dcp_local_seq_lens, dcp_local_seq_lens_cpu = None, None
+ if self.dcp_world_size > 1:
+ dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs_padded]
+ dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[:num_reqs_padded]
+
+ spec_decode_common_attn_metadata = None
+
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
for kv_cache_gid, kv_cache_group in enumerate(
self.kv_cache_config.kv_cache_groups
):
- encoder_seq_lens = self._get_encoder_seq_lens(
- scheduled_encoder_inputs or {},
+ encoder_seq_lens, encoder_seq_lens_cpu = self._get_encoder_seq_lens(
+ num_scheduled_tokens or {},
kv_cache_group.kv_cache_spec,
- num_reqs,
+ num_reqs_padded,
)
if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
# Encoder-only layers do not have KV cache, so we need to
# create a dummy block table and slot mapping for them.
blk_table_tensor = torch.zeros(
- (num_reqs, 1),
+ (num_reqs_padded, 1),
dtype=torch.int32,
device=self.device,
)
slot_mapping = torch.zeros(
- (total_num_scheduled_tokens,),
+ (num_tokens_padded,),
dtype=torch.int64,
device=self.device,
)
else:
blk_table = self.input_batch.block_table[kv_cache_gid]
- blk_table_tensor = blk_table.get_device_tensor(num_reqs)
- slot_mapping = blk_table.slot_mapping.gpu[:total_num_scheduled_tokens]
+ blk_table_tensor = blk_table.get_device_tensor(num_reqs_padded)
+ slot_mapping = blk_table.slot_mapping.gpu[:num_tokens_padded]
# Fill unused with -1. Needed for reshape_and_cache in full cuda
- # graph mode.
- blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1)
+ # graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID
+ slot_mapping[num_tokens:num_tokens_padded].fill_(-1)
+ blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1)
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc,
@@ -1534,9 +1576,9 @@ class GPUModelRunner(
seq_lens=seq_lens,
seq_lens_cpu=seq_lens_cpu,
num_computed_tokens_cpu=num_computed_tokens_cpu,
- num_reqs=num_reqs,
- num_actual_tokens=total_num_scheduled_tokens,
- max_query_len=max_num_scheduled_tokens,
+ num_actual_tokens=num_tokens_padded,
+ num_reqs=num_reqs_padded,
+ max_query_len=max_query_len,
max_seq_len=max_seq_len,
block_table_tensor=blk_table_tensor,
slot_mapping=slot_mapping,
@@ -1544,6 +1586,7 @@ class GPUModelRunner(
num_logits_indices=num_logits_indices,
causal=True,
encoder_seq_lens=encoder_seq_lens,
+ encoder_seq_lens_cpu=encoder_seq_lens_cpu,
dcp_local_seq_lens=dcp_local_seq_lens,
dcp_local_seq_lens_cpu=dcp_local_seq_lens_cpu,
)
@@ -1566,9 +1609,11 @@ class GPUModelRunner(
extra_attn_metadata_args = {}
if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
extra_attn_metadata_args = dict(
- num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs],
+ num_accepted_tokens=self.num_accepted_tokens.gpu[
+ :num_reqs_padded
+ ],
num_decode_draft_tokens_cpu=self.num_decode_draft_tokens.cpu[
- :num_reqs
+ :num_reqs_padded
],
)
@@ -1607,11 +1652,22 @@ class GPUModelRunner(
for layer_name in attn_group.layer_names:
attn_metadata[layer_name] = attn_metadata_i
+ if spec_decode_common_attn_metadata is not None and (
+ num_reqs != num_reqs_padded or num_tokens != num_tokens_padded
+ ):
+ # Currently the drafter still only uses piecewise cudagraphs (and modifies
+ # the attention metadata in directly), and therefore does not want to use
+ # padded attention metadata.
+ spec_decode_common_attn_metadata = (
+ spec_decode_common_attn_metadata.unpadded(num_tokens, num_reqs)
+ )
+
return attn_metadata, spec_decode_common_attn_metadata
def _compute_cascade_attn_prefix_lens(
self,
num_scheduled_tokens: np.ndarray,
+ num_computed_tokens: np.ndarray,
num_common_prefix_blocks: list[int],
) -> list[list[int]] | None:
"""
@@ -1634,6 +1690,7 @@ class GPUModelRunner(
# 0 if cascade attention should not be used
cascade_attn_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
+ num_computed_tokens,
num_common_prefix_blocks[kv_cache_gid],
attn_group.kv_cache_spec,
attn_group.get_metadata_builder(),
@@ -1646,6 +1703,7 @@ class GPUModelRunner(
def _compute_cascade_attn_prefix_len(
self,
num_scheduled_tokens: np.ndarray,
+ num_computed_tokens: np.ndarray,
num_common_prefix_blocks: int,
kv_cache_spec: KVCacheSpec,
attn_metadata_builder: AttentionMetadataBuilder,
@@ -1712,10 +1770,7 @@ class GPUModelRunner(
# and the second kernel will get an empty input. While this is not
# a fundamental problem, our current implementation does not support
# this case.
- num_reqs = len(num_scheduled_tokens)
- common_prefix_len = min(
- common_prefix_len, self.input_batch.num_computed_tokens_cpu[:num_reqs].min()
- )
+ common_prefix_len = min(common_prefix_len, num_computed_tokens.min())
# common_prefix_len should be a multiple of the block size.
common_prefix_len = (
common_prefix_len // kv_cache_spec.block_size * kv_cache_spec.block_size
@@ -1791,6 +1846,53 @@ class GPUModelRunner(
mrope_pos_ptr += completion_part_len
+ def _calc_xdrope_positions(self, scheduler_output: "SchedulerOutput"):
+ xdrope_pos_ptr = 0
+ for index, req_id in enumerate(self.input_batch.req_ids):
+ req = self.requests[req_id]
+ assert req.xdrope_positions is not None
+
+ num_computed_tokens = self.input_batch.num_computed_tokens_cpu[index]
+ num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
+ num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
+ req.prompt_token_ids, req.prompt_embeds
+ )
+
+ if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
+ prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens)
+ completion_part_len = max(0, num_scheduled_tokens - prompt_part_len)
+ else:
+ prompt_part_len = num_scheduled_tokens
+ completion_part_len = 0
+
+ assert num_scheduled_tokens == prompt_part_len + completion_part_len
+
+ if prompt_part_len > 0:
+ # prompt's xdrope_positions are pre-computed
+ dst_start = xdrope_pos_ptr
+ dst_end = xdrope_pos_ptr + prompt_part_len
+ src_start = num_computed_tokens
+ src_end = num_computed_tokens + prompt_part_len
+
+ self.xdrope_positions.cpu[:, dst_start:dst_end] = req.xdrope_positions[
+ :, src_start:src_end
+ ]
+ xdrope_pos_ptr += prompt_part_len
+
+ if completion_part_len > 0:
+ # compute completion's xdrope_positions on-the-fly
+ dst_start = xdrope_pos_ptr
+ dst_end = xdrope_pos_ptr + completion_part_len
+
+ XDRotaryEmbedding.get_next_input_positions_tensor(
+ out=self.xdrope_positions.np,
+ out_offset=dst_start,
+ context_len=num_computed_tokens + prompt_part_len,
+ num_new_tokens=completion_part_len,
+ )
+
+ xdrope_pos_ptr += completion_part_len
+
def _calc_spec_decode_metadata(
self,
num_draft_tokens: np.ndarray,
@@ -2035,6 +2137,7 @@ class GPUModelRunner(
req_start_idx = 0
should_sync_mrope_positions = False
+ should_sync_xdrope_positions = False
for req_id in self.input_batch.req_ids:
mm_embeds_req: list[torch.Tensor] = []
@@ -2108,6 +2211,10 @@ class GPUModelRunner(
self._calc_mrope_positions(scheduler_output)
self.mrope_positions.copy_to_gpu(total_num_scheduled_tokens)
+ if should_sync_xdrope_positions:
+ self._calc_xdrope_positions(scheduler_output)
+ self.xdrope_positions.copy_to_gpu(total_num_scheduled_tokens)
+
return mm_embeds, is_mm_embed
def get_model(self) -> nn.Module:
@@ -2217,19 +2324,6 @@ class GPUModelRunner(
log_stats=self.parallel_config.eplb_config.log_balancedness,
)
- # This is where the second ubatch is adjusted to account for the padding.
- # Should be called after attention metadata creation. This just pads
- # the second ubatch slice out to the total number of tokens
- # (num_tokens + padding)
- @staticmethod
- def pad_out_ubatch_slice(ubatch_slices: UBatchSlices, num_total_tokens: int):
- padded_second_ubatch_slice = slice(
- ubatch_slices[1].token_slice.start, num_total_tokens
- )
- ubatch_slices[1] = UBatchSlice(
- padded_second_ubatch_slice, padded_second_ubatch_slice
- )
-
def _pool(
self,
hidden_states: torch.Tensor,
@@ -2274,18 +2368,7 @@ class GPUModelRunner(
pooler_output=pooler_output,
)
- def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
- if (
- self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
- and hasattr(self, "cudagraph_batch_sizes")
- and self.cudagraph_batch_sizes
- and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]
- ):
- # Use CUDA graphs.
- # Add padding to the batch size.
- return self.vllm_config.pad_for_cudagraph(num_scheduled_tokens)
-
- # Eager mode.
+ def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int:
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
@@ -2382,8 +2465,11 @@ class GPUModelRunner(
input_ids = self.input_ids.gpu[:num_input_tokens]
inputs_embeds = None
model_kwargs = self._init_model_kwargs(num_input_tokens)
+
if self.uses_mrope:
positions = self.mrope_positions.gpu[:, :num_input_tokens]
+ elif self.uses_xdrope_dim > 0:
+ positions = self.xdrope_positions.gpu[:, :num_input_tokens]
else:
positions = self.positions.gpu[:num_input_tokens]
@@ -2479,28 +2565,24 @@ class GPUModelRunner(
sampled_token_ids = sampler_output.sampled_token_ids
logprobs_tensors = sampler_output.logprobs_tensors
invalid_req_indices = []
- cu_num_new_tokens: list[int] | None = None
+ cu_num_tokens: list[int] | None = None
if not self.use_async_scheduling:
# Get the valid generated tokens.
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = self._to_list(sampled_token_ids)
+ # Mask out the sampled tokens that should not be sampled.
+ for i in discard_sampled_tokens_req_indices:
+ valid_sampled_token_ids[int(i)].clear()
else:
# Includes spec decode tokens.
- valid_sampled_token_ids = self.rejection_sampler.parse_output(
+ valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output(
sampled_token_ids,
self.input_batch.vocab_size,
+ discard_sampled_tokens_req_indices,
+ return_cu_num_tokens=logprobs_tensors is not None,
)
- if logprobs_tensors:
- # Needed for extracting logprobs when spec decoding.
- # This must be done prior to discarding sampled tokens.
- cu_num_new_tokens = [0]
- for toks in valid_sampled_token_ids:
- cu_num_new_tokens.append(cu_num_new_tokens[-1] + len(toks))
- # Mask out the sampled tokens that should not be sampled.
- for i in discard_sampled_tokens_req_indices:
- valid_sampled_token_ids[int(i)].clear()
else:
valid_sampled_token_ids = []
invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
@@ -2554,7 +2636,7 @@ class GPUModelRunner(
req_state.output_token_ids.extend(sampled_ids)
logprobs_lists = (
- logprobs_tensors.tolists(cu_num_new_tokens)
+ logprobs_tensors.tolists(cu_num_tokens)
if not self.use_async_scheduling and logprobs_tensors is not None
else None
)
@@ -2622,6 +2704,87 @@ class GPUModelRunner(
**model_kwargs,
)
+ def _determine_batch_execution_and_padding(
+ self,
+ num_tokens: int,
+ num_reqs: int,
+ num_scheduled_tokens_np: np.ndarray,
+ max_num_scheduled_tokens: int,
+ use_cascade_attn: bool,
+ allow_microbatching: bool = True,
+ force_eager: bool = False,
+ # For cudagraph capture TODO(lucas): Refactor how we capture cudagraphs (will
+ # be improved in model runner v2)
+ force_uniform_decode: bool | None = None,
+ force_has_lora: bool | None = None,
+ ) -> tuple[
+ CUDAGraphMode, BatchDescriptor, UBatchSlices | None, torch.Tensor | None
+ ]:
+ num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
+ uniform_decode = (
+ (
+ (max_num_scheduled_tokens == self.uniform_decode_query_len)
+ and (num_tokens_padded == max_num_scheduled_tokens * num_reqs)
+ )
+ if force_uniform_decode is None
+ else force_uniform_decode
+ )
+
+ has_lora = (
+ len(self.input_batch.lora_id_to_lora_request) > 0
+ if force_has_lora is None
+ else force_has_lora
+ )
+
+ dispatch_cudagraph = (
+ lambda num_tokens: self.cudagraph_dispatcher.dispatch(
+ num_tokens=num_tokens,
+ has_lora=has_lora,
+ use_cascade_attn=use_cascade_attn,
+ uniform_decode=uniform_decode,
+ )
+ if not force_eager
+ else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded))
+ )
+
+ cudagraph_mode, batch_descriptor = dispatch_cudagraph(num_tokens_padded)
+ num_tokens_padded = batch_descriptor.num_tokens
+
+ # Extra coordination when running data-parallel since we need to coordinate
+ # across ranks
+ ubatch_slices, num_tokens_across_dp = None, None
+ if self.vllm_config.parallel_config.data_parallel_size > 1:
+ # Disable DP padding when running eager to avoid excessive padding when
+ # running prefills. This lets us set cudagraph_mode="NONE" on the prefiller
+ # in a P/D setup and still use CUDA graphs (enabled by this padding) on the
+ # decoder.
+ allow_dp_padding = (
+ self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+ )
+
+ ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
+ num_tokens_unpadded=num_tokens_padded,
+ parallel_config=self.parallel_config,
+ allow_microbatching=allow_microbatching,
+ allow_dp_padding=allow_dp_padding,
+ num_tokens_padded=num_tokens_padded,
+ uniform_decode=uniform_decode,
+ num_scheduled_tokens_per_request=num_scheduled_tokens_np,
+ )
+
+ # Extract DP padding if there is any
+ if num_tokens_across_dp is not None:
+ dp_rank = self.parallel_config.data_parallel_rank
+ num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())
+
+ # Re-dispatch with DP padding
+ cudagraph_mode, batch_descriptor = dispatch_cudagraph(num_tokens_padded)
+ # Assert to make sure the agreed upon token count is correct otherwise
+ # num_tokens_across_dp will no-longer be valid
+ assert batch_descriptor.num_tokens == num_tokens_padded
+
+ return cudagraph_mode, batch_descriptor, ubatch_slices, num_tokens_across_dp
+
@torch.inference_mode()
def execute_model(
self,
@@ -2693,87 +2856,87 @@ class GPUModelRunner(
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
num_scheduled_tokens_np = np.array(tokens, dtype=np.int32)
max_num_scheduled_tokens = int(num_scheduled_tokens_np.max())
+ num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
(
logits_indices,
spec_decode_metadata,
- ubatch_slices,
- num_tokens_across_dp,
) = self._prepare_inputs(
- scheduler_output, num_scheduled_tokens_np, max_num_scheduled_tokens
+ scheduler_output,
+ num_scheduled_tokens_np,
)
cascade_attn_prefix_lens = None
# Disable cascade attention when using microbatching (DBO)
- if self.cascade_attn_enabled and ubatch_slices is None:
+ if self.cascade_attn_enabled and not self.parallel_config.enable_dbo:
# Pre-compute cascade attention prefix lengths
- # NOTE: Must be AFTER _prepare_inputs uses self.input_batch state
cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens(
num_scheduled_tokens_np,
+ self.input_batch.num_computed_tokens_cpu[:num_reqs],
scheduler_output.num_common_prefix_blocks,
)
- # TODO(lucas): move cudagraph dispatching here:
- # https://github.com/vllm-project/vllm/issues/23789
+ (
+ cudagraph_mode,
+ batch_desc,
+ ubatch_slices,
+ num_tokens_across_dp,
+ ) = self._determine_batch_execution_and_padding(
+ num_tokens=num_tokens_unpadded,
+ num_reqs=num_reqs,
+ num_scheduled_tokens_np=num_scheduled_tokens_np,
+ max_num_scheduled_tokens=max_num_scheduled_tokens,
+ use_cascade_attn=cascade_attn_prefix_lens is not None,
+ )
+
+ logger.debug(
+ "Running batch with cudagraph_mode: %s, batch_descriptor: %s, "
+ "ubatch_slices: %s, num_tokens_across_dp: %s",
+ cudagraph_mode,
+ batch_desc,
+ ubatch_slices,
+ num_tokens_across_dp,
+ )
+
+ num_tokens_padded = batch_desc.num_tokens
+ num_reqs_padded = (
+ batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
+ )
- total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
- attn_metadata, spec_decode_common_attn_metadata = (
+ pad_attn = cudagraph_mode == CUDAGraphMode.FULL
+
+ (attn_metadata, spec_decode_common_attn_metadata) = (
self._build_attention_metadata(
- total_num_scheduled_tokens=total_num_scheduled_tokens,
- max_num_scheduled_tokens=max_num_scheduled_tokens,
+ num_tokens=num_tokens_unpadded,
+ num_tokens_padded=num_tokens_padded if pad_attn else None,
num_reqs=num_reqs,
+ num_reqs_padded=num_reqs_padded if pad_attn else None,
+ max_query_len=max_num_scheduled_tokens,
ubatch_slices=ubatch_slices,
logits_indices=logits_indices,
use_spec_decode=use_spec_decode,
- scheduled_encoder_inputs=scheduler_output.scheduled_encoder_inputs,
+ num_scheduled_tokens=scheduler_output.num_scheduled_tokens,
cascade_attn_prefix_lens=cascade_attn_prefix_lens,
)
)
- dp_rank = self.parallel_config.data_parallel_rank
- if ubatch_slices:
- assert num_tokens_across_dp is not None
- num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
- self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
- elif num_tokens_across_dp is not None:
- num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
- else:
- num_input_tokens = self._get_num_input_tokens(
- scheduler_output.total_num_scheduled_tokens
- )
-
- (
- input_ids,
- inputs_embeds,
- positions,
- intermediate_tensors,
- model_kwargs,
- ec_connector_output,
- ) = self._preprocess(
- scheduler_output, num_input_tokens, intermediate_tensors
- )
-
- uniform_decode = (
- max_num_scheduled_tokens == self.uniform_decode_query_len
- ) and (num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
- batch_desc = BatchDescriptor(
- num_tokens=num_input_tokens,
- uniform_decode=uniform_decode,
- has_lora=len(self.input_batch.lora_id_to_lora_request) > 0,
- )
- cudagraph_runtime_mode, batch_descriptor = (
- self.cudagraph_dispatcher.dispatch(
- batch_desc,
- use_cascade_attn=cascade_attn_prefix_lens is not None,
- )
+ (
+ input_ids,
+ inputs_embeds,
+ positions,
+ intermediate_tensors,
+ model_kwargs,
+ ec_connector_output,
+ ) = self._preprocess(
+ scheduler_output, num_tokens_padded, intermediate_tensors
)
# Set cudagraph mode to none if calc_kv_scales is true.
# KV scales calculation involves dynamic operations that are incompatible
# with CUDA graph capture.
if self.calculate_kv_scales:
- cudagraph_runtime_mode = CUDAGraphMode.NONE
+ cudagraph_mode = CUDAGraphMode.NONE
# Mark KV scales as calculated after the first forward pass
self.calculate_kv_scales = False
@@ -2783,10 +2946,10 @@ class GPUModelRunner(
set_forward_context(
attn_metadata,
self.vllm_config,
- num_tokens=num_input_tokens,
+ num_tokens=num_tokens_padded,
num_tokens_across_dp=num_tokens_across_dp,
- cudagraph_runtime_mode=cudagraph_runtime_mode,
- batch_descriptor=batch_descriptor,
+ cudagraph_runtime_mode=cudagraph_mode,
+ batch_descriptor=batch_desc,
ubatch_slices=ubatch_slices,
),
record_function_or_nullcontext("gpu_model_runner: forward"),
@@ -2836,7 +2999,7 @@ class GPUModelRunner(
if not get_pp_group().is_last_rank:
all_gather_tensors = {
"residual": not is_residual_scattered_for_sp(
- self.vllm_config, num_input_tokens
+ self.vllm_config, num_tokens_padded
)
}
get_pp_group().send_tensor_dict(
@@ -3345,6 +3508,10 @@ class GPUModelRunner(
scope="local",
)
prepare_communication_buffer_for_model(self.model)
+ if (drafter := getattr(self, "drafter", None)) and (
+ drafter_model := getattr(drafter, "model", None)
+ ):
+ prepare_communication_buffer_for_model(drafter_model)
mm_config = self.model_config.multimodal_config
self.is_multimodal_pruning_enabled = (
supports_multimodal_pruning(self.get_model())
@@ -3370,6 +3537,8 @@ class GPUModelRunner(
old_global_expert_indices,
rank_mapping,
)
+ if self.eplb_state.is_async:
+ self.eplb_state.start_async_loop(rank_mapping=rank_mapping)
if (
self.vllm_config.compilation_config.mode
@@ -3642,6 +3811,7 @@ class GPUModelRunner(
create_mixed_batch: bool = False,
remove_lora: bool = True,
activate_lora: bool = False,
+ is_graph_capturing: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Run a dummy forward pass to warm up/profile run or capture the
@@ -3718,52 +3888,44 @@ class GPUModelRunner(
assert sum(num_scheduled_tokens_list) == num_tokens
assert len(num_scheduled_tokens_list) == num_reqs
num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)
- total_num_scheduled_tokens = int(num_scheduled_tokens.sum())
+ num_tokens_unpadded = int(num_scheduled_tokens.sum())
+
num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
- # Disable DP padding when running eager
- allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
-
- # We currently only microbatch if the number of tokens is
- # over a certain threshold.
- ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
- num_tokens_unpadded=total_num_scheduled_tokens,
- parallel_config=self.vllm_config.parallel_config,
- allow_microbatching=allow_microbatching,
- allow_dp_padding=allow_dp_padding,
- num_tokens_padded=total_num_scheduled_tokens,
- uniform_decode=uniform_decode,
- num_scheduled_tokens_per_request=num_scheduled_tokens,
- )
- num_tokens_after_padding = num_tokens
- if num_tokens_across_dp is not None:
- dp_rank = self.parallel_config.data_parallel_rank
- num_tokens_after_padding = int(num_tokens_across_dp[dp_rank])
-
- # filter out the valid batch descriptor
- _cg_mode, batch_descriptor = (
- self.cudagraph_dispatcher.dispatch(
- BatchDescriptor(
- num_tokens=num_tokens_after_padding,
- uniform_decode=uniform_decode,
- has_lora=activate_lora and self.lora_config is not None,
- )
+ _cudagraph_mode, batch_desc, ubatch_slices, num_tokens_across_dp = (
+ self._determine_batch_execution_and_padding(
+ num_tokens=num_tokens_unpadded,
+ num_reqs=num_reqs,
+ num_scheduled_tokens_np=num_scheduled_tokens,
+ max_num_scheduled_tokens=max_query_len,
+ use_cascade_attn=False,
+ allow_microbatching=allow_microbatching,
+ force_eager=is_profile
+ or (cudagraph_runtime_mode == CUDAGraphMode.NONE),
+ # `force_uniform_decode` is used for cudagraph capture; because for
+ # capturing mixed prefill-decode batches, we sometimes use
+ # num_tokens == num_reqs which looks like a uniform decode batch to the
+ # dispatcher; but we actually want to capture a piecewise cudagraph
+ force_uniform_decode=uniform_decode,
+ # `force_has_lora` is used for cudagraph capture; because LoRA is
+ # activated later in the context manager, but we need to know the
+ # LoRA state when determining the batch descriptor for capture
+ force_has_lora=activate_lora,
)
- if not is_profile
- else (CUDAGraphMode.NONE, None)
)
- if cudagraph_runtime_mode is not None:
- # we allow forcing NONE when the dispatcher disagrees to support
- # warm ups for cudagraph capture
- assert (
- cudagraph_runtime_mode == CUDAGraphMode.NONE
- or cudagraph_runtime_mode == _cg_mode
- ), (
- f"Cudagraph runtime mode mismatch at dummy_run. "
- f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}."
- )
+
+ if cudagraph_runtime_mode is None:
+ cudagraph_runtime_mode = _cudagraph_mode
else:
- cudagraph_runtime_mode = _cg_mode
+ assert cudagraph_runtime_mode == _cudagraph_mode, (
+ f"Cudagraph runtime mode mismatch in dummy_run. "
+ f"Expected {_cudagraph_mode}, but got {cudagraph_runtime_mode}."
+ )
+
+ num_tokens_padded = batch_desc.num_tokens
+ num_reqs_padded = (
+ batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
+ )
attn_metadata: PerLayerAttnMetadata | None = None
@@ -3786,9 +3948,9 @@ class GPUModelRunner(
self.query_start_loc.copy_to_gpu()
attn_metadata, _ = self._build_attention_metadata(
- total_num_scheduled_tokens=num_tokens,
- max_num_scheduled_tokens=max_query_len,
- num_reqs=num_reqs,
+ num_tokens=num_tokens_unpadded,
+ num_reqs=num_reqs_padded,
+ max_query_len=max_query_len,
ubatch_slices=ubatch_slices,
for_cudagraph_capture=True,
)
@@ -3801,27 +3963,29 @@ class GPUModelRunner(
remove_lora,
):
# Make sure padding doesn't exceed max_num_tokens
- assert num_tokens_after_padding <= self.max_num_tokens
- model_kwargs = self._init_model_kwargs(num_tokens_after_padding)
+ assert num_tokens_padded <= self.max_num_tokens
+ model_kwargs = self._init_model_kwargs(num_tokens_padded)
if self.supports_mm_inputs and not self.model_config.is_encoder_decoder:
input_ids = None
- inputs_embeds = self.inputs_embeds.gpu[:num_tokens_after_padding]
+ inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded]
model_kwargs = {
**model_kwargs,
**self._dummy_mm_kwargs(num_reqs),
}
elif self.enable_prompt_embeds:
input_ids = None
- inputs_embeds = self.inputs_embeds.gpu[:num_tokens_after_padding]
- model_kwargs = self._init_model_kwargs(num_tokens_after_padding)
+ inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded]
+ model_kwargs = self._init_model_kwargs(num_tokens_padded)
else:
- input_ids = self.input_ids.gpu[:num_tokens_after_padding]
+ input_ids = self.input_ids.gpu[:num_tokens_padded]
inputs_embeds = None
if self.uses_mrope:
- positions = self.mrope_positions.gpu[:, :num_tokens_after_padding]
+ positions = self.mrope_positions.gpu[:, :num_tokens_padded]
+ elif self.uses_xdrope_dim > 0:
+ positions = self.xdrope_positions.gpu[:, :num_tokens_padded]
else:
- positions = self.positions.gpu[:num_tokens_after_padding]
+ positions = self.positions.gpu[:num_tokens_padded]
if get_pp_group().is_first_rank:
intermediate_tensors = None
@@ -3836,26 +4000,26 @@ class GPUModelRunner(
)
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
- num_tokens_after_padding, None, False
+ num_tokens_padded, None, False
)
if ubatch_slices is not None:
# Adjust values to reflect a single ubatch.
# TODO(sage,lucas): this is cruft that should be addressed in
# the padding refactor.
- num_tokens_after_padding = ubatch_slices[0].num_tokens
+ num_tokens_padded = ubatch_slices[0].num_tokens
if num_tokens_across_dp is not None:
- num_tokens_across_dp[:] = num_tokens_after_padding
+ num_tokens_across_dp[:] = num_tokens_padded
with (
self.maybe_randomize_inputs(input_ids),
set_forward_context(
attn_metadata,
self.vllm_config,
- num_tokens=num_tokens_after_padding,
+ num_tokens=num_tokens_padded,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
- batch_descriptor=batch_descriptor,
+ batch_descriptor=batch_desc,
ubatch_slices=ubatch_slices,
),
):
@@ -3875,7 +4039,7 @@ class GPUModelRunner(
if self.speculative_config and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
use_cudagraphs = (
- cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE
+ cudagraph_runtime_mode.has_mode(CUDAGraphMode.PIECEWISE)
and not self.speculative_config.enforce_eager
)
@@ -3889,6 +4053,7 @@ class GPUModelRunner(
self.drafter.dummy_run(
num_tokens,
use_cudagraphs=use_cudagraphs,
+ is_graph_capturing=is_graph_capturing,
)
# This is necessary to avoid blocking DP.
@@ -4121,14 +4286,18 @@ class GPUModelRunner(
# NOTE: This happens when encoder cache needs to store
# the embeddings that encoder outputs are scattered onto.
# In this case we create dummy embeddings of size
- # (encode_budget, hidden_size) and scatter encoder
- # output into it.
+ # (max_tokens_for_modality, hidden_size) and scatter
+ # encoder output into it.
encoder_output_shape = dummy_encoder_outputs[0].shape
- if encoder_output_shape[0] < encoder_budget:
+ max_mm_tokens_per_item = mm_budget.max_tokens_by_modality[
+ dummy_modality
+ ]
+ if encoder_output_shape[0] < max_mm_tokens_per_item:
+ encoder_hidden_size = encoder_output_shape[-1]
expanded_outputs = []
for output in dummy_encoder_outputs:
expanded = output.new_zeros(
- (encoder_budget, encoder_output_shape[-1])
+ (max_mm_tokens_per_item, encoder_hidden_size)
)
num_tokens = output.shape[0]
expanded[:num_tokens].copy_(output)
@@ -4321,6 +4490,7 @@ class GPUModelRunner(
skip_eplb=True,
remove_lora=False,
activate_lora=activate_lora,
+ is_graph_capturing=True,
)
self.maybe_remove_all_loras(self.lora_config)
@@ -4575,8 +4745,7 @@ class GPUModelRunner(
# Trigger cudagraph dispatching keys initialization after
# resolved cudagraph mode.
- cudagraph_mode = self.compilation_config.cudagraph_mode
- assert cudagraph_mode is not None
+ self.compilation_config.cudagraph_mode = cudagraph_mode
self.cudagraph_dispatcher.initialize_cudagraph_keys(
cudagraph_mode, self.uniform_decode_query_len
)
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 6a4bfde5f972b..d0c6091ce2a6e 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -8,12 +8,13 @@ from contextlib import AbstractContextManager, nullcontext
from types import NoneType
from typing import TYPE_CHECKING, Any, cast
+import numpy as np
import torch
import torch.distributed
import torch.nn as nn
import vllm.envs as envs
-from vllm.config import VllmConfig
+from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed import (
ensure_model_parallel_initialized,
init_distributed_environment,
@@ -487,6 +488,7 @@ class Worker(WorkerBase):
hidden_states, last_hidden_states = self.model_runner._dummy_run(
num_tokens=max_num_reqs,
skip_eplb=True,
+ cudagraph_runtime_mode=CUDAGraphMode.NONE,
)
if self.model_runner.is_pooling_model:
self.model_runner._dummy_pooler_run(hidden_states)
@@ -534,12 +536,39 @@ class Worker(WorkerBase):
intermediate_tensors = None
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
- num_input_tokens = self.model_runner._get_num_input_tokens(num_scheduled_tokens)
- all_gather_tensors = {
- "residual": not is_residual_scattered_for_sp(
- self.vllm_config, num_input_tokens
+ all_gather_tensors = {}
+ compilation_config = self.vllm_config.compilation_config
+ parallel_config = self.vllm_config.parallel_config
+
+ if (
+ parallel_config.pipeline_parallel_size > 1
+ and compilation_config.pass_config.enable_sequence_parallelism
+ and forward_pass
+ ):
+ # currently only supported by V1 GPUModelRunner
+ assert isinstance(self.model_runner, GPUModelRunner)
+ num_scheduled_tokens_np = np.array(
+ list(scheduler_output.num_scheduled_tokens.values()),
+ dtype=np.int32,
)
- }
+ # TODO(lucas): This is pretty gross; ideally we should only ever call
+ # `_determine_batch_execution_and_padding` once (will get called again
+ # in `execute_model`) but this requires a larger refactor of PP.
+ _, batch_desc, _, _ = (
+ self.model_runner._determine_batch_execution_and_padding(
+ num_tokens=num_scheduled_tokens,
+ num_reqs=len(num_scheduled_tokens_np),
+ num_scheduled_tokens_np=num_scheduled_tokens_np,
+ max_num_scheduled_tokens=num_scheduled_tokens_np.max(),
+ use_cascade_attn=False, # TODO(lucas): Handle cascade attention
+ )
+ )
+ all_gather_tensors = {
+ "residual": not is_residual_scattered_for_sp(
+ self.vllm_config, batch_desc.num_tokens
+ )
+ }
+
if forward_pass and not get_pp_group().is_first_rank:
tensor_dict = get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group(),
diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py
index ff047d8d03f0e..b799f1be73d9c 100644
--- a/vllm/v1/worker/kv_connector_model_runner_mixin.py
+++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py
@@ -13,7 +13,7 @@ from typing import (
import torch
-from vllm.attention import AttentionBackend
+from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.distributed.kv_transfer import (
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 72d4474b89627..9c1fbfd24149d 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -17,9 +17,8 @@ import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
import vllm.envs as envs
-from vllm.attention import Attention
from vllm.attention.backends.abstract import AttentionType
-from vllm.attention.layer import MLAAttention
+from vllm.attention.layer import Attention, MLAAttention
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper
from vllm.config import (
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index e1a109eca0a88..ce18ca6c37165 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -346,6 +346,6 @@ class TPUWorker:
if USE_TPU_INFERENCE:
- from tpu_inference.worker import TPUWorker as TpuInferenceWorker
+ from tpu_inference.worker.tpu_worker import TPUWorker as TpuInferenceWorker
TPUWorker = TpuInferenceWorker # type: ignore
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index 92e4ce3abdba3..bd88cb1b253f8 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -2,11 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from dataclasses import dataclass, field
-from typing import TYPE_CHECKING
import torch
from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.layer import Attention
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
@@ -17,9 +17,6 @@ from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
-if TYPE_CHECKING:
- from vllm.attention.layer import Attention
-
class MultiModalBudget:
"""Helper class to calculate budget information for multi-modal models."""
@@ -278,7 +275,7 @@ def add_kv_sharing_layers_to_kv_cache_groups(
def bind_kv_cache(
kv_caches: dict[str, torch.Tensor],
- forward_context: dict[str, "Attention"],
+ forward_context: dict[str, Attention],
runner_kv_caches: list[torch.Tensor],
num_attn_module: int = 1,
) -> None: