diff --git a/.buildkite/scripts/annotate-release.sh b/.buildkite/scripts/annotate-release.sh
index 56bb5cedaa0a9..df805e0850806 100755
--- a/.buildkite/scripts/annotate-release.sh
+++ b/.buildkite/scripts/annotate-release.sh
@@ -23,8 +23,8 @@ To download the wheel (by version):
aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}/vllm-${RELEASE_VERSION}-cp38-abi3-manylinux1_x86_64.whl .
aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}/vllm-${RELEASE_VERSION}-cp38-abi3-manylinux2014_aarch64.whl .
-aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}+cu126/vllm-${RELEASE_VERSION}+cu126-cp38-abi3-manylinux1_x86_64.whl .
aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}+cu129/vllm-${RELEASE_VERSION}+cu129-cp38-abi3-manylinux1_x86_64.whl .
+aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}+cu130/vllm-${RELEASE_VERSION}+cu130-cp38-abi3-manylinux1_x86_64.whl .
\`\`\`
To download and upload the image:
@@ -45,9 +45,10 @@ docker tag vllm/vllm-openai:aarch64 vllm/vllm-openai:v${RELEASE_VERSION}-aarch64
docker push vllm/vllm-openai:latest-aarch64
docker push vllm/vllm-openai:v${RELEASE_VERSION}-aarch64
-docker manifest create vllm/vllm-openai:latest vllm/vllm-openai:latest-x86_64 vllm/vllm-openai:latest-aarch64 --amend
-docker manifest create vllm/vllm-openai:v${RELEASE_VERSION} vllm/vllm-openai:v${RELEASE_VERSION}-x86_64 vllm/vllm-openai:v${RELEASE_VERSION}-aarch64 --amend
+docker manifest rm vllm/vllm-openai:latest
+docker manifest create vllm/vllm-openai:latest vllm/vllm-openai:latest-x86_64 vllm/vllm-openai:latest-aarch64
+docker manifest create vllm/vllm-openai:v${RELEASE_VERSION} vllm/vllm-openai:v${RELEASE_VERSION}-x86_64 vllm/vllm-openai:v${RELEASE_VERSION}-aarch64
docker manifest push vllm/vllm-openai:latest
docker manifest push vllm/vllm-openai:v${RELEASE_VERSION}
\`\`\`
-EOF
\ No newline at end of file
+EOF
diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test-arm.sh b/.buildkite/scripts/hardware_ci/run-cpu-test-arm.sh
new file mode 100755
index 0000000000000..d0036f24c8d04
--- /dev/null
+++ b/.buildkite/scripts/hardware_ci/run-cpu-test-arm.sh
@@ -0,0 +1,64 @@
+#!/bin/bash
+
+# This script builds the CPU docker image and runs the offline inference inside the container.
+# It serves as a sanity check for compilation and basic model usage.
+set -ex
+
+# allow to bind to different cores
+CORE_RANGE=${CORE_RANGE:-0-16}
+OMP_CORE_RANGE=${OMP_CORE_RANGE:-0-16}
+NUMA_NODE=${NUMA_NODE:-0}
+
+export CMAKE_BUILD_PARALLEL_LEVEL=32
+
+# Setup cleanup
+remove_docker_container() {
+ set -e;
+ docker rm -f cpu-test-"$NUMA_NODE" || true;
+}
+trap remove_docker_container EXIT
+remove_docker_container
+
+# Try building the docker image
+numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE" --target vllm-test -f docker/Dockerfile.cpu .
+
+# Run the image, setting --shm-size=4g for tensor parallel.
+docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=16 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
+
+function cpu_tests() {
+ set -e
+ export NUMA_NODE=$2
+
+ docker exec cpu-test-"$NUMA_NODE" bash -c "
+ set -e
+ pip list"
+
+ # offline inference
+ docker exec cpu-test-"$NUMA_NODE" bash -c "
+ set -e
+ python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m"
+
+ # Run kernel tests
+ docker exec cpu-test-"$NUMA_NODE" bash -c "
+ set -e
+ pytest -x -v -s tests/kernels/test_onednn.py
+ pytest -x -v -s tests/kernels/attention/test_cpu_attn.py"
+
+ # basic online serving
+ docker exec cpu-test-"$NUMA_NODE" bash -c '
+ set -e
+ VLLM_CPU_OMP_THREADS_BIND=$E2E_OMP_THREADS vllm serve meta-llama/Llama-3.2-3B-Instruct --max-model-len 2048 &
+ server_pid=$!
+ timeout 600 bash -c "until curl localhost:8000/v1/models; do sleep 1; done" || exit 1
+ vllm bench serve \
+ --backend vllm \
+ --dataset-name random \
+ --model meta-llama/Llama-3.2-3B-Instruct \
+ --num-prompts 20 \
+ --endpoint /v1/completions
+ kill -s SIGTERM $server_pid &'
+}
+
+# All of the CPU tests are expected to finish in less than 40 mins.
+export -f cpu_tests
+timeout 2h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh b/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh
index 39ea180173081..3728f73fa2a36 100755
--- a/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh
+++ b/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh
@@ -25,20 +25,22 @@ function cpu_tests() {
# offline inference
podman exec -it "$container_id" bash -c "
+ export TORCH_COMPILE_DISABLE=1
set -xve
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m" >> $HOME/test_basic.log
# Run basic model test
podman exec -it "$container_id" bash -c "
+ export TORCH_COMPILE_DISABLE=1
set -evx
pip install pytest pytest-asyncio einops peft Pillow soundfile transformers_stream_generator matplotlib
- pip install sentence-transformers datamodel_code_generator
+ pip install sentence-transformers datamodel_code_generator tblib
# Note: disable Bart until supports V1
# pytest -v -s tests/models/language/generation/test_bart.py -m cpu_model
- pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-openai-community/gpt2]
- pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-facebook/opt-125m]
- pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-google/gemma-1.1-2b-it]
+ pytest -v -s tests/models/language/generation/test_common.py::test_models[False-False-5-32-openai-community/gpt2]
+ pytest -v -s tests/models/language/generation/test_common.py::test_models[False-False-5-32-facebook/opt-125m]
+ pytest -v -s tests/models/language/generation/test_common.py::test_models[False-False-5-32-google/gemma-1.1-2b-it]
pytest -v -s tests/models/language/pooling/test_classification.py::test_models[float-jason9693/Qwen2.5-1.5B-apeach]
# TODO: Below test case tests/models/language/pooling/test_embedding.py::test_models[True-ssmits/Qwen2-7B-Instruct-embed-base] fails on ppc64le. Disabling it for time being.
# pytest -v -s tests/models/language/pooling/test_embedding.py -m cpu_model" >> $HOME/test_rest.log
diff --git a/.buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh b/.buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh
index 5302f524a0ae4..8106f50f18f66 100644
--- a/.buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh
+++ b/.buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh
@@ -17,7 +17,17 @@ wait_for_server() {
}
MODEL="deepseek-ai/DeepSeek-V2-lite"
-BACKENDS=("deepep_high_throughput" "deepep_low_latency")
+
+# Set BACKENDS based on platform
+if command -v rocm-smi &> /dev/null || [[ -d /opt/rocm ]] || [[ -n "${ROCM_PATH:-}" ]]; then
+ # ROCm platform
+ BACKENDS=("allgather_reducescatter")
+ # Disable MOE padding for ROCm since it is causing eplb to fail
+ export VLLM_ROCM_MOE_PADDING=0
+else
+ # Non-ROCm platform (CUDA/other)
+ BACKENDS=("deepep_high_throughput" "deepep_low_latency")
+fi
cleanup() {
if [[ -n "${SERVER_PID:-}" ]] && kill -0 "${SERVER_PID}" 2>/dev/null; then
diff --git a/.buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep.sh b/.buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh
similarity index 64%
rename from .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep.sh
rename to .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh
index a5135299297e2..6a1bef275d047 100644
--- a/.buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep.sh
+++ b/.buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh
@@ -1,10 +1,12 @@
#!/usr/bin/env bash
set -euxo pipefail
-# args: [THRESHOLD] [NUM_QUESTIONS] [START_PORT]
+# args: [THRESHOLD] [NUM_QUESTIONS] [START_PORT] [DATA_PARALLEL_SIZE] [TENSOR_PARALLEL_SIZE]
THRESHOLD=${1:-0.8}
NUM_Q=${2:-1319}
PORT=${3:-8020}
+DATA_PARALLEL_SIZE=${4:-2}
+TENSOR_PARALLEL_SIZE=${5:-2}
OUT_DIR=${OUT_DIR:-/tmp/vllm-scheduled}
mkdir -p "${OUT_DIR}"
@@ -17,7 +19,16 @@ wait_for_server() {
}
MODEL="QWen/Qwen3-30B-A3B-FP8"
-BACKENDS=("deepep_high_throughput" "deepep_low_latency")
+# Set BACKENDS based on platform
+if command -v rocm-smi &> /dev/null || [[ -d /opt/rocm ]] || [[ -n "${ROCM_PATH:-}" ]]; then
+ # ROCm platform
+ BACKENDS=("allgather_reducescatter")
+ # Disable MOE padding for ROCm since it is causing eplb to fail
+ export VLLM_ROCM_MOE_PADDING=0
+else
+ # Non-ROCm platform (CUDA/other)
+ BACKENDS=("deepep_high_throughput" "deepep_low_latency")
+fi
cleanup() {
if [[ -n "${SERVER_PID:-}" ]] && kill -0 "${SERVER_PID}" 2>/dev/null; then
@@ -36,8 +47,10 @@ for BACK in "${BACKENDS[@]}"; do
VLLM_ALL2ALL_BACKEND=$BACK \
vllm serve "$MODEL" \
--enforce-eager \
- --tensor-parallel-size 2 \
- --data-parallel-size 2 \
+ --enable-eplb \
+ --eplb-config '{"window_size":10, "step_interval":100, "num_redundant_experts":0, "log_balancedness":true}' \
+ --tensor-parallel-size ${TENSOR_PARALLEL_SIZE} \
+ --data-parallel-size ${DATA_PARALLEL_SIZE} \
--enable-expert-parallel \
--trust-remote-code \
--max-model-len 2048 \
diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml
index 2471b509a9fff..4ddf11c0b268f 100644
--- a/.buildkite/test-amd.yaml
+++ b/.buildkite/test-amd.yaml
@@ -61,7 +61,7 @@ steps:
- pytest -v -s -m 'not cpu_test' multimodal
- pytest -v -s utils_
-- label: Async Engine, Inputs, Utils, Worker Test (CPU) # 4 mins
+- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 4 mins
timeout_in_minutes: 10
mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_1
@@ -73,6 +73,7 @@ steps:
- tests/multimodal
- tests/standalone_tests/lazy_imports.py
- tests/transformers_utils
+ - tests/config
no_gpu: true
commands:
- python3 standalone_tests/lazy_imports.py
@@ -80,6 +81,7 @@ steps:
- pytest -v -s test_outputs.py
- pytest -v -s -m 'cpu_test' multimodal
- pytest -v -s transformers_utils
+ - pytest -v -s config
- label: Python-only Installation Test # 10min
timeout_in_minutes: 20
@@ -187,7 +189,7 @@ steps:
- tests/distributed/test_utils
- tests/distributed/test_pynccl
- tests/distributed/test_events
- - tests/compile/test_basic_correctness
+ - tests/compile/fullgraph/test_basic_correctness.py
- examples/offline_inference/rlhf.py
- examples/offline_inference/rlhf_colocate.py
- tests/examples/offline_inference/data_parallel.py
@@ -215,7 +217,7 @@ steps:
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_hybrid_lb_dp.py
- pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
- pytest -v -s distributed/test_utils.py
- - pytest -v -s compile/test_basic_correctness.py
+ - pytest -v -s compile/fullgraph/test_basic_correctness.py
- pytest -v -s distributed/test_pynccl.py
- pytest -v -s distributed/test_events.py
- pytest -v -s distributed/test_symm_mem_allreduce.py
@@ -390,6 +392,15 @@ steps:
commands:
- pytest -v -s v1/attention
+- label: V1 Test attention (B200) # 10min
+ timeout_in_minutes: 30
+ gpu: b200
+ source_file_dependencies:
+ - vllm/v1/attention
+ - tests/v1/attention
+ commands:
+ - VLLM_DISABLE_FLASHINFER_PREFILL=1 pytest -v -s v1/attention # TODO: FI prefill is bugged and causes incorrectness, fix this
+
- label: V1 Test others (CPU) # 5 mins
mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_1
@@ -493,17 +504,12 @@ steps:
- vllm/
- tests/compile
commands:
- - pytest -v -s compile/test_pass_manager.py
- - pytest -v -s compile/test_fusion.py
- - pytest -v -s compile/test_fusion_attn.py
- - pytest -v -s compile/test_functionalization.py
- - pytest -v -s compile/test_silu_mul_quant_fusion.py
- # - pytest -v -s compile/test_sequence_parallelism.py
- # - pytest -v -s compile/test_async_tp.py
- - pytest -v -s compile/test_fusion_all_reduce.py
- - pytest -v -s compile/test_decorator.py
- - pytest -v -s compile/test_noop_elimination.py
- - pytest -v -s compile/test_aot_compile.py
+ # Run unit tests defined directly under compile/,
+ # not including subdirectories, which are usually heavier
+ # tests covered elsewhere.
+ # Use `find` to launch multiple instances of pytest so that
+ # they do not suffer from https://github.com/vllm-project/vllm/issues/28965
+ - "find compile/ -maxdepth 1 -name 'test_*.py' -exec pytest -s -v {} \\;"
- label: PyTorch Fullgraph Smoke Test # 15min
timeout_in_minutes: 30
@@ -515,9 +521,11 @@ steps:
- vllm/
- tests/compile
commands:
- - pytest -v -s compile/test_basic_correctness.py
- - pytest -v -s compile/test_multimodal_compile.py
- - pytest -v -s compile/piecewise/
+ # Run smoke tests under fullgraph directory, except test_full_graph.py
+ # as it is a heavy test that is covered in other steps.
+ # Use `find` to launch multiple instances of pytest so that
+ # they do not suffer from https://github.com/vllm-project/vllm/issues/28965
+ - "find compile/fullgraph/ -name 'test_*.py' -not -name 'test_full_graph.py' -exec pytest -s -v {} \\;"
- label: PyTorch Fullgraph Test # 27min
timeout_in_minutes: 40
@@ -529,10 +537,10 @@ steps:
- vllm/
- tests/compile
commands:
- - pytest -v -s compile/test_full_graph.py -k 'not test_fp8_kv_scale_compile'
+ - pytest -v -s compile/fullgraph/test_full_graph.py -k 'not test_fp8_kv_scale_compile'
# Limit to no custom ops to reduce running time
# Wrap with quotes to escape yaml and avoid starting -k string with a -
- - "pytest -v -s compile/test_fusions_e2e.py -k 'TRITON and -quant_fp8'"
+ - "pytest -v -s compile/distributed/test_fusions_e2e.py -k 'TRITON and not +quant_fp8 and not Llama-4'"
- label: Cudagraph test
timeout_in_minutes: 20
@@ -697,7 +705,7 @@ steps:
- vllm/model_executor/models/whisper.py
commands: # LMEval
# Transcription WER check is skipped because encoder-decoder models are not supported on ROCm, see https://github.com/vllm-project/vllm/issues/27442
- - pytest -s entrypoints/openai/correctness/ --ignore entrypoints/openai/correctness/test_transcription_api_correctness.py
+ - pytest -s entrypoints/openai/correctness/
- label: OpenAI-Compatible Tool Use # 23 min
timeout_in_minutes: 35
@@ -746,6 +754,7 @@ steps:
torch_nightly: true
source_file_dependencies:
- vllm/model_executor/models/
+ - vllm/transformers_utils/
- tests/models/test_initialization.py
commands:
# Only when vLLM model source is modified - test initialization of a large
@@ -998,12 +1007,12 @@ steps:
optional: true
commands:
- pip install --upgrade git+https://github.com/huggingface/transformers
- - pytest -v -s tests/models/test_initialization.py
+ - pytest -v -s tests/models/test_initialization.py -k 'not (Gemma3 or ModernBert or Qwen2_5_VL or Qwen2_5vl or Qwen2VL or TransformersMultiModalEmbeddingModel or TransformersMultiModalForSequenceClassification or Ultravox or Phi4Multimodal or LlavaNextVideo or MiniCPMO or Lfm2Moe or PaliGemma or RobertaForSequenceClassification or Ovis2_5 or Fuyu or DeepseekOCR or KimiVL)'
- pytest -v -s tests/models/test_transformers.py
- - pytest -v -s tests/models/multimodal/processing/
- - pytest -v -s tests/models/multimodal/test_mapping.py
+ # - pytest -v -s tests/models/multimodal/processing/
+ - pytest -v -s tests/models/multimodal/test_mapping.py -k 'not (Gemma3 or Qwen2VL or Qwen2_5_VL)'
- python3 examples/offline_inference/basic/chat.py
- - python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl
+ # - python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl
# Whisper needs spawn method to avoid deadlock
- VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper
@@ -1048,7 +1057,7 @@ steps:
- pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
- pytest -v -s tests/kernels/moe/test_flashinfer.py
-- label: Blackwell Fusion Tests # 30 min
+- label: Blackwell Fusion and Compile Tests # 30 min
timeout_in_minutes: 40
working_dir: "/vllm-workspace/"
gpu: b200
@@ -1066,10 +1075,12 @@ steps:
- pytest -v -s tests/compile/test_fusion_attn.py
- pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
# this runner has 2 GPUs available even though num_gpus=2 is not set
- - pytest -v -s tests/compile/test_fusion_all_reduce.py
+ - pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py
# Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time
# Wrap with quotes to escape yaml
- - "pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and Llama-3.1 and -quant_fp8 and -rms_norm'"
+ - "pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and not +quant_fp8 and not +rms_norm'"
+ # test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
+ - pytest -v -s tests/compile/distributed/test_full_graph.py::test_fp8_kv_scale_compile
- label: Blackwell Fusion E2E Tests # 30 min
timeout_in_minutes: 40
@@ -1086,20 +1097,18 @@ steps:
- vllm/model_executor/layers/layernorm.py
- vllm/model_executor/layers/activation.py
- vllm/model_executor/layers/quantization/input_quant_fp8.py
- - tests/compile/test_fusions_e2e.py
- - tests/compile/test_full_graph.py
+ - tests/compile/distributed/test_fusions_e2e.py
+ - tests/compile/fullgraph/test_full_graph.py
commands:
- nvidia-smi
# Run all e2e fusion tests
- pytest -v -s tests/compile/test_fusions_e2e.py
- # test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
- - pytest -v -s tests/compile/test_full_graph.py::test_fp8_kv_scale_compile
- label: ROCm GPT-OSS Eval
timeout_in_minutes: 60
working_dir: "/vllm-workspace/"
agent_pool: mi325_1
- mirror_hardwares: [amdproduction]
+ mirror_hardwares: [amdexperimental, amdproduction]
optional: true # run on nightlies
source_file_dependencies:
- tests/evals/gpt_oss
@@ -1198,7 +1207,7 @@ steps:
- vllm/worker/worker_base.py
- vllm/v1/engine/
- vllm/v1/worker/
- - tests/compile/test_basic_correctness.py
+ - tests/compile/fullgraph/test_basic_correctness.py
- tests/compile/test_wrapper.py
- tests/distributed/
- tests/entrypoints/llm/test_collective_rpc.py
@@ -1211,7 +1220,7 @@ steps:
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
- pytest -v -s entrypoints/llm/test_collective_rpc.py
- - pytest -v -s ./compile/test_basic_correctness.py
+ - pytest -v -s ./compile/fullgraph/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
- VLLM_TEST_SAME_HOST=1 VLLM_TEST_WITH_DEFAULT_DEVICE_SET=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
@@ -1311,7 +1320,10 @@ steps:
- pytest -v -s -x lora/test_llama_tp.py
- pytest -v -s -x lora/test_llm_with_multi_loras.py
- pytest -v -s -x lora/test_olmoe_tp.py
- - pytest -v -s -x lora/test_gptoss_tp.py
+
+ # Disabled for now because MXFP4 backend on non-cuda platform
+ # doesn't support LoRA yet
+ #- pytest -v -s -x lora/test_gptoss_tp.py
- label: Weight Loading Multiple GPU Test # 33min
@@ -1326,7 +1338,7 @@ steps:
- vllm/
- tests/weight_loading
commands:
- - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt
+ - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-amd.txt
- label: Weight Loading Multiple GPU Test - Large Models # optional
mirror_hardwares: [amdexperimental]
@@ -1334,13 +1346,12 @@ steps:
# grade: Blocking
working_dir: "/vllm-workspace/tests"
num_gpus: 2
- gpu: a100
optional: true
source_file_dependencies:
- vllm/
- tests/weight_loading
commands:
- - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt
+ - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large-amd.txt
- label: NixlConnector PD accuracy tests (Distributed) # 30min
mirror_hardwares: [amdexperimental]
@@ -1417,10 +1428,12 @@ steps:
working_dir: "/vllm-workspace/"
num_gpus: 2
commands:
- - pytest -v -s tests/compile/test_async_tp.py
- - pytest -v -s tests/compile/test_sequence_parallelism.py
- - pytest -v -s tests/compile/test_fusion_all_reduce.py
- - pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm
+ - pytest -v -s tests/compile/distributed/test_async_tp.py
+ - pytest -v -s tests/compile/distributed/test_sequence_parallelism.py
+ - pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py
+ #- pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm
+ - "pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'"
+ - pytest -v -s tests/distributed/test_sequence_parallel.py
- pytest -v -s tests/distributed/test_context_parallel.py
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
- pytest -v -s tests/v1/distributed/test_dbo.py
@@ -1473,4 +1486,4 @@ steps:
num_gpus: 4
working_dir: "/vllm-workspace"
commands:
- - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep.sh 0.8 200 8020
+ - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 4ac76aba67b9c..e444becd9867b 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -167,7 +167,7 @@ steps:
- tests/distributed/test_utils
- tests/distributed/test_pynccl
- tests/distributed/test_events
- - tests/compile/test_basic_correctness
+ - tests/compile/fullgraph/test_basic_correctness.py
- examples/offline_inference/rlhf.py
- examples/offline_inference/rlhf_colocate.py
- tests/examples/offline_inference/data_parallel.py
@@ -192,12 +192,13 @@ steps:
# test with internal dp
- python3 ../examples/offline_inference/data_parallel.py --enforce-eager
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
+ - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_internal_lb_dp.py
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_hybrid_lb_dp.py
- pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
- pytest -v -s distributed/test_utils.py
- - pytest -v -s compile/test_basic_correctness.py
+ - pytest -v -s compile/fullgraph/test_basic_correctness.py
- pytest -v -s distributed/test_pynccl.py
- pytest -v -s distributed/test_events.py
- pytest -v -s distributed/test_symm_mem_allreduce.py
@@ -346,6 +347,18 @@ steps:
commands:
- pytest -v -s v1/attention
+- label: Batch Invariance Tests (H100) # 10min
+ timeout_in_minutes: 25
+ gpu: h100
+ source_file_dependencies:
+ - vllm/
+ - tests/v1/determinism/
+ commands:
+ - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ - pip install pytest-timeout pytest-forked
+ - pytest -v -s v1/determinism/test_batch_invariance.py
+ - pytest -v -s v1/determinism/test_rms_norm_batch_invariant.py
+
- label: V1 Test attention (B200) # 10min
timeout_in_minutes: 30
gpu: b200
@@ -445,18 +458,12 @@ steps:
- vllm/
- tests/compile
commands:
- - pytest -v -s compile/test_graph_partition.py
- - pytest -v -s compile/test_config.py
- - pytest -v -s compile/test_pass_manager.py
- - pytest -v -s compile/test_fusion.py
- - pytest -v -s compile/test_fusion_attn.py
- - pytest -v -s compile/test_functionalization.py
- - pytest -v -s compile/test_silu_mul_quant_fusion.py
- - pytest -v -s compile/test_fusion_all_reduce.py
- - pytest -v -s compile/test_decorator.py
- - pytest -v -s compile/test_noop_elimination.py
- - pytest -v -s compile/test_aot_compile.py
- - pytest -v -s compile/test_qk_norm_rope_fusion.py
+ # Run unit tests defined directly under compile/,
+ # not including subdirectories, which are usually heavier
+ # tests covered elsewhere.
+ # Use `find` to launch multiple instances of pytest so that
+ # they do not suffer from https://github.com/vllm-project/vllm/issues/28965
+ - "find compile/ -maxdepth 1 -name 'test_*.py' -exec pytest -s -v {} \\;"
- label: PyTorch Fullgraph Smoke Test # 15min
timeout_in_minutes: 30
@@ -466,9 +473,11 @@ steps:
- vllm/
- tests/compile
commands:
- - pytest -v -s compile/test_basic_correctness.py
- - pytest -v -s compile/test_multimodal_compile.py
- - pytest -v -s compile/piecewise/
+ # Run smoke tests under fullgraph directory, except test_full_graph.py
+ # as it is a heavy test that is covered in other steps.
+ # Use `find` to launch multiple instances of pytest so that
+ # they do not suffer from https://github.com/vllm-project/vllm/issues/28965
+ - "find compile/fullgraph/ -name 'test_*.py' -not -name 'test_full_graph.py' -exec pytest -s -v {} \\;"
- label: PyTorch Fullgraph Test # 27min
timeout_in_minutes: 40
@@ -479,10 +488,10 @@ steps:
- tests/compile
commands:
# fp8 kv scales not supported on sm89, tested on Blackwell instead
- - pytest -v -s compile/test_full_graph.py -k 'not test_fp8_kv_scale_compile'
+ - pytest -v -s compile/fullgraph/test_full_graph.py -k 'not test_fp8_kv_scale_compile'
# Limit to no custom ops to reduce running time
# Wrap with quotes to escape yaml and avoid starting -k string with a -
- - "pytest -v -s compile/test_fusions_e2e.py -k 'TRITON and not +quant_fp8 and not Llama-4'"
+ - "pytest -v -s compile/distributed/test_fusions_e2e.py -k 'TRITON and not +quant_fp8 and not Llama-4'"
- label: Cudagraph test
timeout_in_minutes: 20
@@ -554,6 +563,25 @@ steps:
commands:
- pytest -v -s kernels/mamba
+- label: Kernels DeepGEMM Test (H100)
+ timeout_in_minutes: 45
+ gpu: h100
+ num_gpus: 1
+ source_file_dependencies:
+ - tools/install_deepgemm.sh
+ - vllm/utils/deep_gemm.py
+ - vllm/model_executor/layers/fused_moe
+ - vllm/model_executor/layers/quantization
+ - tests/kernels/quantization/test_block_fp8.py
+ - tests/kernels/moe/test_deepgemm.py
+ - tests/kernels/moe/test_batched_deepgemm.py
+ - tests/kernels/attention/test_deepgemm_attention.py
+ commands:
+ - pytest -v -s kernels/quantization/test_block_fp8.py -k deep_gemm
+ - pytest -v -s kernels/moe/test_deepgemm.py
+ - pytest -v -s kernels/moe/test_batched_deepgemm.py
+ - pytest -v -s kernels/attention/test_deepgemm_attention.py
+
- label: Model Executor Test # 23min
timeout_in_minutes: 35
torch_nightly: true
@@ -664,6 +692,7 @@ steps:
torch_nightly: true
source_file_dependencies:
- vllm/model_executor/models/
+ - vllm/transformers_utils/
- tests/models/test_initialization.py
commands:
# Only when vLLM model source is modified - test initialization of a large
@@ -876,12 +905,12 @@ steps:
optional: true
commands:
- pip install --upgrade git+https://github.com/huggingface/transformers
- - pytest -v -s tests/models/test_initialization.py -k 'not (Gemma3 or ModernBert or Qwen2_5_VL or Qwen2_5vl or Qwen2VL or TransformersMultiModalEmbeddingModel or TransformersMultiModalForSequenceClassification or Ultravox or Phi4Multimodal or LlavaNextVideo or MiniCPMO or Lfm2Moe or PaliGemma or RobertaForSequenceClassification or Ovis2_5 or Fuyu or DeepseekOCR or KimiVL)'
+ - pytest -v -s tests/models/test_initialization.py -k 'not (Ultravox or Phi4Multimodal or MiniCPMO or Lfm2Moe or RobertaForSequenceClassification or Ovis2_5 or DeepseekOCR or KimiVL)'
- pytest -v -s tests/models/test_transformers.py
# - pytest -v -s tests/models/multimodal/processing/
- - pytest -v -s tests/models/multimodal/test_mapping.py -k 'not (Gemma3 or Qwen2VL or Qwen2_5_VL)'
+ - pytest -v -s tests/models/multimodal/test_mapping.py
- python3 examples/offline_inference/basic/chat.py
- # - python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl
+ - python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl
# Whisper needs spawn method to avoid deadlock
- VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper
@@ -925,6 +954,7 @@ steps:
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
- pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
- pytest -v -s tests/kernels/moe/test_flashinfer.py
+ - pytest -v -s tests/kernels/moe/test_cutedsl_moe.py
- label: Blackwell Fusion and Compile Tests # 30 min
timeout_in_minutes: 40
@@ -934,22 +964,30 @@ steps:
- csrc/quantization/fp4/
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py
+ - vllm/v1/worker/
+ - vllm/v1/cudagraph_dispatcher.py
- vllm/compilation/
# can affect pattern matching
- vllm/model_executor/layers/layernorm.py
- vllm/model_executor/layers/activation.py
- vllm/model_executor/layers/quantization/input_quant_fp8.py
+ - vllm/model_executor/layers/fused_moe/layer.py
+ - tests/compile/test_fusion_attn.py
+ - tests/compile/test_silu_mul_quant_fusion.py
+ - tests/compile/distributed/test_fusion_all_reduce.py
+ - tests/compile/distributed/test_fusions_e2e.py
+ - tests/compile/fullgraph/test_full_graph.py
commands:
- nvidia-smi
- pytest -v -s tests/compile/test_fusion_attn.py
- pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
# this runner has 2 GPUs available even though num_gpus=2 is not set
- - pytest -v -s tests/compile/test_fusion_all_reduce.py
+ - pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py
# Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time
# Wrap with quotes to escape yaml
- - "pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and not +quant_fp8 and not +rms_norm'"
+ - "pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and not +quant_fp8 and not +rms_norm'"
# test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
- - pytest -v -s tests/compile/test_full_graph.py::test_fp8_kv_scale_compile
+ - pytest -v -s tests/compile/fullgraph/test_full_graph.py::test_fp8_kv_scale_compile
- label: Blackwell Fusion E2E Tests # 30 min
timeout_in_minutes: 40
@@ -966,12 +1004,11 @@ steps:
- vllm/model_executor/layers/layernorm.py
- vllm/model_executor/layers/activation.py
- vllm/model_executor/layers/quantization/input_quant_fp8.py
- - tests/compile/test_fusions_e2e.py
- - tests/compile/test_full_graph.py
+ - tests/compile/distributed/test_fusions_e2e.py
commands:
- nvidia-smi
# Run all e2e fusion tests
- - pytest -v -s tests/compile/test_fusions_e2e.py
+ - pytest -v -s tests/compile/distributed/test_fusions_e2e.py
- label: Blackwell GPT-OSS Eval
timeout_in_minutes: 60
@@ -1069,7 +1106,7 @@ steps:
- vllm/worker/worker_base.py
- vllm/v1/engine/
- vllm/v1/worker/
- - tests/compile/test_basic_correctness.py
+ - tests/compile/fullgraph/test_basic_correctness.py
- tests/compile/test_wrapper.py
- tests/distributed/
- tests/entrypoints/llm/test_collective_rpc.py
@@ -1081,10 +1118,11 @@ steps:
# https://github.com/NVIDIA/nccl/issues/1838
- export NCCL_CUMEM_HOST_ENABLE=0
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
+ - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
- pytest -v -s entrypoints/llm/test_collective_rpc.py
- - pytest -v -s ./compile/test_basic_correctness.py
+ - pytest -v -s ./compile/fullgraph/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
- VLLM_TEST_SAME_HOST=1 VLLM_TEST_WITH_DEFAULT_DEVICE_SET=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
@@ -1264,10 +1302,10 @@ steps:
working_dir: "/vllm-workspace/"
num_gpus: 2
commands:
- - pytest -v -s tests/compile/test_async_tp.py
- - pytest -v -s tests/compile/test_sequence_parallelism.py
- - pytest -v -s tests/compile/test_fusion_all_reduce.py
- - "pytest -v -s tests/compile/test_fusions_e2e.py -k 'not Llama-4'"
+ - pytest -v -s tests/compile/distributed/test_async_tp.py
+ - pytest -v -s tests/compile/distributed/test_sequence_parallelism.py
+ - pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py
+ - "pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'"
- pytest -v -s tests/distributed/test_sequence_parallel.py
- pytest -v -s tests/distributed/test_context_parallel.py
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
@@ -1305,11 +1343,20 @@ steps:
commands:
- bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh 0.25 200 8010
-- label: Qwen3-30B-A3B-FP8-block Accuracy
+- label: Qwen3-30B-A3B-FP8-block Accuracy (H100)
timeout_in_minutes: 60
gpu: h100
optional: true
num_gpus: 4
working_dir: "/vllm-workspace"
commands:
- - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep.sh 0.8 200 8020
+ - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020
+
+- label: Qwen3-30B-A3B-FP8-block Accuracy (B200)
+ timeout_in_minutes: 60
+ gpu: b200
+ optional: true
+ num_gpus: 2
+ working_dir: "/vllm-workspace"
+ commands:
+ - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1
\ No newline at end of file
diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index 6e178bb690c56..3247408e1163e 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -9,6 +9,7 @@
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256 @pavanimajety
/vllm/model_executor/layers/mamba @tdoublep
/vllm/model_executor/model_loader @22quinn
+/vllm/model_executor/layers/batch_invariant.py @yewentao256
/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche @tjtanaa
/vllm/vllm_flash_attn @LucasWilkinson
/vllm/lora @jeejeelee
@@ -35,6 +36,9 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/vllm/v1/kv_cache_interface.py @heheda12345
/vllm/v1/offloading @ApostaC
+# Model runner V2
+/vllm/v1/worker/gpu @WoosukKwon
+
# Test ownership
/.buildkite/lm-eval-harness @mgoin
/tests/distributed/test_multi_node_assignment.py @youkaichao
@@ -56,6 +60,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/tests/v1/kv_connector/nixl_integration @NickLucche
/tests/v1/kv_connector @ApostaC
/tests/v1/offloading @ApostaC
+/tests/v1/determinism @yewentao256
# Transformers modeling backend
/vllm/model_executor/models/transformers @hmellor
diff --git a/.github/workflows/macos-smoke-test.yml b/.github/workflows/macos-smoke-test.yml
index 42b05ecd5ac06..a183033c9adde 100644
--- a/.github/workflows/macos-smoke-test.yml
+++ b/.github/workflows/macos-smoke-test.yml
@@ -9,7 +9,7 @@ on:
jobs:
macos-m1-smoke-test:
runs-on: macos-latest
- timeout-minutes: 20
+ timeout-minutes: 30
steps:
- uses: actions/checkout@v4
@@ -37,15 +37,14 @@ jobs:
- name: Verify installation
run: |
python -c "import vllm; print(f'vLLM version: {vllm.__version__}')"
- python -c "import torch; print(f'PyTorch: {torch.__version__}')"
- name: Smoke test vllm serve
- timeout-minutes: 10
run: |
# Start server in background
vllm serve Qwen/Qwen3-0.6B \
- --max-model-len=2048 \
+ --max-model-len=2K \
--load-format=dummy \
+ --hf-overrides '{"num_hidden_layers": 2}' \
--enforce-eager \
--port 8000 &
diff --git a/CMakeLists.txt b/CMakeLists.txt
index ae8e6175443f3..86746a0db4c0e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -136,7 +136,7 @@ elseif(HIP_FOUND)
# ROCm 5.X and 6.X
if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND
- NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM})
+ Torch_VERSION VERSION_LESS ${TORCH_SUPPORTED_VERSION_ROCM})
message(WARNING "Pytorch version >= ${TORCH_SUPPORTED_VERSION_ROCM} "
"expected for ROCm build, saw ${Torch_VERSION} instead.")
endif()
@@ -307,7 +307,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
- set(CUTLASS_REVISION "v4.2.1" CACHE STRING "CUTLASS revision to use")
+ set(CUTLASS_REVISION "v4.2.1")
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
diff --git a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py
index 904f805349148..d072c03c440b2 100644
--- a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py
+++ b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py
@@ -5,11 +5,12 @@ import argparse
import asyncio
import logging
import os
+import time
+import uuid
+from urllib.parse import urlparse
import aiohttp
from quart import Quart, Response, make_response, request
-from rate_limiter import RateLimiter
-from request_queue import RequestQueue
# Configure logging
logging.basicConfig(level=logging.INFO)
@@ -24,26 +25,8 @@ def parse_args():
parser.add_argument(
"--timeout",
type=float,
- default=300,
- help="Timeout for backend service requests in seconds (default: 300)",
- )
- parser.add_argument(
- "--max-concurrent",
- type=int,
- default=100,
- help="Maximum concurrent requests to backend services (default: 100)",
- )
- parser.add_argument(
- "--queue-size",
- type=int,
- default=500,
- help="Maximum number of requests in the queue (default: 500)",
- )
- parser.add_argument(
- "--rate-limit",
- type=int,
- default=40,
- help="Maximum requests per second (default: 40)",
+ default=6 * 60 * 60,
+ help="Timeout for backend service requests in seconds (default: 21600)",
)
parser.add_argument(
"--port",
@@ -54,14 +37,32 @@ def parse_args():
parser.add_argument(
"--prefill-url",
type=str,
- default="http://localhost:8100/v1/completions",
- help="Prefill service endpoint URL",
+ default="http://localhost:8100",
+ help="Prefill service base URL (protocol + host[:port])",
)
parser.add_argument(
"--decode-url",
type=str,
- default="http://localhost:8200/v1/completions",
- help="Decode service endpoint URL",
+ default="http://localhost:8200",
+ help="Decode service base URL (protocol + host[:port])",
+ )
+ parser.add_argument(
+ "--kv-host",
+ type=str,
+ default="localhost",
+ help="Hostname or IP used by KV transfer (default: localhost)",
+ )
+ parser.add_argument(
+ "--prefill-kv-port",
+ type=int,
+ default=14579,
+ help="Prefill KV port (default: 14579)",
+ )
+ parser.add_argument(
+ "--decode-kv-port",
+ type=int,
+ default=14580,
+ help="Decode KV port (default: 14580)",
)
return parser.parse_args()
@@ -73,70 +74,129 @@ def main():
# Initialize configuration using command line parameters
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=args.timeout)
- MAX_CONCURRENT_REQUESTS = args.max_concurrent
- REQUEST_QUEUE_SIZE = args.queue_size
- RATE_LIMIT = args.rate_limit
PREFILL_SERVICE_URL = args.prefill_url
DECODE_SERVICE_URL = args.decode_url
PORT = args.port
+ PREFILL_KV_ADDR = f"{args.kv_host}:{args.prefill_kv_port}"
+ DECODE_KV_ADDR = f"{args.kv_host}:{args.decode_kv_port}"
+
+ logger.info(
+ "Proxy resolved KV addresses -> prefill: %s, decode: %s",
+ PREFILL_KV_ADDR,
+ DECODE_KV_ADDR,
+ )
+
app = Quart(__name__)
- # Initialize the rate limiter and request queue
- rate_limiter = RateLimiter(RATE_LIMIT)
- request_queue = RequestQueue(MAX_CONCURRENT_REQUESTS, REQUEST_QUEUE_SIZE)
-
- # Attach the configuration object to the application instance
+ # Attach the configuration object to the application instance so helper
+ # coroutines can read the resolved backend URLs and timeouts without using
+ # globals.
app.config.update(
{
"AIOHTTP_TIMEOUT": AIOHTTP_TIMEOUT,
- "rate_limiter": rate_limiter,
- "request_queue": request_queue,
"PREFILL_SERVICE_URL": PREFILL_SERVICE_URL,
"DECODE_SERVICE_URL": DECODE_SERVICE_URL,
+ "PREFILL_KV_ADDR": PREFILL_KV_ADDR,
+ "DECODE_KV_ADDR": DECODE_KV_ADDR,
}
)
- # Start queue processing on app startup
- @app.before_serving
- async def startup():
- """Start request processing task when app starts serving"""
- asyncio.create_task(request_queue.process())
+ def _normalize_base_url(url: str) -> str:
+ """Remove any trailing slash so path joins behave predictably."""
+ return url.rstrip("/")
- async def forward_request(url, data):
- """Forward request to backend service with rate limiting and error handling"""
- headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
+ def _get_host_port(url: str) -> str:
+ """Return the hostname:port portion for logging and KV headers."""
+ parsed = urlparse(url)
+ host = parsed.hostname or "localhost"
+ port = parsed.port
+ if port is None:
+ port = 80 if parsed.scheme == "http" else 443
+ return f"{host}:{port}"
- # Use rate limiter as context manager
- async with (
- rate_limiter,
- aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
- ):
- try:
- async with session.post(
- url=url, json=data, headers=headers
- ) as response:
- if response.status == 200:
- # Stream response chunks
- async for chunk_bytes in response.content.iter_chunked(1024):
- yield chunk_bytes
- else:
- # Handle backend service errors
- error_text = await response.text()
- logger.error(
- "Backend service error: %s - %s",
- response.status,
- error_text,
- )
- yield b'{"error": "Backend service error"}'
- except aiohttp.ClientError as e:
- # Handle connection errors
- logger.error("Connection error to %s: %s", url, str(e))
- yield b'{"error": "Service unavailable"}'
- except asyncio.TimeoutError:
- # Handle timeout errors
- logger.error("Timeout connecting to %s", url)
- yield b'{"error": "Service timeout"}'
+ PREFILL_BASE = _normalize_base_url(PREFILL_SERVICE_URL)
+ DECODE_BASE = _normalize_base_url(DECODE_SERVICE_URL)
+ KV_TARGET = _get_host_port(DECODE_SERVICE_URL)
+
+ def _build_headers(request_id: str) -> dict[str, str]:
+ """Construct the headers expected by vLLM's P2P disagg connector."""
+ headers: dict[str, str] = {"X-Request-Id": request_id, "X-KV-Target": KV_TARGET}
+ api_key = os.environ.get("OPENAI_API_KEY")
+ if api_key:
+ headers["Authorization"] = f"Bearer {api_key}"
+ return headers
+
+ async def _run_prefill(
+ request_path: str,
+ payload: dict,
+ headers: dict[str, str],
+ request_id: str,
+ ):
+ url = f"{PREFILL_BASE}{request_path}"
+ start_ts = time.perf_counter()
+ logger.info("[prefill] start request_id=%s url=%s", request_id, url)
+ try:
+ async with (
+ aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
+ session.post(url=url, json=payload, headers=headers) as resp,
+ ):
+ if resp.status != 200:
+ error_text = await resp.text()
+ raise RuntimeError(
+ f"Prefill backend error {resp.status}: {error_text}"
+ )
+ await resp.read()
+ logger.info(
+ "[prefill] done request_id=%s status=%s elapsed=%.2fs",
+ request_id,
+ resp.status,
+ time.perf_counter() - start_ts,
+ )
+ except asyncio.TimeoutError as exc:
+ raise RuntimeError(f"Prefill service timeout at {url}") from exc
+ except aiohttp.ClientError as exc:
+ raise RuntimeError(f"Prefill service unavailable at {url}") from exc
+
+ async def _stream_decode(
+ request_path: str,
+ payload: dict,
+ headers: dict[str, str],
+ request_id: str,
+ ):
+ url = f"{DECODE_BASE}{request_path}"
+ # Stream tokens from the decode service once the prefill stage has
+ # materialized KV caches on the target workers.
+ logger.info("[decode] start request_id=%s url=%s", request_id, url)
+ try:
+ async with (
+ aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
+ session.post(url=url, json=payload, headers=headers) as resp,
+ ):
+ if resp.status != 200:
+ error_text = await resp.text()
+ logger.error(
+ "Decode backend error %s - %s", resp.status, error_text
+ )
+ err_msg = (
+ '{"error": "Decode backend error ' + str(resp.status) + '"}'
+ )
+ yield err_msg.encode()
+ return
+ logger.info(
+ "[decode] streaming response request_id=%s status=%s",
+ request_id,
+ resp.status,
+ )
+ async for chunk_bytes in resp.content.iter_chunked(1024):
+ yield chunk_bytes
+ logger.info("[decode] finished streaming request_id=%s", request_id)
+ except asyncio.TimeoutError:
+ logger.error("Decode service timeout at %s", url)
+ yield b'{"error": "Decode service timeout"}'
+ except aiohttp.ClientError as exc:
+ logger.error("Decode service error at %s: %s", url, exc)
+ yield b'{"error": "Decode service unavailable"}'
async def process_request():
"""Process a single request through prefill and decode stages"""
@@ -146,13 +206,27 @@ def main():
# Create prefill request (max_tokens=1)
prefill_request = original_request_data.copy()
prefill_request["max_tokens"] = 1
+ if "max_completion_tokens" in prefill_request:
+ prefill_request["max_completion_tokens"] = 1
# Execute prefill stage
- async for _ in forward_request(PREFILL_SERVICE_URL, prefill_request):
- continue
+ # The request id encodes both KV socket addresses so the backend can
+ # shuttle tensors directly via NCCL once the prefill response
+ # completes.
+ request_id = (
+ f"___prefill_addr_{PREFILL_KV_ADDR}___decode_addr_"
+ f"{DECODE_KV_ADDR}_{uuid.uuid4().hex}"
+ )
+
+ headers = _build_headers(request_id)
+ await _run_prefill(request.path, prefill_request, headers, request_id)
# Execute decode stage and stream response
- generator = forward_request(DECODE_SERVICE_URL, original_request_data)
+ # Pass the unmodified user request so the decode phase can continue
+ # sampling with the already-populated KV cache.
+ generator = _stream_decode(
+ request.path, original_request_data, headers, request_id
+ )
response = await make_response(generator)
response.timeout = None # Disable timeout for streaming response
return response
@@ -168,23 +242,10 @@ def main():
@app.route("/v1/completions", methods=["POST"])
async def handle_request():
"""Handle incoming API requests with concurrency and rate limiting"""
- # Create task for request processing
- task = asyncio.create_task(process_request())
-
- # Enqueue request or reject if queue is full
- if not await request_queue.enqueue(task):
- return Response(
- response=b'{"error": "Server busy, try again later"}',
- status=503,
- content_type="application/json",
- )
-
try:
- # Return the response from the processing task
- return await task
+ return await process_request()
except asyncio.CancelledError:
- # Handle task cancellation (timeout or queue full)
- logger.warning("Request cancelled due to timeout or queue full")
+ logger.warning("Request cancelled")
return Response(
response=b'{"error": "Request cancelled"}',
status=503,
diff --git a/benchmarks/kernels/benchmark_mrope.py b/benchmarks/kernels/benchmark_mrope.py
index cb848d2bf579e..83bd91917508f 100644
--- a/benchmarks/kernels/benchmark_mrope.py
+++ b/benchmarks/kernels/benchmark_mrope.py
@@ -6,7 +6,7 @@
#
# The CSV file (named with current date/time) contains these columns:
# model_name, tp_size, num_tokens, num_heads, num_kv_heads, head_dim, max_position,
-# rope_theta, is_neox_style, rope_scaling, dtype, torch_mean, torch_median, torch_p99,
+# is_neox_style, rope_parameters, dtype, torch_mean, torch_median, torch_p99,
# torch_min, torch_max, triton_mean, triton_median, triton_p99, triton_min, triton_max,
# speedup
#
@@ -86,9 +86,8 @@ def benchmark_mrope(
num_heads: int,
num_kv_heads: int,
max_position: int = 8192,
- rope_theta: float = 10000,
is_neox_style: bool = True,
- rope_scaling: dict[str, Any] = None,
+ rope_parameters: dict[str, Any] | None = None,
dtype: torch.dtype = torch.bfloat16,
seed: int = 0,
warmup_iter: int = 10,
@@ -102,9 +101,8 @@ def benchmark_mrope(
head_size=head_dim,
rotary_dim=head_dim,
max_position=max_position,
- base=rope_theta,
is_neox_style=is_neox_style,
- rope_scaling=rope_scaling,
+ rope_parameters=rope_parameters,
dtype=dtype,
).to(device=device)
@@ -203,9 +201,8 @@ def benchmark_mrope(
num_kv_heads,
head_dim,
max_position,
- rope_theta,
is_neox_style,
- str(rope_scaling),
+ str(rope_parameters),
str(dtype).split(".")[-1],
torch_stats["mean"],
torch_stats["median"],
@@ -255,9 +252,8 @@ if __name__ == "__main__":
"num_kv_heads",
"head_dim",
"max_position",
- "rope_theta",
"is_neox_style",
- "rope_scaling",
+ "rope_parameters",
"dtype",
"torch_mean",
"torch_median",
@@ -303,7 +299,7 @@ if __name__ == "__main__":
q_size = num_heads * head_dim
kv_size = num_kv_heads * head_dim
is_neox_style = True
- rope_theta = config.rope_theta
+ rope_parameters = config.rope_parameters
max_position = config.max_position_embeddings
for num_tokens in num_tokens_list:
@@ -315,9 +311,8 @@ if __name__ == "__main__":
num_heads=num_heads,
num_kv_heads=num_kv_heads,
max_position=max_position,
- rope_theta=rope_theta,
is_neox_style=is_neox_style,
- rope_scaling=config.rope_scaling,
+ rope_parameters=rope_parameters,
dtype=getattr(torch, args.dtype),
seed=args.seed,
warmup_iter=args.warmup_iter,
diff --git a/benchmarks/kernels/deepgemm/README.md b/benchmarks/kernels/deepgemm/README.md
index 41e68e047be82..a28c6956be0e9 100644
--- a/benchmarks/kernels/deepgemm/README.md
+++ b/benchmarks/kernels/deepgemm/README.md
@@ -2,7 +2,7 @@
This directory includes benchmarks between DeepSeek's DeepGEMM block fp8 kernels against vLLM's existing triton and CUTLASS-based kernels.
-Currently this just includes dense GEMMs and only works on Hopper GPUs.
+Currently, this just includes dense GEMMs and only works on Hopper GPUs.
## Setup
diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake
index 567c8959f0454..ff687e0af7b44 100644
--- a/cmake/external_projects/vllm_flash_attn.cmake
+++ b/cmake/external_projects/vllm_flash_attn.cmake
@@ -38,7 +38,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
- GIT_TAG 58e0626a692f09241182582659e3bf8f16472659
+ GIT_TAG 86f8f157cf82aa2342743752b97788922dd7de43
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
diff --git a/csrc/cache.h b/csrc/cache.h
index b162a4a2bc31f..f2a5ec0acf5cd 100644
--- a/csrc/cache.h
+++ b/csrc/cache.h
@@ -41,11 +41,12 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double scale, const std::string& kv_cache_dtype);
void gather_and_maybe_dequant_cache(
- torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
- torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
- torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
- torch::Tensor const& cu_seq_lens, // [BATCH+1]
- int64_t batch_size, const std::string& kv_cache_dtype,
+ torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
+ torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
+ torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
+ torch::Tensor const& cu_seq_lens, // [BATCH+1]
+ torch::Tensor const& token_to_seq, // [MAX_TOKEN_ACROSS_CHUNKS]
+ int64_t num_tokens, const std::string& kv_cache_dtype,
torch::Tensor const& scale,
std::optional seq_starts = std::nullopt);
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 0aa0dc14c7480..8a5457206c706 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -552,7 +552,11 @@ __global__ void indexer_k_quant_and_cache_kernel(
#ifndef USE_ROCM
__syncwarp();
#endif
+#if defined(__gfx942__)
+ float scale = fmaxf(amax, 1e-4) / 224.0f;
+#else
float scale = fmaxf(amax, 1e-4) / 448.0f;
+#endif
if (use_ue8m0) {
scale = exp2f(ceilf(log2f(scale)));
}
@@ -901,87 +905,80 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
namespace vllm {
// grid is launched with dimensions (batch, num_splits)
-template
+template
__global__ void gather_and_maybe_dequant_cache(
- const cache_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
- // ENTRIES...]
- scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...]
- const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
- const int32_t* __restrict__ cu_seq_lens, // [BATCH+1]
- const int32_t block_size, const int32_t entry_size,
+ const cache_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
+ // ENTRIES...]
+ scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...]
+ const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
+ const int32_t* __restrict__ cu_seq_lens, // [BATCH+1]
+ const int32_t* __restrict__ token_to_seq, // [MAX_TOKEN_ACROSS_CHUNK]
+ const int32_t num_tokens, const int32_t block_size,
const int64_t block_table_stride, const int64_t cache_block_stride,
const int64_t cache_entry_stride, const int64_t dst_entry_stride,
const float* __restrict__ scale,
const int32_t* __restrict__ seq_starts) { // Optional: starting offsets per
// batch
+ constexpr int vec_size = sizeof(float4) / sizeof(scalar_t);
+ using ltype = vllm::vec_n_t;
+ using stype = vllm::vec_n_t;
+ // We are adding this for code readability which will be optimized out when
+ // build in release.
+ assert(CTA_SIZE == blockDim.x);
- const int64_t bid = blockIdx.x; // Batch ID
- const int32_t num_splits = gridDim.y;
- const int32_t split = blockIdx.y;
- const int32_t seq_start = cu_seq_lens[bid];
- const int32_t seq_end = cu_seq_lens[bid + 1];
- const int32_t seq_len = seq_end - seq_start;
- const int32_t tot_blocks = cuda_utils::ceil_div(seq_len, block_size);
- const int32_t split_blocks = cuda_utils::ceil_div(tot_blocks, num_splits);
+#pragma unroll
+ for (int token_id = blockIdx.x; token_id < num_tokens;
+ token_id += gridDim.x) {
+ int64_t batch_id = token_to_seq[token_id];
+ int64_t batch_start = cu_seq_lens[batch_id];
+ int64_t batch_end = cu_seq_lens[batch_id + 1];
+ int32_t batch_offset = token_id - batch_start;
- const int32_t split_start = split * split_blocks;
- const int32_t split_end = min((split + 1) * split_blocks, tot_blocks);
+ if (token_id >= batch_end) return;
+ int32_t offset = 0;
+ if (seq_starts != nullptr) {
+ offset = seq_starts[batch_id];
+ }
+ batch_offset += offset;
+ int32_t block_table_id = batch_offset / block_size;
+ int32_t slot_id = batch_offset % block_size;
+ int32_t block_table_offset = batch_id * block_table_stride + block_table_id;
+ int32_t block_id = block_table[block_table_offset];
+ int64_t cache_offset =
+ block_id * cache_block_stride + slot_id * cache_entry_stride;
+ constexpr int32_t vec_iter_cnt = ENTRY_SIZE / vec_size;
+ scalar_t* dst_ = dst + token_id * dst_entry_stride;
+ cache_t* src_ = const_cast(src_cache) + cache_offset;
- const bool is_active_split = (split_start < tot_blocks);
- const bool is_last_split = (split_end == tot_blocks);
-
- if (!is_active_split) return;
-
- int32_t full_blocks_end = split_end;
- int32_t partial_block_size = 0;
-
- // Adjust the pointer for the block_table for this batch.
- // If seq_starts is provided, compute an offset based on (seq_starts[bid] /
- // page_size)
- const int32_t batch_offset = bid * block_table_stride;
- int32_t offset = 0;
- if (seq_starts != nullptr) {
- offset = seq_starts[bid] / block_size;
- }
- const int32_t* batch_block_table = block_table + batch_offset + offset;
-
- // Adjust dst pointer based on the cumulative sequence lengths.
- dst += seq_start * dst_entry_stride;
-
- if (is_last_split) {
- partial_block_size = seq_len % block_size;
- if (partial_block_size) full_blocks_end -= 1;
- }
-
- auto copy_entry = [&](const cache_t* __restrict__ _src,
- scalar_t* __restrict__ _dst) {
- for (int i = threadIdx.x; i < entry_size; i += blockDim.x) {
+#pragma unroll
+ for (int idx = threadIdx.x; idx < vec_iter_cnt; idx += CTA_SIZE) {
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
- _dst[i] = static_cast(_src[i]);
+ reinterpret_cast(dst_)[idx] =
+ static_cast(reinterpret_cast(src_)[idx]);
} else {
- _dst[i] =
- fp8::scaled_convert(_src[i], *scale);
+ ltype loaded_val = reinterpret_cast(src_)[idx];
+ stype store_val;
+#pragma unroll
+ for (int j = 0; j < vec_size; ++j) {
+ store_val.val[j] = fp8::scaled_convert(
+ loaded_val.val[j], *scale);
+ }
+ reinterpret_cast(dst_)[idx] = store_val;
}
}
- };
-
- for (int pid = split_start; pid < full_blocks_end; ++pid) {
- auto block_id = batch_block_table[pid];
- auto block_start_ptr = src_cache + block_id * cache_block_stride;
- auto block_dst_ptr = dst + pid * block_size * dst_entry_stride;
- for (int eid = 0; eid < block_size; ++eid) {
- copy_entry(block_start_ptr + eid * cache_entry_stride,
- block_dst_ptr + eid * dst_entry_stride);
- }
- }
-
- if (partial_block_size) {
- auto block_id = batch_block_table[full_blocks_end];
- auto block_start_ptr = src_cache + block_id * cache_block_stride;
- auto block_dst_ptr = dst + full_blocks_end * block_size * dst_entry_stride;
- for (int eid = 0; eid < partial_block_size; ++eid) {
- copy_entry(block_start_ptr + eid * cache_entry_stride,
- block_dst_ptr + eid * dst_entry_stride);
+ // process tail
+ constexpr int32_t tail_cnt = ENTRY_SIZE % vec_size;
+ dst_ = dst_ + ENTRY_SIZE - tail_cnt;
+ src_ = src_ + ENTRY_SIZE - tail_cnt;
+#pragma unroll
+ for (int idx = threadIdx.x; idx < tail_cnt; idx += CTA_SIZE) {
+ if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+ dst_[idx] = static_cast(src_[idx]);
+ } else {
+ dst_[idx] =
+ fp8::scaled_convert(src_[idx], *scale);
+ }
}
}
}
@@ -992,34 +989,38 @@ __global__ void gather_and_maybe_dequant_cache(
// SCALAR_T is the data type of the destination tensor.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
-#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \
- vllm::gather_and_maybe_dequant_cache \
- <<>>( \
- reinterpret_cast(src_cache.data_ptr()), \
- reinterpret_cast(dst.data_ptr()), \
- block_table.data_ptr(), cu_seq_lens.data_ptr(), \
- block_size, entry_size, block_table_stride, cache_block_stride, \
- cache_entry_stride, dst_entry_stride, \
- reinterpret_cast(scale.data_ptr()), seq_starts_ptr);
+#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \
+ vllm::gather_and_maybe_dequant_cache \
+ <<>>( \
+ reinterpret_cast(src_cache.data_ptr()), \
+ reinterpret_cast(dst.data_ptr()), \
+ block_table.data_ptr(), cu_seq_lens.data_ptr(), \
+ token_to_seq.data_ptr(), num_tokens, block_size, \
+ block_table_stride, cache_block_stride, cache_entry_stride, \
+ dst_entry_stride, reinterpret_cast(scale.data_ptr()), \
+ seq_starts_ptr);
// Gather sequences from the cache into the destination tensor.
// - cu_seq_lens contains the cumulative sequence lengths for each batch
// - block_table contains the cache block indices for each sequence
+// - token_to_seq contains the back mapping from token_id to batch_id
// - Optionally, seq_starts (if provided) offsets the starting block index by
// (seq_starts[bid] / page_size)
void gather_and_maybe_dequant_cache(
- torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
- torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
- torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
- torch::Tensor const& cu_seq_lens, // [BATCH+1]
- int64_t batch_size, const std::string& kv_cache_dtype,
+ torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
+ torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
+ torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
+ torch::Tensor const& cu_seq_lens, // [BATCH+1]
+ torch::Tensor const& token_to_seq, // [MAX_TOKEN_ACROSS_CHUNKS]
+ int64_t num_tokens, const std::string& kv_cache_dtype,
torch::Tensor const& scale,
std::optional seq_starts = std::nullopt) {
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int32_t block_size = src_cache.size(1);
- int32_t entry_size = src_cache.flatten(2, -1).size(2);
+ int32_t head_dim = dst.size(-1);
TORCH_CHECK(block_table.dtype() == torch::kInt32,
"block_table must be int32");
@@ -1029,6 +1030,9 @@ void gather_and_maybe_dequant_cache(
TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
"seq_starts must be int32");
}
+ TORCH_CHECK(head_dim == 576,
+ "gather_and_maybe_dequant_cache only support the head_dim to 576 "
+ "for better performance")
TORCH_CHECK(src_cache.device() == dst.device(),
"src_cache and dst must be on the same device");
@@ -1046,10 +1050,9 @@ void gather_and_maybe_dequant_cache(
int64_t cache_entry_stride = src_cache.stride(1);
int64_t dst_entry_stride = dst.stride(0);
- // Decide on the number of splits based on the batch size.
- int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
- dim3 grid(batch_size, num_splits);
- dim3 block(1024);
+ constexpr int32_t thread_block_size = 64;
+ dim3 grid(num_tokens);
+ dim3 block(thread_block_size);
const int32_t* seq_starts_ptr =
seq_starts.has_value() ? seq_starts.value().data_ptr() : nullptr;
diff --git a/csrc/cpu/cpu_attn.cpp b/csrc/cpu/cpu_attn.cpp
index 50f17c758c148..92f8bee5a47a0 100644
--- a/csrc/cpu/cpu_attn.cpp
+++ b/csrc/cpu/cpu_attn.cpp
@@ -13,6 +13,18 @@
#define AMX_DISPATCH(...) case cpu_attention::ISA::AMX:
#endif
+#ifdef __aarch64__
+ #include "cpu_attn_neon.hpp"
+ #define NEON_DISPATCH(...) \
+ case cpu_attention::ISA::NEON: { \
+ using attn_impl = cpu_attention::AttentionImpl; \
+ return __VA_ARGS__(); \
+ }
+#else
+ #define NEON_DISPATCH(...) case cpu_attention::ISA::NEON:
+#endif // #ifdef __aarch64__
+
#define CPU_ATTN_DISPATCH_CASE(HEAD_DIM, ...) \
case HEAD_DIM: { \
constexpr size_t head_dim = HEAD_DIM; \
@@ -41,6 +53,7 @@
[&] { \
switch (ISA_TYPE) { \
AMX_DISPATCH(__VA_ARGS__) \
+ NEON_DISPATCH(__VA_ARGS__) \
case cpu_attention::ISA::VEC: { \
using attn_impl = \
cpu_attention::AttentionImpl
class AttentionImpl {};
@@ -143,6 +143,12 @@ struct AttentionMetadata {
case ISA::VEC:
ss << "VEC, ";
break;
+ case ISA::VEC16:
+ ss << "VEC16, ";
+ break;
+ case ISA::NEON:
+ ss << "NEON, ";
+ break;
}
ss << "workitem_group_num: " << workitem_group_num
<< ", reduction_item_num: " << reduction_item_num
@@ -841,7 +847,7 @@ struct VecTypeTrait {
};
#endif
-#if !defined(__powerpc__)
+#if !defined(__powerpc__) && !defined(__s390x__)
template <>
struct VecTypeTrait {
using vec_t = vec_op::FP16Vec16;
diff --git a/csrc/cpu/cpu_attn_neon.hpp b/csrc/cpu/cpu_attn_neon.hpp
new file mode 100644
index 0000000000000..827f0cfbc718e
--- /dev/null
+++ b/csrc/cpu/cpu_attn_neon.hpp
@@ -0,0 +1,386 @@
+#ifndef CPU_ATTN_NEON_HPP
+#define CPU_ATTN_NEON_HPP
+
+#include "cpu_attn_impl.hpp"
+#include <arm_neon.h>
+#include <cstring>
+namespace cpu_attention {
+
+namespace {
+
+#define BLOCK_SIZE_ALIGNMENT 32
+#define HEAD_SIZE_ALIGNMENT 32
+#define MAX_Q_HEAD_NUM_PER_ITER 16
+
+// These do not use vectorized class for loading / converting
+// because csrc/cpu/cpu_types_arm.hpp does not have fallback options
+// for vec_op::BF16Vec8 / vec_op::BF16Vec16 on Arm HW that
+// doesn't support BF16.
+// We don't use vec_op::FP32Vec* or vec_op::FP16Vec* for consistency.
+template <typename kv_cache_t>
+FORCE_INLINE void load_row8_B_as_f32(const kv_cache_t* p, float32x4_t& b0,
+ float32x4_t& b1);
+
+template <>
+FORCE_INLINE void load_row8_B_as_f32(const float* p, float32x4_t& b0,
+ float32x4_t& b1) {
+ b0 = vld1q_f32(p + 0);
+ b1 = vld1q_f32(p + 4);
+}
+
+template <>
+FORCE_INLINE void load_row8_B_as_f32(const c10::Half* p,
+ float32x4_t& b0,
+ float32x4_t& b1) {
+  const float16_t* h = reinterpret_cast<const float16_t*>(p);
+ float16x8_t v = vld1q_f16(h);
+ b0 = vcvt_f32_f16(vget_low_f16(v));
+ b1 = vcvt_f32_f16(vget_high_f16(v));
+}
+
+template <>
+FORCE_INLINE void load_row8_B_as_f32(const c10::BFloat16* p,
+ float32x4_t& b0,
+ float32x4_t& b1) {
+  const uint16_t* u = reinterpret_cast<const uint16_t*>(p);
+#ifdef ARM_BF16_SUPPORT
+ uint16x8_t u0 = vld1q_u16(u);
+ bfloat16x8_t bf0 = vreinterpretq_bf16_u16(u0);
+ b0 = vcvtq_low_f32_bf16(bf0);
+ b1 = vcvtq_high_f32_bf16(bf0);
+#else
+ uint16x8_t x0 = vld1q_u16(u);
+ uint32x4_t lo = vshlq_n_u32(vmovl_u16(vget_low_u16(x0)), 16);
+ uint32x4_t hi = vshlq_n_u32(vmovl_u16(vget_high_u16(x0)), 16);
+ b0 = vreinterpretq_f32_u32(lo);
+ b1 = vreinterpretq_f32_u32(hi);
+#endif
+}
+
+// Mx8, with 1 <= M <= 8 , K streamed, unroll-by-4 with NEON FMLAs
+// #Loads = (K // 4) * (M + 4 * sizeof(kv_cache_t) / 2)
+// #FMLAs = (K // 4) * (4 * 2 * M)
+// We have (4 * 2 * M) FMLAs for (M + 4 * sizeof(kv_cache_t) / 2) loads
+template <int32_t M, typename kv_cache_t>
+FORCE_INLINE void gemm_micro_neon_fmla_Mx8_Ku4(
+ const float* __restrict A, // [M x K],
+ const kv_cache_t* __restrict B, // [K x 8],
+ float* __restrict C, // [M x 8],
+ int64_t lda, int64_t ldb, int64_t ldc, int32_t K, bool accumulate) {
+ // kernel supports max M of 8, as it'd spill for larger M
+ static_assert(1 <= M && M <= 8, "M must be in [1,8]");
+
+// helpers for per-M codegen
+#define ROWS_APPLY(OP) OP(0) OP(1) OP(2) OP(3) OP(4) OP(5) OP(6) OP(7)
+#define IF_M(i) if constexpr (M > (i))
+
+ // A row base pointers
+#define DECL_A(i) const float* a##i = A + (i) * lda;
+ ROWS_APPLY(DECL_A)
+#undef DECL_A
+
+ // declare 2 accumulators per row of M
+#define DECL_ACC(i) float32x4_t acc##i##_0, acc##i##_1;
+ ROWS_APPLY(DECL_ACC)
+#undef DECL_ACC
+
+ // initialize accumulators
+#define INIT_ACC(i) \
+ IF_M(i) { \
+ if (accumulate) { \
+ acc##i##_0 = vld1q_f32(C + (i) * ldc + 0); \
+ acc##i##_1 = vld1q_f32(C + (i) * ldc + 4); \
+ } else { \
+ acc##i##_0 = vdupq_n_f32(0.f); \
+ acc##i##_1 = vdupq_n_f32(0.f); \
+ } \
+ }
+ ROWS_APPLY(INIT_ACC)
+#undef INIT_ACC
+
+ int32_t k = 0;
+
+ // K unrolled by 4
+ for (; k + 3 < K; k += 4) {
+ // load A[k..k+3] for each active row (M)
+#define LOAD_A4(i) \
+ float32x4_t a##i##v; \
+ IF_M(i) a##i##v = vld1q_f32(a##i + k);
+ ROWS_APPLY(LOAD_A4)
+#undef LOAD_A4
+
+ // helper: FMA lane L from aiv
+#define FMAS_LANE(i, aiv, L) \
+ IF_M(i) { \
+ acc##i##_0 = vfmaq_laneq_f32(acc##i##_0, b0, aiv, L); \
+ acc##i##_1 = vfmaq_laneq_f32(acc##i##_1, b1, aiv, L); \
+ }
+
+ // k + 0
+ {
+ float32x4_t b0, b1;
+ load_row8_B_as_f32(B + (int64_t)(k + 0) * ldb, b0, b1);
+#define STEP_K0(i) FMAS_LANE(i, a##i##v, 0)
+ ROWS_APPLY(STEP_K0)
+#undef STEP_K0
+ }
+ // k + 1
+ {
+ float32x4_t b0, b1;
+ load_row8_B_as_f32(B + (int64_t)(k + 1) * ldb, b0, b1);
+#define STEP_K1(i) FMAS_LANE(i, a##i##v, 1)
+ ROWS_APPLY(STEP_K1)
+#undef STEP_K1
+ }
+ // k + 2
+ {
+ float32x4_t b0, b1;
+ load_row8_B_as_f32(B + (int64_t)(k + 2) * ldb, b0, b1);
+#define STEP_K2(i) FMAS_LANE(i, a##i##v, 2)
+ ROWS_APPLY(STEP_K2)
+#undef STEP_K2
+ }
+ // k + 3
+ {
+ float32x4_t b0, b1;
+ load_row8_B_as_f32(B + (int64_t)(k + 3) * ldb, b0, b1);
+#define STEP_K3(i) FMAS_LANE(i, a##i##v, 3)
+ ROWS_APPLY(STEP_K3)
+#undef STEP_K3
+ }
+#undef FMAS_LANE
+ }
+
+ // K tail
+ for (; k < K; ++k) {
+ float32x4_t b0, b1;
+ load_row8_B_as_f32(B + (int64_t)k * ldb, b0, b1);
+#define TAIL_ROW(i) \
+ IF_M(i) { \
+ float32x4_t ai = vdupq_n_f32(*(a##i + k)); \
+ acc##i##_0 = vfmaq_f32(acc##i##_0, b0, ai); \
+ acc##i##_1 = vfmaq_f32(acc##i##_1, b1, ai); \
+ }
+ ROWS_APPLY(TAIL_ROW)
+#undef TAIL_ROW
+ }
+
+ // store accumulators to C
+#define STORE_ROW(i) \
+ IF_M(i) { \
+ vst1q_f32(C + (i) * ldc + 0, acc##i##_0); \
+ vst1q_f32(C + (i) * ldc + 4, acc##i##_1); \
+ }
+ ROWS_APPLY(STORE_ROW)
+#undef STORE_ROW
+
+#undef ROWS_APPLY
+#undef IF_M
+}
+
+template <int32_t N, typename kv_cache_t>
+FORCE_INLINE void gemm_macro_neon_fmla_Mx8_Ku4(const float* __restrict A,
+ const kv_cache_t* __restrict B,
+ float* __restrict C, int32_t M,
+ int32_t K, int64_t lda,
+ int64_t ldb, int64_t ldc,
+ bool accumulate) {
+ // micro kernel is Mx8
+ static_assert(N % 8 == 0, "N must be a multiple of 8");
+ for (int32_t m = 0; m < M;) {
+ int32_t mb = (M - m >= 8) ? 8 : (M - m >= 4) ? 4 : (M - m >= 2) ? 2 : 1;
+ const float* Ab = A + m * lda;
+ float* Cb = C + m * ldc;
+
+ for (int32_t n = 0; n < N; n += 8) {
+ const kv_cache_t* Bn = B + n;
+ float* Cn = Cb + n;
+ switch (mb) {
+ case 8:
+ gemm_micro_neon_fmla_Mx8_Ku4<8, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc,
+ K, accumulate);
+ break;
+ case 4:
+ gemm_micro_neon_fmla_Mx8_Ku4<4, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc,
+ K, accumulate);
+ break;
+ case 2:
+ gemm_micro_neon_fmla_Mx8_Ku4<2, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc,
+ K, accumulate);
+ break;
+ default:
+ gemm_micro_neon_fmla_Mx8_Ku4<1, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc,
+ K, accumulate);
+ break;
+ }
+ }
+ // no tail loop for N as it's guaranteed to be a multiple of 8
+ m += mb;
+ }
+}
+
+template <typename kv_cache_t, int32_t k_size>
+class TileGemmNeonFMLA {
+ public:
+  template <AttentionGemmPhase phase>
+ FORCE_INLINE static void gemm(const int32_t m_size,
+ float* __restrict__ a_tile,
+ kv_cache_t* __restrict__ b_tile,
+ float* __restrict__ c_tile, const int64_t lda,
+ const int64_t ldb, const int64_t ldc,
+ const int32_t block_size,
+ const int32_t dynamic_k_size,
+ const bool accum_c) {
+ if constexpr (phase == AttentionGemmPhase::QK) {
+ gemm_macro_neon_fmla_Mx8_Ku4(
+ a_tile, b_tile, c_tile, m_size, k_size, lda, ldb, ldc, accum_c);
+ } else {
+ gemm_macro_neon_fmla_Mx8_Ku4(
+ a_tile, b_tile, c_tile, m_size, dynamic_k_size, lda, ldb, ldc,
+ accum_c);
+ }
+ }
+};
+
+} // namespace
+
+// this is similar to "ISA::VEC" at the moment
+template <typename scalar_t, int64_t head_dim>
+class AttentionImpl<ISA::NEON, scalar_t, head_dim> {
+ public:
+ using query_t = scalar_t;
+ using q_buffer_t = float;
+ using kv_cache_t = scalar_t;
+ using logits_buffer_t = float;
+ using partial_output_buffer_t = float;
+ using prob_buffer_t = float;
+
+ constexpr static int64_t BlockSizeAlignment =
+ BLOCK_SIZE_ALIGNMENT; // KV token num unit of QK and PV phases
+ constexpr static int64_t HeadDimAlignment =
+ HEAD_SIZE_ALIGNMENT; // headdim num unit of PV phase
+ constexpr static int64_t MaxQHeadNumPerIteration = MAX_Q_HEAD_NUM_PER_ITER;
+ constexpr static int64_t HeadDim = head_dim;
+ constexpr static ISA ISAType = ISA::NEON;
+ constexpr static bool scale_on_logits = false; // apply scale on q_buffer
+
+ static_assert(HeadDim % HeadDimAlignment == 0);
+ // the gemm micro kernel is Mx8
+ static_assert(HeadDimAlignment % 8 == 0);
+ static_assert(BlockSizeAlignment % 8 == 0);
+
+ public:
+ template typename attention>
+ FORCE_INLINE void execute_attention(DEFINE_CPU_ATTENTION_PARAMS) {
+ attention> attention_iteration;
+ attention_iteration(CPU_ATTENTION_PARAMS);
+ }
+
+ // k_cache_token_group_stride: stride of K cache when move to next
+ // BlockSizeAlignment tokens in a block
+ constexpr static int64_t k_cache_token_group_stride(
+ const int32_t block_size) {
+ return BlockSizeAlignment; // layout of k_cache block is [head_dim,
+ // block_size], row-major
+ }
+
+ // v_cache_token_group_stride: stride of V cache when move to next
+ // BlockSizeAlignment tokens in a block
+ constexpr static int64_t v_cache_token_group_stride(
+ const int32_t block_size) {
+ return head_dim * BlockSizeAlignment; // layout of v_cache is [block_size,
+ // head_dim], row-major
+ }
+
+ // v_cache_head_group_stride: stride of V cache when move to next
+ // HeadDimAlignment head dims in a block
+ constexpr static int64_t v_cache_head_group_stride(const int32_t block_size) {
+ return HeadDimAlignment; // layout of v_cache is [block_size, head_dim],
+ // row-major
+ }
+
+ // Copy q to q_buffer and cast it to fp32
+ static void copy_q_heads_tile(
+ scalar_t* __restrict__ src, // [q_num, q_heads_per_kv, head_size]
+ float* __restrict__ q_buffer, const int32_t q_num,
+ const int32_t q_heads_per_kv, const int64_t q_num_stride,
+ const int64_t q_head_stride, float scale) {
+ static_assert(head_dim % 16 == 0);
+ constexpr int32_t unroll_size = head_dim / 16;
+    using load_vec_t = typename VecTypeTrait<scalar_t>::vec_t;
+
+ vec_op::FP32Vec16 scale_vec(scale);
+ for (int32_t q_num_idx = 0; q_num_idx < q_num; ++q_num_idx) {
+ for (int32_t q_head_idx = 0; q_head_idx < q_heads_per_kv; ++q_head_idx) {
+ scalar_t* __restrict__ curr_q =
+ src + q_num_idx * q_num_stride + q_head_idx * q_head_stride;
+ float* __restrict__ curr_q_buffer =
+ q_buffer + q_num_idx * q_heads_per_kv * head_dim +
+ q_head_idx * head_dim;
+
+ vec_op::unroll_loop([&](int32_t i) {
+ load_vec_t vec(curr_q);
+ vec_op::FP32Vec16 fp32_vec(vec);
+ fp32_vec = fp32_vec * scale_vec;
+ fp32_vec.save(curr_q_buffer);
+
+ curr_q += 16;
+ curr_q_buffer += 16;
+ });
+ }
+ }
+ }
+
+ // reshape K as column-major and V as row-major
+ static void reshape_and_cache(
+ const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
+ scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
+ const int64_t* __restrict__ slot_mapping, const int64_t token_num,
+ const int64_t key_token_num_stride, const int64_t value_token_num_stride,
+ const int64_t head_num, const int64_t key_head_num_stride,
+ const int64_t value_head_num_stride, const int64_t num_blocks,
+ const int64_t num_blocks_stride, const int64_t cache_head_num_stride,
+ const int64_t block_size, const int64_t block_size_stride) {
+#pragma omp parallel for collapse(2)
+ for (int64_t token_idx = 0; token_idx < token_num; ++token_idx) {
+ for (int64_t head_idx = 0; head_idx < head_num; ++head_idx) {
+ const int64_t pos = slot_mapping[token_idx];
+ if (pos < 0) {
+ // skip
+ continue;
+ }
+
+ const int64_t block_idx = pos / block_size;
+ const int64_t block_offset = pos % block_size;
+ {
+ // Write Key
+ const scalar_t* key_start_ptr = key +
+ token_idx * key_token_num_stride +
+ head_idx * key_head_num_stride;
+ scalar_t* key_cache_start_ptr =
+ key_cache + block_idx * num_blocks_stride +
+ head_idx * cache_head_num_stride + block_offset;
+
+#pragma GCC unroll 8
+ for (int64_t i = 0, j = 0; i < head_dim; ++i, j += block_size) {
+ key_cache_start_ptr[j] = key_start_ptr[i];
+ }
+ }
+ {
+ // Write Value
+ const scalar_t* value_start_ptr = value +
+ token_idx * value_token_num_stride +
+ head_idx * value_head_num_stride;
+ scalar_t* value_cache_start_ptr =
+ value_cache + block_idx * num_blocks_stride +
+ head_idx * cache_head_num_stride + block_offset * head_dim;
+ std::memcpy(value_cache_start_ptr, value_start_ptr,
+ sizeof(scalar_t) * head_dim);
+ }
+ }
+ }
+ }
+};
+} // namespace cpu_attention
+
+#endif // #ifndef CPU_ATTN_NEON_HPP
diff --git a/csrc/cpu/cpu_types_scalar.hpp b/csrc/cpu/cpu_types_scalar.hpp
index 1a9278bc662e5..f9da78283da5e 100644
--- a/csrc/cpu/cpu_types_scalar.hpp
+++ b/csrc/cpu/cpu_types_scalar.hpp
@@ -26,10 +26,6 @@ namespace vec_op {
#define FORCE_INLINE __attribute__((always_inline)) inline
-#define __max(a, b) ((a) > (b) ? (a) : (b))
-#define __min(a, b) ((a) < (b) ? (a) : (b))
-#define __abs(a) ((a) < (0) ? (0 - a) : (a))
-
typedef struct f16x8_t {
uint16_t val[8];
} f16x8_t;
@@ -99,7 +95,7 @@ struct FP16Vec16 : public Vec {
void save(void* ptr) const { *reinterpret_cast(ptr) = reg; }
void save(void* ptr, const int elem_num) const {
- int num = __min(elem_num, VEC_ELEM_NUM);
+ int num = std::min(elem_num, VEC_ELEM_NUM);
std::memcpy(ptr, &(reg.val[0]), num * sizeof(uint16_t));
}
};
@@ -128,7 +124,7 @@ struct BF16Vec16 : public Vec {
void save(void* ptr) const { *reinterpret_cast(ptr) = reg; }
void save(void* ptr, const int elem_num) const {
- int num = __min(elem_num, VEC_ELEM_NUM);
+ int num = std::min(elem_num, VEC_ELEM_NUM);
std::memcpy(ptr, &(reg.val[0]), num * sizeof(uint16_t));
}
};
@@ -143,9 +139,9 @@ struct BF16Vec32 : public Vec {
explicit BF16Vec32(f16x32_t data) : reg(data) {};
explicit BF16Vec32(BF16Vec8& vec8_data) {
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+ unroll_loop([&vec8_data, this](int i) {
reg.val[i] = vec8_data.reg.val[i % BF16Vec8::VEC_ELEM_NUM];
- }
+ });
}
void save(void* ptr) const { *reinterpret_cast(ptr) = reg; }
@@ -157,15 +153,11 @@ struct FP32Vec4 : public Vec {
f32x4_t reg;
explicit FP32Vec4(float v) {
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- reg.val[i] = v;
- }
+ unroll_loop([&v, this](int i) { reg.val[i] = v; });
}
explicit FP32Vec4() {
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- reg.val[i] = 0.0f;
- }
+ unroll_loop([this](int i) { reg.val[i] = 0.0f; });
}
explicit FP32Vec4(const float* ptr)
@@ -182,15 +174,11 @@ struct FP32Vec8 : public Vec {
f32x8_t reg;
explicit FP32Vec8(float v) {
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- reg.val[i] = v;
- }
+ unroll_loop([&v, this](int i) { reg.val[i] = v; });
}
explicit FP32Vec8() {
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- reg.val[i] = 0.0f;
- }
+ unroll_loop([this](int i) { reg.val[i] = 0.0f; });
}
explicit FP32Vec8(const float* ptr)
@@ -201,78 +189,68 @@ struct FP32Vec8 : public Vec {
explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {};
explicit FP32Vec8(const FP16Vec8& v) {
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- reg.val[i] = fp16_to_float(v.reg.val[i]);
- }
+ unroll_loop(
+ [&v, this](int i) { reg.val[i] = fp16_to_float(v.reg.val[i]); });
}
FP32Vec8(const BF16Vec8& v) {
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- reg.val[i] = bf16_to_float(v.reg.val[i]);
- }
+ unroll_loop(
+ [&v, this](int i) { reg.val[i] = bf16_to_float(v.reg.val[i]); });
}
float reduce_sum() const {
float result = 0;
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- result += reg.val[i];
- }
+ unroll_loop(
+ [&result, this](int i) { result += reg.val[i]; });
return result;
}
FP32Vec8 exp() const {
f32x8_t ret;
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- ret.val[i] = expf(reg.val[i]);
- }
+ unroll_loop(
+ [&ret, this](int i) { ret.val[i] = expf(reg.val[i]); });
return FP32Vec8(ret);
}
FP32Vec8 tanh() const {
f32x8_t ret;
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- ret.val[i] = tanhf(reg.val[i]);
- }
+ unroll_loop(
+ [&ret, this](int i) { ret.val[i] = tanhf(reg.val[i]); });
return FP32Vec8(ret);
}
FP32Vec8 er() const {
f32x8_t ret;
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- ret.val[i] = erf(reg.val[i]);
- }
+ unroll_loop(
+ [&ret, this](int i) { ret.val[i] = erf(reg.val[i]); });
return FP32Vec8(ret);
}
FP32Vec8 operator*(const FP32Vec8& b) const {
f32x8_t ret;
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- ret.val[i] = reg.val[i] * b.reg.val[i];
- }
+ unroll_loop(
+ [&ret, &b, this](int i) { ret.val[i] = reg.val[i] * b.reg.val[i]; });
return FP32Vec8(ret);
}
FP32Vec8 operator+(const FP32Vec8& b) const {
f32x8_t ret;
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- ret.val[i] = reg.val[i] + b.reg.val[i];
- }
+ unroll_loop(
+ [&ret, &b, this](int i) { ret.val[i] = reg.val[i] + b.reg.val[i]; });
return FP32Vec8(ret);
}
FP32Vec8 operator-(const FP32Vec8& b) const {
f32x8_t ret;
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- ret.val[i] = reg.val[i] - b.reg.val[i];
- }
+ unroll_loop(
+ [&ret, &b, this](int i) { ret.val[i] = reg.val[i] - b.reg.val[i]; });
return FP32Vec8(ret);
}
FP32Vec8 operator/(const FP32Vec8& b) const {
f32x8_t ret;
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- ret.val[i] = reg.val[i] / b.reg.val[i];
- }
+ unroll_loop(
+ [&ret, &b, this](int i) { ret.val[i] = reg.val[i] / b.reg.val[i]; });
return FP32Vec8(ret);
}
@@ -284,15 +262,11 @@ struct FP32Vec16 : public Vec {
f32x16_t reg;
explicit FP32Vec16(float v) {
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- reg.val[i] = v;
- }
+ unroll_loop([&v, this](int i) { reg.val[i] = v; });
}
explicit FP32Vec16() {
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- reg.val[i] = 0.0f;
- }
+ unroll_loop([this](int i) { reg.val[i] = 0.0f; });
}
explicit FP32Vec16(const float* ptr)
@@ -301,29 +275,27 @@ struct FP32Vec16 : public Vec {
explicit FP32Vec16(f32x16_t data) : reg(data) {};
FP32Vec16(const FP32Vec4& data) {
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+ unroll_loop([&data, this](int i) {
reg.val[i] = data.reg.val[i % FP32Vec4::VEC_ELEM_NUM];
- }
+ });
}
FP32Vec16(const FP32Vec8& data) {
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+ unroll_loop([&data, this](int i) {
reg.val[i] = data.reg.val[i % FP32Vec8::VEC_ELEM_NUM];
- }
+ });
}
FP32Vec16(const FP32Vec16& data) : reg(data.reg) {};
explicit FP32Vec16(const FP16Vec16& v) {
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- reg.val[i] = fp16_to_float(v.reg.val[i]);
- }
+ unroll_loop(
+ [&v, this](int i) { reg.val[i] = fp16_to_float(v.reg.val[i]); });
}
explicit FP32Vec16(const BF16Vec16& v) {
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- reg.val[i] = bf16_to_float(v.reg.val[i]);
- }
+ unroll_loop(
+ [&v, this](int i) { reg.val[i] = bf16_to_float(v.reg.val[i]); });
}
explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v)) {};
@@ -331,82 +303,74 @@ struct FP32Vec16 : public Vec {
FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {};
FP32Vec16 operator*(const FP32Vec16& b) const {
- FP32Vec16 result(0.0f);
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- result.reg.val[i] = reg.val[i] * b.reg.val[i];
- }
- return result;
+ f32x16_t ret;
+ unroll_loop(
+ [&ret, &b, this](int i) { ret.val[i] = reg.val[i] * b.reg.val[i]; });
+ return FP32Vec16(ret);
}
FP32Vec16 operator+(const FP32Vec16& b) const {
- FP32Vec16 result(0.0f);
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- result.reg.val[i] = reg.val[i] + b.reg.val[i];
- }
- return result;
+ f32x16_t ret;
+ unroll_loop(
+ [&ret, &b, this](int i) { ret.val[i] = reg.val[i] + b.reg.val[i]; });
+ return FP32Vec16(ret);
}
FP32Vec16 operator-(const FP32Vec16& b) const {
- FP32Vec16 result(0.0f);
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- result.reg.val[i] = reg.val[i] - b.reg.val[i];
- }
- return result;
+ f32x16_t ret;
+ unroll_loop(
+ [&ret, &b, this](int i) { ret.val[i] = reg.val[i] - b.reg.val[i]; });
+ return FP32Vec16(ret);
}
FP32Vec16 operator/(const FP32Vec16& b) const {
- FP32Vec16 result(0.0f);
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- result.reg.val[i] = reg.val[i] / b.reg.val[i];
- }
- return result;
+ f32x16_t ret;
+ unroll_loop(
+ [&ret, &b, this](int i) { ret.val[i] = reg.val[i] / b.reg.val[i]; });
+ return FP32Vec16(ret);
}
FP32Vec16 max(const FP32Vec16& b) const {
- FP32Vec16 result(0.0f);
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- result.reg.val[i] = __max(reg.val[i], b.reg.val[i]);
- }
- return result;
+ f32x16_t ret;
+ unroll_loop([&ret, &b, this](int i) {
+ ret.val[i] = std::max(reg.val[i], b.reg.val[i]);
+ });
+ return FP32Vec16(ret);
}
FP32Vec16 min(const FP32Vec16& b) const {
- FP32Vec16 result(0.0f);
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- result.reg.val[i] = __min(reg.val[i], b.reg.val[i]);
- }
- return result;
+ f32x16_t ret;
+ unroll_loop([&ret, &b, this](int i) {
+ ret.val[i] = std::min(reg.val[i], b.reg.val[i]);
+ });
+ return FP32Vec16(ret);
}
FP32Vec16 abs() const {
- FP32Vec16 result(0.0f);
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- result.reg.val[i] = __abs(reg.val[i]);
- }
- return result;
+ f32x16_t ret;
+ unroll_loop(
+ [&ret, this](int i) { ret.val[i] = std::abs(reg.val[i]); });
+ return FP32Vec16(ret);
}
float reduce_sum() const {
float result = 0.0f;
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- result += reg.val[i];
- }
+ unroll_loop(
+ [&result, this](int i) { result += reg.val[i]; });
return result;
}
float reduce_max() const {
- float result = reg.val[0];
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- result = __max(reg.val[i], result);
- }
+    float result = std::numeric_limits<float>::lowest();
+ unroll_loop(
+ [&result, this](int i) { result = std::max(reg.val[i], result); });
return result;
}
float reduce_min() const {
- float result = reg.val[0];
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- result = __min(reg.val[i], result);
- }
+    float result = std::numeric_limits<float>::max();
+ unroll_loop(
+ [&result, this](int i) { result = std::min(reg.val[i], result); });
return result;
}
@@ -414,13 +378,9 @@ struct FP32Vec16 : public Vec {
float reduce_sub_sum(int idx) {
static_assert(VEC_ELEM_NUM % group_size == 0);
float sum = 0.0;
- int start = idx * group_size;
- int end = (idx + 1) * group_size;
-
- for (; (start < VEC_ELEM_NUM) && (start < end); ++start) {
- sum += reg.val[start];
- }
-
+ const int start = idx * group_size;
+ unroll_loop(
+ [&sum, &start, this](int i) { sum += reg.val[start + i]; });
return sum;
}
@@ -477,17 +437,13 @@ inline void storeFP32(float v, c10::BFloat16* ptr) {
}
inline FP16Vec16::FP16Vec16(const FP32Vec16& v) {
- int i = 0;
- for (i = 0; i < FP16Vec16::VEC_ELEM_NUM; ++i) {
- reg.val[i] = float_to_fp16(v.reg.val[i]);
- }
+ unroll_loop(
+ [&v, this](int i) { reg.val[i] = float_to_fp16(v.reg.val[i]); });
}
inline FP16Vec8 ::FP16Vec8(const FP32Vec8& v) {
- int i = 0;
- for (i = 0; i < FP16Vec8::VEC_ELEM_NUM; ++i) {
- reg.val[i] = float_to_fp16(v.reg.val[i]);
- }
+ unroll_loop(
+ [&v, this](int i) { reg.val[i] = float_to_fp16(v.reg.val[i]); });
}
inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
@@ -495,17 +451,13 @@ inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
}
inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
- int i = 0;
- for (i = 0; i < BF16Vec8::VEC_ELEM_NUM; ++i) {
- reg.val[i] = float_to_bf16(v.reg.val[i]);
- }
+ unroll_loop(
+ [&v, this](int i) { reg.val[i] = float_to_bf16(v.reg.val[i]); });
}
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
- int i = 0;
- for (i = 0; i < BF16Vec16::VEC_ELEM_NUM; ++i) {
- reg.val[i] = float_to_bf16(v.reg.val[i]);
- }
+ unroll_loop(
+ [&v, this](int i) { reg.val[i] = float_to_bf16(v.reg.val[i]); });
}
inline void prefetch(const void* addr) { __builtin_prefetch(addr, 0, 3); }
diff --git a/csrc/cpu/cpu_types_vxe.hpp b/csrc/cpu/cpu_types_vxe.hpp
index 51bca37e699b9..9efd8b7ec14a4 100644
--- a/csrc/cpu/cpu_types_vxe.hpp
+++ b/csrc/cpu/cpu_types_vxe.hpp
@@ -4,6 +4,7 @@
#include
#include
+#include <cmath>
#include
namespace vec_op {
@@ -174,8 +175,9 @@ struct FP32Vec8 : public Vec {
}
explicit FP32Vec8(const BF16Vec8& v) {
- reg.val[0] = (__vector float)vec_mergeh(zero, v.reg);
- reg.val[1] = (__vector float)vec_mergel(zero, v.reg);
+ // On big-endian s390x, place BF16 first to get correct byte order
+ reg.val[0] = (__vector float)vec_mergeh(v.reg, zero);
+ reg.val[1] = (__vector float)vec_mergel(v.reg, zero);
}
float reduce_sum() const {
@@ -189,51 +191,257 @@ struct FP32Vec8 : public Vec {
}
FP32Vec8 exp() const {
- // TODO: Vectorize this
- AliasReg ar;
- ar.reg = reg;
- f32x4x4_t ret;
- ret.val[0][0] = std::exp(ar.values[0]);
- ret.val[0][1] = std::exp(ar.values[1]);
- ret.val[0][2] = std::exp(ar.values[2]);
- ret.val[0][3] = std::exp(ar.values[3]);
- ret.val[1][0] = std::exp(ar.values[4]);
- ret.val[1][1] = std::exp(ar.values[5]);
- ret.val[1][2] = std::exp(ar.values[6]);
- ret.val[1][3] = std::exp(ar.values[7]);
- return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
+ f32x4x2_t out;
+
+ const __vector float log2e = vec_splats(1.44269504088896341f);
+ const __vector float one = vec_splats(1.0f);
+ const __vector float min_x = vec_splats(-87.3f);
+ const __vector float max_x = vec_splats(88.7f);
+
+ // 5th-degree minimax polynomial for 2^r (r in [0,1))
+ const __vector float c1 = vec_splats(0.6931471805599453f);
+ const __vector float c2 = vec_splats(0.240226506959101f);
+ const __vector float c3 = vec_splats(0.05550410866482158f);
+ const __vector float c4 = vec_splats(0.009618129107628477f);
+ const __vector float c5 = vec_splats(0.0013333558146428443f);
+
+ for (int i = 0; i < 2; i++) {
+ __vector float x = reg.val[i];
+
+ x = vec_max(x, min_x);
+ x = vec_min(x, max_x);
+
+ __vector float y = vec_mul(x, log2e);
+
+ __vector float kf = vec_floor(y);
+ __vector float r = vec_sub(y, kf);
+
+ __vector signed int k = vec_signed(kf);
+ const __vector signed int min_k = vec_splats((signed int)-126);
+ const __vector signed int max_k = vec_splats((signed int)127);
+ k = vec_min(vec_max(k, min_k), max_k);
+
+ // Build 2^k from exponent bits
+ __vector signed int exp_int = vec_add(k, vec_splats((signed int)127));
+ __vector unsigned int bits = (__vector unsigned int)exp_int;
+ bits = vec_sl(bits, vec_splats((unsigned int)23));
+ __vector float pow2k = (__vector float)bits;
+
+ // Improved minimax polynomial
+ __vector float poly = vec_madd(c5, r, c4);
+ poly = vec_madd(poly, r, c3);
+ poly = vec_madd(poly, r, c2);
+ poly = vec_madd(poly, r, c1);
+ poly = vec_madd(poly, r, one);
+
+ out.val[i] = vec_mul(pow2k, poly);
+ }
+
+ return FP32Vec8(out);
}
FP32Vec8 tanh() const {
- // TODO: Vectorize this
- AliasReg ar;
- ar.reg = reg;
- f32x4x4_t ret;
- ret.val[0][0] = std::tanh(ar.values[0]);
- ret.val[0][1] = std::tanh(ar.values[1]);
- ret.val[0][2] = std::tanh(ar.values[2]);
- ret.val[0][3] = std::tanh(ar.values[3]);
- ret.val[1][0] = std::tanh(ar.values[4]);
- ret.val[1][1] = std::tanh(ar.values[5]);
- ret.val[1][2] = std::tanh(ar.values[6]);
- ret.val[1][3] = std::tanh(ar.values[7]);
- return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
+ // tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
+ const __vector float one = vec_splats(1.0f);
+ const __vector float two = vec_splats(2.0f);
+ const __vector float zero = vec_splats(0.0f);
+ const __vector float sat =
+ vec_splats(9.0f); // beyond this, tanh(x) ~ sign(x)
+
+ f32x4x2_t out;
+
+ for (int i = 0; i < 2; i++) {
+ __vector float x = reg.val[i];
+ __vector float ax = vec_abs(x);
+
+ // sign(x): +1 or -1
+ __vector float sign = vec_sel(vec_splats(-1.0f), one, vec_cmpgt(x, zero));
+
+ // saturation mask: |x| > sat
+ __vector __bool int saturated = vec_cmpgt(ax, sat);
+
+ // 2x
+ __vector float two_x = vec_mul(x, two);
+
+ // Build a temporary FP32Vec8 with both lanes = 2x, reuse exp()
+ f32x4x2_t tmp;
+ tmp.val[0] = two_x;
+ tmp.val[1] = two_x;
+ FP32Vec8 exp_2x_vec(tmp);
+
+ FP32Vec8 e2x = exp_2x_vec.exp();
+ __vector float e = e2x.reg.val[i];
+
+ // tanh(x) = (e - 1) / (e + 1)
+ __vector float num = vec_sub(e, one);
+ __vector float den = vec_add(e, one);
+
+ __vector float t = vec_div(num, den);
+
+ // For large |x|, clamp to sign(x)
+ out.val[i] = vec_sel(t, sign, saturated);
+ }
+
+ return FP32Vec8(out);
}
FP32Vec8 er() const {
- // TODO: Vectorize this
- AliasReg ar;
- ar.reg = reg;
- f32x4x4_t ret;
- ret.val[0][0] = std::erf(ar.values[0]);
- ret.val[0][1] = std::erf(ar.values[1]);
- ret.val[0][2] = std::erf(ar.values[2]);
- ret.val[0][3] = std::erf(ar.values[3]);
- ret.val[1][0] = std::erf(ar.values[4]);
- ret.val[1][1] = std::erf(ar.values[5]);
- ret.val[1][2] = std::erf(ar.values[6]);
- ret.val[1][3] = std::erf(ar.values[7]);
- return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
+ // A&S 7.1.26 approximation:
+ // erf(x) = sign(x) * (1 - ((((a5*t + a4)*t + a3)*t + a2)*t + a1) * t *
+ // exp(-x^2)) t = 1 / (1 + p*|x|), p = 0.3275911
+
+ const __vector float one = vec_splats(1.0f);
+ const __vector float zero = vec_splats(0.0f);
+ const __vector float p = vec_splats(0.3275911f);
+
+ // Polynomial coeffs
+ const __vector float a1 = vec_splats(0.254829592f);
+ const __vector float a2 = vec_splats(-0.284496736f);
+ const __vector float a3 = vec_splats(1.421413741f);
+ const __vector float a4 = vec_splats(-1.453152027f);
+ const __vector float a5 = vec_splats(1.061405429f);
+
+ // Threshold where erf(x) ~ sign(x)
+ const __vector float sat = vec_splats(6.0f);
+
+ f32x4x2_t out;
+
+ for (int lane = 0; lane < 2; lane++) {
+ __vector float x = reg.val[lane];
+ __vector float ax = vec_abs(x);
+
+ // sign(x)
+ __vector float sign = vec_sel(vec_splats(-1.0f), one, vec_cmpgt(x, zero));
+
+ // |x| > 6 → erf(x) = ±1
+ __vector __bool int saturated = vec_cmpgt(ax, sat);
+
+ // t = 1 / (1 + p * |x|)
+ __vector float t = vec_madd(p, ax, one);
+ t = vec_div(one, t);
+
+ // poly = a5
+ __vector float poly = a5;
+ poly = vec_madd(poly, t, a4);
+ poly = vec_madd(poly, t, a3);
+ poly = vec_madd(poly, t, a2);
+ poly = vec_madd(poly, t, a1);
+
+ // full polynomial: poly = poly * t
+ poly = vec_mul(poly, t);
+
+ // Compute exp(-x^2)
+ __vector float x2 = vec_mul(x, x);
+ __vector float neg_x2 = vec_neg(x2);
+
+ f32x4x2_t tmp;
+ tmp.val[0] = neg_x2;
+ tmp.val[1] = neg_x2;
+ FP32Vec8 exp_neg_x2(tmp);
+
+ FP32Vec8 e = exp_neg_x2.exp();
+ __vector float ex = e.reg.val[lane];
+
+ // erf(x) = sign * (1 - poly * exp(-x^2))
+ __vector float term = vec_mul(poly, ex);
+ __vector float y = vec_sub(one, term);
+ y = vec_mul(y, sign);
+
+ // saturated → ±1
+ __vector float sat_val = vec_mul(sign, one);
+ out.val[lane] = vec_sel(y, sat_val, saturated);
+ }
+
+ return FP32Vec8(out);
+ }
+ // Elementwise sigmoid(x) = 1 / (1 + exp(-x))
+ FP32Vec8 sigmoid() const {
+ const __vector float one = vec_splats(1.0f);
+
+ f32x4x2_t neg;
+ for (int i = 0; i < 2; ++i) {
+ neg.val[i] = vec_neg(reg.val[i]);
+ }
+
+ FP32Vec8 neg_x(neg);
+ FP32Vec8 e = neg_x.exp(); // exp(-x)
+
+ f32x4x2_t denom;
+ for (int i = 0; i < 2; ++i) {
+ denom.val[i] = vec_add(one, e.reg.val[i]);
+ }
+
+ FP32Vec8 denom_vec(denom);
+ FP32Vec8 one_vec(1.0f);
+
+ return one_vec / denom_vec;
+ }
+
+ // Tanh-based GELU:
+ // gelu(x) = 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x^3)))
+ FP32Vec8 gelu_tanh() const {
+ const __vector float k_s2pi = vec_splats(0.7978845608028654f); // √(2/π)
+ const __vector float k_0_0447 = vec_splats(0.044715f);
+
+ f32x4x2_t x2, x3, inner;
+ for (int i = 0; i < 2; ++i) {
+ __vector float x = reg.val[i];
+ x2.val[i] = vec_mul(x, x); // x^2
+ x3.val[i] = vec_mul(x2.val[i], x); // x^3
+ __vector float t = vec_madd(k_0_0447, x3.val[i], x); // x + 0.044715*x^3
+ inner.val[i] = vec_mul(k_s2pi, t); // √(2/π)*(...)
+ }
+
+ FP32Vec8 inner_vec(inner);
+ FP32Vec8 t = inner_vec.tanh(); // tanh part
+
+ FP32Vec8 one_vec(1.0f);
+ FP32Vec8 half_vec(0.5f);
+
+ FP32Vec8 x_vec(*this);
+ return x_vec * half_vec * (one_vec + t);
+ }
+
+ // Erf-based GELU:
+ // gelu(x) = 0.5 * x * (1 + erf(x / √2))
+ FP32Vec8 gelu_erf() const {
+ const __vector float inv_sqrt2 = vec_splats(0.7071067811865476f); // 1/√2
+ FP32Vec8 x_vec(*this);
+
+ f32x4x2_t scaled;
+ for (int i = 0; i < 2; ++i) {
+ scaled.val[i] = vec_mul(reg.val[i], inv_sqrt2);
+ }
+ FP32Vec8 x_scaled(scaled);
+
+ FP32Vec8 erf_x = x_scaled.er();
+
+ FP32Vec8 one_vec(1.0f);
+ FP32Vec8 half_vec(0.5f);
+
+ return x_vec * half_vec * (one_vec + erf_x);
+ }
+
+ // Elementwise reciprocal: 1/x (scalar per lane, for correctness)
+ FP32Vec8 rcp() const {
+ AliasReg in, out;
+ in.reg = reg;
+
+ for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+ out.values[i] = 1.0f / in.values[i];
+ }
+ return FP32Vec8(out.reg);
+ }
+
+ // Elementwise rsqrt(x) = 1 / sqrt(x) (scalar per lane, for correctness)
+ FP32Vec8 rsqrt() const {
+ AliasReg in, out;
+ in.reg = reg;
+
+ for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+ out.values[i] = 1.0f / std::sqrt(in.values[i]);
+ }
+ return FP32Vec8(out.reg);
}
FP32Vec8 operator*(const FP32Vec8& b) const {
@@ -316,10 +524,11 @@ struct FP32Vec16 : public Vec {
}
explicit FP32Vec16(const BF16Vec16& v) {
- reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]);
- reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]);
- reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]);
- reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]);
+ // On big-endian s390x, place BF16 first to get correct byte order
+ reg.val[0] = (__vector float)vec_mergeh(v.reg.val[0], zero);
+ reg.val[1] = (__vector float)vec_mergel(v.reg.val[0], zero);
+ reg.val[2] = (__vector float)vec_mergeh(v.reg.val[1], zero);
+ reg.val[3] = (__vector float)vec_mergel(v.reg.val[1], zero);
}
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
@@ -376,6 +585,23 @@ struct FP32Vec16 : public Vec {
return result;
}
+ FP32Vec16 max(const FP32Vec16& b) const {
+ return FP32Vec16(f32x4x4_t({vec_max(reg.val[0], b.reg.val[0]),
+ vec_max(reg.val[1], b.reg.val[1]),
+ vec_max(reg.val[2], b.reg.val[2]),
+ vec_max(reg.val[3], b.reg.val[3])}));
+ }
+
+ float reduce_max() const {
+ AliasReg ar;
+ ar.reg = reg;
+ float result = ar.values[0];
+ unroll_loop([&result, &ar](int i) {
+ if (ar.values[i] > result) result = ar.values[i];
+ });
+ return result;
+ }
+
void save(float* ptr) const {
vec_xst(reg.val[0], 0, ptr);
vec_xst(reg.val[1], 16, ptr);
@@ -402,15 +628,14 @@ struct VecType {
using vec_type = BF16Vec8;
};
+// On s390x, FP16 (Half) is not natively supported, use FP32 vectors instead
+using FP16Vec16 = FP32Vec16;
+
template <typename T>
void storeFP32(float v, T* ptr) {
*ptr = v;
}
-inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
- acc = acc + a * b;
-}
-
namespace c10 {
struct BFloat16 {
uint16_t value; // Assume BFloat16 is defined as a struct containing a 16-bit
@@ -429,6 +654,79 @@ inline void storeFP32(float v, c10::BFloat16* ptr) {
#define __VEC_CLASS_FP_NAN (1 << 6)
#endif
+// Optimized FMA (Fused Multiply-Add) implementations using IBM Z vector
+// intrinsics
+
+// FP32Vec4 FMA: acc = acc + (a * b) or equivalently acc = fma(a, b, acc)
+FORCE_INLINE void fma(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) {
+ acc.reg = vec_madd(a.reg, b.reg, acc.reg);
+}
+
+// FP32Vec8 FMA: acc = acc + (a * b)
+FORCE_INLINE void fma(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) {
+ acc.reg.val[0] = vec_madd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_madd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+}
+
+// FP32Vec16 FMA: acc = acc + (a * b)
+FORCE_INLINE void fma(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
+ acc.reg.val[0] = vec_madd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_madd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+ acc.reg.val[2] = vec_madd(a.reg.val[2], b.reg.val[2], acc.reg.val[2]);
+ acc.reg.val[3] = vec_madd(a.reg.val[3], b.reg.val[3], acc.reg.val[3]);
+}
+
+// Multiply-Subtract: acc = acc - (a * b)
+FORCE_INLINE void fms(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) {
+ acc.reg = vec_msub(a.reg, b.reg, acc.reg);
+}
+
+FORCE_INLINE void fms(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) {
+ acc.reg.val[0] = vec_msub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_msub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+}
+
+FORCE_INLINE void fms(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
+ acc.reg.val[0] = vec_msub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_msub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+ acc.reg.val[2] = vec_msub(a.reg.val[2], b.reg.val[2], acc.reg.val[2]);
+ acc.reg.val[3] = vec_msub(a.reg.val[3], b.reg.val[3], acc.reg.val[3]);
+}
+
+// Negative Multiply-Add: acc = -(a * b) + acc
+FORCE_INLINE void nfma(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) {
+ acc.reg = vec_nmadd(a.reg, b.reg, acc.reg);
+}
+
+FORCE_INLINE void nfma(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) {
+ acc.reg.val[0] = vec_nmadd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_nmadd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+}
+
+FORCE_INLINE void nfma(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
+ acc.reg.val[0] = vec_nmadd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_nmadd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+ acc.reg.val[2] = vec_nmadd(a.reg.val[2], b.reg.val[2], acc.reg.val[2]);
+ acc.reg.val[3] = vec_nmadd(a.reg.val[3], b.reg.val[3], acc.reg.val[3]);
+}
+
+// Negative Multiply-Subtract: acc = -(a * b) - acc
+FORCE_INLINE void nfms(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) {
+ acc.reg = vec_nmsub(a.reg, b.reg, acc.reg);
+}
+
+FORCE_INLINE void nfms(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) {
+ acc.reg.val[0] = vec_nmsub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_nmsub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+}
+
+FORCE_INLINE void nfms(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
+ acc.reg.val[0] = vec_nmsub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_nmsub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+ acc.reg.val[2] = vec_nmsub(a.reg.val[2], b.reg.val[2], acc.reg.val[2]);
+ acc.reg.val[3] = vec_nmsub(a.reg.val[3], b.reg.val[3], acc.reg.val[3]);
+}
+
const static __vector unsigned char omask = {2, 3, 6, 7, 10, 11, 14, 15,
18, 19, 22, 23, 26, 27, 30, 31};
const static __vector unsigned int bias = {0x00007fff, 0x00007fff, 0x00007fff,
@@ -441,13 +739,24 @@ const static __vector unsigned int one = {1, 1, 1, 1};
inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
__vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
__vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
+ __vector unsigned int lsb0 = inp0 >> sh16;
+ __vector unsigned int lsb1 = inp1 >> sh16;
+ lsb0 = lsb0 & one;
+ lsb1 = lsb1 & one;
+ __vector unsigned int rnd0 = lsb0 + bias;
+ __vector unsigned int rnd1 = lsb1 + bias;
+ inp0 = inp0 + rnd0;
+ inp1 = inp1 + rnd1;
int cc;
__vector __bool int sel0 =
vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc);
__vector __bool int sel1 =
vec_fp_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN, &cc);
- inp0 = vec_sel(inp0, nan, sel0) >> sh16;
- inp1 = vec_sel(inp1, nan, sel1) >> sh16;
+ inp0 = vec_sel(inp0, nan, sel0);
+ inp1 = vec_sel(inp1, nan, sel1);
+ inp0 = inp0 >> sh16;
+ inp1 = inp1 >> sh16;
+
reg = (__vector signed short)vec_perm(inp0, inp1, omask);
}
@@ -456,6 +765,22 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
__vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
__vector unsigned int inp2 = (__vector unsigned int)(v.reg.val[2]);
__vector unsigned int inp3 = (__vector unsigned int)(v.reg.val[3]);
+ __vector unsigned int lsb0 = inp0 >> sh16;
+ __vector unsigned int lsb1 = inp1 >> sh16;
+ __vector unsigned int lsb2 = inp2 >> sh16;
+ __vector unsigned int lsb3 = inp3 >> sh16;
+ lsb0 = lsb0 & one;
+ lsb1 = lsb1 & one;
+ lsb2 = lsb2 & one;
+ lsb3 = lsb3 & one;
+ __vector unsigned int rnd0 = lsb0 + bias;
+ __vector unsigned int rnd1 = lsb1 + bias;
+ __vector unsigned int rnd2 = lsb2 + bias;
+ __vector unsigned int rnd3 = lsb3 + bias;
+ inp0 = inp0 + rnd0;
+ inp1 = inp1 + rnd1;
+ inp2 = inp2 + rnd2;
+ inp3 = inp3 + rnd3;
int cc;
__vector __bool int sel0 =
vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc);
@@ -465,15 +790,164 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
vec_fp_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN, &cc);
__vector __bool int sel3 =
vec_fp_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN, &cc);
- inp0 = vec_sel(inp0, nan, sel0) >> sh16;
- inp1 = vec_sel(inp1, nan, sel1) >> sh16;
- inp2 = vec_sel(inp2, nan, sel2) >> sh16;
- inp3 = vec_sel(inp3, nan, sel3) >> sh16;
+ inp0 = vec_sel(inp0, nan, sel0);
+ inp1 = vec_sel(inp1, nan, sel1);
+ inp2 = vec_sel(inp2, nan, sel2);
+ inp3 = vec_sel(inp3, nan, sel3);
+ inp0 = inp0 >> sh16;
+ inp1 = inp1 >> sh16;
+ inp2 = inp2 >> sh16;
+ inp3 = inp3 >> sh16;
+
reg.val[0] = (__vector signed short)vec_perm(inp0, inp1, omask);
reg.val[1] = (__vector signed short)vec_perm(inp2, inp3, omask);
}
-inline void prefetch(const void* addr) { void __dcbt(const void* addr); }
+// 1D softmax over `n` elements in `input`, writes result to `output`.
+// Uses FP32Vec8 for main body, scalar tail handling.
+// Requirement: n > 0
+FORCE_INLINE void softmax_fp32vec8(float* output, const float* input, int n) {
+ if (n <= 0) return;
+
+ // ---------- Pass 1: find max ----------
+ float max_val = -std::numeric_limits<float>::infinity();
+ int i = 0;
+
+ for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
+ FP32Vec8 v(input + i);
+ FP32Vec8::AliasReg ar;
+ ar.reg = v.reg;
+ for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
+ if (ar.values[j] > max_val) max_val = ar.values[j];
+ }
+ }
+ for (; i < n; ++i) {
+ if (input[i] > max_val) max_val = input[i];
+ }
+
+ // ---------- Pass 2: compute exp(x - max) and sum ----------
+ float sum = 0.0f;
+ i = 0;
+
+ for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
+ float tmp[FP32Vec8::VEC_ELEM_NUM];
+ for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
+ tmp[j] = input[i + j] - max_val;
+ }
+
+ FP32Vec8 v(tmp);
+ FP32Vec8 e = v.exp();
+
+ FP32Vec8::AliasReg ar;
+ ar.reg = e.reg;
+ for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
+ output[i + j] = ar.values[j];
+ sum += ar.values[j];
+ }
+ }
+
+ // Tail
+ for (; i < n; ++i) {
+ float x = input[i] - max_val;
+ float ex = std::exp(x); // scalar tail
+ output[i] = ex;
+ sum += ex;
+ }
+
+ // ---------- Pass 3: normalize ----------
+ float inv_sum = 1.0f / sum;
+ i = 0;
+
+ for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
+ float tmp[FP32Vec8::VEC_ELEM_NUM];
+ for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
+ tmp[j] = output[i + j] * inv_sum;
+ }
+ FP32Vec8 v(tmp);
+ v.save(output + i);
+ }
+
+ for (; i < n; ++i) {
+ output[i] *= inv_sum;
+ }
+}
+
+// 1D RMSNorm kernel:
+// input: x[0..n-1]
+// weight: w[0..n-1] (gamma), may be nullptr
+// output: y[i] = x[i] * inv_rms * (weight[i] if weight != nullptr else 1)
+// eps: small epsilon for numerical stability
+FORCE_INLINE void rmsnorm_fp32vec8(float* output, const float* input,
+ const float* weight, int n, float eps) {
+ if (n <= 0) return;
+
+ // ---------- Pass 1: compute sum of squares ----------
+ float sum_sq = 0.0f;
+ int i = 0;
+
+ for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
+ FP32Vec8 x_vec(input + i);
+
+ FP32Vec8 sq = x_vec * x_vec;
+
+ FP32Vec8::AliasReg ar;
+ ar.reg = sq.reg;
+ for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
+ sum_sq += ar.values[j];
+ }
+ }
+
+ // Tail
+ for (; i < n; ++i) {
+ float v = input[i];
+ sum_sq += v * v;
+ }
+
+ float mean_sq = sum_sq / static_cast<float>(n);
+ float inv_rms = 1.0f / std::sqrt(mean_sq + eps);
+
+ // ---------- Pass 2: scale (and apply weight if given) ----------
+ const float inv_rms_f = inv_rms;
+ i = 0;
+
+ if (weight) {
+ // with gamma
+ for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
+ FP32Vec8 x_vec(input + i);
+
+ float wtmp[FP32Vec8::VEC_ELEM_NUM];
+ for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
+ wtmp[j] = weight[i + j];
+ }
+ FP32Vec8 w_vec(wtmp);
+
+ FP32Vec8 scale_vec(inv_rms_f);
+ FP32Vec8 y = x_vec * scale_vec * w_vec;
+ y.save(output + i);
+ }
+
+ for (; i < n; ++i) {
+ output[i] = input[i] * inv_rms_f * weight[i];
+ }
+ } else {
+ // without gamma
+ for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
+ FP32Vec8 x_vec(input + i);
+ FP32Vec8 scale_vec(inv_rms_f);
+ FP32Vec8 y = x_vec * scale_vec;
+ y.save(output + i);
+ }
+
+ for (; i < n; ++i) {
+ output[i] = input[i] * inv_rms_f;
+ }
+ }
+}
+
+// Prefetch data to cache for better memory access performance
+FORCE_INLINE void prefetch(const void* addr) {
+ __builtin_prefetch(addr, 0, 3); // 0=read, 3=high temporal locality
+}
}; // namespace vec_op
diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp
index b07d20bab7dd9..e0e3ef71b485f 100644
--- a/csrc/cpu/torch_bindings.cpp
+++ b/csrc/cpu/torch_bindings.cpp
@@ -172,7 +172,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Quantization
#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__)) || \
defined(__powerpc64__)
- at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
// Helper function to release oneDNN handlers
ops.def("release_dnnl_matmul_handler(int handler) -> ()",
&release_dnnl_matmul_handler);
@@ -208,15 +207,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Compute int8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
- "Tensor? azp) -> ()",
- {stride_tag});
+ "Tensor? azp) -> ()");
ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);
// Compute int8 quantized tensor and scaling factor
ops.def(
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
- "Tensor!? azp) -> ()",
- {stride_tag});
+ "Tensor!? azp) -> ()");
ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
&dynamic_scaled_int8_quant);
#endif
diff --git a/csrc/cpu/utils.cpp b/csrc/cpu/utils.cpp
index c5a48352e3089..5199ba2af024f 100644
--- a/csrc/cpu/utils.cpp
+++ b/csrc/cpu/utils.cpp
@@ -45,31 +45,54 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
// Memory node binding
if (numa_available() != -1) {
int mem_node_id = numa_node_of_cpu(omp_cpu_ids.front());
- // Verify all CPUs are on the same NUMA node
- for (size_t i = 1; i < omp_cpu_ids.size(); ++i) {
- int node_id = numa_node_of_cpu(omp_cpu_ids[i]);
- TORCH_CHECK(node_id == mem_node_id, "CPU ", omp_cpu_ids[i],
- " is on NUMA node ", node_id, ", but CPU ",
- omp_cpu_ids.front(), " is on NUMA node ", mem_node_id,
- ". All CPUs should be on the same NUMA node for optimal "
- "performance. Memory will be bound to NUMA node ",
- mem_node_id, ".");
+ std::set<int> node_ids;
+ for (const auto& cpu_id : omp_cpu_ids) {
+ int node_id = numa_node_of_cpu(cpu_id);
+ if (node_id != -1) {
+ node_ids.insert(node_id);
+ }
+ TORCH_WARN(node_id == mem_node_id, "CPU ", cpu_id, " is on NUMA node ",
+ node_id, ", but CPU ", omp_cpu_ids.front(),
+ " is on NUMA node ", mem_node_id,
+ ". All CPUs should be on the same NUMA node for optimal "
+ "performance. Memory will be bound to NUMA node ",
+ mem_node_id, ".");
}
- bitmask* mask = numa_parse_nodestring(std::to_string(mem_node_id).c_str());
- bitmask* src_mask = numa_get_membind();
+ // Concatenate all node_ids into a single comma-separated string
+ if (!node_ids.empty()) {
+ std::string node_ids_str;
+ for (const int node_id : node_ids) {
+ if (!node_ids_str.empty()) {
+ node_ids_str += ",";
+ }
+ node_ids_str += std::to_string(node_id);
+ }
- int pid = getpid();
+ bitmask* mask = numa_parse_nodestring(node_ids_str.c_str());
+ bitmask* src_mask = numa_get_membind();
- // move all existing pages to the specified numa node.
- *(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp);
- int page_num = numa_migrate_pages(pid, src_mask, mask);
- if (page_num == -1) {
- TORCH_WARN("numa_migrate_pages failed. errno: " + std::to_string(errno));
+ int pid = getpid();
+
+ if (mask && src_mask) {
+ // move all existing pages to the specified numa node.
+ *(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp);
+ int page_num = numa_migrate_pages(pid, src_mask, mask);
+ if (page_num == -1) {
+ TORCH_WARN("numa_migrate_pages failed. errno: " +
+ std::to_string(errno));
+ }
+
+ // restrict memory allocation node.
+ numa_set_membind(mask);
+ numa_set_strict(1);
+
+ numa_free_nodemask(mask);
+ numa_free_nodemask(src_mask);
+ } else {
+ TORCH_WARN("numa_parse_nodestring or numa_get_membind failed. errno: " +
+ std::to_string(errno));
+ }
}
-
- // restrict memory allocation node.
- numa_set_membind(mask);
- numa_set_strict(1);
}
// OMP threads binding
diff --git a/csrc/cpu/utils.hpp b/csrc/cpu/utils.hpp
index d8399c56f6af8..d3def306b8069 100644
--- a/csrc/cpu/utils.hpp
+++ b/csrc/cpu/utils.hpp
@@ -6,6 +6,10 @@
#include
#include
+#if defined(__APPLE__)
+ #include <sys/sysctl.h>
+#endif
+
#include "cpu_types.hpp"
namespace cpu_utils {
@@ -21,10 +25,12 @@ struct VecTypeTrait {
using vec_t = vec_op::FP32Vec16;
};
+#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
template <>
struct VecTypeTrait {
using vec_t = vec_op::BF16Vec16;
};
+#endif
template <>
struct VecTypeTrait {
@@ -44,9 +50,21 @@ struct Counter {
inline int64_t get_l2_size() {
static int64_t size = []() {
+#if defined(__APPLE__)
+ // macOS doesn't have _SC_LEVEL2_CACHE_SIZE. Use sysctlbyname.
+ int64_t l2_cache_size = 0;
+ size_t len = sizeof(l2_cache_size);
+ if (sysctlbyname("hw.l2cachesize", &l2_cache_size, &len, NULL, 0) == 0 &&
+ l2_cache_size > 0) {
+ return l2_cache_size >> 1; // use 50% of L2 cache
+ }
+ // Fallback if sysctlbyname fails
+ return 128LL * 1024 >> 1; // use 50% of 128KB
+#else
long l2_cache_size = sysconf(_SC_LEVEL2_CACHE_SIZE);
assert(l2_cache_size != -1);
return l2_cache_size >> 1; // use 50% of L2 cache
+#endif
}();
return size;
}
diff --git a/csrc/cuda_view.cu b/csrc/cuda_view.cu
index 938bd4ab7fc62..9853fc942bab7 100644
--- a/csrc/cuda_view.cu
+++ b/csrc/cuda_view.cu
@@ -22,15 +22,10 @@ torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) {
auto strides = cpu_tensor.strides();
auto options = cpu_tensor.options().device(torch::kCUDA);
- // from_blob signature: from_blob(void *data, IntArrayRef sizes, ..., Deleter,
- // const TensorOptions &) Provide a no-op deleter. The CPU tensor holds the
- // memory, so we don't free it here.
- auto deleter = [](void*) {
- // no-op, since the memory is owned by the original CPU tensor
- };
-
+ // use default no-op deleter, since the memory is owned by the original CPU
+ // tensor
torch::Tensor cuda_tensor =
- torch::from_blob(device_ptr, sizes, strides, deleter, options);
+ torch::from_blob(device_ptr, sizes, strides, options);
TORCH_CHECK(cuda_tensor.device().is_cuda(),
"Resulting tensor is not on CUDA device");
diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h
index 9ae0ed975edde..e1d131e4a7851 100644
--- a/csrc/dispatch_utils.h
+++ b/csrc/dispatch_utils.h
@@ -117,3 +117,24 @@
break; \
} \
}
+
+#define VLLM_DISPATCH_RANK234(NUM_DIMS, ...) \
+ switch (NUM_DIMS) { \
+ case 2: { \
+ constexpr int tensor_rank = 2; \
+ __VA_ARGS__(); \
+ break; \
+ } \
+ case 3: { \
+ constexpr int tensor_rank = 3; \
+ __VA_ARGS__(); \
+ break; \
+ } \
+ case 4: { \
+ constexpr int tensor_rank = 4; \
+ __VA_ARGS__(); \
+ break; \
+ } \
+ default: \
+ TORCH_CHECK(false, "Expects rank 2, 3 or 4 tensors but got ", NUM_DIMS); \
+ }
diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu
index 48771e4b3aff9..dfc67b933ccae 100644
--- a/csrc/layernorm_kernels.cu
+++ b/csrc/layernorm_kernels.cu
@@ -10,16 +10,38 @@
namespace vllm {
// TODO(woosuk): Further optimize this kernel.
-template
+template
__global__ void rms_norm_kernel(
- scalar_t* __restrict__ out, // [..., hidden_size]
- const scalar_t* __restrict__ input, // [..., hidden_size]
- const int64_t input_stride,
+ scalar_t* __restrict__ out, // [..., hidden_size]
+ const scalar_t* __restrict__ input, // [..., hidden_size]
+ const int64_t input_stride_d2, // input.stride(-2)
+ const int64_t input_stride_d3, // input.stride(-3)
+ const int64_t input_stride_d4, // input.stride(-4)
+ const int64_t input_shape_d2, // input.size(-2)
+ const int64_t input_shape_d3, // input.size(-3)
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
- const scalar_t* input_row = input + blockIdx.x * input_stride;
+ const scalar_t* input_row;
+ if constexpr (NUM_DIMS == 2) {
+ // 2D for layernorm normal case [batch_size, hidden]
+ input_row = input + blockIdx.x * input_stride_d2;
+ } else if constexpr (NUM_DIMS == 3) {
+ // 3D for q/k norm [batch_size, num_heads, head_size]
+ int batch_idx = blockIdx.x / input_shape_d2;
+ int head_idx = blockIdx.x % input_shape_d2;
+ input_row =
+ input + batch_idx * input_stride_d3 + head_idx * input_stride_d2;
+ } else if constexpr (NUM_DIMS == 4) {
+ // 4D for transformers model_impl qk norm [batch, seq, head, head_dim]
+ int batch_idx = blockIdx.x / (input_shape_d3 * input_shape_d2);
+ int remaining = blockIdx.x % (input_shape_d3 * input_shape_d2);
+ int seq_idx = remaining / input_shape_d2;
+ int head_idx = remaining % input_shape_d2;
+ input_row = input + batch_idx * input_stride_d4 +
+ seq_idx * input_stride_d3 + head_idx * input_stride_d2;
+ }
auto vec_op = [&variance](const vec_n_t& vec) {
#pragma unroll
@@ -164,38 +186,44 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
double epsilon) {
TORCH_CHECK(out.is_contiguous());
+ if (input.stride(-1) != 1) {
+ input = input.contiguous();
+ }
TORCH_CHECK(input.stride(-1) == 1);
TORCH_CHECK(weight.is_contiguous());
int hidden_size = input.size(-1);
- // We cannot just use `input.stride(-2)` if the tensor is not row-major.
- // Instead, we use a 2d view to get the second-innermost stride.
- // That way the dimensions (except the last one) can be arbitrarily permuted.
- torch::Tensor input_view = input.view({-1, hidden_size});
-
- int num_tokens = input_view.numel() / hidden_size;
- int64_t input_stride = input_view.stride(-2);
+ int num_tokens = input.numel() / hidden_size;
+ int num_dims = input.dim();
+ int64_t input_stride_d2 = input.stride(-2);
+ int64_t input_stride_d3 = (num_dims >= 3) ? input.stride(-3) : 0;
+ int64_t input_stride_d4 = (num_dims >= 4) ? input.stride(-4) : 0;
+ int64_t input_shape_d2 = (num_dims >= 3) ? input.size(-2) : 0;
+ int64_t input_shape_d3 = (num_dims >= 4) ? input.size(-3) : 0;
// For large num_tokens, use smaller blocks to increase SM concurrency.
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
dim3 grid(num_tokens);
- const at::cuda::OptionalCUDAGuard device_guard(device_of(input_view));
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- VLLM_DISPATCH_FLOATING_TYPES(
- input_view.scalar_type(), "rms_norm_kernel", [&] {
- const int calculated_vec_size =
- std::gcd(16 / sizeof(scalar_t), hidden_size);
- const int block_size =
- std::min(hidden_size / calculated_vec_size, max_block_size);
- dim3 block(block_size);
- VLLM_DISPATCH_VEC_SIZE(calculated_vec_size, [&] {
- vllm::rms_norm_kernel<<>>(
- out.data_ptr(), input_view.data_ptr(),
- input_stride, weight.data_ptr(), epsilon, num_tokens,
- hidden_size);
- });
+ VLLM_DISPATCH_RANK234(num_dims, [&] {
+ VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
+ const int calculated_vec_size =
+ std::gcd(16 / sizeof(scalar_t), hidden_size);
+ const int block_size =
+ std::min(hidden_size / calculated_vec_size, max_block_size);
+ dim3 block(block_size);
+ VLLM_DISPATCH_VEC_SIZE(calculated_vec_size, [&] {
+ vllm::rms_norm_kernel
+ <<>>(
+ out.data_ptr(), input.data_ptr(),
+ input_stride_d2, input_stride_d3, input_stride_d4,
+ input_shape_d2, input_shape_d3, weight.data_ptr(),
+ epsilon, num_tokens, hidden_size);
});
+ });
+ });
}
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index c3ae06a30e3e8..14913bef13125 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -20,18 +20,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops
//
- // The default behavior in PyTorch 2.6 was changed to "requires_contiguous",
- // so we need
- // to override this for many GEMMs with the following tag. Otherwise,
- // torch.compile will force all input tensors to be contiguous(), which
- // will break many custom ops that require column-major weight matrices.
- // This was a bug and PyTorch 2.7 has since fixed this.
-#if TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 6
- #define stride_tag at::Tag::needs_fixed_stride_order
-#else
- #define stride_tag
-#endif
-
ops.def(
"persistent_masked_m_silu_mul_quant(Tensor input, Tensor counts, Tensor! "
"y_q, Tensor! y_s,"
@@ -241,15 +229,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Quantized GEMM for AWQ.
ops.def(
"awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
- "Tensor _zeros, SymInt split_k_iters) -> Tensor",
- {stride_tag});
+ "Tensor _zeros, SymInt split_k_iters) -> Tensor");
ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);
// Dequantization for AWQ.
ops.def(
"awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
- "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor",
- {stride_tag});
+ "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor");
ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
// Note about marlin kernel 'workspace' arguments:
@@ -271,8 +257,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
"Tensor b_scales, Tensor workspace, "
"int b_q_type, "
- "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor",
- {stride_tag});
+ "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor");
// conditionally compiled so impl in source file
// Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
@@ -298,8 +283,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor? channel_scales,"
" Tensor? token_scales,"
" str? schedule"
- ") -> Tensor",
- {stride_tag});
+ ") -> Tensor");
ops.def(
"machete_prepack_B("
" Tensor B,"
@@ -319,8 +303,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? "
"g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, "
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
- "bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor",
- {stride_tag});
+ "bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
// conditionally compiled so impl registration is in source file
// gptq_marlin repack from GPTQ.
@@ -346,8 +329,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor token_scales,"
" ScalarType? out_type,"
" str? maybe_schedule"
- ") -> Tensor",
- {stride_tag});
+ ") -> Tensor");
// pack scales
ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor");
// encode and reorder weight matrix
@@ -394,24 +376,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def(
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor block_scale_a, Tensor block_scale_b,"
- " Tensor alpha) -> ()",
- {stride_tag});
+ " Tensor alpha) -> ()");
ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
// cutlass blockwise scaledgroup GEMM
ops.def(
"cutlass_blockwise_scaled_grouped_mm(Tensor! output, Tensor a, Tensor b, "
"Tensor scales_a, Tensor scales_b, "
- "Tensor problem_sizes, Tensor expert_offsets) -> ()",
- {stride_tag});
+ "Tensor problem_sizes, Tensor expert_offsets) -> ()");
// conditionally compiled so impl registration is in source file
// cutlass nvfp4 block scaled group GEMM
ops.def(
"cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
- " Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()",
- {stride_tag});
+ " Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()");
// conditionally compiled so impl registration is in source file
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
@@ -419,8 +398,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def(
"cutlass_scaled_mm(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
- " Tensor b_scales, Tensor? bias) -> ()",
- {stride_tag});
+ " Tensor b_scales, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
// CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
@@ -429,8 +407,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor azp_adj,"
- " Tensor? azp, Tensor? bias) -> ()",
- {stride_tag});
+ " Tensor? azp, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);
// Check if cutlass scaled_mm is supported for CUDA devices of the given
@@ -449,8 +426,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
" Tensor problem_sizes, Tensor a_strides, "
" Tensor b_strides, Tensor c_strides, bool per_act_token, "
- " bool per_out_ch) -> ()",
- {stride_tag});
+ " bool per_out_ch) -> ()");
ops.impl("cutlass_moe_mm", torch::kCUDA, &cutlass_moe_mm);
// A function that computes data required to run fused MoE with w8a8 grouped
@@ -464,8 +440,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
" Tensor! input_permutation, "
" Tensor! output_permutation, int num_experts, "
- " int n, int k, Tensor? blockscale_offsets) -> ()",
- {stride_tag});
+ " int n, int k, Tensor? blockscale_offsets) -> "
+ "()");
ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);
// A function that computes problem sizes for each expert's multiplication
@@ -476,8 +452,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor! problem_sizes1, "
" Tensor! problem_sizes2, "
" int num_experts, int n, int k, "
- " Tensor? blockscale_offsets) -> ()",
- {stride_tag});
+ " Tensor? blockscale_offsets) -> ()");
ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA,
&get_cutlass_moe_mm_problem_sizes);
@@ -492,8 +467,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor! problem_sizes2, "
" Tensor expert_num_tokens, "
" int num_local_experts, int padded_m, "
- " int n, int k) -> ()",
- {stride_tag});
+ " int n, int k) -> ()");
ops.impl("get_cutlass_pplx_moe_mm_data", torch::kCUDA,
&get_cutlass_pplx_moe_mm_data);
@@ -517,8 +491,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"cutlass_scaled_sparse_mm(Tensor! out, Tensor a,"
" Tensor bt_nzs,"
" Tensor bt_meta, Tensor a_scales,"
- " Tensor b_scales, Tensor? bias) -> ()",
- {stride_tag});
+ " Tensor b_scales, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_sparse_mm", torch::kCUDA, &cutlass_scaled_sparse_mm);
// CUTLASS sparse matrix compressor
@@ -567,8 +540,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
"Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, bool "
"use_v2_format, int bit) "
- "-> Tensor",
- {stride_tag});
+ "-> Tensor");
ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
// Post processing for GPTQ.
@@ -723,7 +695,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
cache_ops.def(
"gather_and_maybe_dequant_cache(Tensor src_cache, Tensor! dst, "
" Tensor block_table, Tensor cu_seq_lens, "
- " int batch_size, "
+ " Tensor token_to_seq, "
+ " int num_tokens, "
" str kv_cache_dtype, "
" Tensor scale, Tensor? seq_starts) -> ()");
cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA,
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 964700e2a43ac..84a1802dbe03a 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -20,8 +20,8 @@ ARG PYTHON_VERSION=3.12
# glibc version is baked into the distro, and binaries built with one glibc
# version are not backwards compatible with OSes that use an earlier version.
ARG BUILD_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
-# TODO: Restore to base image after FlashInfer AOT wheel fixed
-ARG FINAL_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04
+# Using cuda base image with minimal dependencies necessary for JIT compilation (FlashInfer, DeepGEMM, EP kernels)
+ARG FINAL_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-base-ubuntu22.04
# By parameterizing the Deadsnakes repository URL, we allow third-party to use
# their own mirror. When doing so, we don't benefit from the transparent
@@ -56,7 +56,6 @@ ARG UV_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}
# PyTorch provides its own indexes for standard and nightly builds
ARG PYTORCH_CUDA_INDEX_BASE_URL=https://download.pytorch.org/whl
-ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL=https://download.pytorch.org/whl/nightly
# PIP supports multiple authentication schemes, including keyring
# By parameterizing the PIP_KEYRING_PROVIDER variable and setting it to
@@ -86,7 +85,20 @@ ARG GET_PIP_URL
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
&& apt-get update -y \
- && apt-get install -y ccache software-properties-common git curl sudo python3-pip \
+ && apt-get install -y --no-install-recommends \
+ ccache \
+ software-properties-common \
+ git \
+ curl \
+ sudo \
+ python3-pip \
+ libibverbs-dev \
+ # Upgrade to GCC 10 to avoid https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92519
+ # as it was causing spam when compiling the CUTLASS kernels
+ gcc-10 \
+ g++-10 \
+ && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 110 --slave /usr/bin/g++ g++ /usr/bin/g++-10 \
+ && rm -rf /var/lib/apt/lists/* \
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
&& $HOME/.local/bin/uv venv /opt/venv --python ${PYTHON_VERSION} \
&& rm -f /usr/bin/python3 /usr/bin/python3-config /usr/bin/pip \
@@ -98,7 +110,6 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
ARG PIP_INDEX_URL UV_INDEX_URL
ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
ARG PYTORCH_CUDA_INDEX_BASE_URL
-ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL
ARG PIP_KEYRING_PROVIDER UV_KEYRING_PROVIDER
# Activate virtual environment and add uv to PATH
@@ -112,10 +123,6 @@ ENV UV_INDEX_STRATEGY="unsafe-best-match"
# Use copy mode to avoid hardlink failures with Docker cache mounts
ENV UV_LINK_MODE=copy
-# Upgrade to GCC 10 to avoid https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92519
-# as it was causing spam when compiling the CUTLASS kernels
-RUN apt-get install -y gcc-10 g++-10
-RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 110 --slave /usr/bin/g++ g++ /usr/bin/g++-10
RUN </dev/null 2>&1; then \
+ uv pip install --system /tmp/deepgemm/dist/*.whl; \
+ else \
+ echo "No DeepGEMM wheels to install; skipping."; \
+ fi'
+
+# Pytorch now installs NVSHMEM, setting LD_LIBRARY_PATH (https://github.com/pytorch/pytorch/blob/d38164a545b4a4e4e0cf73ce67173f70574890b6/.ci/manywheel/build_cuda.sh#L141C14-L141C36)
+ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
+
+# Install EP kernels wheels (pplx-kernels and DeepEP) that have been built in the `build` stage
+RUN --mount=type=bind,from=build,src=/tmp/ep_kernels_workspace/dist,target=/vllm-workspace/ep_kernels/dist \
+ --mount=type=cache,target=/root/.cache/uv \
+ uv pip install --system ep_kernels/dist/*.whl --verbose \
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
-# Install DeepGEMM from source
-ARG DEEPGEMM_GIT_REF
-COPY tools/install_deepgemm.sh /tmp/install_deepgemm.sh
-RUN --mount=type=cache,target=/root/.cache/uv \
- VLLM_DOCKER_BUILD_CONTEXT=1 TORCH_CUDA_ARCH_LIST="9.0a 10.0a" /tmp/install_deepgemm.sh --cuda-version "${CUDA_VERSION}" ${DEEPGEMM_GIT_REF:+--ref "$DEEPGEMM_GIT_REF"}
-
-COPY tools/install_gdrcopy.sh install_gdrcopy.sh
-RUN set -eux; \
+RUN --mount=type=bind,source=tools/install_gdrcopy.sh,target=/tmp/install_gdrcopy.sh,ro \
+ set -eux; \
case "${TARGETPLATFORM}" in \
linux/arm64) UUARCH="aarch64" ;; \
linux/amd64) UUARCH="x64" ;; \
*) echo "Unsupported TARGETPLATFORM: ${TARGETPLATFORM}" >&2; exit 1 ;; \
esac; \
- ./install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "${GDRCOPY_CUDA_VERSION}" "${UUARCH}"; \
- rm ./install_gdrcopy.sh
-
-# Install EP kernels(pplx-kernels and DeepEP)
-COPY tools/ep_kernels/install_python_libraries.sh install_python_libraries.sh
-ENV CUDA_HOME=/usr/local/cuda
-RUN export TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-9.0a 10.0a+PTX}" \
- && bash install_python_libraries.sh
+ /tmp/install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "${GDRCOPY_CUDA_VERSION}" "${UUARCH}"
# CUDA image changed from /usr/local/nvidia to /usr/local/cuda in 12.8 but will
# return to /usr/local/nvidia in 13.0 to allow container providers to mount drivers
@@ -432,6 +460,11 @@ ENV UV_INDEX_STRATEGY="unsafe-best-match"
# Use copy mode to avoid hardlink failures with Docker cache mounts
ENV UV_LINK_MODE=copy
+RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
+ && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
+ && apt-get update -y \
+ && apt-get install -y git
+
# install development dependencies (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \
CUDA_MAJOR="${CUDA_VERSION%%.*}"; \
@@ -472,12 +505,11 @@ ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
# Reference: https://github.com/astral-sh/uv/pull/1694
ENV UV_HTTP_TIMEOUT=500
-COPY requirements/kv_connectors.txt requirements/kv_connectors.txt
-
# install additional dependencies for openai api server
RUN --mount=type=cache,target=/root/.cache/uv \
+ --mount=type=bind,source=requirements/kv_connectors.txt,target=/tmp/kv_connectors.txt,ro \
if [ "$INSTALL_KV_CONNECTORS" = "true" ]; then \
- uv pip install --system -r requirements/kv_connectors.txt; \
+ uv pip install --system -r /tmp/kv_connectors.txt; \
fi; \
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
BITSANDBYTES_VERSION="0.42.0"; \
diff --git a/docker/Dockerfile.cpu b/docker/Dockerfile.cpu
index 4c961defaeda2..eb3807ef0ca4e 100644
--- a/docker/Dockerfile.cpu
+++ b/docker/Dockerfile.cpu
@@ -37,6 +37,7 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 \
&& curl -LsSf https://astral.sh/uv/install.sh | sh
+ENV CC=/usr/bin/gcc-12 CXX=/usr/bin/g++-12
ENV CCACHE_DIR=/root/.cache/ccache
ENV CMAKE_CXX_COMPILER_LAUNCHER=ccache
@@ -122,6 +123,15 @@ WORKDIR /workspace/vllm
RUN --mount=type=bind,src=requirements/test.in,target=requirements/test.in \
cp requirements/test.in requirements/cpu-test.in && \
sed -i '/mamba_ssm/d' requirements/cpu-test.in && \
+ remove_packages_not_supported_on_aarch64() { \
+ case "$(uname -m)" in \
+ aarch64|arm64) \
+ sed -i '/decord/d' requirements/cpu-test.in; \
+ sed -i '/terratorch/d' requirements/cpu-test.in; \
+ ;; \
+ esac; \
+ }; \
+ remove_packages_not_supported_on_aarch64 && \
sed -i 's/^torch==.*/torch==2.8.0/g' requirements/cpu-test.in && \
sed -i 's/torchaudio.*/torchaudio/g' requirements/cpu-test.in && \
sed -i 's/torchvision.*/torchvision/g' requirements/cpu-test.in && \
diff --git a/docker/Dockerfile.nightly_torch b/docker/Dockerfile.nightly_torch
index b88b9c4992200..d663c82c3885e 100644
--- a/docker/Dockerfile.nightly_torch
+++ b/docker/Dockerfile.nightly_torch
@@ -76,34 +76,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/common.txt
-# must put before installing xformers, so it can install the correct version of xfomrers.
-ARG torch_cuda_arch_list='8.0;8.6;8.9;9.0'
-ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
-
-# Build xformers with cuda and torch nightly
-# following official xformers guidance: https://github.com/facebookresearch/xformers#build
-# todo(elainewy): cache xformers build result for faster build
-ARG max_jobs=16
-ENV MAX_JOBS=${max_jobs}
-ARG XFORMERS_COMMIT=f2de641ef670510cadab099ce6954031f52f191c
-
-ENV CCACHE_DIR=/root/.cache/ccache
-RUN --mount=type=cache,target=/root/.cache/ccache \
- --mount=type=cache,target=/root/.cache/uv \
- echo 'git clone xformers...' \
- && git clone https://github.com/facebookresearch/xformers.git --recursive \
- && cd xformers \
- && git checkout ${XFORMERS_COMMIT} \
- && git submodule update --init --recursive \
- && echo 'finish git clone xformers...' \
- && rm -rf build \
- && python3 setup.py bdist_wheel --dist-dir=../xformers-dist --verbose \
- && cd .. \
- && rm -rf xformers
-
-RUN --mount=type=cache,target=/root/.cache/uv \
- uv pip install --system xformers-dist/*.whl --verbose
-
# build can take a long time, and the torch nightly version fetched from url can be different in next docker stage.
# track the nightly torch version used in the build, when we set up runtime environment we can make sure the version is the same
RUN uv pip freeze | grep -i '^torch\|^torchvision\|^torchaudio' > torch_build_versions.txt
@@ -233,11 +205,6 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/vllm
--mount=type=cache,target=/root/.cache/uv \
uv pip install --system vllm-dist/*.whl --verbose
-# install xformers again for the new environment
-RUN --mount=type=bind,from=base,src=/workspace/xformers-dist,target=/vllm-workspace/xformers-dist \
- --mount=type=cache,target=/root/.cache/uv \
- uv pip install --system /vllm-workspace/xformers-dist/*.whl --verbose
-
ARG torch_cuda_arch_list='8.0;8.6;8.9;9.0'
# install package for build flashinfer
@@ -307,7 +274,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/nightly_torch_test.txt
# Logging to confirm the torch versions
-RUN pip freeze | grep -E 'torch|xformers|vllm|flashinfer'
+RUN pip freeze | grep -E 'torch|vllm|flashinfer'
# Logging to confirm all the packages are installed
RUN pip freeze
diff --git a/docker/Dockerfile.ppc64le b/docker/Dockerfile.ppc64le
index ad9eae94b83dd..b16bea3607d2f 100644
--- a/docker/Dockerfile.ppc64le
+++ b/docker/Dockerfile.ppc64le
@@ -8,8 +8,8 @@ FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} AS openbl
ARG MAX_JOBS
ARG OPENBLAS_VERSION=0.3.30
-RUN microdnf install -y dnf && dnf install -y gcc-toolset-13 make wget unzip \
- && source /opt/rh/gcc-toolset-13/enable \
+RUN microdnf install -y dnf && dnf install -y gcc-toolset-14 make wget unzip \
+ && source /opt/rh/gcc-toolset-14/enable \
&& wget https://github.com/OpenMathLib/OpenBLAS/releases/download/v$OPENBLAS_VERSION/OpenBLAS-$OPENBLAS_VERSION.zip \
&& unzip OpenBLAS-$OPENBLAS_VERSION.zip \
&& cd OpenBLAS-$OPENBLAS_VERSION \
@@ -57,7 +57,7 @@ COPY --from=openblas-builder /tmp/control /dev/null
RUN --mount=type=bind,from=openblas-builder,source=/OpenBLAS-$OPENBLAS_VERSION/,target=/openblas/,rw \
dnf install -y openssl-devel \
&& dnf install -y \
- git tar gcc-toolset-13 automake libtool \
+ git tar gcc-toolset-14 automake libtool \
pkgconfig xsimd zeromq-devel kmod findutils protobuf* \
libtiff-devel libjpeg-devel zlib-devel freetype-devel libwebp-devel \
harfbuzz-devel libraqm-devel libimagequant-devel libxcb-devel \
@@ -84,7 +84,7 @@ ARG _GLIBCXX_USE_CXX11_ABI=1
ARG OPENBLAS_VERSION=0.3.30
RUN --mount=type=cache,target=/root/.cache/uv \
- source /opt/rh/gcc-toolset-13/enable && \
+ source /opt/rh/gcc-toolset-14/enable && \
git clone --recursive https://github.com/pytorch/pytorch.git -b v${TORCH_VERSION} && \
cd pytorch && \
uv pip install -r requirements.txt && \
@@ -97,7 +97,7 @@ ARG TORCHVISION_VERSION=0.22.0
ARG TORCHVISION_USE_NVJPEG=0
ARG TORCHVISION_USE_FFMPEG=0
RUN --mount=type=cache,target=/root/.cache/uv \
- source /opt/rh/gcc-toolset-13/enable && \
+ source /opt/rh/gcc-toolset-14/enable && \
git clone --recursive https://github.com/pytorch/vision.git -b v${TORCHVISION_VERSION} && \
cd vision && \
MAX_JOBS=${MAX_JOBS:-$(nproc)} \
@@ -113,7 +113,7 @@ ARG USE_ROCM=0
ARG USE_CUDA=0
ARG TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_FFMPEG=1
RUN --mount=type=cache,target=/root/.cache/uv \
- source /opt/rh/gcc-toolset-13/enable && \
+ source /opt/rh/gcc-toolset-14/enable && \
git clone --recursive https://github.com/pytorch/audio.git -b v${TORCHAUDIO_VERSION} && \
cd audio && \
MAX_JOBS=${MAX_JOBS:-$(nproc)} \
@@ -130,7 +130,7 @@ ARG MAX_JOBS
ARG PYARROW_PARALLEL
ARG PYARROW_VERSION=21.0.0
RUN --mount=type=cache,target=/root/.cache/uv \
- source /opt/rh/gcc-toolset-13/enable && \
+ source /opt/rh/gcc-toolset-14/enable && \
git clone --recursive https://github.com/apache/arrow.git -b apache-arrow-${PYARROW_VERSION} && \
cd arrow/cpp && \
mkdir build && cd build && \
@@ -162,7 +162,7 @@ ARG OPENCV_VERSION=86
ARG OPENCV_PATCH=97f3f39
ARG ENABLE_HEADLESS=1
RUN --mount=type=cache,target=/root/.cache/uv \
- source /opt/rh/gcc-toolset-13/enable && \
+ source /opt/rh/gcc-toolset-14/enable && \
git clone --recursive https://github.com/opencv/opencv-python.git -b ${OPENCV_VERSION} && \
cd opencv-python && \
sed -i -E -e 's/"setuptools.+",/"setuptools",/g' pyproject.toml && \
@@ -196,7 +196,7 @@ ARG MAX_JOBS
ARG NUMBA_VERSION=0.61.2
# Clone all required dependencies
-RUN dnf install ninja-build llvm15 llvm15-devel -y && source /opt/rh/gcc-toolset-13/enable && export PATH=$PATH:/usr/lib64/llvm15/bin && \
+RUN dnf install ninja-build llvm15 llvm15-devel -y && source /opt/rh/gcc-toolset-14/enable && export PATH=$PATH:/usr/lib64/llvm15/bin && \
git clone --recursive https://github.com/numba/numba.git -b ${NUMBA_VERSION} && \
cd ./numba && \
if ! grep '#include "dynamic_annotations.h"' numba/_dispatcher.cpp; then \
@@ -211,6 +211,9 @@ RUN dnf install ninja-build llvm15 llvm15-devel -y && source /opt/rh/gcc-toolset
FROM base-builder AS vllmcache-builder
+ENV LLVM_CONFIG=/usr/lib64/llvm15/bin/llvm-config
+ENV PATH=/usr/lib64/llvm15/bin:$PATH
+
COPY --from=torch-builder /tmp/control /dev/null
COPY --from=arrow-builder /tmp/control /dev/null
COPY --from=cv-builder /tmp/control /dev/null
@@ -225,10 +228,13 @@ ARG GRPC_PYTHON_BUILD_SYSTEM_OPENSSL=1
RUN --mount=type=cache,target=/root/.cache/uv \
dnf install llvm15 llvm15-devel -y && \
rpm -ivh --nodeps https://mirror.stream.centos.org/9-stream/CRB/ppc64le/os/Packages/protobuf-lite-devel-3.14.0-16.el9.ppc64le.rpm && \
- source /opt/rh/gcc-toolset-13/enable && \
+ source /opt/rh/gcc-toolset-14/enable && \
git clone https://github.com/huggingface/xet-core.git && cd xet-core/hf_xet/ && \
uv pip install maturin && \
uv build --wheel --out-dir /hf_wheels/
+
+ENV CXXFLAGS="-fno-lto -Wno-error=free-nonheap-object" \
+ CFLAGS="-fno-lto"
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,from=torch-builder,source=/torchwheels/,target=/torchwheels/,ro \
--mount=type=bind,from=arrow-builder,source=/arrowwheels/,target=/arrowwheels/,ro \
@@ -236,7 +242,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,from=numa-builder,source=/numactl/,target=/numactl/,rw \
--mount=type=bind,from=numba-builder,source=/numbawheels/,target=/numbawheels/,ro \
--mount=type=bind,src=.,dst=/src/,rw \
- source /opt/rh/gcc-toolset-13/enable && \
+ source /opt/rh/gcc-toolset-14/enable && \
export PATH=$PATH:/usr/lib64/llvm15/bin && \
uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl /numbawheels/*.whl && \
sed -i -e 's/.*torch.*//g' /src/pyproject.toml /src/requirements/*.txt && \
@@ -260,7 +266,7 @@ FROM base-builder AS lapack-builder
ARG MAX_JOBS
ARG LAPACK_VERSION=3.12.1
RUN git clone --recursive https://github.com/Reference-LAPACK/lapack.git -b v${LAPACK_VERSION} \
- && cd lapack && source /opt/rh/gcc-toolset-13/enable \
+ && cd lapack && source /opt/rh/gcc-toolset-14/enable \
&& cmake -B build -S . \
&& cmake --build build -j ${MAX_JOBS:-$(nproc)}
@@ -299,7 +305,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,from=openblas-builder,source=/OpenBLAS-$OPENBLAS_VERSION/,target=/openblas/,rw \
rpm -ivh https://dl.fedoraproject.org/pub/epel/epel-release-latest-9.noarch.rpm && \
microdnf install --nodocs -y \
- libomp tar findutils openssl llvm15 llvm15-devel \
+ libomp libicu tar findutils openssl llvm15 llvm15-devel \
pkgconfig xsimd g++ gcc-fortran libsndfile \
libtiff libjpeg openjpeg2 zlib zeromq \
freetype lcms2 libwebp tcl tk utf8proc \
diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm
index 731a97d93da1f..42466d1801cf6 100644
--- a/docker/Dockerfile.rocm
+++ b/docker/Dockerfile.rocm
@@ -7,6 +7,8 @@ FROM ${BASE_IMAGE} AS base
ARG ARG_PYTORCH_ROCM_ARCH
ENV PYTORCH_ROCM_ARCH=${ARG_PYTORCH_ROCM_ARCH:-${PYTORCH_ROCM_ARCH}}
+ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
+ENV RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=1
# Install some basic utilities
RUN apt-get update -q -y && apt-get install -q -y \
@@ -121,8 +123,6 @@ COPY --from=export_vllm /benchmarks ${COMMON_WORKDIR}/vllm/benchmarks
COPY --from=export_vllm /examples ${COMMON_WORKDIR}/vllm/examples
COPY --from=export_vllm /docker ${COMMON_WORKDIR}/vllm/docker
-ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
-ENV RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=1
ENV TOKENIZERS_PARALLELISM=false
# ENV that can improve safe tensor loading, and end-to-end time
diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu
index 5d5b82c4fa5af..adac43c6accbe 100644
--- a/docker/Dockerfile.xpu
+++ b/docker/Dockerfile.xpu
@@ -1,4 +1,4 @@
-FROM intel/deep-learning-essentials:2025.1.3-0-devel-ubuntu24.04 AS vllm-base
+FROM intel/deep-learning-essentials:2025.2.2-0-devel-ubuntu24.04 AS vllm-base
RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \
echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list && \
@@ -25,10 +25,14 @@ RUN apt clean && apt-get update -y && \
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.12 1
-RUN apt install -y libze1 libze-dev libze-intel-gpu1 intel-opencl-icd libze-intel-gpu-raytracing
+RUN apt install -y libze1 libze-dev libze-intel-gpu1 intel-opencl-icd libze-intel-gpu-raytracing intel-ocloc
+
+# This oneCCL release includes BMG support, which the default oneCCL shipped with oneAPI 2025.2 lacks.
+RUN wget https://github.com/uxlfoundation/oneCCL/releases/download/2021.15.6/intel-oneccl-2021.15.6.9_offline.sh
+RUN bash intel-oneccl-2021.15.6.9_offline.sh -a --silent --eula accept && \
+ echo "source /opt/intel/oneapi/setvars.sh --force" >> /root/.bashrc && \
+ echo "source /opt/intel/oneapi/ccl/2021.15/env/vars.sh --force" >> /root/.bashrc
-RUN wget https://github.com/uxlfoundation/oneCCL/releases/download/2021.15.4/intel-oneccl-2021.15.4.11_offline.sh
-RUN bash intel-oneccl-2021.15.4.11_offline.sh -a --silent --eula accept && echo "source /opt/intel/oneapi/setvars.sh --force" >> /root/.bashrc
SHELL ["bash", "-c"]
CMD ["bash", "-c", "source /root/.bashrc && exec bash"]
@@ -72,6 +76,7 @@ RUN python3 -m pip install -e tests/vllm_test_utils
ENV NIXL_VERSION=0.7.0
RUN python3 /workspace/vllm/tools/install_nixl_from_source_ubuntu.py
+# Remove the oneCCL bundled with torch to avoid conflicts with the standalone oneCCL installed above
RUN --mount=type=cache,target=/root/.cache/pip \
pip uninstall oneccl oneccl-devel -y
diff --git a/docs/.nav.yml b/docs/.nav.yml
index 3151ea0e2ec22..c8bf00efb2370 100644
--- a/docs/.nav.yml
+++ b/docs/.nav.yml
@@ -24,14 +24,16 @@ nav:
- deployment/integrations
- Training: training
- Configuration:
- - configuration/README.md
- configuration/*
+ - TPU: https://docs.vllm.ai/projects/tpu/en/latest/
- Models:
- models/supported_models.md
- models/generative_models.md
- models/pooling_models.md
- models/extensions
- - Hardware Supported Models: models/hardware_supported_models
+ - Hardware Supported Models:
+ - models/hardware_supported_models/*
+ - TPU: https://docs.vllm.ai/projects/tpu/en/latest/recommended_models_features/
- Features: features
- Developer Guide:
- contributing/README.md
diff --git a/docs/assets/contributing/dockerfile-stages-dependency.png b/docs/assets/contributing/dockerfile-stages-dependency.png
index f8c104ba14259..b327eb2151f50 100644
Binary files a/docs/assets/contributing/dockerfile-stages-dependency.png and b/docs/assets/contributing/dockerfile-stages-dependency.png differ
diff --git a/docs/configuration/conserving_memory.md b/docs/configuration/conserving_memory.md
index 5ce43c7984057..0aa89a89eae5c 100644
--- a/docs/configuration/conserving_memory.md
+++ b/docs/configuration/conserving_memory.md
@@ -49,9 +49,6 @@ llm = LLM(model="adept/fuyu-8b", max_model_len=2048, max_num_seqs=2)
By default, we optimize model inference using CUDA graphs which take up extra memory in the GPU.
-!!! warning
- CUDA graph capture takes up more memory in V1 than in V0.
-
You can adjust `compilation_config` to achieve a better balance between inference speed and memory usage:
??? code
diff --git a/docs/configuration/env_vars.md b/docs/configuration/env_vars.md
index 2c0a898754fa0..f6d548a19d91f 100644
--- a/docs/configuration/env_vars.md
+++ b/docs/configuration/env_vars.md
@@ -7,8 +7,6 @@ vLLM uses the following environment variables to configure the system:
All environment variables used by vLLM are prefixed with `VLLM_`. **Special care should be taken for Kubernetes users**: please do not name the service as `vllm`, otherwise environment variables set by Kubernetes might conflict with vLLM's environment variables, because [Kubernetes sets environment variables for each service with the capitalized service name as the prefix](https://kubernetes.io/docs/concepts/services-networking/service/#environment-variables).
-??? code
-
- ```python
- --8<-- "vllm/envs.py:env-vars-definition"
- ```
+```python
+--8<-- "vllm/envs.py:env-vars-definition"
+```
diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md
index b0d390d7e1cbb..fdd9c317b022f 100644
--- a/docs/configuration/optimization.md
+++ b/docs/configuration/optimization.md
@@ -31,9 +31,7 @@ In vLLM V1, the default preemption mode is `RECOMPUTE` rather than `SWAP`, as re
Chunked prefill allows vLLM to process large prefills in smaller chunks and batch them together with decode requests. This feature helps improve both throughput and latency by better balancing compute-bound (prefill) and memory-bound (decode) operations.
-In vLLM V1, **chunked prefill is always enabled by default**. This is different from vLLM V0, where it was conditionally enabled based on model characteristics.
-
-With chunked prefill enabled, the scheduling policy prioritizes decode requests. It batches all pending decode requests before scheduling any prefill operations. When there are available tokens in the `max_num_batched_tokens` budget, it schedules pending prefills. If a pending prefill request cannot fit into `max_num_batched_tokens`, it automatically chunks it.
+In V1, **chunked prefill is enabled by default whenever possible**. With chunked prefill enabled, the scheduling policy prioritizes decode requests. It batches all pending decode requests before scheduling any prefill operations. When there are available tokens in the `max_num_batched_tokens` budget, it schedules pending prefills. If a pending prefill request cannot fit into `max_num_batched_tokens`, it automatically chunks it.
This policy has two benefits:
diff --git a/docs/configuration/tpu.md b/docs/configuration/tpu.md
deleted file mode 100644
index 2d24c9c6e2e95..0000000000000
--- a/docs/configuration/tpu.md
+++ /dev/null
@@ -1,111 +0,0 @@
-# TPU Optimization Tips
-
-This doc serves as a collection of handy tips for optimizing your vLLM on TPU workload.
-
-## Get started
-
-Looking for setup and installation instructions? Find them [here](https://docs.vllm.ai/projects/tpu/en/latest/getting_started/installation/).
-
-### TPU workload sizing
-
-When selecting the ideal number of chips for a single serving instance, it's important to account for both the model size and the average request context length. Adequate HBM for the KV cache is essential to ensure a sufficient number of concurrent requests can be processed.
-
-The following colab [calculator](https://colab.research.google.com/github/ericehanley/rightsize-vllm/blob/main/HBM_Calculator.ipynb) will tell you:
-
-- KV cache size requirement per token and per request
-- TPU/GPU memory consumed by the model weights
-- TPU/GPU memory allocated for the KV cache
-- Maximum \# of requests you can approximately set (--max-num-seqs)
-
-This approach serves as a general rule of thumb.
-
-#### Latency-throughput tradeoff
-
-As with rightsizing the number of chips for your workload, consider adjusting `--max-num-seqs` to fine-tune the latency-throughput balance. Decreasing `--max-num-seqs` and/or increasing the number of chips can help reduce latency.
-
-`--max-num-seqs` defines the number of concurrent decode slots, effectively limiting the number of requests the server can process tokens for simultaneously. Increasing this value allows the server to pre-allocate more HBM to handle a higher number of concurrent requests, which can maximize overall throughput. However, this often increases the end-to-end (e2e) latency per request.
-
-Therefore, carefully tuning `--max-num-seqs` is crucial to achieving the desired balance between latency and throughput for your specific workload.
-
-In a similar way, `--max-num-batch-tokens` can be adjusted down to improve latency, or adjusted up to improve throughput.
-
-#### Compilation and Caching
-
-Coming from a GPU background, one of the key differences you'll notice with TPUs is an initial compilation step. TPUs are specialized accelerators (ASICs) that achieve maximum performance by executing pre-compiled, static computation graphs via the XLA compiler. Unlike GPUs, which can handle dynamic input shapes more flexibly, TPUs require a specific compiled graph for each tensor shape (e.g., batch size and sequence length) they process.
-
-To manage this, vLLM performs a one-time "warmup" process when you first launch the server. During this phase, it pre-compiles the model for various common input shapes and saves these compiled graphs to a cache on disk or remote storage (located at `~/.cache/vllm/xla_cache` by default). This process can range significantly, anywhere from a few minutes to an hour depending on the size of the model and context length used.
-
-Although the first compilation can take some time, for all subsequent server launches, vLLM can load these graphs directly from the cache, eliminating the compilation time for future runs.
-
-Use `VLLM_XLA_CACHE_PATH` environment variable to write to shareable storage for future deployed nodes (like when using autoscaling).
-
-#### Reducing compilation time
-
-This initial compilation time ranges significantly and is impacted by many of the arguments discussed in this optimization doc. Factors that influence the length of time to compile are things like model size and `--max-num-batch-tokens`. Other arguments you can tune are things like `VLLM_TPU_MOST_MODEL_LEN`.
-
-### Optimize based on your data
-
-#### max-model-len vs. most-model-len
-
-
-
-If most of your requests are shorter than the maximum model length but you still need to accommodate occasional longer requests, setting a high maximum model length can negatively impact performance. In these cases, you can try introducing most-model-len by specifying the `VLLM_TPU_MOST_MODEL_LEN` environment variable.
-
-For example, 1% requests are 32k length and 99% requests are 2k length. You can pass 32k into `--max-model-len 32768` and use `VLLM_TPU_MOST_MODEL_LEN=2048`.
-
-The requests get subdivided into max-model-len and most-model-len categories, for the latter category, you can gain better performance since the server can process more requests at a time.
-
-#### Padding
-
-For online serving with latency requirements, consider switching to bucket padding by setting the `VLLM_TPU_BUCKET_PADDING_GAP` environment variable. Because of the layout of the TPU, try using increments of 128 (e.g., 128, 256, etc.)
-
-The server pads the requests into fixed lengths before sending them to the model to avoid recompilation. To read more about TPU padding, see [here](https://cloud.google.com/tpu/docs/performance-guide#xla-efficiencies). Currently, there are 2 ways to pad the requests:
-
-1. the default exponential padding (pad to the nearest power of 2)
-2. bucket padding (pad to the nearest linearly increasing bucket).
-
-When using bucket padding, the buckets start from 16, end at max_model_len, and increment by `VLLM_TPU_BUCKET_PADDING_GAP`.
-
-For example, max_model_len=512, padding_gap=64, the buckets will be [16, 32, 64, 128, 192, 256, 320, 384, 448, 512].
-
-The fewer tokens you pad, the less unnecessary computation TPU does, the better performance you can get. For example, if num_tokens=300, with exponential padding, you pad to 512, with the bucket_padding above, you pad to 320.
-
-However, you need to be careful to choose the padding gap. If the gap is too small, it means the number of buckets is large, leading to increased warmup (precompile) time and higher memory to store the compiled graph. Too many compiled graphs may lead to HBM OOM. Conversely, an overly large gap yields no performance improvement compared to the default exponential padding.
-
-#### Quantization
-
-If possible, use the precision that matches the chip’s hardware acceleration:
-
-- v5e has int4/int8 hardware acceleration in the MXU
-- v6e has int4/int8 hardware acceleration in the MXU
-
-Supported quantized formats and features in vLLM on TPU [Jul '25]:
-
-- INT8 W8A8
-- INT8 W8A16
-- FP8 KV cache
-- [WIP] FP8 W8A8
-- [WIP] AWQ
-- [WIP] FP4 W4A8
-
-#### Parallelization
-
-Don't set TP to be less than the number of chips on a single-host deployment.
-
-Although it’s common to do this with GPUs, don't try to fragment 2 or 8 different workloads across 8 chips on a single host. If you need 1 or 4 chips, just create an instance with 1 or 4 chips (these are partial-host machine types).
-
-### Tune your workloads
-
-Although we try to have great default configs, we strongly recommend you check out the [vLLM auto-tuner](../../benchmarks/auto_tune/README.md) to optimize your workloads for your use case.
-
-### Future Topics We'll Cover
-
-#### Profiling
-
-The auto-tuner provides a profile of optimized configurations as its final step. However, interpreting this profile can be challenging for new users. We plan to expand this section in the future with more detailed guidance. In the meantime, you can learn how to collect a TPU profile using vLLM's native profiling tools [here](../examples/offline_inference/profiling_tpu.md). This profile can provide valuable insights into your workload's performance.
-
-#### SPMD
-
-More details to come.
-
-**Want us to cover something that isn't listed here? Open up an issue please and cite this doc. We'd love to hear your questions or tips.**
diff --git a/docs/contributing/ci/update_pytorch_version.md b/docs/contributing/ci/update_pytorch_version.md
index 09fd85a466eed..735bb2e205332 100644
--- a/docs/contributing/ci/update_pytorch_version.md
+++ b/docs/contributing/ci/update_pytorch_version.md
@@ -98,21 +98,6 @@ to warm it up so that future builds are faster.
-## Update dependencies
-
-Several vLLM dependencies like xFormers depend on PyTorch and need
-to be updated accordingly. Rather than waiting for all of them to publish new
-releases (which would take too much time), they can be built from
-source to unblock the update process.
-
-### xFormers
-
-```bash
-export TORCH_CUDA_ARCH_LIST='7.5 8.0+PTX 9.0a'
-MAX_JOBS=16 uv pip install --system \
- --no-build-isolation "git+https://github.com/facebookresearch/xformers@v0.0.32.post2"
-```
-
## Update all the different vLLM platforms
Rather than attempting to update all vLLM platforms in a single pull request, it's more manageable
diff --git a/docs/contributing/model/basic.md b/docs/contributing/model/basic.md
index a7b54f015c2da..e828de0adf3c2 100644
--- a/docs/contributing/model/basic.md
+++ b/docs/contributing/model/basic.md
@@ -133,8 +133,6 @@ We consider 3 different scenarios:
For case (1), we recommend looking at the implementation of [`MambaForCausalLM`](../../../vllm/model_executor/models/mamba.py) (for Mamba-1) or [`Mamba2ForCausalLM`](../../../vllm/model_executor/models/mamba2.py) (for Mamba-2) as a reference.
The model should inherit protocol `IsAttentionFree` and also implement class methods `get_mamba_state_dtype_from_config` and `get_mamba_state_shape_from_config` to calculate the state shapes and data types from the config.
For the mamba layers themselves, please use the [`MambaMixer`](../../../vllm/model_executor/layers/mamba/mamba_mixer.py) (for Mamba-1) or [`MambaMixer2`](../../../vllm/model_executor/layers/mamba/mamba_mixer2.py) (for Mamba-2) classes.
-Please *do not* use the `MambaCacheManager` (deprecated in V1) or replicate any of the V0-specific code paths in the existing model implementations.
-V0-only classes and code will be removed in the very near future.
The model should also be added to the `MODELS_CONFIG_MAP` dictionary in [vllm/model_executor/models/config.py](../../../vllm/model_executor/models/config.py) to ensure that the runtime defaults are optimized.
For case (2), we recommend using as a reference the implementation of [`JambaForCausalLM`](../../../vllm/model_executor/models/jamba.py) (for an example of a model that uses Mamba-1 and attention together) or [`BambaForCausalLM`](../../../vllm/model_executor/models/bamba.py) (for an example of a model that uses Mamba-2 and attention together).
@@ -146,6 +144,7 @@ We use "mamba-like" to refer to layers that posses a state that is updated in-pl
For implementing new custom mamba-like layers, one should inherit from `MambaBase` and implement the methods `get_state_dtype`, `get_state_shape` to calculate the data types and state shapes at runtime, as well as `mamba_type` and `get_attn_backend`.
It is also necessary to implement the "attention meta-data" class which handles the meta-data that is common across all layers.
Please see [`LinearAttentionMetadata`](../../../vllm/v1/attention/backends/linear_attn.py) or [`ShortConvAttentionMetadata`](../../../vllm/v1/attention/backends/short_conv_attn.py) for examples of this.
+It is also worth noting that we should update `MAMBA_TYPE_TO_BACKEND_MAP` and `MambaAttentionBackendEnum` in [`registry.py`](../../../vllm/attention/backends/registry.py) when adding a new mamba backend.
Finally, if one wants to support torch compile and CUDA graphs, it necessary to wrap the call to the mamba-like layer inside a custom op and register it.
Please see the calls to `direct_register_custom_op` in [vllm/model_executor/models/minimax_text_01.py](../../../vllm/model_executor/models/minimax_text_01.py) or [vllm/model_executor/layers/mamba/short_conv.py](../../../vllm/model_executor/layers/mamba/short_conv.py) for examples of this.
The new custom op should then be added to the list `_attention_ops` in [vllm/config/compilation.py](../../../vllm/config/compilation.py) to ensure that piecewise CUDA graphs works as intended.
diff --git a/docs/deployment/docker.md b/docs/deployment/docker.md
index 1c639f3533d47..0e636c87f38a4 100644
--- a/docs/deployment/docker.md
+++ b/docs/deployment/docker.md
@@ -82,8 +82,7 @@ DOCKER_BUILDKIT=1 docker build . \
## Building for Arm64/aarch64
-A docker container can be built for aarch64 systems such as the Nvidia Grace-Hopper. At time of this writing, this requires the use
-of PyTorch Nightly and should be considered **experimental**. Using the flag `--platform "linux/arm64"` will attempt to build for arm64.
+A docker container can be built for aarch64 systems such as the Nvidia Grace-Hopper. At time of this writing, this should be considered **experimental**. Using the flag `--platform "linux/arm64"` will attempt to build for arm64.
!!! note
Multiple modules must be compiled, so this process can take a while. Recommend using `--build-arg max_jobs=` & `--build-arg nvcc_threads=`
@@ -94,7 +93,6 @@ of PyTorch Nightly and should be considered **experimental**. Using the flag `--
```bash
# Example of building on Nvidia GH200 server. (Memory usage: ~15GB, Build time: ~1475s / ~25 min, Image size: 6.93GB)
- python3 use_existing_torch.py
DOCKER_BUILDKIT=1 docker build . \
--file docker/Dockerfile \
--target vllm-openai \
@@ -102,7 +100,8 @@ of PyTorch Nightly and should be considered **experimental**. Using the flag `--
-t vllm/vllm-gh200-openai:latest \
--build-arg max_jobs=66 \
--build-arg nvcc_threads=2 \
- --build-arg torch_cuda_arch_list="9.0 10.0+PTX"
+ --build-arg torch_cuda_arch_list="9.0 10.0+PTX" \
+ --build-arg RUN_WHEEL_CHECK=false
```
!!! note
diff --git a/docs/deployment/frameworks/skypilot.md b/docs/deployment/frameworks/skypilot.md
index f4a984a6433e2..e9b0d5f0671c3 100644
--- a/docs/deployment/frameworks/skypilot.md
+++ b/docs/deployment/frameworks/skypilot.md
@@ -4,7 +4,7 @@
-vLLM can be **run and scaled to multiple service replicas on clouds and Kubernetes** with [SkyPilot](https://github.com/skypilot-org/skypilot), an open-source framework for running LLMs on any cloud. More examples for various open models, such as Llama-3, Mixtral, etc, can be found in [SkyPilot AI gallery](https://skypilot.readthedocs.io/en/latest/gallery/index.html).
+vLLM can be **run and scaled to multiple service replicas on clouds and Kubernetes** with [SkyPilot](https://github.com/skypilot-org/skypilot), an open-source framework for running LLMs on any cloud. More examples for various open models, such as Llama-3, Mixtral, etc., can be found in [SkyPilot AI gallery](https://skypilot.readthedocs.io/en/latest/gallery/index.html).
## Prerequisites
diff --git a/docs/deployment/integrations/kserve.md b/docs/deployment/integrations/kserve.md
index edf79fca4f93e..37b29aa1a4876 100644
--- a/docs/deployment/integrations/kserve.md
+++ b/docs/deployment/integrations/kserve.md
@@ -2,4 +2,4 @@
vLLM can be deployed with [KServe](https://github.com/kserve/kserve) on Kubernetes for highly scalable distributed model serving.
-Please see [this guide](https://kserve.github.io/website/latest/modelserving/v1beta1/llm/huggingface/) for more details on using vLLM with KServe.
+Please see [this guide](https://kserve.github.io/website/docs/model-serving/generative-inference/overview) for more details on using vLLM with KServe.
diff --git a/docs/design/debug_vllm_compile.md b/docs/design/debug_vllm_compile.md
index 3b454e851b54e..408d2878309dd 100644
--- a/docs/design/debug_vllm_compile.md
+++ b/docs/design/debug_vllm_compile.md
@@ -9,7 +9,7 @@ TL;DR:
|----------|----------|-------------|
| --enforce-eager | enforce_eager=True | Turn off torch.compile and CUDAGraphs |
| -O.mode=0 | mode=CompilationMode.NONE | Turn off torch.compile only |
-| -O.cudagraph_mode=NONE | compilation_config=CompilationConfig(mode=CompilationMode.NONE) | Turn off CUDAGraphs only |
+| -O.cudagraph_mode=NONE | compilation_config=CompilationConfig(cudagraph_mode=CUDAGraphMode.NONE) | Turn off CUDAGraphs only |
| -O.backend=eager | compilation_config=CompilationConfig(backend='eager') | Turn off TorchInductor |
## vLLM-torch.compile overview
@@ -151,6 +151,76 @@ To avoid this, please either:
2. wrap the branching logic into a custom operator. TorchDynamo does not
trace into custom operators.
+## Debugging constraint violations and dynamic shapes guards issues
+
+Dynamic-shape guards are a specific category of Dynamo guards. They are constraints that `torch.compile`
+attaches to dynamic dimensions (e.g., `seq_len`) to ensure the compiled artifact remains valid.
+These guards typically appear when framework code, custom passes, or user code branches based on
+dynamic shape values.
+
+**Example:**
+
+```python
+if x > 10:
+ # path A
+else:
+ # path B
+```
+
+This creates a guard `x > 10` or `x <= 10` depending on which path was traced.
+
+**vLLM's Assumption:**
+vLLM assumes that all guards added by torch.compile are safe to drop and will not
+constrain the compiled graph to specific input shapes. When this assumption is violated,
+it can cause issues that users need to debug.
+Some side effects that indicate this assumption is violated are runtime errors
+or `ConstraintViolationError` exceptions.
+
+A `ConstraintViolationError` will be thrown if a dynamic shape gets constrained to
+a single value. If you encounter a constraint violation error or suspect that a dynamic
+shapes guard is being added incorrectly, you can use stricter dynamic shape modes to
+help debug the issue:
+
+```sh
+# Online - using unbacked mode
+vllm serve meta-llama/Llama-3.2-1B -O.dynamic_shapes_config.type=unbacked
+
+# Online - using backed_size_oblivious mode
+vllm serve meta-llama/Llama-3.2-1B -O.dynamic_shapes_config.type=backed_size_oblivious
+```
+
+```py
+# Offline - using unbacked mode
+from vllm.config.compilation import CompilationConfig, DynamicShapesConfig, DynamicShapesType
+LLM(model, compilation_config=CompilationConfig(
+ dynamic_shapes_config=DynamicShapesConfig(type=DynamicShapesType.UNBACKED)
+))
+
+# Offline - using backed_size_oblivious mode
+from vllm.config.compilation import CompilationConfig, DynamicShapesConfig, DynamicShapesType
+LLM(model, compilation_config=CompilationConfig(
+ dynamic_shapes_config=DynamicShapesConfig(type=DynamicShapesType.BACKED_SIZE_OBLIVIOUS)
+))
+```
+
+These modes are stricter and reduce or eliminate the need for dynamic shape guards, which can help isolate issues:
+
+- `unbacked`: Uses unbacked symints which don't allow guards, making it easier to identify where guards are being incorrectly added
+- `backed_size_oblivious`: Uses a mode that is more strict about guarding.
+
+For more details on dynamic shapes modes, see [Dynamic shapes and vLLM guard dropping](torch_compile.md#dynamic-shapes-and-vllm-guard-dropping).
+
+### Printing guards
+
+To see all guards that are being added during compilation, you can use `TORCH_LOGS=+dynamic`:
+
+```sh
+TORCH_LOGS=+dynamic vllm serve meta-llama/Llama-3.2-1B
+```
+
+Look for `[guard added]` in the logs to see where guards are being added. This can help you identify which operations are
+causing guards to be added incorrectly.
+
## Debugging TorchInductor
TorchInductor takes a captured graph and then compiles it down to some Python code
diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md
index 7663b82266f0b..f0d5a3e934f39 100644
--- a/docs/design/moe_kernel_features.md
+++ b/docs/design/moe_kernel_features.md
@@ -1,22 +1,22 @@
-# Fused MoE Kernel features
+# Fused MoE Kernel Features
The purpose of this document is to provide an overview of the various MoE kernels (both modular and non-modular) so it will be easier to select an appropriate set of kernels for any particular situation. This includes information about the all2all backends used by modular kernels.
## Fused MoE Modular All2All backends
-There are a number of all2all communication backends that are used to implement expert parallelism (EP) for the `FusedMoE` layer. The different `FusedMoEPrepareAndFinalize` sub-classes provide an interface for each all2all backend.
+There are a number of all2all communication backends that are used to implement expert parallelism (EP) for the `FusedMoE` layer. The different `FusedMoEPrepareAndFinalize` subclasses provide an interface for each all2all backend.
The following table describes the relevant features of each backend, i.e. activation format, supported quantization schemes and async support.
-The output activation format (standard or batched) corresponds to the output of the prepare step of the `FusedMoEPrepareAndFinalize` subclass, the finalize step requires the same format. All the backend `prepare` methods expect activations in standard format and all the `finalize methods return activations in standard format. More details on the formats can be found in the [Fused MoE Modular Kernel](./fused_moe_modular_kernel.md) document.
+The output activation format (standard or batched) corresponds to the output of the prepare step of the `FusedMoEPrepareAndFinalize` subclass, and the finalize step requires the same format. All the backend `prepare` methods expect activations in the standard format and all the `finalize` methods return activations in standard format. More details on the formats can be found in the [Fused MoE Modular Kernel](./fused_moe_modular_kernel.md) document.
-The quantization types and formats enumerate which quantization schemes are supported by each `FusedMoEPrepareAndFinalize` class. The quantization can happen before or after the dispatch based on the format the all2all backend supports. e.g. deepep_high_throughput supports only block-quantized fp8 format, any other format will result in dispatching in higher precision and quantizing afterwards. The output of the prepare step for each backend is the quantized type. The finalize step generally requires the same input type as the original activations, e.g. if the original input is bfloat16 and the quantization scheme is fp8 w/per-tensor scales, `prepare` will return fp8/per-tensor scale activations and `finalize` will take bfloat16 activations. See the diagrams in [Fused MoE Modular Kernel](./fused_moe_modular_kernel.md) for more details on the types and formats of activations at each step of the MoE process. If no quantization type is specified, the kernel operates on float16 and/or bfloat16.
+The quantization types and formats enumerate which quantization schemes are supported by each `FusedMoEPrepareAndFinalize` class. The quantization can happen before or after the dispatch based on the format the all2all backend supports, e.g. deepep_high_throughput supports only block-quantized fp8 format. Any other format will result in dispatching in higher precision and quantizing afterwards. The output of the prepare step for each backend is the quantized type. The finalize step generally requires the same input type as the original activations, e.g. if the original input is bfloat16 and the quantization scheme is fp8 with per-tensor scales, `prepare` will return fp8/per-tensor scale activations and `finalize` will take bfloat16 activations. See the diagrams in [Fused MoE Modular Kernel](./fused_moe_modular_kernel.md) for more details on the types and formats of activations at each step of the MoE process. If no quantization type is specified, the kernel operates on float16 and/or bfloat16.
Async backends support the use of DBO (Dual Batch Overlap) and shared expert overlap (where shared experts are computed during the combine step).
-Certain models require the topk weights to be applied to the input activations rather than the output activations when topk==1, e.g. llama. For modular kernels, this feature is supported by the `FusedMoEPrepareAndFinalize` subclass, for non-modular kernels, it is up to the experts function to deal with this flag.
+Certain models require the topk weights to be applied to the input activations rather than the output activations when topk==1, e.g. Llama. For modular kernels, this feature is supported by the `FusedMoEPrepareAndFinalize` subclass. For non-modular kernels, it is up to the experts function to deal with this flag.
-unless otherwise specified, backends are controlled via `VLLM_ALL2ALL_BACKEND`. All backends except `flashinfer` only work with EP+DP or EP+TP. `Flashinfer` can work with EP or DP w/o EP.
+Unless otherwise specified, backends are controlled via `VLLM_ALL2ALL_BACKEND`. All backends except `flashinfer` only work with EP+DP or EP+TP. `Flashinfer` can work with EP or DP without EP.
-| Backend | Output act. format | Quant. types | Quant. format | Async | Apply Weight On Input | Sub-class |
-|---------------------------------------|--------------------|-----------------|------------------------|-------|-----------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| naive | standard | all1 | G,A,T | N | 6 | [layer.py][vllm.model_executor.layers.fused_moe.layer.FusedMoE.forward_impl] |
-| pplx | batched | fp8,int8 | G,A,T | Y | Y | [`PplxPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.pplx_prepare_finalize.PplxPrepareAndFinalize] |
-| deepep_high_throughput | standard | fp8 | G(128),A,T2 | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize.DeepEPLLPrepareAndFinalize] |
-| deepep_low_latency | batched | fp8 | G(128),A,T3 | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize.DeepEPHTPrepareAndFinalize] |
-| flashinfer_all2allv | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferAllToAllMoEPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize.FlashInferAllToAllMoEPrepareAndFinalize] |
-| flashinfer4 | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferCutlassMoEPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize.FlashInferCutlassMoEPrepareAndFinalize] |
-| flashinfer4 | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferCutlassMoEPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize.FlashInferCutlassMoEPrepareAndFinalize] |
-| MoEPrepareAndFinalizeNoEP5 | standard | fp8,int8 | G,A,T | N | Y | [`MoEPrepareAndFinalizeNoEP`][vllm.model_executor.layers.fused_moe.prepare_finalize.MoEPrepareAndFinalizeNoEP] |
-| BatchedPrepareAndFinalize5 | batched | fp8,int8 | G,A,T | N | Y | [`BatchedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedPrepareAndFinalize] |
+| Backend | Output act. format | Quant. types | Quant. format | Async | Apply Weight On Input | Subclass |
+|---------|--------------------|--------------|---------------|-------|-----------------------|-----------|
+| naive | standard | all1 | G,A,T | N | 6 | [layer.py][vllm.model_executor.layers.fused_moe.layer.FusedMoE.forward_impl] |
+| pplx | batched | fp8,int8 | G,A,T | Y | Y | [`PplxPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.pplx_prepare_finalize.PplxPrepareAndFinalize] |
+| deepep_high_throughput | standard | fp8 | G(128),A,T2 | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize.DeepEPHTPrepareAndFinalize] |
+| deepep_low_latency | batched | fp8 | G(128),A,T3 | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize.DeepEPLLPrepareAndFinalize] |
+| flashinfer_all2allv | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferAllToAllMoEPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize.FlashInferAllToAllMoEPrepareAndFinalize] |
+| flashinfer4 | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferCutlassMoEPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize.FlashInferCutlassMoEPrepareAndFinalize] |
+| MoEPrepareAndFinalizeNoEP5 | standard | fp8,int8 | G,A,T | N | Y | [`MoEPrepareAndFinalizeNoEP`][vllm.model_executor.layers.fused_moe.prepare_finalize.MoEPrepareAndFinalizeNoEP] |
+| BatchedPrepareAndFinalize5 | batched | fp8,int8 | G,A,T | N | Y | [`BatchedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedPrepareAndFinalize] |
!!! info "Table key"
1. All types: mxfp4, nvfp4, int4, int8, fp8
2. A,T quantization occurs after dispatch.
3. All quantization happens after dispatch.
4. Controlled by different env vars (`VLLM_FLASHINFER_MOE_BACKEND` "throughput" or "latency")
- 5. This is a no-op dispatcher that can be used to pair with any modular experts to produce a modular kernel that runs w/o dispatch or combine. These cannot be selected via environment variable. These are generally use for testing or adapting an expert subclass to the `fused_experts` API.
+    5. This is a no-op dispatcher that can be used to pair with any modular experts to produce a modular kernel that runs without dispatch or combine. These cannot be selected via environment variable. These are generally used for testing or adapting an expert subclass to the `fused_experts` API.
6. This depends on the experts implementation.
---
@@ -66,44 +65,43 @@ Modular kernels are supported by the following `FusedMoEMethodBase` classes.
- [`Mxfp4MoEMethod`][vllm.model_executor.layers.quantization.mxfp4.Mxfp4MoEMethod]
- [`UnquantizedFusedMoEMethod`][vllm.model_executor.layers.fused_moe.layer.UnquantizedFusedMoEMethod]
-## Fused MoE Experts Kernels
+## Fused Experts Kernels
-The are a number of MoE experts kernel implementations for different quantization types and architectures. Most follow the general API of the base Triton [`fused_experts`][vllm.model_executor.layers.fused_moe.fused_moe.fused_experts] function. Many have modular kernel adapters so they can be used with compatible all2all backends. This table lists each experts kernel and its particular properties.
+There are a number of MoE experts kernel implementations for different quantization types and architectures. Most follow the general API of the base Triton [`fused_experts`][vllm.model_executor.layers.fused_moe.fused_moe.fused_experts] function. Many have modular kernel adapters, so they can be used with compatible all2all backends. This table lists each experts kernel and its particular properties.
-Each kernel must be provided with one of the supported input activation formats. Some flavors of kernels support both standard and batched formats through different entry points, e.g. `TritonExperts` and `BatchedTritonExperts`. Batched format kernels are currently only needed for matching with certain all2all backends, e.g. `pplx`, `DeepEPLLPrepareAndFinalize`.
+Each kernel must be provided with one of the supported input activation formats. Some flavors of kernels support both standard and batched formats through different entry points, e.g. `TritonExperts` and `BatchedTritonExperts`. Batched format kernels are currently only needed for matching with certain all2all backends, e.g. `pplx` and `DeepEPLLPrepareAndFinalize`.
Similar to the backend kernels, each experts kernel only supports certain quantization formats. For non-modular experts, the activations will be in the original type and quantized internally by the kernel. Modular experts will expect the activations to already be in the quantized format. Both types of experts will yield outputs in the original activation type.
-Each experts kernel supports one or more activation functions, e.g. silu, gelu that are applied to the intermediate results.
+Each experts kernel supports one or more activation functions, e.g. silu or gelu, which are applied to the intermediate results.
As with the backends, some experts support applying topk weights on the input activations. The entries in the column in this table only apply to the non-modular experts.
Most experts flavors include an equivalent modular interface which will be a subclass of `FusedMoEPermuteExpertsUnpermute`.
-To be used with a particular `FusedMoEPrepareAndFinalize` sub-class, MoE kernels must have compatible activation formats, quantization types and quantization formats.
+To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels must have compatible activation formats, quantization types and quantization formats.
-| Kernel | Input act. format | Quant. types | Quant. format | Activation function | Apply Weight On Input | Modular | Source |
-|------------------------------|-----------------------|------------------|---------------|-------------------------------------------------------------|-----------------------|---------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| triton | standard | all1 | G,A,T | silu, gelu,swigluoai,silu_no_mul,gelu_no_mul | Y | Y | [`fused_experts`][vllm.model_executor.layers.fused_moe.fused_moe.fused_experts],[`TritonExperts`][vllm.model_executor.layers.fused_moe.fused_moe.TritonExperts] |
-| triton (batched) | batched | all1 | G,A,T | silu, gelu | 6 | Y | [`BatchedTritonExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedTritonExperts] |
-| deep gemm | standard,batched | fp8 | G(128),A,T | silu, gelu | 6 | Y | [`deep_gemm_moe_fp8`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.deep_gemm_moe_fp8],[`DeepGemmExperts`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.DeepGemmExperts],[`BatchedDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe.BatchedDeepGemmExperts] |
-| cutlass_fp4 | standard,batched | nvfp4 | A,T | silu | Y | Y | [`cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.cutlass_moe_fp4],[`CutlassExpertsFp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp4] |
-| cutlass_fp8 | standard,batched | fp8 | A,T | silu, gelu | Y | Y | [`cutlass_moe_fp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.cutlass_moe_fp8],[`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],[`CutlasBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] |
-| flashinfer | standard | nvfp4,fp8 | T | 5 | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
-| gpt oss triton | standard | N/A | N/A | 5 | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
-| deep gemm+triton2 | standard,batched | all1 | G(128),A,T | silu, gelu | 6 | Y | [`TritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe.TritonOrDeepGemmExperts],[`BatchedTritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe.BatchedTritonOrDeepGemmExperts] |
-| marlin | standard | 3 | 3 | silu,swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
-| marlin experts | standard,batched | N/A | N/A | silu,swigluoai | Y | Y | [`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
-| trtllm | standard | mxfp4,nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
-| pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
-| iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] |
-| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] |
-| cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] |
-| naive batched4 | batched | int8,fp8 | G,A,T | silu, gelu | 6 | Y | [`NaiveBatchedExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.NaiveBatchedExperts] |
+| Kernel | Input act. format | Quant. types | Quant. format | Activation function | Apply Weight On Input | Modular | Source |
+|--------|-------------------|--------------|---------------|---------------------|-----------------------|---------|--------|
+| triton | standard | all1 | G,A,T | silu, gelu,swigluoai,silu_no_mul,gelu_no_mul | Y | Y | [`fused_experts`][vllm.model_executor.layers.fused_moe.fused_moe.fused_experts],[`TritonExperts`][vllm.model_executor.layers.fused_moe.fused_moe.TritonExperts] |
+| triton (batched) | batched | all1 | G,A,T | silu, gelu | 6 | Y | [`BatchedTritonExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedTritonExperts] |
+| deep gemm | standard,batched | fp8 | G(128),A,T | silu, gelu | 6 | Y | [`deep_gemm_moe_fp8`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.deep_gemm_moe_fp8],[`DeepGemmExperts`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.DeepGemmExperts],[`BatchedDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe.BatchedDeepGemmExperts] |
+| cutlass_fp4 | standard,batched | nvfp4 | A,T | silu | Y | Y | [`cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.cutlass_moe_fp4],[`CutlassExpertsFp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp4] |
+| cutlass_fp8 | standard,batched | fp8 | A,T | silu, gelu | Y | Y | [`cutlass_moe_fp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.cutlass_moe_fp8],[`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],[`CutlasBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] |
+| flashinfer | standard | nvfp4,fp8 | T | 5 | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
+| gpt oss triton | standard | N/A | N/A | 5 | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
+| deep gemm+triton2 | standard,batched | all1 | G(128),A,T | silu, gelu | 6 | Y | [`TritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe.TritonOrDeepGemmExperts],[`BatchedTritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe.BatchedTritonOrDeepGemmExperts] |
+| marlin | standard,batched | 3 / N/A | 3 / N/A | silu,swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
+| trtllm | standard | mxfp4,nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
+| pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
+| iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] |
+| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] |
+| cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] |
+| naive batched4 | batched | int8,fp8 | G,A,T | silu, gelu | 6 | Y | [`NaiveBatchedExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.NaiveBatchedExperts] |
!!! info "Table key"
1. All types: mxfp4, nvfp4, int4, int8, fp8
- 2. A dispatcher wrapper around triton and deep gemm experts. Will select based on type + shape + quantization params
+ 2. A dispatcher wrapper around triton and deep gemm experts. Will select based on type + shape + quantization params
3. uint4, uint8, fp8, fp4
4. This is a naive implementation of experts that supports batched format. Mainly used for testing.
5. The `activation` parameter is ignored and SwiGlu is used by default instead.
@@ -113,8 +111,8 @@ To be used with a particular `FusedMoEPrepareAndFinalize` sub-class, MoE kernels
The following table shows "families" of modular kernels that are intended to work together. There are some combinations which may work but have not yet been tested, e.g. flashinfer with other fp8 experts. Note that the "naive" backend will work with any non-modular experts.
-| backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses |
-|----------------------------------|------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------|
-| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,`TritonExperts`,`TritonOrDeepGemmExperts`,`CutlassExpertsFp8`, `MarlinExperts` |
-| deepep_low_latency,pplx | `DeepEPLLPrepareAndFinalize`,`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,`BatchedTritonExperts`,`BatchedTritonOrDeepGemmExperts`,`CutlassBatchedExpertsFp8`,`BatchedMarlinExperts`|
-| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |
+| backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses |
+|---------|-----------------------------------------|----------------------------------------------|
+| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,`TritonExperts`,`TritonOrDeepGemmExperts`,`CutlassExpertsFp8`, `MarlinExperts` |
+| deepep_low_latency,pplx | `DeepEPLLPrepareAndFinalize`,`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,`BatchedTritonExperts`,`BatchedTritonOrDeepGemmExperts`,`CutlassBatchedExpertsFp8`,`BatchedMarlinExperts` |
+| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |
diff --git a/docs/design/plugin_system.md b/docs/design/plugin_system.md
index dc2f7c4aed3c3..9c84889f7f03d 100644
--- a/docs/design/plugin_system.md
+++ b/docs/design/plugin_system.md
@@ -4,7 +4,7 @@ The community frequently requests the ability to extend vLLM with custom feature
## How Plugins Work in vLLM
-Plugins are user-registered code that vLLM executes. Given vLLM's architecture (see [Arch Overview](arch_overview.md)), multiple processes may be involved, especially when using distributed inference with various parallelism techniques. To enable plugins successfully, every process created by vLLM needs to load the plugin. This is done by the [load_general_plugins](https://github.com/vllm-project/vllm/blob/c76ac49d266e27aa3fea84ef2df1f813d24c91c7/vllm/plugins/__init__.py#L16) function in the `vllm.plugins` module. This function is called for every process created by vLLM before it starts any work.
+Plugins are user-registered code that vLLM executes. Given vLLM's architecture (see [Arch Overview](arch_overview.md)), multiple processes may be involved, especially when using distributed inference with various parallelism techniques. To enable plugins successfully, every process created by vLLM needs to load the plugin. This is done by the [load_plugins_by_group][vllm.plugins.load_plugins_by_group] function in the `vllm.plugins` module.
## How vLLM Discovers Plugins
@@ -49,7 +49,7 @@ Every plugin has three parts:
- **Platform plugins** (with group name `vllm.platform_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree platforms into vLLM. The plugin function should return `None` when the platform is not supported in the current environment, or the platform class's fully qualified name when the platform is supported.
-- **IO Processor plugins** (with group name `vllm.io_processor_plugins`): The primary use case for these plugins is to register custom pre/post processing of the model prompt and model output for pooling models. The plugin function returns the IOProcessor's class fully qualified name.
+- **IO Processor plugins** (with group name `vllm.io_processor_plugins`): The primary use case for these plugins is to register custom pre-/post-processing of the model prompt and model output for pooling models. The plugin function returns the IOProcessor's class fully qualified name.
- **Stat logger plugins** (with group name `vllm.stat_logger_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree loggers into vLLM. The entry point should be a class that subclasses StatLoggerBase.
@@ -57,6 +57,100 @@ Every plugin has three parts:
- **Being re-entrant**: The function specified in the entry point should be re-entrant, meaning it can be called multiple times without causing issues. This is necessary because the function might be called multiple times in some processes.
+### Platform plugins guidelines
+
+1. Create a platform plugin project, for example, `vllm_add_dummy_platform`. The project structure should look like this:
+
+ ```shell
+ vllm_add_dummy_platform/
+ ├── vllm_add_dummy_platform/
+ │ ├── __init__.py
+ │ ├── my_dummy_platform.py
+ │ ├── my_dummy_worker.py
+ │ ├── my_dummy_attention.py
+ │ ├── my_dummy_device_communicator.py
+ │ ├── my_dummy_custom_ops.py
+ ├── setup.py
+ ```
+
+2. In the `setup.py` file, add the following entry point:
+
+ ```python
+ setup(
+ name="vllm_add_dummy_platform",
+ ...
+ entry_points={
+ "vllm.platform_plugins": [
+ "my_dummy_platform = vllm_add_dummy_platform:register"
+ ]
+ },
+ ...
+ )
+ ```
+
+ Please make sure `vllm_add_dummy_platform:register` is a callable function and returns the platform class's fully qualified name. For example:
+
+ ```python
+ def register():
+ return "vllm_add_dummy_platform.my_dummy_platform.MyDummyPlatform"
+ ```
+
+3. Implement the platform class `MyDummyPlatform` in `my_dummy_platform.py`. The platform class should inherit from `vllm.platforms.interface.Platform`. Please follow the interface to implement the functions one by one. At a minimum, the following important functions and properties should be implemented:
+
+ - `_enum`: This property is the device enumeration from [PlatformEnum][vllm.platforms.interface.PlatformEnum]. Usually, it should be `PlatformEnum.OOT`, which means the platform is out-of-tree.
+ - `device_type`: This property should return the type of the device which pytorch uses. For example, `"cpu"`, `"cuda"`, etc.
+ - `device_name`: This property is usually set to the same value as `device_type`. It's mainly used for logging purposes.
+ - `check_and_update_config`: This function is called very early in vLLM's initialization process. It's used for plugins to update the vLLM configuration. For example, the block size, graph mode config, etc., can be updated in this function. The most important thing is that the **worker_cls** should be set in this function to let vLLM know which worker class to use for the worker process.
+ - `get_attn_backend_cls`: This function should return the attention backend class's fully qualified name.
+ - `get_device_communicator_cls`: This function should return the device communicator class's fully qualified name.
+
+4. Implement the worker class `MyDummyWorker` in `my_dummy_worker.py`. The worker class should inherit from [WorkerBase][vllm.v1.worker.worker_base.WorkerBase]. Please follow the interface to implement the functions one by one. Basically, all interfaces in the base class should be implemented, since they are called throughout vLLM. To make sure a model can be executed, the basic functions that should be implemented are:
+
+ - `init_device`: This function is called to set up the device for the worker.
+ - `initialize_cache`: This function is called to set the cache config for the worker.
+ - `load_model`: This function is called to load the model weights to the device.
+ - `get_kv_cache_spaces`: This function is called to generate the kv cache spaces for the model.
+ - `determine_available_memory`: This function is called to profile the peak memory usage of the model to determine how much memory can be used for the KV cache without OOMs.
+ - `initialize_from_config`: This function is called to allocate the device KV cache with the specified kv_cache_config.
+ - `execute_model`: This function is called at every step to run inference on the model.
+
+ Additional functions that can be implemented are:
+
+ - If the plugin wants to support sleep mode feature, please implement the `sleep` and `wakeup` functions.
+ - If the plugin wants to support graph mode feature, please implement the `compile_or_warm_up_model` function.
+ - If the plugin wants to support speculative decoding feature, please implement the `take_draft_token_ids` function.
+ - If the plugin wants to support the LoRA feature, please implement the `add_lora`, `remove_lora`, `list_loras` and `pin_lora` functions.
+ - If the plugin wants to support the data parallelism feature, please implement the `execute_dummy_batch` function.
+
+ Please look at the worker base class [WorkerBase][vllm.v1.worker.worker_base.WorkerBase] for more functions that can be implemented.
+
+5. Implement the attention backend class `MyDummyAttention` in `my_dummy_attention.py`. The attention backend class should inherit from [AttentionBackend][vllm.attention.backends.abstract.AttentionBackend]. It's used to calculate attention with your device. Take `vllm.v1.attention.backends` as an example; it contains many attention backend implementations.
+
+6. Implement custom ops for high performance. Most ops can be run by the PyTorch native implementation, though the performance may not be good. In this case, you can implement specific custom ops for your plugins. Currently, vLLM supports these kinds of custom ops:
+
+ - PyTorch ops
+ There are 3 kinds of PyTorch ops:
+
+ - `communicator ops`: Device communicator ops, such as all-reduce, all-gather, etc.
+ Please implement the device communicator class `MyDummyDeviceCommunicator` in `my_dummy_device_communicator.py`. The device communicator class should inherit from [DeviceCommunicatorBase][vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase].
+ - `common ops`: Common ops, such as matmul, softmax, etc.
+ Please implement common ops by registering them in the out-of-tree (OOT) way. See more details in the [CustomOp][vllm.model_executor.custom_op.CustomOp] class.
+ - `csrc ops`: C++ ops. These ops are implemented in C++ and registered as torch custom ops.
+ Follow the csrc module and `vllm._custom_ops` to implement your ops.
+
+ - Triton ops
+ The custom-op registration mechanism doesn't work for Triton ops yet.
+
+7. (optional) Implement other pluggable modules, such as LoRA, graph backend, quantization, mamba attention backend, etc.
+
## Compatibility Guarantee
-vLLM guarantees the interface of documented plugins, such as `ModelRegistry.register_model`, will always be available for plugins to register models. However, it is the responsibility of plugin developers to ensure their plugins are compatible with the version of vLLM they are targeting. For example, `"vllm_add_dummy_model.my_llava:MyLlava"` should be compatible with the version of vLLM that the plugin targets. The interface for the model may change during vLLM's development.
+vLLM guarantees the interface of documented plugins, such as `ModelRegistry.register_model`, will always be available for plugins to register models. However, it is the responsibility of plugin developers to ensure their plugins are compatible with the version of vLLM they are targeting. For example, `"vllm_add_dummy_model.my_llava:MyLlava"` should be compatible with the version of vLLM that the plugin targets.
+
+The interface for the model/module may change during vLLM's development. If you see any deprecation log info, please upgrade your plugin to the latest version.
+
+## Deprecation announcement
+
+!!! warning "Deprecations"
+ - `use_v1` parameter in `Platform.get_attn_backend_cls` is deprecated. It will be removed in v0.13.0 or v1.0.0.
+ - `_Backend` in `vllm.attention` is deprecated. It will be removed in v0.13.0 or v1.0.0. Please use `vllm.attention.backends.registry.register_backend` to add new attention backend to `AttentionBackendEnum` instead.
diff --git a/docs/design/prefix_caching.md b/docs/design/prefix_caching.md
index bd4070f381d81..cf792fdabe1a6 100644
--- a/docs/design/prefix_caching.md
+++ b/docs/design/prefix_caching.md
@@ -1,6 +1,6 @@
# Automatic Prefix Caching
-Prefix caching kv-cache blocks is a popular optimization in LLM inference to avoid redundant prompt computations. The core idea is simple – we cache the kv-cache blocks of processed requests, and reuse these blocks when a new request comes in with the same prefix as previous requests. Since prefix caching is almost a free lunch and won’t change model outputs, it has been widely used by many public endpoints (e.g., OpenAI, Anthropic, etc) and most open source LLM inference frameworks (e.g., SGLang).
+Prefix caching kv-cache blocks is a popular optimization in LLM inference to avoid redundant prompt computations. The core idea is simple – we cache the kv-cache blocks of processed requests, and reuse these blocks when a new request comes in with the same prefix as previous requests. Since prefix caching is almost a free lunch and won’t change model outputs, it has been widely used by many public endpoints (e.g., OpenAI, Anthropic, etc.) and most open source LLM inference frameworks (e.g., SGLang).
While there are many ways to implement prefix caching, vLLM chooses a hash-based approach. Specifically, we hash each kv-cache block by the tokens in the block and the tokens in the prefix before the block:
@@ -94,9 +94,6 @@ To improve privacy in shared environments, vLLM supports isolating prefix cache
With this setup, cache sharing is limited to users or requests that explicitly agree on a common salt, enabling cache reuse within a trust group while isolating others.
-!!! note
- Cache isolation is not supported in engine V0.
-
## Data Structure
The prefix caching in vLLM v1 is implemented in the KV cache manager. The basic building block is the “Block” data class (simplified):
diff --git a/docs/design/torch_compile.md b/docs/design/torch_compile.md
index 27edc4f89201d..7b0b2c1e96978 100644
--- a/docs/design/torch_compile.md
+++ b/docs/design/torch_compile.md
@@ -29,6 +29,109 @@ A unique aspect of vLLM's `torch.compile` integration, is that we guarantee all
By default, the cache saves compiled artifacts as binary files. If you would like to interact with the generated code for debugging purposes, set the field `compile_cache_save_format=unpacked` in the compilation config, or omit this and set the env variable `VLLM_COMPILE_CACHE_SAVE_FORMAT=unpacked`.
+## Dynamic shapes and vllm guard dropping
+
+`torch.compile` is designed to guard on dynamic shapes with no hesitation
+when needed. This contradicts with vLLM's `torch.compile` approach of
+dropping the guards since many of those guards could be material.
+
+`torch.compile` provides two kinds of dynamic shapes: `backed` and `unbacked`.
+`torch.compile` guards on `backed` dynamic shapes and does not provide a
+guarantee that no guards will be added to them. User code, dynamo,
+inductor, and autograd all can add guards. Moreover, for 0/1
+specializations, backed symbols are specialized unconditionally to 0, 1,
+or >=2 even without encountering a branching on those ranges.
+
+On the contrary, `unbacked` dynamic shapes are guaranteed not to be guarded
+on and are not 0/1 specialized. However, there is a possibility of
+throwing a data dependent error when a branch that requires their value is
+encountered and no explicit unbacked handling is defined. The framework is
+converging to a state where it won't throw DDE but rather pick general
+paths. One downside of using unbacked is missed optimization opportunities
+due to either perf bugs or picking general paths, also using a fixed
+non-example input-based hint (this will be fixed soon with override_hint
+API). An example of picking general paths is assuming the input is not
+contiguous in functions that call contiguous() and reshape() when contiguity
+can't be symbolically proven, at the cost of possibly introducing a clone.
+
+`backed_size_oblivious` is a flag that enables treating backed symbols as
+unbacked wherever explicit handling for unbacked is defined. With this
+mode, 0/1 specializations are mostly avoided in framework code and the
+default 0/1 specialization does not happen. However, there is still no
+guarantee that torch.compile won't guard, especially due to user code or
+custom passes. `backed_size_oblivious` is experimental in PyTorch compile
+and could be deprecated. That said, it's a safer option to use than
+`backed` and the probability of reducing performance is lower than
+`unbacked`.
+
+### Configuring Dynamic Shapes
+
+The `DynamicShapesConfig` allows you to control the dynamic shapes behavior by
+setting the `type` field. You can choose between three modes:
+`BACKED` (default), `UNBACKED`, and `BACKED_SIZE_OBLIVIOUS`.
+
+#### Offline Inference Example (Using LLM class)
+
+When using the `LLM` class for offline inference, you can configure dynamic
+shapes through the `compilation_config` parameter:
+
+```python
+from vllm import LLM, SamplingParams
+from vllm.config.compilation import CompilationConfig, DynamicShapesConfig, DynamicShapesType
+
+# Example: Using backed_size_oblivious (experimental, safer than backed)
+llm = LLM(
+ model="meta-llama/Llama-3.2-1B",
+ compilation_config=CompilationConfig(
+ dynamic_shapes_config=DynamicShapesConfig(
+ type=DynamicShapesType.BACKED_SIZE_OBLIVIOUS
+ )
+ )
+)
+
+# Example: Using unbacked (strongest guarantee against guards)
+llm = LLM(
+ model="meta-llama/Llama-3.2-1B",
+ compilation_config=CompilationConfig(
+ dynamic_shapes_config=DynamicShapesConfig(
+ type=DynamicShapesType.UNBACKED
+ )
+ )
+)
+
+# Generate outputs
+prompts = ["Hello, my name is", "The future of AI is"]
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+outputs = llm.generate(prompts, sampling_params)
+```
+
+#### Online Serving Example (Using vllm serve)
+
+When using `vllm serve` for online serving, you can configure dynamic shapes
+through the `--compilation-config` flag:
+
+```bash
+# Example: Using unbacked
+vllm serve meta-llama/Llama-3.2-1B \
+ --compilation-config '{"dynamic_shapes_config": {"type": "unbacked"}}'
+
+
+# Alternative: Using dot notation (simpler for single values)
+vllm serve meta-llama/Llama-3.2-1B -O.dynamic_shapes_config.type=unbacked
+```
+
+#### Choosing the Right Mode
+
+- **BACKED** (default): Use when you're willing to accept potential unsafe dropping of guards
+for maximal performance. Guards could be unsoundly added and then ignored.
+
+- **UNBACKED**: Use when you need the strongest guarantee against guards.
+ This is the most conservative option but may miss some optimization opportunities.
+
+- **BACKED_SIZE_OBLIVIOUS**: Use when you want a balance between avoiding guards
+ and performance. This experimental mode is safer than BACKED but still not as
+ conservative as UNBACKED.
+
## Python Code Compilation
In the very verbose logs, we can see:
@@ -122,7 +225,7 @@ When all the shapes are known, `torch.compile` can compare different configs, an
triton_mm_4 0.0130 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=2
triton_mm_8 0.0134 ms 97.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
triton_mm_12 0.0148 ms 87.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4
- mm 0.0160 ms 81.6%
+ mm 0.0160 ms 81.6%
triton_mm_16 0.0165 ms 78.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8
triton_mm_3 0.0199 ms 65.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=2
triton_mm_1 0.0203 ms 64.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=2
diff --git a/docs/features/README.md b/docs/features/README.md
index ad9de9ff8f368..5faf3768f3214 100644
--- a/docs/features/README.md
+++ b/docs/features/README.md
@@ -59,20 +59,23 @@ th:not(:first-child) {
### Feature x Hardware
-| Feature | Volta | Turing | Ampere | Ada | Hopper | CPU | AMD | TPU | Intel GPU |
-|-----------------------------------------------------------|---------------------|-----------|-----------|--------|------------|--------------------|--------|-----| ------------|
-| [CP](../configuration/optimization.md#chunked-prefill) | [❌](https://github.com/vllm-project/vllm/issues/2729) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| [APC](automatic_prefix_caching.md) | [❌](https://github.com/vllm-project/vllm/issues/3687) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| [LoRA](lora.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| [SD](spec_decode.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | [🟠](https://github.com/vllm-project/vllm/issues/26963) |
-| CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | [❌](https://github.com/vllm-project/vllm/issues/26970) |
-| [pooling](../models/pooling_models.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |
-| enc-dec | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ |
-| [mm](multimodal_inputs.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | [🟠](https://github.com/vllm-project/vllm/issues/26965) |
-| logP | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |
-| prmpt logP | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |
-| async output | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ |
-| multi-step | ✅ | ✅ | ✅ | ✅ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/8477) | ✅ | ❌ | ✅ |
-| best-of | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |
-| beam-search | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |
-| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | [❌](https://github.com/vllm-project/vllm/issues/25097) | ✅ |
+| Feature | Volta | Turing | Ampere | Ada | Hopper | CPU | AMD | Intel GPU |
+|-----------------------------------------------------------|---------------------|-----------|-----------|--------|------------|--------------------|--------| ------------|
+| [CP](../configuration/optimization.md#chunked-prefill) | [❌](https://github.com/vllm-project/vllm/issues/2729) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| [APC](automatic_prefix_caching.md) | [❌](https://github.com/vllm-project/vllm/issues/3687) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| [LoRA](lora.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| [SD](spec_decode.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | [🟠](https://github.com/vllm-project/vllm/issues/26963) |
+| CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/26970) |
+| [pooling](../models/pooling_models.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| enc-dec | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |
+| [mm](multimodal_inputs.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | [🟠](https://github.com/vllm-project/vllm/issues/26965) |
+| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ |
+| logP | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| prmpt logP | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| async output | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ |
+| multi-step | ✅ | ✅ | ✅ | ✅ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/8477) | ✅ | ✅ |
+| best-of | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| beam-search | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+
+!!! note
+ For information on feature support on Google TPU, please refer to the [TPU-Inference Recommended Models and Features](https://docs.vllm.ai/projects/tpu/en/latest/recommended_models_features/) documentation.
diff --git a/docs/features/multimodal_inputs.md b/docs/features/multimodal_inputs.md
index 5f684604e6031..4656ee43ea251 100644
--- a/docs/features/multimodal_inputs.md
+++ b/docs/features/multimodal_inputs.md
@@ -365,6 +365,8 @@ You must enable this feature via `enable_mm_embeds=True`.
The vLLM engine may crash if incorrect shape of embeddings is passed.
Only enable this flag for trusted users!
+#### Image Embeddings
+
??? code
```python
@@ -441,6 +443,36 @@ For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embedd
print(generated_text)
```
+#### Audio Embeddings
+
+You can pass pre-computed audio embeddings similar to image embeddings:
+
+??? code
+
+ ```python
+ from vllm import LLM
+ import torch
+
+ # Enable audio embeddings support
+ llm = LLM(model="fixie-ai/ultravox-v0_5-llama-3_2-1b", enable_mm_embeds=True)
+
+ # Refer to the HuggingFace repo for the correct format to use
+ prompt = "USER: \nWhat is in this audio?\nASSISTANT:"
+
+ # Load pre-computed audio embeddings
+ # torch.Tensor of shape (1, audio_feature_size, hidden_size of LM)
+ audio_embeds = torch.load(...)
+
+ outputs = llm.generate({
+ "prompt": prompt,
+ "multi_modal_data": {"audio": audio_embeds},
+ })
+
+ for o in outputs:
+ generated_text = o.outputs[0].text
+ print(generated_text)
+ ```
+
## Online Serving
Our OpenAI-compatible server accepts multi-modal data via the [Chat Completions API](https://platform.openai.com/docs/api-reference/chat). Media inputs also support optional UUIDs users can provide to uniquely identify each media, which is used to cache the media results across requests.
diff --git a/docs/features/nixl_connector_usage.md b/docs/features/nixl_connector_usage.md
index 1ce038f4d6525..f0e25e31aa0b3 100644
--- a/docs/features/nixl_connector_usage.md
+++ b/docs/features/nixl_connector_usage.md
@@ -158,7 +158,7 @@ python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \
## Experimental Feature
-### Heterogenuous KV Layout support
+### Heterogeneous KV Layout support
Support use case: Prefill with 'HND' and decode with 'NHD' with experimental configuration
diff --git a/docs/features/quantization/README.md b/docs/features/quantization/README.md
index 74f005c496ee5..7b5287bad3bb8 100644
--- a/docs/features/quantization/README.md
+++ b/docs/features/quantization/README.md
@@ -43,24 +43,27 @@ th:not(:first-child) {
}
-| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | Intel Gaudi | x86 CPU | Google TPU |
-|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-------------|-----------|--------------|
-| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ |
-| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ |
-| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
-| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ✅︎ |
-| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
-| BitBLAS | ✅︎ | ✅ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
-| BitBLAS (GPTQ) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
-| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
-| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
-| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
-| INC (W8A8) | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅︎ | ❌ | ❌ |
+| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | Intel Gaudi | x86 CPU |
+|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-------------|-----------|
+| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ |
+| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ |
+| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
+| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ |
+| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ |
+| BitBLAS | ✅︎ | ✅ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
+| BitBLAS (GPTQ) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
+| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
+| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
+| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ |
+| INC (W8A8) | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅︎ | ❌ |
- Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0.
- ✅︎ indicates that the quantization method is supported on the specified hardware.
- ❌ indicates that the quantization method is not supported on the specified hardware.
+!!! note
+ For information on quantization support on Google TPU, please refer to the [TPU-Inference Recommended Models and Features](https://docs.vllm.ai/projects/tpu/en/latest/recommended_models_features/) documentation.
+
!!! note
This compatibility chart is subject to change as vLLM continues to evolve and expand its support for different hardware platforms and quantization methods.
diff --git a/docs/features/quantization/fp8.md b/docs/features/quantization/fp8.md
index 0c5111fb8af0d..d4a6176b236f1 100644
--- a/docs/features/quantization/fp8.md
+++ b/docs/features/quantization/fp8.md
@@ -60,7 +60,7 @@ Since simple RTN does not require data for weight quantization and the activatio
??? code
```python
- from llmcompressor.transformers import oneshot
+ from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
# Configure the simple PTQ quantization
diff --git a/docs/features/quantization/inc.md b/docs/features/quantization/inc.md
index 5e86e9388f328..9875bc44c9144 100644
--- a/docs/features/quantization/inc.md
+++ b/docs/features/quantization/inc.md
@@ -22,9 +22,6 @@ export QUANT_CONFIG=/path/to/quant/config/inc/meta-llama-3.1-405b-instruct/maxab
vllm serve meta-llama/Llama-3.1-405B-Instruct --quantization inc --kv-cache-dtype fp8_inc --tensor_paralel_size 8
```
-!!! tip
- If you are just prototyping or testing your model with FP8, you can use the `VLLM_SKIP_WARMUP=true` environment variable to disable the warmup stage, which can take a long time. However, we do not recommend disabling this feature in production environments as it causes a significant performance drop.
-
!!! tip
When using FP8 models, you may experience timeouts caused by the long compilation time of FP8 operations. To mitigate this problem, you can use the below environment variables:
`VLLM_ENGINE_ITERATION_TIMEOUT_S` - to adjust the vLLM server timeout. You can set the value in seconds, e.g., 600 equals 10 minutes.
diff --git a/docs/features/quantization/int4.md b/docs/features/quantization/int4.md
index 035e7ea291f9e..9752039097d63 100644
--- a/docs/features/quantization/int4.md
+++ b/docs/features/quantization/int4.md
@@ -80,7 +80,7 @@ Now, apply the quantization algorithms:
??? code
```python
- from llmcompressor.transformers import oneshot
+ from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
diff --git a/docs/features/quantization/int8.md b/docs/features/quantization/int8.md
index ec8a77f74ffef..701ca6378cb16 100644
--- a/docs/features/quantization/int8.md
+++ b/docs/features/quantization/int8.md
@@ -87,7 +87,7 @@ Now, apply the quantization algorithms:
??? code
```python
- from llmcompressor.transformers import oneshot
+ from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
diff --git a/docs/features/quantization/quantized_kvcache.md b/docs/features/quantization/quantized_kvcache.md
index 56cf057678be6..d26a5e217f314 100644
--- a/docs/features/quantization/quantized_kvcache.md
+++ b/docs/features/quantization/quantized_kvcache.md
@@ -78,7 +78,7 @@ Here's a complete example using `meta-llama/Llama-3.1-8B-Instruct` (most models
```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
- from llmcompressor.transformers import oneshot
+ from llmcompressor import oneshot
# Select model and load it
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
diff --git a/docs/features/quantization/quark.md b/docs/features/quantization/quark.md
index bd7bc186e13aa..c54d7d2251999 100644
--- a/docs/features/quantization/quark.md
+++ b/docs/features/quantization/quark.md
@@ -306,7 +306,7 @@ As examples, we provide some ready-to-use quantized mixed precision model to sho
### 2. inference the quantized mixed precision model in vLLM
-Models quantized with AMD Quark using mixed precision can natively be reload in vLLM, and e.g. evaluated using lm-evaluation-harness as follow:
+Models quantized with AMD Quark using mixed precision can natively be reloaded in vLLM, and e.g. evaluated using lm-evaluation-harness as follows:
```bash
lm_eval --model vllm \
diff --git a/docs/features/structured_outputs.md b/docs/features/structured_outputs.md
index e38627c707884..7d52891bea7b9 100644
--- a/docs/features/structured_outputs.md
+++ b/docs/features/structured_outputs.md
@@ -7,7 +7,7 @@ This document shows you some examples of the different options that are
available to generate structured outputs.
!!! warning
- If you are still using the following deprecated API fields, please update your code to use `structured_outputs` as demonstrated in the rest of this document:
+ If you are still using the following deprecated API fields which were removed in v0.12.0, please update your code to use `structured_outputs` as demonstrated in the rest of this document:
- `guided_json` -> `{"structured_outputs": {"json": ...}}` or `StructuredOutputsParams(json=...)`
- `guided_regex` -> `{"structured_outputs": {"regex": ...}}` or `StructuredOutputsParams(regex=...)`
diff --git a/docs/features/tool_calling.md b/docs/features/tool_calling.md
index 7e6c69e717dba..dd79ba19b7247 100644
--- a/docs/features/tool_calling.md
+++ b/docs/features/tool_calling.md
@@ -142,7 +142,7 @@ Flags: `--tool-call-parser hermes`
Supported models:
* `mistralai/Mistral-7B-Instruct-v0.3` (confirmed)
-* Additional mistral function-calling models are compatible as well.
+* Additional Mistral function-calling models are compatible as well.
Known issues:
@@ -158,12 +158,25 @@ Known issues:
Recommended flags:
-1. To use [mistral-common](https://github.com/mistralai/mistral-common) the official Mistral tokenization backend:
+1. To use Mistral AI's official format:
- `--tokenizer_mode mistral --config_format mistral --load_format mistral --tool-call-parser mistral`
+ `--tool-call-parser mistral`
-2. To use the default Transformers tokenization backend:
- `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja`
+2. To use the Transformers format when available:
+
+ `--tokenizer_mode hf --config_format hf --load_format hf --tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja`
+
+!!! note
+ Models officially released by Mistral AI have two possible formats:
+
+ 1. The official format that is used by default with `auto` or `mistral` arguments:
+
+ `--tokenizer_mode mistral --config_format mistral --load_format mistral`
+ This format uses [mistral-common](https://github.com/mistralai/mistral-common), Mistral AI's tokenizer backend.
+
+ 2. The Transformers format, when available, that is used with `hf` arguments:
+
+ `--tokenizer_mode hf --config_format hf --load_format hf --chat-template examples/tool_chat_template_mistral_parallel.jinja`
### Llama Models (`llama3_json`)
diff --git a/docs/getting_started/installation/gpu.cuda.inc.md b/docs/getting_started/installation/gpu.cuda.inc.md
index b2d0d64a2d355..601d3659af886 100644
--- a/docs/getting_started/installation/gpu.cuda.inc.md
+++ b/docs/getting_started/installation/gpu.cuda.inc.md
@@ -158,10 +158,7 @@ uv pip install -e .
##### Use an existing PyTorch installation
-There are scenarios where the PyTorch dependency cannot be easily installed with `uv`, e.g.:
-
-- Building vLLM with PyTorch nightly or a custom PyTorch build.
-- Building vLLM with aarch64 and CUDA (GH200), where the PyTorch wheels are not available on PyPI. Currently, only the PyTorch nightly has wheels for aarch64 with CUDA. You can run `uv pip install --index-url https://download.pytorch.org/whl/nightly/cu128 torch torchvision torchaudio` to [install PyTorch nightly](https://pytorch.org/get-started/locally/) and then build vLLM on top of it.
+There are scenarios where the PyTorch dependency cannot be easily installed with `uv`, for example, when building vLLM with non-default PyTorch builds (like nightly or a custom build).
To build vLLM using an existing PyTorch installation:
diff --git a/docs/getting_started/quickstart.md b/docs/getting_started/quickstart.md
index cfc8b4d9838a7..94920dc5306b3 100644
--- a/docs/getting_started/quickstart.md
+++ b/docs/getting_started/quickstart.md
@@ -283,10 +283,10 @@ Currently, vLLM supports multiple backends for efficient Attention computation a
If desired, you can also manually set the backend of your choice by configuring the environment variable `VLLM_ATTENTION_BACKEND` to one of the following options:
-- On NVIDIA CUDA: `FLASH_ATTN`, `FLASHINFER` or `XFORMERS`.
+- On NVIDIA CUDA: `FLASH_ATTN` or `FLASHINFER`.
- On AMD ROCm: `TRITON_ATTN`, `ROCM_ATTN`, `ROCM_AITER_FA` or `ROCM_AITER_UNIFIED_ATTN`.
-For AMD ROCm, you can futher control the specific Attention implementation using the following variables:
+For AMD ROCm, you can further control the specific Attention implementation using the following variables:
- Triton Unified Attention: `VLLM_ROCM_USE_AITER=0 VLLM_V1_USE_PREFILL_DECODE_ATTENTION=0 VLLM_ROCM_USE_AITER_MHA=0`
- AITER Unified Attention: `VLLM_ROCM_USE_AITER=1 VLLM_USE_AITER_UNIFIED_ATTENTION=1 VLLM_V1_USE_PREFILL_DECODE_ATTENTION=0 VLLM_ROCM_USE_AITER_MHA=0`
diff --git a/docs/models/hardware_supported_models/cpu.md b/docs/models/hardware_supported_models/cpu.md
new file mode 100644
index 0000000000000..0832755f8fbe2
--- /dev/null
+++ b/docs/models/hardware_supported_models/cpu.md
@@ -0,0 +1,26 @@
+# CPU - Intel® Xeon®
+
+## Supported Models
+
+### Text-only Language Models
+
+| Model | Architecture | Supported |
+|--------------------------------------|-------------------------------------------|-----------|
+| meta-llama/Llama-3.1 / 3.3 | LlamaForCausalLM | ✅ |
+| meta-llama/Llama-4-Scout | Llama4ForConditionalGeneration | ✅ |
+| meta-llama/Llama-4-Maverick | Llama4ForConditionalGeneration | ✅ |
+| ibm-granite/granite (Granite-MOE) | GraniteMoeForCausalLM | ✅ |
+| Qwen/Qwen3 | Qwen3ForCausalLM | ✅ |
+| zai-org/GLM-4.5 | GLMForCausalLM | ✅ |
+| google/gemma | GemmaForCausalLM | ✅ |
+
+### Multimodal Language Models
+
+| Model | Architecture | Supported |
+|--------------------------------------|-------------------------------------------|-----------|
+| Qwen/Qwen2.5-VL | Qwen2VLForConditionalGeneration | ✅ |
+| openai/whisper | WhisperForConditionalGeneration | ✅ |
+
+✅ Runs and optimized.
+🟨 Runs and correct but not optimized to green yet.
+❌ Does not pass accuracy test or does not run.
diff --git a/docs/models/hardware_supported_models/tpu.md b/docs/models/hardware_supported_models/tpu.md
deleted file mode 100644
index 7b0a5ba6e72da..0000000000000
--- a/docs/models/hardware_supported_models/tpu.md
+++ /dev/null
@@ -1,34 +0,0 @@
-# TPU
-
-## Supported Models
-
-### Text-only Language Models
-
-| Model | Architecture | Supported |
-|-----------------------------------------------------|--------------------------------|-----------|
-| mistralai/Mixtral-8x7B-Instruct-v0.1 | MixtralForCausalLM | 🟨 |
-| mistralai/Mistral-Small-24B-Instruct-2501 | MistralForCausalLM | ✅ |
-| mistralai/Codestral-22B-v0.1 | MistralForCausalLM | ✅ |
-| mistralai/Mixtral-8x22B-Instruct-v0.1 | MixtralForCausalLM | ❌ |
-| meta-llama/Llama-3.3-70B-Instruct | LlamaForCausalLM | ✅ |
-| meta-llama/Llama-3.1-8B-Instruct | LlamaForCausalLM | ✅ |
-| meta-llama/Llama-3.1-70B-Instruct | LlamaForCausalLM | ✅ |
-| meta-llama/Llama-4-* | Llama4ForConditionalGeneration | ❌ |
-| microsoft/Phi-3-mini-128k-instruct | Phi3ForCausalLM | 🟨 |
-| microsoft/phi-4 | Phi3ForCausalLM | ❌ |
-| google/gemma-3-27b-it | Gemma3ForConditionalGeneration | 🟨 |
-| google/gemma-3-4b-it | Gemma3ForConditionalGeneration | ❌ |
-| deepseek-ai/DeepSeek-R1 | DeepseekV3ForCausalLM | ❌ |
-| deepseek-ai/DeepSeek-V3 | DeepseekV3ForCausalLM | ❌ |
-| RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8 | LlamaForCausalLM | ✅ |
-| RedHatAI/Meta-Llama-3.1-70B-Instruct-quantized.w8a8 | LlamaForCausalLM | ✅ |
-| Qwen/Qwen3-8B | Qwen3ForCausalLM | ✅ |
-| Qwen/Qwen3-32B | Qwen3ForCausalLM | ✅ |
-| Qwen/Qwen2.5-7B-Instruct | Qwen2ForCausalLM | ✅ |
-| Qwen/Qwen2.5-32B | Qwen2ForCausalLM | ✅ |
-| Qwen/Qwen2.5-14B-Instruct | Qwen2ForCausalLM | ✅ |
-| Qwen/Qwen2.5-1.5B-Instruct | Qwen2ForCausalLM | 🟨 |
-
-✅ Runs and optimized.
-🟨 Runs and correct but not optimized to green yet.
-❌ Does not pass accuracy test or does not run.
diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 3c9295b6414a7..5f64f5182477d 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -79,7 +79,9 @@ To make your model compatible with the Transformers modeling backend, it needs:
1. Add `is_causal = False` to `MyAttention`.
- If your model is mixture-of-experts (MoE):
1. Your sparse MoE block must have an attribute called `experts`.
- 2. The class of `experts` (`MyExperts`) must inherit from `nn.ModuleList`.
+ 2. The class of `experts` (`MyExperts`) must either:
+ - Inherit from `nn.ModuleList` (naive).
+ - Or contain all 3D `nn.Parameters` (packed).
3. `MyExperts.forward` must accept `hidden_states`, `top_k_index`, `top_k_weights`.
2. `MyAttention` must use `ALL_ATTENTION_FUNCTIONS` to call attention.
3. `MyModel` must contain `_supports_attention_backend = True`.
@@ -422,7 +424,7 @@ th {
| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ |
| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | ✅︎ | ✅︎ |
| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | ✅︎ |
-| `OLMo3ForCausalLM` | OLMo3 | TBA | ✅︎ | ✅︎ |
+| `OLMo3ForCausalLM` | OLMo3 | `allenai/Olmo-3-7B-Instruct`, `allenai/Olmo-3-32B-Think`, etc. | ✅︎ | ✅︎ |
| `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ |
| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | ✅︎ | ✅︎ |
| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ |
@@ -435,6 +437,7 @@ th {
| `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ |
| `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ |
| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | ✅︎ |
+| `Plamo3ForCausalLM` | PLaMo3 | `pfnet/plamo-3-nict-2b-base`, `pfnet/plamo-3-nict-8b-base`, etc. | | ✅︎ |
| `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ |
| `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/QwQ-32B-Preview`, `Qwen/Qwen2-7B-Instruct`, `Qwen/Qwen2-7B`, etc. | ✅︎ | ✅︎ |
| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ |
@@ -678,6 +681,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + IE+ + VE+ | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ |
| `H2OVLChatModel` | H2OVL | T + IE+ | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ |
+| `HunYuanVLForConditionalGeneration` | HunyuanOCR | T + IE+ | `tencent/HunyuanOCR`, etc. | ✅︎ | ✅︎ |
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | |
| `InternS1ForConditionalGeneration` | Intern-S1 | T + IE+ + VE+ | `internlm/Intern-S1`, `internlm/Intern-S1-mini`, etc. | ✅︎ | ✅︎ |
| `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + IE+ + (VE+ ) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ |
@@ -699,6 +703,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I+ | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ |
| `MolmoForCausalLM` | Molmo | T + I+ | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ |
| `NVLM_D_Model` | NVLM-D 1.0 | T + I+ | `nvidia/NVLM-D-72B`, etc. | | ✅︎ |
+| `OpenCUAForConditionalGeneration` | OpenCUA-7B | T + IE+ | `xlangai/OpenCUA-7B` | ✅︎ | ✅︎ |
| `Ovis` | Ovis2, Ovis1.6 | T + I+ | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ |
| `Ovis2_5` | Ovis2.5 | T + I+ + V | `AIDC-AI/Ovis2.5-9B`, etc. | | |
| `PaddleOCRVLForConditionalGeneration` | Paddle-OCR | T + I+ | `PaddlePaddle/PaddleOCR-VL`, etc. | | |
diff --git a/docs/serving/openai_compatible_server.md b/docs/serving/openai_compatible_server.md
index 23df3963823aa..e3280bd15b55c 100644
--- a/docs/serving/openai_compatible_server.md
+++ b/docs/serving/openai_compatible_server.md
@@ -49,7 +49,8 @@ We currently support the following OpenAI APIs:
- *Note: `suffix` parameter is not supported.*
- [Chat Completions API](#chat-api) (`/v1/chat/completions`)
- Only applicable to [text generation models](../models/generative_models.md) with a [chat template](../serving/openai_compatible_server.md#chat-template).
- - *Note: `parallel_tool_calls` and `user` parameters are ignored.*
+ - *Note: `user` parameter is ignored.*
+ - *Note:* Setting the `parallel_tool_calls` parameter to `false` ensures vLLM only returns zero or one tool call per request. Setting it to `true` (the default) allows returning more than one tool call per request. There is no guarantee more than one tool call will be returned if this is set to `true`, as that behavior is model dependent and not all models are designed to support parallel tool calls.
- [Embeddings API](#embeddings-api) (`/v1/embeddings`)
- Only applicable to [embedding models](../models/pooling_models.md).
- [Transcriptions API](#transcriptions-api) (`/v1/audio/transcriptions`)
diff --git a/docs/serving/parallelism_scaling.md b/docs/serving/parallelism_scaling.md
index 14cd3b057791c..a32840ea73b9a 100644
--- a/docs/serving/parallelism_scaling.md
+++ b/docs/serving/parallelism_scaling.md
@@ -118,14 +118,16 @@ The common practice is to set the tensor parallel size to the number of GPUs in
```bash
vllm serve /path/to/the/model/in/the/container \
--tensor-parallel-size 8 \
- --pipeline-parallel-size 2
+ --pipeline-parallel-size 2 \
+ --distributed-executor-backend ray
```
Alternatively, you can set `tensor_parallel_size` to the total number of GPUs in the cluster:
```bash
vllm serve /path/to/the/model/in/the/container \
- --tensor-parallel-size 16
+ --tensor-parallel-size 16 \
+ --distributed-executor-backend ray
```
## Optimizing network communication for tensor parallelism
diff --git a/docs/usage/reproducibility.md b/docs/usage/reproducibility.md
index d8a1943209c1e..a8e49d0a3398f 100644
--- a/docs/usage/reproducibility.md
+++ b/docs/usage/reproducibility.md
@@ -1,24 +1,23 @@
# Reproducibility
-vLLM does not guarantee the reproducibility of the results by default, for the sake of performance. You need to do the following to achieve
+vLLM does not guarantee the reproducibility of the results by default, for the sake of performance. To achieve
reproducible results:
-- For V1: Turn off multiprocessing to make the scheduling deterministic by setting `VLLM_ENABLE_V1_MULTIPROCESSING=0`.
-- For V0: Set the global seed (see below).
+- In offline mode, you can either set `VLLM_ENABLE_V1_MULTIPROCESSING=0`, which makes scheduling deterministic,
+ or enable [batch invariance](../features/batch_invariance.md) to make the outputs insensitive to scheduling.
+- In online mode, you can only enable [batch invariance](../features/batch_invariance.md).
Example: [examples/offline_inference/reproducibility.py](../../examples/offline_inference/reproducibility.py)
!!! warning
- Applying the above settings [changes the random state in user code](#locality-of-random-state).
+ Setting `VLLM_ENABLE_V1_MULTIPROCESSING=0` will change the random state of user code
+ (i.e. the code that constructs [LLM][vllm.LLM] class).
!!! note
Even with the above settings, vLLM only provides reproducibility
when it runs on the same hardware and the same vLLM version.
- Also, the online serving API (`vllm serve`) does not support reproducibility
- because it is almost impossible to make the scheduling deterministic in the
- online setting.
## Setting the global seed
@@ -26,27 +25,17 @@ The `seed` parameter in vLLM is used to control the random states for various ra
If a specific seed value is provided, the random states for `random`, `np.random`, and `torch.manual_seed` will be set accordingly.
-However, in some cases, setting the seed will also [change the random state in user code](#locality-of-random-state).
-
### Default Behavior
-In V0, the `seed` parameter defaults to `None`. When the `seed` parameter is `None`, the random states for `random`, `np.random`, and `torch.manual_seed` are not set. This means that each run of vLLM will produce different results if `temperature > 0`, as expected.
-
In V1, the `seed` parameter defaults to `0` which sets the random state for each worker, so the results will remain consistent for each vLLM run even if `temperature > 0`.
+It is impossible to un-specify a seed for V1 because different workers need to sample the same outputs
+for workflows such as speculative decoding. For more information, see:
+
!!! note
- It is impossible to un-specify a seed for V1 because different workers need to sample the same outputs
- for workflows such as speculative decoding.
-
- For more information, see:
+ The random state in user code (i.e. the code that constructs [LLM][vllm.LLM] class) is updated by vLLM
+ only if the workers are run in the same process as user code, i.e.: `VLLM_ENABLE_V1_MULTIPROCESSING=0`.
-### Locality of random state
-
-The random state in user code (i.e. the code that constructs [LLM][vllm.LLM] class) is updated by vLLM under the following conditions:
-
-- For V0: The seed is specified.
-- For V1: The workers are run in the same process as user code, i.e.: `VLLM_ENABLE_V1_MULTIPROCESSING=0`.
-
-By default, these conditions are not active so you can use vLLM without having to worry about
-accidentally making deterministic subsequent operations that rely on random state.
+ By default, `VLLM_ENABLE_V1_MULTIPROCESSING=1` so you can use vLLM without having to worry about
+ accidentally making deterministic subsequent operations that rely on random state.
diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md
index 8d8a9e0f50805..5f647aafd61d4 100644
--- a/docs/usage/v1_guide.md
+++ b/docs/usage/v1_guide.md
@@ -2,11 +2,9 @@
!!! announcement
- We have started the process of deprecating V0. Please read [RFC #18571](https://github.com/vllm-project/vllm/issues/18571) for more details.
+ We have fully deprecated V0. Please read [RFC #18571](https://github.com/vllm-project/vllm/issues/18571) for more details.
-V1 is now enabled by default for all supported use cases, and we will gradually enable it for every use case we plan to support. Please share any feedback on [GitHub](https://github.com/vllm-project/vllm) or in the [vLLM Slack](https://inviter.co/vllm-slack).
-
-## Why vLLM V1?
+ If you have a use case that works on V0 Engine but not V1, please share it on [GitHub](https://github.com/vllm-project/vllm) or in the [vLLM Slack](https://inviter.co/vllm-slack).
vLLM V0 successfully supported a wide range of models and hardware, but as new features were developed independently, the system grew increasingly complex. This complexity made it harder to integrate new capabilities and introduced technical debt, revealing the need for a more streamlined and unified design.
@@ -32,16 +30,44 @@ Upgrade to vLLM’s Core Architecture](https://blog.vllm.ai/2025/01/27/v1-alpha-
This living user guide outlines a few known **important changes and limitations** introduced by vLLM V1. The team has been working actively to bring V1 as the default engine, therefore this guide will be updated constantly as more features get supported on vLLM V1.
-## Current Status
+## Differences from V0
-For each item, our progress towards V1 support falls into one of the following states:
+This section lists some differences in behavior between V0 and V1.
-- **🚀 Optimized**: Nearly fully optimized, with no further work currently planned.
-- **🟢 Functional**: Fully operational, with ongoing optimizations.
-- **🚧 WIP**: Under active development.
-- **🟡 Planned**: Scheduled for future implementation (some may have open PRs/RFCs).
-- **🟠 Delayed**: Temporarily dropped in V1 but planned to be re-introduced later.
-- **🔴 Deprecated**: Not planned for V1 unless there is strong demand.
+### Chunked Prefill
+
+Chunked prefill is enabled by default whenever possible, unlike in V0 where it was conditionally enabled based on model characteristics.
+
+### CUDA Graphs
+
+CUDA graph capture takes up more memory in V1 than in V0.
+
+### Semantic Changes to Logprobs
+
+#### Logprobs Calculation
+
+By default, logprobs in V1 are now returned immediately once computed from the model’s raw output (i.e.
+before applying any logits post-processing such as temperature scaling or penalty
+adjustments). As a result, the returned logprobs do not reflect the final adjusted
+probabilities used during sampling.
+
+You can adjust this behavior by setting the `--logprobs-mode` flag.
+Four modes are supported: `raw_logprobs` (default), `processed_logprobs`, `raw_logits`, `processed_logits`.
+Raw means the values before applying any logit processors, like bad words.
+Processed means the values after applying all processors, including temperature and top_k/top_p.
+
+#### Prompt Logprobs with Prefix Caching
+
+While V1 supports passing prompt logprobs with prefix caching enabled, it no longer caches the logprobs.
+For a request requiring prompt logprobs, the engine will ignore the prefix cache and recompute the prefill of the full prompt to generate the logprobs.
+
+## Feature Support
+
+For each item, its support in vLLM V1 falls into one of the following states:
+
+- **🟢 Functional**: Fully operational with optimizations comparable to or better than V0.
+- **🟡 In Progress**: Planned to be in vLLM V1, with open PRs/RFCs.
+- **🔴 Removed**: Dropped from vLLM V1. Will only consider re-introducing if there is strong demand.
!!! note
vLLM V1’s unified scheduler treats both prompt and output tokens the same
@@ -57,13 +83,13 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the
### Hardware
-| Hardware | Status |
-|------------|-----------------------------------------------|
-| **NVIDIA** | 🚀 |
-| **AMD** | 🟢 |
+| Hardware | Status |
+|------------------|-----------------------------------------------|
+| **NVIDIA** | 🟢 |
+| **AMD** | 🟢 |
| **INTEL GPU** | 🟢 |
-| **TPU** | 🟢 |
-| **CPU** | 🟢 (x86\_64/aarch64) 🟡 (MacOS) |
+| **TPU** | 🟢 |
+| **CPU** | 🟢 |
!!! note
@@ -78,23 +104,21 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the
### Models
-| Model Type | Status |
-|-----------------------------|------------------------------------------------------------------------------------|
-| **Decoder-only Models** | 🚀 Optimized |
-| **Encoder-Decoder Models** | 🟢 Whisper only |
-| **Embedding Models** | 🟢 Functional |
-| **Mamba Models** | 🟢 (Mamba-2), 🟢 (Mamba-1) |
-| **Multimodal Models** | 🟢 Functional |
+| Model Type | Status |
+|-----------------------------|-------------------------------------------------------------------------|
+| **Decoder-only Models** | 🟢 |
+| **Encoder-Decoder Models** | 🟢 (Whisper), 🔴 (Others) |
+| **Pooling Models** | 🟢 |
+| **Mamba Models** | 🟢 |
+| **Multimodal Models** | 🟢 |
See below for the status of models that are not yet supported or have more features planned in V1.
-#### Embedding Models
+#### Pooling Models
-The initial basic support is now functional.
+Now fully supported, with prefix caching and chunked prefill newly available for last-pooling models.
-Later, we will consider using [hidden states processor](https://github.com/vllm-project/vllm/issues/12249),
-which is based on [global logits processor](https://github.com/vllm-project/vllm/pull/13360)
-to enable simultaneous generation and embedding using the same engine instance in V1.
+We are working on enabling prefix caching and chunked prefill for more categories of pooling models.
#### Mamba Models
@@ -112,24 +136,25 @@ Please note that prefix caching is not yet supported for any of the above models
Whisper is supported. Other models requiring cross-attention between separate
encoder and decoder (e.g., `BartForConditionalGeneration`,
-`MllamaForConditionalGeneration`) are not supported.
+`MllamaForConditionalGeneration`) are no longer supported.
### Features
| Feature | Status |
|---------------------------------------------|-----------------------------------------------------------------------------------|
-| **Prefix Caching** | 🚀 Optimized |
-| **Chunked Prefill** | 🚀 Optimized |
-| **LoRA** | 🚀 Optimized |
+| **Prefix Caching** | 🟢 Functional |
+| **Chunked Prefill** | 🟢 Functional |
+| **LoRA** | 🟢 Functional |
| **Logprobs Calculation** | 🟢 Functional |
-| **FP8 KV Cache** | 🟢 Functional on Hopper devices () |
-| **Spec Decode** | 🚀 Optimized |
-| **Prompt Logprobs with Prefix Caching** | 🟡 Planned ([RFC #13414](https://github.com/vllm-project/vllm/issues/13414)) |
+| **FP8 KV Cache** | 🟢 Functional |
+| **Spec Decode** | 🟢 Functional |
+| **Prompt Logprobs with Prefix Caching** | 🟢 Functional |
| **Structured Output Alternative Backends** | 🟢 Functional |
-| **Request-level Structured Output Backend** | 🔴 Deprecated |
-| **best_of** | 🔴 Deprecated ([RFC #13361](https://github.com/vllm-project/vllm/issues/13361)) |
-| **Per-Request Logits Processors** | 🔴 Deprecated ([RFC #13360](https://github.com/vllm-project/vllm/pull/13360)) |
-| **GPU <> CPU KV Cache Swapping** | 🔴 Deprecated |
+| **Concurrent Partial Prefills** | 🟡 [In Progress](https://github.com/vllm-project/vllm/issues/14003) |
+| **best_of** | 🔴 [Removed](https://github.com/vllm-project/vllm/issues/13361) |
+| **Per-Request Logits Processors** | 🔴 [Removed](https://github.com/vllm-project/vllm/pull/13360) |
+| **GPU <> CPU KV Cache Swapping** | 🔴 Removed |
+| **Request-level Structured Output Backend** | 🔴 Removed |
!!! note
@@ -139,38 +164,17 @@ encoder and decoder (e.g., `BartForConditionalGeneration`,
prefix caching, and speculative decoding without a strict separation between prefill
and decode phases.
-#### Semantic Changes to Logprobs
+#### Removed Features
-vLLM V1 supports logprobs and prompt logprobs. However, there are some important semantic
-differences compared to V0:
-
-##### Logprobs Calculation
-
-By default, logprobs in V1 are now returned immediately once computed from the model’s raw output (i.e.
-before applying any logits post-processing such as temperature scaling or penalty
-adjustments). As a result, the returned logprobs do not reflect the final adjusted
-probabilities used during sampling.
-
-You can adjust this behavior by setting the `--logprobs-mode` flag.
-Four modes are supported: `raw_logprobs` (default), `processed_logprobs`, `raw_logits`, `processed_logits`.
-Raw means the values before applying any logit processors, like bad words.
-Processed means the values after applying all processors, including temperature and top_k/top_p.
-
-##### Prompt Logprobs with Prefix Caching
-
-Logprobs are not cached. For a request requiring prompt logprobs, the engine will ignore the prefix cache and recompute the prefill of full prompt to generate the logprobs.
-
-#### Deprecated Features
-
-As part of the major architectural rework in vLLM V1, several legacy features have been deprecated.
+As part of the major architectural rework in vLLM V1, several legacy features have been removed.
##### Sampling features
-- **best_of**: This feature has been deprecated due to limited usage. See details at [RFC #13361](https://github.com/vllm-project/vllm/issues/13361).
+- **best_of**: This feature has been removed due to limited usage. See details at [RFC #13361](https://github.com/vllm-project/vllm/issues/13361).
- **Per-Request Logits Processors**: In V0, users could pass custom
processing functions to adjust logits on a per-request basis. In vLLM V1, this
- feature has been deprecated. Instead, the design is moving toward supporting **global logits
- processors**, a feature the team is actively working on for future releases. See details at [RFC #13360](https://github.com/vllm-project/vllm/pull/13360).
+ feature has been removed. Instead, we now support **global logits processors**
+ which are set at startup time, see [RFC #17799](https://github.com/vllm-project/vllm/issues/17799).
##### KV Cache features
@@ -179,4 +183,4 @@ to handle request preemptions.
##### Structured Output features
-- **Request-level Structured Output Backend**: Deprecated, alternative backends (outlines, guidance) with fallbacks is supported now.
+- **Request-level Structured Output Backend**: Removed; alternative backends (outlines, guidance) with fallbacks are supported now.
diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py
old mode 100644
new mode 100755
index 04e6f99f8957e..df6e96ca375fc
--- a/examples/offline_inference/audio_language.py
+++ b/examples/offline_inference/audio_language.py
@@ -425,6 +425,13 @@ def parse_args():
default=None,
help="Set the seed when initializing `vllm.LLM`.",
)
+ parser.add_argument(
+ "--tensor-parallel-size",
+ "-tp",
+ type=int,
+ default=None,
+ help="Tensor parallel size to override the model's default setting. ",
+ )
return parser.parse_args()
@@ -434,6 +441,12 @@ def main(args):
if model not in model_example_map:
raise ValueError(f"Model type {model} is not supported.")
+ if args.tensor_parallel_size is not None and args.tensor_parallel_size < 1:
+ raise ValueError(
+ f"tensor_parallel_size must be a positive integer, "
+ f"got {args.tensor_parallel_size}"
+ )
+
audio_count = args.num_audios
req_data = model_example_map[model](
question_per_audio_count[audio_count], audio_count
@@ -446,6 +459,8 @@ def main(args):
)
engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
+ if args.tensor_parallel_size is not None:
+ engine_args["tensor_parallel_size"] = args.tensor_parallel_size
llm = LLM(**engine_args)
# We set temperature to 0.2 so that outputs can be different
diff --git a/examples/offline_inference/context_extension.py b/examples/offline_inference/context_extension.py
index df39e4c25d5c8..67d33e1881ee9 100644
--- a/examples/offline_inference/context_extension.py
+++ b/examples/offline_inference/context_extension.py
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This script demonstrates how to extend the context length
-of a Qwen model using the YARN method (rope_scaling)
+of a Qwen model using the YARN method (rope_parameters)
and run a simple chat example.
Usage:
@@ -19,8 +19,8 @@ def create_llm():
# Use yarn to extend context
hf_overrides = {
- "rope_theta": rope_theta,
- "rope_scaling": {
+ "rope_parameters": {
+ "rope_theta": rope_theta,
"rope_type": "yarn",
"factor": factor,
"original_max_position_embeddings": original_max_position_embeddings,
diff --git a/examples/offline_inference/profiling_tpu/README.md b/examples/offline_inference/profiling_tpu/README.md
deleted file mode 100644
index 8c9c1c92b6764..0000000000000
--- a/examples/offline_inference/profiling_tpu/README.md
+++ /dev/null
@@ -1,70 +0,0 @@
-# vLLM TPU Profiling
-
-This script is used to profile the TPU performance of vLLM for specific prefill or decode token shapes.
-
-Note: an actual running server is a mix of both prefill of many shapes and decode of many shapes.
-
-We assume you are on a TPU already (this was tested on TPU v6e) and have installed vLLM according to the [Google TPU installation guide](https://docs.vllm.ai/en/latest/getting_started/installation/google_tpu.html).
-
-> In all examples below, we run several warmups before (so `--enforce-eager` is okay)
-
-## Profile Examples
-
-### Generate Prefill Trace
-
-This example runs Qwen/Qwen2.5-7B-Instruct with a single request of 1024 input tokens. This is set up in attempt to profile just the prefill time and operations.
-
-```bash
-export XLA_HLO_DEBUG=1
-export MODEL=Qwen/Qwen2.5-7B-Instruct
-export VLLM_TPU_PROFILE_DURATION_MS=3000
-export VLLM_TPU_PROFILE_DELAY_MS=0
-
-python3 profiling.py \
- --model $MODEL \
- --input-len 1024 --output-len 1 \
- --batch-size 1 --enforce-eager \
- --max-model-len 2048 \
- --tensor-parallel-size 1 \
- --profile-result-dir profiles
-```
-
-### Generate Decode Trace
-
-This example runs Llama 3.1 70B with a batch of 32 requests where each has 1 input token and 128 output tokens. This is set up in attempt to profile just the 32 decodes running in parallel by having an extremely small prefill of 1 token and setting `VLLM_TPU_PROFILE_DELAY_MS=1000` to skip the first second of inference (hopefully prefill).
-
-```bash
-export XLA_HLO_DEBUG=1
-export MODEL=meta-llama/Llama-3.1-70B-Instruct
-export VLLM_TPU_PROFILE_DURATION_MS=2000
-export VLLM_TPU_PROFILE_DELAY_MS=1000
-
-rm -rf ~/.cache/vllm/xla_cache
-python3 profiling.py \
- --model $MODEL \
- --input-len 1 \
- --output-len 128 \
- --batch-size 32 \
- --enforce-eager \
- --profile-result-dir profiles \
- --max-model-len 2048 --tensor-parallel-size 8
-```
-
-## Visualizing the profiles
-
-Once you have collected your profiles with this script, you can visualize them using [TensorBoard](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm).
-
-Here are most likely the dependencies you need to install:
-
-```bash
-pip install tensorflow-cpu \
- tensorboard-plugin-profile \
- etils \
- importlib_resources
-```
-
-Then you just need to point TensorBoard to the directory where you saved the profiles and visit `http://localhost:6006/` in your browser:
-
-```bash
-tensorboard --logdir profiles/ --port 6006
-```
diff --git a/examples/offline_inference/profiling_tpu/profiling.py b/examples/offline_inference/profiling_tpu/profiling.py
deleted file mode 100644
index 3b127e4fd29df..0000000000000
--- a/examples/offline_inference/profiling_tpu/profiling.py
+++ /dev/null
@@ -1,110 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import argparse
-import dataclasses
-import os
-import time
-
-import numpy as np
-import torch_xla.debug.profiler as xp
-from tqdm import tqdm
-
-from vllm import LLM, SamplingParams
-from vllm.engine.arg_utils import EngineArgs
-from vllm.inputs import PromptType
-from vllm.utils.argparse_utils import FlexibleArgumentParser
-
-DURATION_MS = int(os.getenv("VLLM_TPU_PROFILE_DURATION_MS", 3000))
-DELAY_MS = int(os.getenv("VLLM_TPU_PROFILE_DELAY_MS", 0))
-
-
-def main(args: argparse.Namespace):
- print(args)
-
- engine_args = EngineArgs.from_cli_args(args)
- llm = LLM(**dataclasses.asdict(engine_args))
- server = xp.start_server(9012) # noqa: F841
-
- sampling_params = SamplingParams(
- temperature=0.0,
- ignore_eos=True,
- max_tokens=args.output_len,
- )
- print(sampling_params)
- dummy_prompt_token_ids = np.random.randint(
- 10000, size=(args.batch_size, args.input_len)
- )
- dummy_prompts: list[PromptType] = [
- {"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist()
- ]
-
- def run_to_completion():
- start_time = time.perf_counter()
- llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False)
- end_time = time.perf_counter()
- latency = end_time - start_time
- return latency
-
- # Warmup
- print("Warming up...")
- warmup_latencies = []
- for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
- warmup_latencies.append(run_to_completion())
- print(f"Average warmup latency: {np.mean(warmup_latencies):.4f}s")
-
- # Profile
- profile_dir = args.profile_result_dir
- print(f"Profiling (results will be saved to '{profile_dir}')...")
- # Enable tracing on server
- xp.trace_detached(
- "localhost:9012", profile_dir, delay_ms=DELAY_MS, duration_ms=DURATION_MS
- )
- if DELAY_MS == 0:
- time.sleep(1.0)
- profile_latencies = []
- for _ in tqdm(range(args.num_iters), desc="Profile iterations"):
- profile_latencies.append(run_to_completion())
- print(f"Average profile latency: {np.mean(profile_latencies):.4f}s")
-
- return
-
-
-def parse_args():
- parser = FlexibleArgumentParser(
- description="Benchmark the latency of processing a single batch of "
- "requests till completion."
- )
- parser.add_argument("--input-len", type=int, default=32)
- parser.add_argument("--output-len", type=int, default=128)
- parser.add_argument("--batch-size", type=int, default=8)
- parser.add_argument(
- "--num-iters-warmup",
- type=int,
- default=5,
- help="Number of iterations to run for warmup.",
- )
- parser.add_argument(
- "--num-iters",
- type=int,
- default=1,
- help="Number of iterations to run for profiling.",
- )
- parser.add_argument(
- "--profile-result-dir",
- type=str,
- default="profiles",
- help=(
- "path to save the pytorch profiler output. Can be visualized "
- "with ui.perfetto.dev or Tensorboard "
- "(https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm)."
- ),
- )
-
- parser = EngineArgs.add_cli_args(parser)
- return parser.parse_args()
-
-
-if __name__ == "__main__":
- args = parse_args()
- main(args)
diff --git a/examples/offline_inference/qwen3_omni/only_thinker.py b/examples/offline_inference/qwen3_omni/only_thinker.py
new file mode 100644
index 0000000000000..88a61ed694c2e
--- /dev/null
+++ b/examples/offline_inference/qwen3_omni/only_thinker.py
@@ -0,0 +1,170 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+This example shows how to use vLLM for running offline inference
+with the correct prompt format on Qwen3-Omni (thinker only).
+"""
+
+from typing import NamedTuple
+
+from vllm import LLM, SamplingParams
+from vllm.assets.audio import AudioAsset
+from vllm.assets.image import ImageAsset
+from vllm.assets.video import VideoAsset
+from vllm.multimodal.image import convert_image_mode
+from vllm.utils.argparse_utils import FlexibleArgumentParser
+
+
+class QueryResult(NamedTuple):
+ inputs: dict
+ limit_mm_per_prompt: dict[str, int]
+
+
+# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
+# lower-end GPUs.
+# Unless specified, these settings have been tested to work on a single L4.
+
+default_system = (
+ "You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
+ "Group, capable of perceiving auditory and visual inputs, as well as "
+ "generating text and speech."
+)
+
+
+def get_mixed_modalities_query() -> QueryResult:
+ question = (
+ "What is recited in the audio? "
+ "What is the content of this image? Why is this video funny?"
+ )
+ prompt = (
+ f"<|im_start|>system\n{default_system}<|im_end|>\n"
+ "<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|>"
+ "<|vision_start|><|image_pad|><|vision_end|>"
+ "<|vision_start|><|video_pad|><|vision_end|>"
+ f"{question}<|im_end|>\n"
+ f"<|im_start|>assistant\n"
+ )
+ return QueryResult(
+ inputs={
+ "prompt": prompt,
+ "multi_modal_data": {
+ "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
+ "image": convert_image_mode(
+ ImageAsset("cherry_blossom").pil_image, "RGB"
+ ),
+ "video": VideoAsset(name="baby_reading", num_frames=16).np_ndarrays,
+ },
+ },
+ limit_mm_per_prompt={"audio": 1, "image": 1, "video": 1},
+ )
+
+
+def get_use_audio_in_video_query() -> QueryResult:
+ question = (
+ "Describe the content of the video in details, then convert what the "
+ "baby say into text."
+ )
+ prompt = (
+ f"<|im_start|>system\n{default_system}<|im_end|>\n"
+ "<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>"
+ f"{question}<|im_end|>\n"
+ f"<|im_start|>assistant\n"
+ )
+ asset = VideoAsset(name="baby_reading", num_frames=16)
+ audio = asset.get_audio(sampling_rate=16000)
+ return QueryResult(
+ inputs={
+ "prompt": prompt,
+ "multi_modal_data": {
+ "video": asset.np_ndarrays,
+ "audio": audio,
+ },
+ "mm_processor_kwargs": {
+ "use_audio_in_video": True,
+ },
+ },
+ limit_mm_per_prompt={"audio": 1, "video": 1},
+ )
+
+
+def get_multi_audios_query() -> QueryResult:
+ question = "Are these two audio clips the same?"
+ prompt = (
+ f"<|im_start|>system\n{default_system}<|im_end|>\n"
+ "<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|>"
+ "<|audio_start|><|audio_pad|><|audio_end|>"
+ f"{question}<|im_end|>\n"
+ f"<|im_start|>assistant\n"
+ )
+ return QueryResult(
+ inputs={
+ "prompt": prompt,
+ "multi_modal_data": {
+ "audio": [
+ AudioAsset("winning_call").audio_and_sample_rate,
+ AudioAsset("mary_had_lamb").audio_and_sample_rate,
+ ],
+ },
+ },
+ limit_mm_per_prompt={
+ "audio": 2,
+ },
+ )
+
+
+query_map = {
+ "mixed_modalities": get_mixed_modalities_query,
+ "use_audio_in_video": get_use_audio_in_video_query,
+ "multi_audios": get_multi_audios_query,
+}
+
+
+def main(args):
+ model_name = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
+ query_result = query_map[args.query_type]()
+
+ llm = LLM(
+ model=model_name,
+ max_model_len=12800,
+ max_num_seqs=5,
+ limit_mm_per_prompt=query_result.limit_mm_per_prompt,
+ seed=args.seed,
+ )
+
+ # We set temperature to 0.2 so that outputs can be different
+ # even when all prompts are identical when running batch inference.
+ sampling_params = SamplingParams(temperature=0.2, max_tokens=256)
+
+ outputs = llm.generate(query_result.inputs, sampling_params=sampling_params)
+
+ for o in outputs:
+ generated_text = o.outputs[0].text
+ print(generated_text)
+
+
+def parse_args():
+ parser = FlexibleArgumentParser(
+ description="Demo on using vLLM for offline inference with "
+ "audio language models"
+ )
+ parser.add_argument(
+ "--query-type",
+ "-q",
+ type=str,
+ default="mixed_modalities",
+ choices=query_map.keys(),
+ help="Query type.",
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=None,
+ help="Set the seed when initializing `vllm.LLM`.",
+ )
+
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/offline_inference/reproducibility.py b/examples/offline_inference/reproducibility.py
index d909438b41042..72c1e841dca45 100644
--- a/examples/offline_inference/reproducibility.py
+++ b/examples/offline_inference/reproducibility.py
@@ -11,12 +11,11 @@ import random
from vllm import LLM, SamplingParams
-# V1 only: Turn off multiprocessing to make the scheduling deterministic.
+# Either:
+## Turn off multiprocessing to make the scheduling deterministic, or
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
-
-# V0 only: Set the global seed. The default seed is None, which is
-# not reproducible.
-SEED = 42
+## Enable batch invariance to get consistent results regardless of scheduling.
+os.environ["VLLM_BATCH_INVARIANT"] = "1"
prompts = [
"Hello, my name is",
@@ -28,7 +27,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
def main():
- llm = LLM(model="facebook/opt-125m", seed=SEED)
+ llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(prompts, sampling_params)
print("-" * 50)
for output in outputs:
diff --git a/examples/offline_inference/rlhf_utils.py b/examples/offline_inference/rlhf_utils.py
index 13def88439ef2..5c0787b8778d6 100644
--- a/examples/offline_inference/rlhf_utils.py
+++ b/examples/offline_inference/rlhf_utils.py
@@ -30,8 +30,8 @@ class WorkerExtension:
"""
The class for vLLM's worker to inherit from.
By defining an extension class, the code can work no matter what is
- the underlying worker class. This way, the code can be compatible
- with both vLLM V0 and V1.
+ the underlying worker class.
+
NOTE: we define this class in a separate module, and the main module
should pass the full qualified name as `worker_extension_cls` argument.
"""
@@ -96,8 +96,8 @@ class ColocateWorkerExtension:
"""
The class for vLLM's worker to inherit from, in the colocate setting.
By defining an extension class, the code can work no matter what is
- the underlying worker class. This way, the code can be compatible
- with both vLLM V0 and V1.
+ the underlying worker class.
+
NOTE: we define this class in a separate module, and the main module
should pass the full qualified name as `worker_extension_cls` argument.
"""
diff --git a/examples/offline_inference/save_sharded_state.py b/examples/offline_inference/save_sharded_state.py
index e25f46b126e6f..88ee48b98bff6 100644
--- a/examples/offline_inference/save_sharded_state.py
+++ b/examples/offline_inference/save_sharded_state.py
@@ -67,22 +67,9 @@ def main(args):
Path(args.output).mkdir(exist_ok=True)
# Dump worker states to output directory
- # Check which engine version is being used
- is_v1_engine = hasattr(llm.llm_engine, "engine_core")
-
- if is_v1_engine:
- # For V1 engine, we need to use engine_core.save_sharded_state
- print("Using V1 engine save path")
- llm.llm_engine.engine_core.save_sharded_state(
- path=args.output, pattern=args.file_pattern, max_size=args.max_file_size
- )
- else:
- # For V0 engine
- print("Using V0 engine save path")
- model_executor = llm.llm_engine.model_executor
- model_executor.save_sharded_state(
- path=args.output, pattern=args.file_pattern, max_size=args.max_file_size
- )
+ llm.llm_engine.engine_core.save_sharded_state(
+ path=args.output, pattern=args.file_pattern, max_size=args.max_file_size
+ )
# Copy metadata files to output directory
for file in os.listdir(model_path):
diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py
index 3cdc3b245b72a..67a0732459709 100644
--- a/examples/offline_inference/spec_decode.py
+++ b/examples/offline_inference/spec_decode.py
@@ -158,11 +158,7 @@ def main(args):
print(f"generated text: {output.outputs[0].text}")
print("-" * 50)
- try:
- metrics = llm.get_metrics()
- except AssertionError:
- print("Metrics are not supported in the V0 engine.")
- return
+ metrics = llm.get_metrics()
total_num_output_tokens = sum(
len(output.outputs[0].token_ids) for output in outputs
diff --git a/examples/offline_inference/tpu.py b/examples/offline_inference/tpu.py
deleted file mode 100644
index 0093b63b0b1f3..0000000000000
--- a/examples/offline_inference/tpu.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import argparse
-import os
-
-from vllm import LLM, SamplingParams
-
-prompts = [
- "A robot may not injure a human being",
- "It is only with the heart that one can see rightly;",
- "The greatest glory in living lies not in never falling,",
-]
-answers = [
- " or, through inaction, allow a human being to come to harm.",
- " what is essential is invisible to the eye.",
- " but in rising every time we fall.",
-]
-N = 1
-# Currently, top-p sampling is disabled. `top_p` should be 1.0.
-sampling_params = SamplingParams(temperature=0, top_p=1.0, n=N, max_tokens=16)
-
-
-def main():
- parser = argparse.ArgumentParser(description="TPU offline inference example")
- parser.add_argument("--use-spmd", action="store_true", help="Enable SPMD mode")
- args = parser.parse_args()
-
- llm_args = {
- "model": "Qwen/Qwen2-1.5B-Instruct",
- "max_num_batched_tokens": 64,
- "max_num_seqs": 4,
- "max_model_len": 128,
- }
- if args.use_spmd:
- os.environ["VLLM_XLA_USE_SPMD"] = "1"
- # Can only hardcode the number of chips for now.
- # calling xr.global_runtime_device_count() beforeing init SPMD env in
- # torch_xla will mess up the distributed env.
- llm_args["tensor_parallel_size"] = 8
- # Use Llama, for num_kv_heads = 8.
- llm_args["model"] = "meta-llama/Llama-3.1-8B-Instruct"
-
- # Set `enforce_eager=True` to avoid ahead-of-time compilation.
- # In real workloads, `enforce_eager` should be `False`.
- llm = LLM(**llm_args)
- outputs = llm.generate(prompts, sampling_params)
- print("-" * 50)
- for output, answer in zip(outputs, answers):
- prompt = output.prompt
- generated_text = output.outputs[0].text
- print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
- assert generated_text.startswith(answer)
- print("-" * 50)
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
old mode 100644
new mode 100755
index 624de2a2debc3..8f72bf6f0b0d1
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -538,6 +538,31 @@ def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
)
+# HunyuanOCR
+def run_hunyuan_vl(questions: list[str], modality: str) -> ModelRequestData:
+ assert modality == "image"
+
+ model_name = "tencent/HunyuanOCR"
+
+ engine_args = EngineArgs(
+ model=model_name,
+ max_model_len=8192,
+ limit_mm_per_prompt={modality: 1},
+ )
+
+ placeholder = "<|hy_place▁holder▁no▁100|><|hy_place▁holder▁no▁102|><|hy_place▁holder▁no▁101|>" # noqa: E501
+ prompts = [
+ f"<|hy_begin▁of▁sentence|>{placeholder}{question}<|hy_User|>"
+ for question in questions
+ ]
+
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompts=prompts,
+ stop_token_ids=None,
+ )
+
+
# naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B
def run_hyperclovax_seed_vision(
questions: list[str], modality: str
@@ -1820,6 +1845,7 @@ model_example_map = {
"glm4_5v": run_glm4_5v,
"glm4_5v_fp8": run_glm4_5v_fp8,
"h2ovl_chat": run_h2ovl,
+ "hunyuan_vl": run_hunyuan_vl,
"hyperclovax_seed_vision": run_hyperclovax_seed_vision,
"idefics3": run_idefics3,
"interns1": run_interns1,
@@ -2038,6 +2064,13 @@ def parse_args():
help="If True, will send all requests in a second batch with empty mm "
"data to verify cache hits with UUIDs.",
)
+ parser.add_argument(
+ "--tensor-parallel-size",
+ "-tp",
+ type=int,
+ default=None,
+ help="Tensor parallel size to override the model's default setting.",
+ )
return parser.parse_args()
@@ -2046,6 +2079,12 @@ def main(args):
if model not in model_example_map:
raise ValueError(f"Model type {model} is not supported.")
+ if args.tensor_parallel_size is not None and args.tensor_parallel_size < 1:
+ raise ValueError(
+ f"tensor_parallel_size must be a positive integer, "
+ f"got {args.tensor_parallel_size}"
+ )
+
modality = args.modality
mm_input = get_multi_modal_input(args)
data = mm_input["data"]
@@ -2063,6 +2102,8 @@ def main(args):
"seed": args.seed,
"mm_processor_cache_gb": 0 if args.disable_mm_processor_cache else 4,
}
+ if args.tensor_parallel_size is not None:
+ engine_args["tensor_parallel_size"] = args.tensor_parallel_size
llm = LLM(**engine_args)
# Don't want to check the flag multiple times, so just hijack `prompts`.
diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py
old mode 100644
new mode 100755
index d6e169548f15b..7ba4e64b567de
--- a/examples/offline_inference/vision_language_multi_image.py
+++ b/examples/offline_inference/vision_language_multi_image.py
@@ -1110,6 +1110,7 @@ def load_r_vl(question: str, image_urls: list[str]) -> ModelRequestData:
model=model_name,
max_model_len=16384,
max_num_seqs=16,
+ trust_remote_code=True,
limit_mm_per_prompt={"image": len(image_urls)},
)
@@ -1351,10 +1352,18 @@ model_example_map = {
}
-def run_generate(model, question: str, image_urls: list[str], seed: int | None):
+def run_generate(
+ model,
+ question: str,
+ image_urls: list[str],
+ seed: int | None,
+ tensor_parallel_size: int | None,
+):
req_data = model_example_map[model](question, image_urls)
- engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
+ engine_args = asdict(req_data.engine_args) | {"seed": seed}
+ if tensor_parallel_size is not None:
+ engine_args["tensor_parallel_size"] = tensor_parallel_size
llm = LLM(**engine_args)
sampling_params = SamplingParams(
@@ -1377,7 +1386,13 @@ def run_generate(model, question: str, image_urls: list[str], seed: int | None):
print("-" * 50)
-def run_chat(model: str, question: str, image_urls: list[str], seed: int | None):
+def run_chat(
+ model: str,
+ question: str,
+ image_urls: list[str],
+ seed: int | None,
+ tensor_parallel_size: int | None,
+):
req_data = model_example_map[model](question, image_urls)
# Disable other modalities to save memory
@@ -1387,6 +1402,8 @@ def run_chat(model: str, question: str, image_urls: list[str], seed: int | None)
)
engine_args = asdict(req_data.engine_args) | {"seed": seed}
+ if tensor_parallel_size is not None:
+ engine_args["tensor_parallel_size"] = tensor_parallel_size
llm = LLM(**engine_args)
sampling_params = (
@@ -1462,6 +1479,13 @@ def parse_args():
default=2,
help="Number of images to use for the demo.",
)
+ parser.add_argument(
+ "--tensor-parallel-size",
+ "-tp",
+ type=int,
+ default=None,
+ help="Tensor parallel size to override the model's default setting.",
+ )
return parser.parse_args()
@@ -1469,13 +1493,20 @@ def main(args: Namespace):
model = args.model_type
method = args.method
seed = args.seed
+ tensor_parallel_size = args.tensor_parallel_size
+
+ if tensor_parallel_size is not None and tensor_parallel_size < 1:
+ raise ValueError(
+ f"tensor_parallel_size must be a positive integer, "
+ f"got {tensor_parallel_size}"
+ )
image_urls = IMAGE_URLS[: args.num_images]
if method == "generate":
- run_generate(model, QUESTION, image_urls, seed)
+ run_generate(model, QUESTION, image_urls, seed, tensor_parallel_size)
elif method == "chat":
- run_chat(model, QUESTION, image_urls, seed)
+ run_chat(model, QUESTION, image_urls, seed, tensor_parallel_size)
else:
raise ValueError(f"Invalid method: {method}")
diff --git a/examples/online_serving/disaggregated_prefill.sh b/examples/online_serving/disaggregated_prefill.sh
index d434e22b1ae88..cd2f2e44a4d69 100644
--- a/examples/online_serving/disaggregated_prefill.sh
+++ b/examples/online_serving/disaggregated_prefill.sh
@@ -24,7 +24,14 @@ cleanup() {
exit 0
}
-export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
+
+if [[ -z "${VLLM_HOST_IP:-}" ]]; then
+ export VLLM_HOST_IP=127.0.0.1
+ echo "Using default VLLM_HOST_IP=127.0.0.1 (override by exporting VLLM_HOST_IP before running this script)"
+else
+ echo "Using provided VLLM_HOST_IP=${VLLM_HOST_IP}"
+fi
+
# install quart first -- required for disagg prefill proxy serve
if python3 -c "import quart" &> /dev/null; then
@@ -38,7 +45,7 @@ fi
wait_for_server() {
local port=$1
timeout 1200 bash -c "
- until curl -s localhost:${port}/v1/completions > /dev/null; do
+ until curl -i localhost:${port}/v1/models > /dev/null; do
sleep 1
done" && return 0 || return 1
}
@@ -48,21 +55,23 @@ wait_for_server() {
# prefilling instance, which is the KV producer
CUDA_VISIBLE_DEVICES=0 vllm serve $MODEL_NAME \
+ --host 0.0.0.0 \
--port 8100 \
--max-model-len 100 \
--gpu-memory-utilization 0.8 \
--trust-remote-code \
--kv-transfer-config \
- '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' &
+ '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":"1e9","kv_port":"14579","kv_connector_extra_config":{"proxy_ip":"'"$VLLM_HOST_IP"'","proxy_port":"30001","http_ip":"'"$VLLM_HOST_IP"'","http_port":"8100","send_type":"PUT_ASYNC"}}' &
-# decoding instance, which is the KV consumer
+# decoding instance, which is the KV consumer
CUDA_VISIBLE_DEVICES=1 vllm serve $MODEL_NAME \
+ --host 0.0.0.0 \
--port 8200 \
--max-model-len 100 \
--gpu-memory-utilization 0.8 \
--trust-remote-code \
--kv-transfer-config \
- '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' &
+ '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":"1e10","kv_port":"14580","kv_connector_extra_config":{"proxy_ip":"'"$VLLM_HOST_IP"'","proxy_port":"30001","http_ip":"'"$VLLM_HOST_IP"'","http_port":"8200","send_type":"PUT_ASYNC"}}' &
# wait until prefill and decode instances are ready
wait_for_server 8100
diff --git a/examples/online_serving/gradio_openai_chatbot_webserver.py b/examples/online_serving/gradio_openai_chatbot_webserver.py
index d5d0a07a29183..c76c60cc4472d 100644
--- a/examples/online_serving/gradio_openai_chatbot_webserver.py
+++ b/examples/online_serving/gradio_openai_chatbot_webserver.py
@@ -25,25 +25,17 @@ import gradio as gr
from openai import OpenAI
-def format_history_to_openai(history):
- history_openai_format = [
- {"role": "system", "content": "You are a great AI assistant."}
- ]
- for human, assistant in history:
- history_openai_format.append({"role": "user", "content": human})
- history_openai_format.append({"role": "assistant", "content": assistant})
- return history_openai_format
-
-
def predict(message, history, client, model_name, temp, stop_token_ids):
- # Format history to OpenAI chat format
- history_openai_format = format_history_to_openai(history)
- history_openai_format.append({"role": "user", "content": message})
+ messages = [
+ {"role": "system", "content": "You are a great AI assistant."},
+ *history,
+ {"role": "user", "content": message},
+ ]
# Send request to OpenAI API (vLLM server)
stream = client.chat.completions.create(
model=model_name,
- messages=history_openai_format,
+ messages=messages,
temperature=temp,
stream=True,
extra_body={
diff --git a/examples/online_serving/openai_embedding_long_text/service.sh b/examples/online_serving/openai_embedding_long_text/service.sh
index 1577de85f7ff2..b5c92749466b0 100644
--- a/examples/online_serving/openai_embedding_long_text/service.sh
+++ b/examples/online_serving/openai_embedding_long_text/service.sh
@@ -22,7 +22,6 @@ API_KEY=${API_KEY:-"your-api-key"}
POOLING_TYPE=${POOLING_TYPE:-"auto"} # auto, MEAN, CLS, LAST
export VLLM_ENABLE_CHUNKED_PROCESSING=true
export CUDA_VISIBLE_DEVICES=2,3,4,5
-# export VLLM_ATTENTION_BACKEND=XFORMERS
echo "🚀 Starting vLLM Embedding Server with Enhanced Chunked Processing"
echo "=================================================================="
diff --git a/examples/online_serving/openai_responses_client.py b/examples/online_serving/openai_responses_client.py
new file mode 100644
index 0000000000000..b4eb24671507a
--- /dev/null
+++ b/examples/online_serving/openai_responses_client.py
@@ -0,0 +1,44 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Set up this example by starting a vLLM OpenAI-compatible server.
+Reasoning models can be used through the Responses API as seen here
+https://platform.openai.com/docs/api-reference/responses
+For example:
+vllm serve Qwen/Qwen3-8B --reasoning-parser qwen3
+
+"""
+
+from openai import OpenAI
+
+input_messages = [{"role": "user", "content": "What model are you?"}]
+
+
+def main():
+ base_url = "http://localhost:8000/v1"
+ client = OpenAI(base_url=base_url, api_key="empty")
+ model = "Qwen/Qwen3-8B" # get_first_model(client)
+ response = client.responses.create(
+ model=model,
+ input=input_messages,
+ )
+
+ for message in response.output:
+ if message.type == "reasoning":
+ # append reasoning message
+ input_messages.append(message)
+
+ response_2 = client.responses.create(
+ model=model,
+ input=input_messages,
+ )
+ print(response_2.output_text)
+ # I am Qwen, a large language model developed by Alibaba Cloud.
+ # I am designed to assist with a wide range of tasks, including
+ # answering questions, creating content, coding, and engaging in
+ # conversations. I can help with various topics and provide
+ # information or support in multiple languages. How can I assist you today?
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/online_serving/prometheus_grafana/README.md b/examples/online_serving/prometheus_grafana/README.md
index 5cd4dab5a8fa7..9615210a2ad80 100644
--- a/examples/online_serving/prometheus_grafana/README.md
+++ b/examples/online_serving/prometheus_grafana/README.md
@@ -46,7 +46,7 @@ Navigate to [`http://localhost:3000`](http://localhost:3000). Log in with the de
Navigate to [`http://localhost:3000/connections/datasources/new`](http://localhost:3000/connections/datasources/new) and select Prometheus.
-On Prometheus configuration page, we need to add the `Prometheus Server URL` in `Connection`. For this setup, Grafana and Prometheus are running in separate containers, but Docker creates DNS name for each containers. You can just use `http://prometheus:9090`.
+On Prometheus configuration page, we need to add the `Prometheus Server URL` in `Connection`. For this setup, Grafana and Prometheus are running in separate containers, but Docker creates DNS name for each container. You can just use `http://prometheus:9090`.
Click `Save & Test`. You should get a green check saying "Successfully queried the Prometheus API.".
diff --git a/requirements/common.txt b/requirements/common.txt
index 1058ab91a02a5..3f8cd588422d0 100644
--- a/requirements/common.txt
+++ b/requirements/common.txt
@@ -19,12 +19,12 @@ pillow # Required for image processing
prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer == 0.11.3
-llguidance >= 1.3.0, < 1.4.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64" or platform_machine == "s390x"
+llguidance >= 1.3.0, < 1.4.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64" or platform_machine == "s390x" or platform_machine == "ppc64le"
outlines_core == 0.2.11
# required for outlines backend disk cache
diskcache == 5.6.3
lark == 1.2.2
-xgrammar == 0.1.25; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64" or platform_machine == "s390x"
+xgrammar == 0.1.27; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64" or platform_machine == "s390x" or platform_machine == "ppc64le"
typing_extensions >= 4.10
filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
partial-json-parser # used for parsing partial JSON outputs
diff --git a/requirements/cuda.txt b/requirements/cuda.txt
index d63fe9e1e77c1..15e8aadc56f47 100644
--- a/requirements/cuda.txt
+++ b/requirements/cuda.txt
@@ -9,6 +9,5 @@ torch==2.9.0
torchaudio==2.9.0
# These must be updated alongside torch
torchvision==0.24.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
-xformers==0.0.33.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.9
# FlashInfer should be updated together with the Dockerfile
flashinfer-python==0.5.2
diff --git a/requirements/kv_connectors.txt b/requirements/kv_connectors.txt
index b1f3269cd3813..083230c171096 100644
--- a/requirements/kv_connectors.txt
+++ b/requirements/kv_connectors.txt
@@ -1,2 +1,2 @@
lmcache
-nixl >= 0.6.0 # Required for disaggregated prefill
+nixl >= 0.7.1 # Required for disaggregated prefill
diff --git a/requirements/rocm-test.txt b/requirements/rocm-test.txt
index 432e11977872d..8a91b59de6f72 100644
--- a/requirements/rocm-test.txt
+++ b/requirements/rocm-test.txt
@@ -39,3 +39,13 @@ mteb[bm25s]>=1.38.11, <2
# Required for eval tests
lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d
+
+# Required for multiprocessed tests that use spawn method
+multiprocess==0.70.16
+
+# Plugins test
+terratorch @ git+https://github.com/IBM/terratorch.git@07184fcf91a1324f831ff521dd238d97fe350e3e
+torchgeo==0.7.0
+
+# Required for suffix decoding test
+arctic-inference == 0.1.1
diff --git a/requirements/rocm.txt b/requirements/rocm.txt
index 6f1cca90e5e2b..abbd33d6e1240 100644
--- a/requirements/rocm.txt
+++ b/requirements/rocm.txt
@@ -15,3 +15,4 @@ setuptools-scm>=8
runai-model-streamer[s3,gcs]==0.15.0
conch-triton-kernels==1.2.1
timm>=1.0.17
+fastsafetensors @ git+https://github.com/foundation-model-stack/fastsafetensors.git@d6f998a03432b2452f8de2bb5cefb5af9795d459
diff --git a/requirements/test.in b/requirements/test.in
index 30d97e9b9c7d0..05f6bcca5c2c4 100644
--- a/requirements/test.in
+++ b/requirements/test.in
@@ -36,7 +36,7 @@ opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
# TODO: Use lm-eval[api]==0.4.10 once released
lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test
-mteb[bm25s]>=1.38.11, <2 # required for mteb test
+mteb[bm25s]>=2, <3 # required for mteb test
transformers==4.57.1
tokenizers==0.22.0
schemathesis>=3.39.15 # Required for openai schema test.
diff --git a/requirements/test.txt b/requirements/test.txt
index 3263b74c08797..bcd511660f85e 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -201,8 +201,6 @@ email-validator==2.2.0
# via pydantic
encodec==0.1.1
# via vocos
-eval-type-backport==0.2.2
- # via mteb
evaluate==0.4.3
# via lm-eval
fastapi==0.116.1
@@ -490,7 +488,7 @@ msgpack==1.1.0
# via
# librosa
# ray
-mteb==1.38.11
+mteb==2.1.2
# via -r requirements/test.in
multidict==6.1.0
# via
diff --git a/requirements/xpu.txt b/requirements/xpu.txt
index 59ea710684a2c..c1dc4195b5231 100644
--- a/requirements/xpu.txt
+++ b/requirements/xpu.txt
@@ -10,9 +10,9 @@ wheel
jinja2>=3.1.6
datasets # for benchmark scripts
numba == 0.61.2 # Required for N-gram speculative decoding
-torch==2.8.0+xpu
+--extra-index-url=https://download.pytorch.org/whl/xpu
+torch==2.9.0+xpu
torchaudio
torchvision
---extra-index-url=https://download.pytorch.org/whl/xpu
-intel-extension-for-pytorch @ https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.8.10.post1%2Bxpu-cp312-cp312-linux_x86_64.whl
+intel-extension-for-pytorch @ https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.9.10.post0%2Bxpu-cp312-cp312-linux_x86_64.whl
diff --git a/setup.py b/setup.py
index 5591bcb132447..8871b04d8fc46 100644
--- a/setup.py
+++ b/setup.py
@@ -74,18 +74,6 @@ def is_ninja_available() -> bool:
return which("ninja") is not None
-def is_url_available(url: str) -> bool:
- from urllib.request import urlopen
-
- status = None
- try:
- with urlopen(url) as f:
- status = f.status
- except Exception:
- return False
- return status == 200
-
-
class CMakeExtension(Extension):
def __init__(self, name: str, cmake_lists_dir: str = ".", **kwa) -> None:
super().__init__(name, sources=[], py_limited_api=True, **kwa)
@@ -533,28 +521,6 @@ def get_nvcc_cuda_version() -> Version:
return nvcc_cuda_version
-def get_gaudi_sw_version():
- """
- Returns the driver version.
- """
- # Enable console printing for `hl-smi` check
- output = subprocess.run(
- "hl-smi",
- shell=True,
- text=True,
- capture_output=True,
- env={"ENABLE_CONSOLE": "true"},
- )
- if output.returncode == 0 and output.stdout:
- return (
- output.stdout.split("\n")[2]
- .replace(" ", "")
- .split(":")[1][:-1]
- .split("-")[0]
- )
- return "0.0.0" # when hl-smi is not available
-
-
def get_vllm_version() -> str:
# Allow overriding the version. This is useful to build platform-specific
# wheels (e.g. CPU, TPU) without modifying the source.
diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py
index 0cf1e85d4e8ee..521d6c33dd390 100644
--- a/tests/basic_correctness/test_basic_correctness.py
+++ b/tests/basic_correctness/test_basic_correctness.py
@@ -74,9 +74,6 @@ def test_models(
model_executor: str,
enable_prompt_embeds: bool,
) -> None:
- if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
- pytest.skip(f"{backend} does not support gemma2 with full context length.")
-
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", backend)
diff --git a/tests/compile/README.md b/tests/compile/README.md
new file mode 100644
index 0000000000000..300a956860005
--- /dev/null
+++ b/tests/compile/README.md
@@ -0,0 +1,5 @@
+# compile test folder structure
+
+- `compile/test_*.py` : various unit tests meant for testing particular code path/features. Future tests are most likely added here. New test files added here will be included in CI automatically
+- `compile/fullgraph/` : full model tests, including all tests previously in compile/piecewise. These tests do not target particular features. New test files added here will be included in CI automatically
+- `compile/distributed/` : tests that require multiple GPUs. New test files added here will **NOT** be included in CI automatically, as these tests generally need to be manually configured to run on runners with a particular number/type of GPUs.
diff --git a/tests/compile/piecewise/__init__.py b/tests/compile/distributed/__init__.py
similarity index 100%
rename from tests/compile/piecewise/__init__.py
rename to tests/compile/distributed/__init__.py
diff --git a/tests/compile/test_async_tp.py b/tests/compile/distributed/test_async_tp.py
similarity index 99%
rename from tests/compile/test_async_tp.py
rename to tests/compile/distributed/test_async_tp.py
index 71ee228781438..86d409f1eadb0 100644
--- a/tests/compile/test_async_tp.py
+++ b/tests/compile/distributed/test_async_tp.py
@@ -27,13 +27,13 @@ from vllm.distributed.parallel_state import (
from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables
-from ..models.registry import HF_EXAMPLE_MODELS
-from ..utils import (
+from ...models.registry import HF_EXAMPLE_MODELS
+from ...utils import (
compare_two_settings,
create_new_process_for_each_test,
multi_gpu_test,
)
-from .backend import TestBackend
+from ..backend import TestBackend
FP8_DTYPE = current_platform.fp8_dtype()
diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/distributed/test_fusion_all_reduce.py
similarity index 99%
rename from tests/compile/test_fusion_all_reduce.py
rename to tests/compile/distributed/test_fusion_all_reduce.py
index 6d0a0ed7d89d2..d401d57032752 100644
--- a/tests/compile/test_fusion_all_reduce.py
+++ b/tests/compile/distributed/test_fusion_all_reduce.py
@@ -33,8 +33,8 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables
-from ..utils import has_module_attribute, multi_gpu_test
-from .backend import TestBackend
+from ...utils import has_module_attribute, multi_gpu_test
+from ..backend import TestBackend
class TestAllReduceRMSNormModel(torch.nn.Module):
diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/distributed/test_fusions_e2e.py
similarity index 97%
rename from tests/compile/test_fusions_e2e.py
rename to tests/compile/distributed/test_fusions_e2e.py
index f22d60ef000b2..53c3f875d2003 100644
--- a/tests/compile/test_fusions_e2e.py
+++ b/tests/compile/distributed/test_fusions_e2e.py
@@ -18,7 +18,7 @@ from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
from vllm.utils.torch_utils import is_torch_equal_or_newer
-from ..utils import flat_product, multi_gpu_test
+from ...utils import flat_product, multi_gpu_test
is_blackwell = lambda: current_platform.is_device_capability(100)
"""Are we running on Blackwell, a lot of tests depend on it"""
@@ -47,12 +47,8 @@ if current_platform.is_cuda():
ModelBackendTestCase(
# Use smaller model for L40s in CI
model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
- # TODO while llama4 is broken, use FLASHINFER for llama3 on Blackwell
- # so FI attention+fp8_quant is at least tested once
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
- backend=AttentionBackendEnum.FLASHINFER
- if is_blackwell()
- else AttentionBackendEnum.TRITON_ATTN,
+ backend=AttentionBackendEnum.TRITON_ATTN,
matches=Matches(
attention_fusion=32,
allreduce_fusion=65,
@@ -65,9 +61,9 @@ if current_platform.is_cuda():
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
# TODO FlashInfer attn broken on Hopper with kvcache=fp8:
# https://github.com/vllm-project/vllm/issues/28568
- # TODO FlashInfer attn broken on Blackwell for llama4:
- # https://github.com/vllm-project/vllm/issues/28604
- backend=AttentionBackendEnum.TRITON_ATTN,
+ backend=AttentionBackendEnum.FLASHINFER
+ if is_blackwell()
+ else AttentionBackendEnum.TRITON_ATTN,
matches=Matches(
attention_fusion=48,
allreduce_fusion=96,
@@ -115,6 +111,17 @@ if current_platform.is_cuda():
async_tp=96, # MLP is MoE, half the fusions of dense
),
),
+ ModelBackendTestCase(
+ model_name="openai/gpt-oss-20b",
+ model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
+ backend=AttentionBackendEnum.FLASHINFER,
+ matches=Matches(
+ attention_fusion=0,
+ allreduce_fusion=49,
+ sequence_parallel=49,
+ async_tp=48,
+ ),
+ ),
]
elif current_platform.is_rocm():
diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/distributed/test_sequence_parallelism.py
similarity index 99%
rename from tests/compile/test_sequence_parallelism.py
rename to tests/compile/distributed/test_sequence_parallelism.py
index 9cd7f64b04af5..30084dfd5a950 100644
--- a/tests/compile/test_sequence_parallelism.py
+++ b/tests/compile/distributed/test_sequence_parallelism.py
@@ -32,8 +32,8 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables
-from ..utils import multi_gpu_test
-from .backend import TestBackend
+from ...utils import multi_gpu_test
+from ..backend import TestBackend
FP8_DTYPE = current_platform.fp8_dtype()
prompts = [
diff --git a/tests/entrypoints/pooling/correctness/__init__.py b/tests/compile/fullgraph/__init__.py
similarity index 100%
rename from tests/entrypoints/pooling/correctness/__init__.py
rename to tests/compile/fullgraph/__init__.py
diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/fullgraph/test_basic_correctness.py
similarity index 99%
rename from tests/compile/test_basic_correctness.py
rename to tests/compile/fullgraph/test_basic_correctness.py
index 3f6898607f6b9..965938c4433dd 100644
--- a/tests/compile/test_basic_correctness.py
+++ b/tests/compile/fullgraph/test_basic_correctness.py
@@ -7,7 +7,7 @@ import pytest
from vllm.config import CompilationMode
from vllm.utils.torch_utils import cuda_device_count_stateless
-from ..utils import compare_all_settings
+from ...utils import compare_all_settings
@dataclasses.dataclass
diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/fullgraph/test_full_cudagraph.py
similarity index 100%
rename from tests/compile/piecewise/test_full_cudagraph.py
rename to tests/compile/fullgraph/test_full_cudagraph.py
diff --git a/tests/compile/test_full_graph.py b/tests/compile/fullgraph/test_full_graph.py
similarity index 99%
rename from tests/compile/test_full_graph.py
rename to tests/compile/fullgraph/test_full_graph.py
index b4e5e56ac9fe6..2c11ecef7f029 100644
--- a/tests/compile/test_full_graph.py
+++ b/tests/compile/fullgraph/test_full_graph.py
@@ -15,7 +15,7 @@ from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassC
from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_torch_equal_or_newer
-from ..utils import create_new_process_for_each_test
+from ...utils import create_new_process_for_each_test
def models_list(*, all: bool = True, keywords: list[str] | None = None):
diff --git a/tests/compile/test_multimodal_compile.py b/tests/compile/fullgraph/test_multimodal_compile.py
similarity index 100%
rename from tests/compile/test_multimodal_compile.py
rename to tests/compile/fullgraph/test_multimodal_compile.py
diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/fullgraph/test_multiple_graphs.py
similarity index 100%
rename from tests/compile/piecewise/test_multiple_graphs.py
rename to tests/compile/fullgraph/test_multiple_graphs.py
diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/fullgraph/test_simple.py
similarity index 100%
rename from tests/compile/piecewise/test_simple.py
rename to tests/compile/fullgraph/test_simple.py
diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/fullgraph/test_toy_llama.py
similarity index 100%
rename from tests/compile/piecewise/test_toy_llama.py
rename to tests/compile/fullgraph/test_toy_llama.py
diff --git a/tests/compile/test_dynamic_shapes_compilation.py b/tests/compile/test_dynamic_shapes_compilation.py
new file mode 100644
index 0000000000000..c20aea822fe81
--- /dev/null
+++ b/tests/compile/test_dynamic_shapes_compilation.py
@@ -0,0 +1,88 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import gc
+
+import pytest
+import torch
+
+from vllm import LLM, SamplingParams
+from vllm.config.compilation import CompilationMode, DynamicShapesType
+from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.utils.torch_utils import is_torch_equal_or_newer
+
+
+def get_test_models():
+ """Get list of models to test based on PyTorch version"""
+    # TODO: "Qwen/Qwen3-4B-Instruct-2507" fails; fix the issue and add it back.
+ return ["gpt2", "Qwen/Qwen2-7B-Instruct", "meta-llama/Llama-3.1-8B"]
+
+
+@pytest.mark.parametrize("model_name", get_test_models())
+@pytest.mark.parametrize(
+ "shapes_type",
+ [
+ DynamicShapesType.BACKED,
+ DynamicShapesType.UNBACKED,
+ DynamicShapesType.BACKED_SIZE_OBLIVIOUS,
+ ],
+)
+@pytest.mark.parametrize("use_aot_compile", ["0"])
+@pytest.mark.parametrize("use_bytecode_hook", [True, False])
+@pytest.mark.skipif(
+ not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
+)
+def test_dynamic_shapes_compilation(
+ monkeypatch, model_name, shapes_type, use_aot_compile, use_bytecode_hook
+):
+ """Test that all dynamic shapes types compile successfully"""
+ print(
+ f"\nTesting model: {model_name} with {shapes_type.name}, "
+ f"AOT compile: {use_aot_compile}, "
+ f"Bytecode hook: {use_bytecode_hook}"
+ )
+ if use_bytecode_hook and shapes_type == DynamicShapesType.UNBACKED:
+ pytest.skip("UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0")
+
+ monkeypatch.setenv("VLLM_USE_AOT_COMPILE", use_aot_compile)
+ monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")
+
+ prompt = "Hello, my name is"
+
+ print(f"Testing {shapes_type.name} dynamic shapes...")
+
+ # Initialize the model with specific dynamic shapes configuration
+ model = LLM(
+ model=model_name,
+ compilation_config={
+ "mode": CompilationMode.VLLM_COMPILE,
+ "dynamic_shapes_config": {
+ "type": shapes_type.value,
+ },
+ },
+ )
+
+ output = model.generate(prompt)
+ result = output[0].outputs[0].text
+ # Example of setting the sampling parameters
+ tokenizer = get_tokenizer(model_name)
+ yes_tokens = tokenizer.encode("yes", add_special_tokens=False)
+ no_tokens = tokenizer.encode("no", add_special_tokens=False)
+ allowed_ids = list(set(yes_tokens + no_tokens))
+ sampling_params = SamplingParams(
+ max_tokens=1, temperature=0, allowed_token_ids=allowed_ids
+ )
+
+ output = model.generate(
+ "answer with yes or no is " + result + " rubbish for prompt " + prompt + "?",
+ sampling_params=sampling_params,
+ )
+ result = output[0].outputs[0].text
+ assert result == "yes"
+
+ # Clean up GPU memory
+ del model
+ gc.collect()
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+ print("GPU memory cleared")
diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py
index 11ae96e930da7..515e0a93ac2a8 100644
--- a/tests/compile/test_functionalization.py
+++ b/tests/compile/test_functionalization.py
@@ -137,7 +137,7 @@ class TestRotaryEmbedding(torch.nn.Module):
self.head_dim,
rotary_dim=self.rotary_dim,
max_position=max_position,
- base=base,
+ rope_parameters={"rope_type": "default", "rope_theta": base},
)
def forward(self, positions, q, k):
@@ -172,7 +172,7 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
- base=base,
+ rope_parameters={"rope_type": "default", "rope_theta": base},
)
def forward(self, positions, hidden_states):
diff --git a/tests/config/test_config_utils.py b/tests/config/test_config_utils.py
new file mode 100644
index 0000000000000..1277c7e64eb21
--- /dev/null
+++ b/tests/config/test_config_utils.py
@@ -0,0 +1,166 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from dataclasses import dataclass
+from enum import Enum
+
+import pytest
+
+from vllm.config.utils import get_hash_factors, hash_factors, normalize_value
+
+# Helpers
+
+
+def endswith_fqname(obj, suffix: str) -> bool:
+ # normalize_value(type) returns fully-qualified name
+ # Compare suffix to avoid brittle import paths.
+ out = normalize_value(obj)
+ return isinstance(out, str) and out.endswith(suffix)
+
+
+def expected_path(p_str: str = ".") -> str:
+ import pathlib
+
+ p = pathlib.Path(p_str)
+ return p.expanduser().resolve().as_posix()
+
+
+# Minimal dataclass to test get_hash_factors.
+# Avoid importing heavy vLLM configs.
+@dataclass
+class SimpleConfig:
+ a: object
+ b: object | None = None
+
+
+class DummyLogprobsMode(Enum):
+ RAW_LOGITS = "raw_logits"
+
+
+def test_hash_factors_deterministic():
+ """Test that hash_factors produces consistent SHA-256 hashes"""
+ factors = {"a": 1, "b": "test"}
+ hash1 = hash_factors(factors)
+ hash2 = hash_factors(factors)
+
+ assert hash1 == hash2
+ # Dict key insertion order should not affect the hash.
+ factors_reordered = {"b": "test", "a": 1}
+ assert hash_factors(factors_reordered) == hash1
+ assert len(hash1) == 64
+ assert all(c in "0123456789abcdef" for c in hash1)
+
+
+@pytest.mark.parametrize(
+ "inp, expected",
+ [
+ (None, None),
+ (True, True),
+ (1, 1),
+ (1.0, 1.0),
+ ("x", "x"),
+ (b"ab", "6162"),
+ (bytearray(b"ab"), "6162"),
+ ([1, 2], (1, 2)),
+ ({"b": 2, "a": 1}, (("a", 1), ("b", 2))),
+ ],
+)
+def test_normalize_value_matrix(inp, expected):
+ """Parametric input→expected normalization table."""
+ assert normalize_value(inp) == expected
+
+
+def test_normalize_value_enum():
+ # Enums normalize to (module.QualName, value).
+ # DummyLogprobsMode uses a string payload.
+ out = normalize_value(DummyLogprobsMode.RAW_LOGITS)
+ assert isinstance(out, tuple)
+ assert out[0].endswith("DummyLogprobsMode")
+ # Expect string payload 'raw_logits'.
+ assert out[1] == "raw_logits"
+
+
+def test_normalize_value_set_order_insensitive():
+ # Sets are unordered; normalize_value sorts elements for determinism.
+ assert normalize_value({3, 1, 2}) == normalize_value({1, 2, 3})
+
+
+def test_normalize_value_path_normalization():
+ from pathlib import Path # local import to avoid global dependency
+
+ # Paths expand/resolve to absolute strings.
+ # Stabilizes hashing across working dirs.
+ assert normalize_value(Path(".")) == expected_path(".")
+
+
+def test_normalize_value_uuid_and_to_json():
+ # Objects may normalize via uuid() or to_json_string().
+ class HasUUID:
+ def uuid(self):
+ return "test-uuid"
+
+ class ToJson:
+ def to_json_string(self):
+ return '{"x":1}'
+
+ assert normalize_value(HasUUID()) == "test-uuid"
+ assert normalize_value(ToJson()) == '{"x":1}'
+
+
+@pytest.mark.parametrize(
+ "bad",
+ [
+ (lambda x: x),
+ (type("CallableInstance", (), {"__call__": lambda self: 0}))(),
+ (lambda: (lambda: 0))(), # nested function instance
+ ],
+)
+def test_error_cases(bad):
+ """Inputs expected to raise TypeError."""
+ # Reject functions/lambdas/callable instances
+ # to avoid under-hashing.
+ with pytest.raises(TypeError):
+ normalize_value(bad)
+
+
+def test_enum_vs_int_disambiguation():
+ # int stays primitive
+ nf_int = normalize_value(1)
+ assert nf_int == 1
+
+ # enum becomes ("module.QualName", value)
+ nf_enum = normalize_value(DummyLogprobsMode.RAW_LOGITS)
+ assert isinstance(nf_enum, tuple) and len(nf_enum) == 2
+ enum_type, enum_val = nf_enum
+ assert enum_type.endswith(".DummyLogprobsMode")
+ assert enum_val == "raw_logits"
+
+ # Build factor dicts from configs with int vs enum
+ f_int = get_hash_factors(SimpleConfig(1), set())
+ f_enum = get_hash_factors(SimpleConfig(DummyLogprobsMode.RAW_LOGITS), set())
+ # The int case remains a primitive value
+ assert f_int["a"] == 1
+ # The enum case becomes a tagged tuple ("module.QualName", "raw_logits")
+ assert isinstance(f_enum["a"], tuple) and f_enum["a"][1] == "raw_logits"
+ # Factor dicts must differ so we don't collide primitives with Enums.
+ assert f_int != f_enum
+ # Hash digests must differ correspondingly
+ assert hash_factors(f_int) != hash_factors(f_enum)
+
+ # Hash functions produce stable hex strings
+ h_int = hash_factors(f_int)
+ h_enum = hash_factors(f_enum)
+ assert isinstance(h_int, str) and len(h_int) == 64
+ assert isinstance(h_enum, str) and len(h_enum) == 64
+
+
+def test_classes_are_types():
+ """Types normalize to FQNs; include real vLLM types."""
+ # Only classes allowed; functions/lambdas are rejected.
+ # Canonical form is the fully-qualified name.
+ assert isinstance(normalize_value(str), str)
+
+ class LocalDummy:
+ pass
+
+ assert endswith_fqname(LocalDummy, ".LocalDummy")
diff --git a/tests/conftest.py b/tests/conftest.py
index b17081352edcf..163593eb3f14f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -748,6 +748,14 @@ class VllmRunner:
# being captured which can trigger edge cases that we don't handle yet.
kwargs["compilation_config"] = {"cudagraph_capture_sizes": [4]}
+        # Make sure we have at least one cudagraph large enough for a single decode.
+ if (speculative_config := kwargs.get("speculative_config")) and (
+ num_speculative_tokens := speculative_config["num_speculative_tokens"]
+ ):
+ kwargs["compilation_config"]["cudagraph_capture_sizes"].append(
+ num_speculative_tokens + 1
+ )
+
with init_ctx:
self.llm = LLM(
model=model_name,
@@ -845,6 +853,7 @@ class VllmRunner:
@staticmethod
def _final_steps_generate_w_logprobs(
req_outputs: list[RequestOutput],
+ include_prompt_token_ids: bool = False,
) -> list[TokensTextLogprobsPromptLogprobs]:
outputs: list[TokensTextLogprobsPromptLogprobs] = []
for req_output in req_outputs:
@@ -853,9 +862,26 @@ class VllmRunner:
output_str = sample.text
output_ids = list(sample.token_ids)
output_logprobs = sample.logprobs
- outputs.append(
- (output_ids, output_str, output_logprobs, req_output.prompt_logprobs)
- )
+ if include_prompt_token_ids:
+ outputs.append(
+ ( # type: ignore[arg-type]
+ output_ids,
+ output_str,
+ output_logprobs,
+ req_output.prompt_token_ids,
+ req_output.prompt_logprobs,
+ )
+ )
+ else:
+ outputs.append(
+ (
+ output_ids,
+ output_str,
+ output_logprobs,
+ req_output.prompt_logprobs,
+ )
+ )
+
return outputs
def generate_w_logprobs(
@@ -865,6 +891,7 @@ class VllmRunner:
images: PromptImageInput | None = None,
audios: PromptAudioInput | None = None,
videos: PromptVideoInput | None = None,
+ include_prompt_token_ids: bool = False,
**kwargs: Any,
) -> list[TokensTextLogprobs] | list[TokensTextLogprobsPromptLogprobs]:
inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
@@ -874,7 +901,7 @@ class VllmRunner:
)
toks_str_logsprobs_prompt_logprobs = self._final_steps_generate_w_logprobs(
- req_outputs
+ req_outputs, include_prompt_token_ids
)
# Omit prompt logprobs if not required by sampling params
return (
diff --git a/tests/distributed/eplb_utils.py b/tests/distributed/eplb_utils.py
new file mode 100644
index 0000000000000..27a63e0215148
--- /dev/null
+++ b/tests/distributed/eplb_utils.py
@@ -0,0 +1,49 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import os
+import random
+
+import torch
+import torch.multiprocessing as mp
+
+from vllm.distributed.parallel_state import (
+ init_distributed_environment,
+)
+from vllm.utils.system_utils import update_environment_variables
+
+mp.set_start_method("spawn", force=True)
+
+
+def distributed_run(fn, world_size, *args):
+ number_of_processes = world_size
+ processes: list[mp.Process] = []
+ for i in range(number_of_processes):
+ env: dict[str, str] = {}
+ env["RANK"] = str(i)
+ env["LOCAL_RANK"] = str(i)
+ env["WORLD_SIZE"] = str(number_of_processes)
+ env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
+ env["MASTER_ADDR"] = "localhost"
+ env["MASTER_PORT"] = "12345"
+ p = mp.Process(target=fn, args=(env, world_size, *args))
+ processes.append(p)
+ p.start()
+
+ for p in processes:
+ p.join()
+
+ for p in processes:
+ assert p.exitcode == 0
+
+
+def set_env_vars_and_device(env: dict[str, str]) -> None:
+ update_environment_variables(env)
+ local_rank = os.environ["LOCAL_RANK"]
+ device = torch.device(f"cuda:{local_rank}")
+ torch.cuda.set_device(device)
+ init_distributed_environment()
+
+ # Ensure each worker process has the same random seed
+ random.seed(42)
+ torch.manual_seed(42)
diff --git a/tests/distributed/test_context_parallel.py b/tests/distributed/test_context_parallel.py
index b16fd0d06b145..7e4713b8aece0 100644
--- a/tests/distributed/test_context_parallel.py
+++ b/tests/distributed/test_context_parallel.py
@@ -31,7 +31,7 @@ class ParallelSetup(NamedTuple):
tp_size: int
pp_size: int
dcp_size: int
- dcp_kv_cache_interleave_size: int
+ cp_kv_cache_interleave_size: int
eager_mode: bool
chunked_prefill: bool
@@ -55,7 +55,7 @@ class CPTestSettings:
tp_base: int = 4,
pp_base: int = 1,
dcp_base: int = 1,
- dcp_kv_cache_interleave_size: int = 1,
+ cp_kv_cache_interleave_size: int = 1,
multi_node_only: bool = False,
runner: RunnerOption = "auto",
load_format: str | None = None,
@@ -71,7 +71,7 @@ class CPTestSettings:
tp_size=tp_base,
pp_size=pp_multiplier * pp_base,
dcp_size=int(dcp_multiplier * tp_base),
- dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
+ cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
eager_mode=eager_mode_val,
chunked_prefill=chunked_prefill_val,
)
@@ -116,7 +116,7 @@ def _compare_cp_with_tp(
tp_size,
pp_size,
dcp_size,
- dcp_kv_cache_interleave_size,
+ cp_kv_cache_interleave_size,
eager_mode,
chunked_prefill,
) = parallel_setup
@@ -197,7 +197,7 @@ def _compare_cp_with_tp(
"--decode-context-parallel-size",
str(dcp_size),
"--dcp-kv-cache-interleave-size",
- str(dcp_kv_cache_interleave_size),
+ str(cp_kv_cache_interleave_size),
"--distributed-executor-backend",
distributed_backend,
]
@@ -227,7 +227,7 @@ CP_TEXT_GENERATION_MODELS = {
"deepseek-ai/DeepSeek-V2-Lite-Chat": [
CPTestSettings.detailed(),
CPTestSettings.detailed(tp_base=2),
- CPTestSettings.detailed(tp_base=2, dcp_kv_cache_interleave_size=64),
+ CPTestSettings.detailed(tp_base=2, cp_kv_cache_interleave_size=64),
],
"bigcode/gpt_bigcode-santacoder": [
CPTestSettings.detailed(),
diff --git a/tests/distributed/test_eplb_execute.py b/tests/distributed/test_eplb_execute.py
index 7b45ae82c72d4..781dfd44c1ef6 100644
--- a/tests/distributed/test_eplb_execute.py
+++ b/tests/distributed/test_eplb_execute.py
@@ -1,63 +1,24 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import multiprocessing
-import os
+import asyncio
import random
import pytest
import torch
import torch.distributed
-from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
+from vllm.distributed.eplb.rebalance_execute import (
+ move_from_buffer,
+ rearrange_expert_weights_inplace,
+ transfer_layer,
+)
from vllm.distributed.parallel_state import (
ensure_model_parallel_initialized,
get_tp_group,
- init_distributed_environment,
)
-from vllm.utils.system_utils import update_environment_variables
-
-def distributed_run(fn, world_size):
- number_of_processes = world_size
- processes: list[multiprocessing.Process] = []
- for i in range(number_of_processes):
- env: dict[str, str] = {}
- env["RANK"] = str(i)
- env["LOCAL_RANK"] = str(i)
- env["WORLD_SIZE"] = str(number_of_processes)
- env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
- env["MASTER_ADDR"] = "localhost"
- env["MASTER_PORT"] = "12345"
- p = multiprocessing.Process(target=fn, args=(env,))
- processes.append(p)
- p.start()
-
- for p in processes:
- p.join()
-
- for p in processes:
- assert p.exitcode == 0
-
-
-def worker_fn_wrapper(fn):
- # `multiprocessing.Process` cannot accept environment variables directly
- # so we need to pass the environment variables as arguments
- # and update the environment variables in the function
- def wrapped_fn(env):
- update_environment_variables(env)
- local_rank = os.environ["LOCAL_RANK"]
- device = torch.device(f"cuda:{local_rank}")
- torch.cuda.set_device(device)
- init_distributed_environment()
-
- # Ensure each worker process has the same random seed
- random.seed(42)
- torch.manual_seed(42)
-
- fn()
-
- return wrapped_fn
+from .eplb_utils import distributed_run, set_env_vars_and_device
def create_expert_indices_with_redundancy(
@@ -275,6 +236,173 @@ def verify_redundant_experts_have_same_weights(
)
+def _test_async_transfer_layer_without_mtp_worker(
+ env,
+ world_size: int,
+ num_layers: int,
+ num_local_experts: int,
+ num_logical_experts: int,
+) -> None:
+ set_env_vars_and_device(env)
+ ensure_model_parallel_initialized(
+ tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
+ )
+
+ tp_group = get_tp_group()
+ ep_group = tp_group.device_group
+ ep_rank = torch.distributed.get_rank()
+ device = torch.device(f"cuda:{ep_rank}")
+
+ total_physical_experts = world_size * num_local_experts
+ hidden_sizes = [16, 32]
+
+ redundancy_config = create_redundancy_config(
+ num_logical_experts,
+ total_physical_experts,
+ )
+ old_indices = create_expert_indices_with_redundancy(
+ num_layers,
+ num_logical_experts,
+ total_physical_experts,
+ redundancy_config,
+ )
+
+ new_redundancy_config = create_redundancy_config(
+ num_logical_experts,
+ total_physical_experts,
+ )
+ new_indices = create_expert_indices_with_redundancy(
+ num_layers,
+ num_logical_experts,
+ total_physical_experts,
+ new_redundancy_config,
+ )
+
+ expert_weights = create_expert_weights(
+ num_layers,
+ num_local_experts,
+ hidden_sizes,
+ ep_rank,
+ device,
+ old_indices,
+ )
+
+ expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
+ cuda_stream = torch.cuda.Stream(device=device)
+
+ for layer_idx in range(num_layers):
+ is_unchanged, is_received_locally, experts_recv_loc = asyncio.run(
+ transfer_layer(
+ old_global_expert_indices=old_indices,
+ new_global_expert_indices=new_indices,
+ expert_weights=expert_weights,
+ expert_weights_buffer=expert_buffer,
+ ep_group=ep_group,
+ layer=layer_idx,
+ cuda_stream=cuda_stream,
+ )
+ )
+
+ cuda_stream.synchronize()
+ move_from_buffer(
+ expert_weights=expert_weights[layer_idx],
+ expert_weights_buffer=expert_buffer,
+ is_unchanged=is_unchanged,
+ is_received_locally=is_received_locally,
+ experts_recv_loc=experts_recv_loc,
+ new_indices=new_indices[layer_idx].tolist(),
+ ep_group=ep_group,
+ )
+
+ verify_expert_weights_after_shuffle(
+ expert_weights,
+ new_indices,
+ hidden_sizes,
+ ep_rank,
+ num_local_experts,
+ )
+ verify_redundant_experts_have_same_weights(
+ expert_weights,
+ new_indices,
+ hidden_sizes,
+ world_size,
+ num_local_experts,
+ )
+
+
+def _test_rearrange_expert_weights_with_redundancy(
+ env, world_size, num_layers, num_local_experts, num_logical_experts
+) -> None:
+ # Initialize model parallel (using tensor parallel as an entrypoint
+ # to expert parallel)
+ set_env_vars_and_device(env)
+ ensure_model_parallel_initialized(
+ tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
+ )
+
+ ep_group = get_tp_group().cpu_group
+ ep_rank = torch.distributed.get_rank()
+ device = torch.device(f"cuda:{ep_rank}")
+
+ # Test parameters
+ total_physical_experts = world_size * num_local_experts
+ hidden_sizes = [32, 64] # Two different weight matrices
+
+ # Create old expert indices (with redundancy)
+ redundancy_config = create_redundancy_config(
+ num_logical_experts, total_physical_experts
+ )
+
+ old_indices = create_expert_indices_with_redundancy(
+ num_layers,
+ num_logical_experts,
+ total_physical_experts,
+ redundancy_config,
+ )
+
+ # Create new expert indices (with redundancy)
+ new_redundancy_config = create_redundancy_config(
+ num_logical_experts, total_physical_experts
+ )
+ new_indices = create_expert_indices_with_redundancy(
+ num_layers,
+ num_logical_experts,
+ total_physical_experts,
+ new_redundancy_config,
+ )
+
+ # Create expert weights
+ expert_weights = create_expert_weights(
+ num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
+ )
+
+ # Execute weight rearrangement
+ rearrange_expert_weights_inplace(
+ old_indices,
+ new_indices,
+ expert_weights,
+ ep_group,
+ is_profile=False,
+ )
+
+ # Verify the rearrangement result
+ verify_expert_weights_after_shuffle(
+ expert_weights,
+ new_indices,
+ hidden_sizes,
+ ep_rank,
+ num_local_experts,
+ )
+
+ verify_redundant_experts_have_same_weights(
+ expert_weights,
+ new_indices,
+ hidden_sizes,
+ world_size,
+ num_local_experts,
+ )
+
+
@pytest.mark.parametrize(
"world_size,num_layers,num_local_experts,num_logical_experts",
[
@@ -305,78 +433,95 @@ def test_rearrange_expert_weights_with_redundancy(
if torch.cuda.device_count() < world_size:
pytest.skip(f"Need at least {world_size} GPUs to run the test")
+ distributed_run(
+ _test_rearrange_expert_weights_with_redundancy,
+ world_size,
+ num_layers,
+ num_local_experts,
+ num_logical_experts,
+ )
- @worker_fn_wrapper
- def worker_fn():
- # Initialize model parallel (using tensor parallel as an entrypoint
- # to expert parallel)
- ensure_model_parallel_initialized(
- tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
- )
- ep_group = get_tp_group().cpu_group
- ep_rank = torch.distributed.get_rank()
- device = torch.device(f"cuda:{ep_rank}")
+def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
+ set_env_vars_and_device(env)
+ ensure_model_parallel_initialized(
+ tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
+ )
- # Test parameters
- total_physical_experts = world_size * num_local_experts
- hidden_sizes = [32, 64] # Two different weight matrices
+ ep_group = get_tp_group().cpu_group
+ ep_rank = torch.distributed.get_rank()
+ device = torch.device(f"cuda:{ep_rank}")
- # Create old expert indices (with redundancy)
- redundancy_config = create_redundancy_config(
- num_logical_experts, total_physical_experts
- )
+ num_layers = 2
+ num_local_experts = 2
+ total_physical_experts = world_size * num_local_experts
+ num_logical_experts = total_physical_experts // 2 # Some redundancy
+ hidden_sizes = [32, 64]
- old_indices = create_expert_indices_with_redundancy(
- num_layers,
- num_logical_experts,
- total_physical_experts,
- redundancy_config,
- )
+ # Create redundancy configuration
+ redundancy_config = [2] * num_logical_experts
- # Create new expert indices (with redundancy)
- new_redundancy_config = create_redundancy_config(
- num_logical_experts, total_physical_experts
- )
- new_indices = create_expert_indices_with_redundancy(
- num_layers,
- num_logical_experts,
- total_physical_experts,
- new_redundancy_config,
- )
+ # Same indices - no change
+ indices = create_expert_indices_with_redundancy(
+ num_layers, num_logical_experts, total_physical_experts, redundancy_config
+ )
- # Create expert weights
- expert_weights = create_expert_weights(
- num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
- )
+ expert_weights = create_expert_weights(
+ num_layers, num_local_experts, hidden_sizes, ep_rank, device, indices
+ )
- # Execute weight rearrangement
- rearrange_expert_weights_inplace(
- old_indices,
- new_indices,
- expert_weights,
- ep_group,
- is_profile=False,
- )
+ # Save original weights
+ original_weights = []
+ for layer_weights in expert_weights:
+ layer_copy = []
+ for weight in layer_weights:
+ layer_copy.append(weight.clone())
+ original_weights.append(layer_copy)
- # Verify the rearrangement result
- verify_expert_weights_after_shuffle(
- expert_weights,
- new_indices,
- hidden_sizes,
- ep_rank,
- num_local_experts,
- )
+ # Execute rearrangement (should be no change)
+ rearrange_expert_weights_inplace(
+ indices,
+ indices, # Same indices
+ expert_weights,
+ ep_group,
+ is_profile=False,
+ )
- verify_redundant_experts_have_same_weights(
- expert_weights,
- new_indices,
- hidden_sizes,
- world_size,
- num_local_experts,
- )
+ # Verify that the weights have not changed
+ for layer in range(num_layers):
+ for weight_idx in range(len(hidden_sizes)):
+ torch.testing.assert_close(
+ expert_weights[layer][weight_idx],
+ original_weights[layer][weight_idx],
+ msg=f"""Layer {layer}, weight {weight_idx}
+ should remain unchanged""",
+ )
- distributed_run(worker_fn, world_size)
+
+@pytest.mark.parametrize(
+ "world_size,num_layers,num_local_experts,num_logical_experts",
+ [
+ (2, 2, 2, 3),
+ ],
+)
+def test_async_transfer_layer_without_mtp(
+ world_size: int,
+ num_layers: int,
+ num_local_experts: int,
+ num_logical_experts: int,
+):
+ """Exercise async EPLB transfer path without MTP/spec decode."""
+
+ if torch.cuda.device_count() < world_size:
+ pytest.skip(f"Need at least {world_size} GPUs to run the test")
+
+ distributed_run(
+ _test_async_transfer_layer_without_mtp_worker,
+ world_size,
+ num_layers,
+ num_local_experts,
+ num_logical_experts,
+ )
@pytest.mark.parametrize("world_size", [2, 4])
@@ -388,62 +533,69 @@ def test_rearrange_expert_weights_no_change(world_size):
if torch.cuda.device_count() < world_size:
pytest.skip(f"Need at least {world_size} GPUs to run the test")
+ distributed_run(_test_rearrange_expert_weights_no_change, world_size)
- @worker_fn_wrapper
- def worker_fn():
- ensure_model_parallel_initialized(
- tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
- )
- ep_group = get_tp_group().cpu_group
- ep_rank = torch.distributed.get_rank()
- device = torch.device(f"cuda:{ep_rank}")
+def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
+ set_env_vars_and_device(env)
+ ensure_model_parallel_initialized(
+ tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
+ )
- num_layers = 2
- num_local_experts = 2
- total_physical_experts = world_size * num_local_experts
- num_logical_experts = total_physical_experts // 2 # Some redundancy
- hidden_sizes = [32, 64]
+ ep_group = get_tp_group().cpu_group
+ ep_rank = torch.distributed.get_rank()
+ device = torch.device(f"cuda:{ep_rank}")
- # Create redundancy configuration
- redundancy_config = [2] * num_logical_experts
+ num_layers = 1
+ num_local_experts = 2
+ total_physical_experts = world_size * num_local_experts
+ num_logical_experts = total_physical_experts // 2
+ hidden_sizes = [32]
- # Same indices - no change
- indices = create_expert_indices_with_redundancy(
- num_layers, num_logical_experts, total_physical_experts, redundancy_config
- )
+ # Create different index distributions
+ old_redundancy = create_redundancy_config(
+ num_logical_experts, total_physical_experts
+ )
+ new_redundancy = create_redundancy_config(
+ num_logical_experts, total_physical_experts
+ )
- expert_weights = create_expert_weights(
- num_layers, num_local_experts, hidden_sizes, ep_rank, device, indices
- )
+ old_indices = create_expert_indices_with_redundancy(
+ num_layers, num_logical_experts, total_physical_experts, old_redundancy
+ )
+ new_indices = create_expert_indices_with_redundancy(
+ num_layers, num_logical_experts, total_physical_experts, new_redundancy
+ )
- # Save original weights
- original_weights = []
- for layer_weights in expert_weights:
- layer_copy = []
- for weight in layer_weights:
- layer_copy.append(weight.clone())
- original_weights.append(layer_copy)
+ expert_weights = create_expert_weights(
+ num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
+ )
- # Execute rearrangement (should be no change)
- rearrange_expert_weights_inplace(
- indices,
- indices, # Same indices
- expert_weights,
- ep_group,
- is_profile=False,
- )
+ # Save original weights
+ original_weights = []
+ for layer_weights in expert_weights:
+ layer_copy = []
+ for weight in layer_weights:
+ layer_copy.append(weight.clone())
+ original_weights.append(layer_copy)
- # Verify that the weights have not changed
- for layer in range(num_layers):
- for weight_idx in range(len(hidden_sizes)):
- torch.testing.assert_close(
- expert_weights[layer][weight_idx],
- original_weights[layer][weight_idx],
- msg=f"Layer {layer}, weight {weight_idx} should remain unchanged",
- )
+ # Execute profile mode rearrangement
+ rearrange_expert_weights_inplace(
+ old_indices,
+ new_indices,
+ expert_weights,
+ ep_group,
+ is_profile=True, # Profile mode
+ )
- distributed_run(worker_fn, world_size)
+ # In profile mode, the weights should remain unchanged
+ for layer in range(num_layers):
+ for weight_idx in range(len(hidden_sizes)):
+ torch.testing.assert_close(
+ expert_weights[layer][weight_idx],
+ original_weights[layer][weight_idx],
+ msg="In profile mode, the weights should remain unchanged",
+ )
@pytest.mark.parametrize("world_size", [2, 4])
@@ -452,66 +604,4 @@ def test_rearrange_expert_weights_profile_mode(world_size):
if torch.cuda.device_count() < world_size:
pytest.skip(f"Need at least {world_size} GPUs to run the test")
-
- @worker_fn_wrapper
- def worker_fn():
- ensure_model_parallel_initialized(
- tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
- )
-
- ep_group = get_tp_group().cpu_group
- ep_rank = torch.distributed.get_rank()
- device = torch.device(f"cuda:{ep_rank}")
-
- num_layers = 1
- num_local_experts = 2
- total_physical_experts = world_size * num_local_experts
- num_logical_experts = total_physical_experts // 2
- hidden_sizes = [32]
-
- # Create different index distributions
- old_redundancy = create_redundancy_config(
- num_logical_experts, total_physical_experts
- )
- new_redundancy = create_redundancy_config(
- num_logical_experts, total_physical_experts
- )
-
- old_indices = create_expert_indices_with_redundancy(
- num_layers, num_logical_experts, total_physical_experts, old_redundancy
- )
- new_indices = create_expert_indices_with_redundancy(
- num_layers, num_logical_experts, total_physical_experts, new_redundancy
- )
-
- expert_weights = create_expert_weights(
- num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
- )
-
- # Save original weights
- original_weights = []
- for layer_weights in expert_weights:
- layer_copy = []
- for weight in layer_weights:
- layer_copy.append(weight.clone())
- original_weights.append(layer_copy)
-
- # Execute profile mode rearrangement
- rearrange_expert_weights_inplace(
- old_indices,
- new_indices,
- expert_weights,
- ep_group,
- is_profile=True, # Profile mode
- )
-
- # In profile mode, the weights should remain unchanged
- for layer in range(num_layers):
- for weight_idx in range(len(hidden_sizes)):
- torch.testing.assert_close(
- expert_weights[layer][weight_idx],
- original_weights[layer][weight_idx],
- msg="In profile mode, the weights should remain unchanged",
- )
-
- distributed_run(worker_fn, world_size)
+ distributed_run(_test_rearrange_expert_weights_profile_mode, world_size)
diff --git a/tests/distributed/test_eplb_fused_moe_layer.py b/tests/distributed/test_eplb_fused_moe_layer.py
new file mode 100644
index 0000000000000..55f26519887a1
--- /dev/null
+++ b/tests/distributed/test_eplb_fused_moe_layer.py
@@ -0,0 +1,285 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Test that the interaction between EPLB and FusedMoE Layer is okay
+
+from dataclasses import dataclass
+
+import pytest
+import torch
+
+from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
+from vllm.distributed.parallel_state import (
+ ensure_model_parallel_initialized,
+ get_tp_group,
+)
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+
+from .eplb_utils import distributed_run, set_env_vars_and_device
+
+
+@dataclass
+class TestConfig:
+ num_layers: int
+ num_experts: int
+ num_local_experts: int
+ num_topk: int
+ hidden_size: int
+ intermediate_size: int
+ weight_dtype: torch.dtype
+ weight_scale_dtype: torch.dtype | None
+ column_major_scales: bool
+
+
+def make_expert_weights(
+ layer_idx: int,
+ global_expert_idx: int,
+ global_num_experts: int,
+ tensor_shape: tuple[int, ...],
+ tensor_dtype: torch.dtype,
+ tensor_device: torch.device,
+ is_column_major: bool,
+) -> torch.Tensor:
+ assert len(tensor_shape) == 2
+
+ if is_column_major:
+ tensor_shape = (tensor_shape[1], tensor_shape[0])
+
+ x = torch.empty(tensor_shape, dtype=tensor_dtype, device=tensor_device)
+ value_offset = (layer_idx * global_num_experts + global_expert_idx) * x.numel()
+ x.view(-1).copy_(
+ torch.arange(
+ value_offset,
+ value_offset + x.numel(),
+ dtype=tensor_dtype,
+ device=tensor_device,
+ )
+ )
+
+ if is_column_major:
+ x = torch.transpose(x, 1, 0)
+ assert not x.is_contiguous()
+ return x
+
+
+def make_fused_moe_layer(
+ rank: int,
+ layer_idx: int,
+ test_config: TestConfig,
+) -> FusedMoE:
+ fml = FusedMoE(
+ num_experts=test_config.num_experts,
+ top_k=test_config.num_topk,
+ hidden_size=test_config.hidden_size,
+ intermediate_size=test_config.intermediate_size,
+ prefix=f"dummy_layer_{layer_idx}",
+ activation="silu",
+ is_act_and_mul=True,
+ params_dtype=test_config.weight_dtype,
+ )
+
+ device = torch.device(f"cuda:{rank}")
+
+ from functools import partial
+
+ _make_expert_weights = partial(
+ make_expert_weights,
+ layer_idx=layer_idx,
+ global_num_experts=test_config.num_experts,
+ tensor_device=device,
+ )
+
+ assert isinstance(fml.w13_weight.data, torch.Tensor)
+ assert isinstance(fml.w2_weight.data, torch.Tensor)
+ fml.w13_weight.data = fml.w13_weight.data.to(device=device)
+ fml.w2_weight.data = fml.w2_weight.data.to(device=device)
+ w13_weight = fml.w13_weight.data
+ w2_weight = fml.w2_weight.data
+ assert w13_weight.size(0) == test_config.num_local_experts
+ for i in range(test_config.num_local_experts):
+ g_i = rank * test_config.num_local_experts + i
+ w13_weight_e = w13_weight[i]
+ w2_weight_e = w2_weight[i]
+ w13_weight_e.copy_(
+ _make_expert_weights(
+ global_expert_idx=g_i,
+ tensor_shape=w13_weight_e.shape,
+ tensor_dtype=w13_weight_e.dtype,
+ is_column_major=False,
+ )
+ )
+ w2_weight_e.copy_(
+ _make_expert_weights(
+ global_expert_idx=g_i,
+ tensor_shape=w2_weight_e.shape,
+ tensor_dtype=w2_weight_e.dtype,
+ is_column_major=False,
+ )
+ )
+
+ block_size = 16
+
+ def block_quant_scales_shape(
+ shape: tuple[int, ...], is_column_major: bool
+ ) -> tuple[int, ...]:
+ assert len(shape) == 3
+ if not is_column_major:
+ return (shape[0], shape[1] // block_size, shape[2] // block_size)
+ else:
+ return (shape[0], shape[2] // block_size, shape[1] // block_size)
+
+ is_column_major = test_config.column_major_scales
+ w13_weight_scale_inv = torch.empty(
+ block_quant_scales_shape(w13_weight.shape, is_column_major),
+ dtype=test_config.weight_dtype,
+ device=device,
+ )
+ w2_weight_scale_inv = torch.empty(
+ block_quant_scales_shape(w2_weight.shape, is_column_major),
+ dtype=test_config.weight_dtype,
+ device=device,
+ )
+
+ for i in range(test_config.num_local_experts):
+ g_i = rank * test_config.num_local_experts + i
+ w13_s_e = w13_weight_scale_inv[i]
+ w2_s_e = w2_weight_scale_inv[i]
+ w13_s_e.copy_(
+ _make_expert_weights(
+ global_expert_idx=g_i,
+ tensor_shape=w13_s_e.shape,
+ tensor_dtype=w13_s_e.dtype,
+ # Fill data in row-major and then
+ # transpose if test_config requires col-major.
+ is_column_major=False,
+ )
+ )
+ w2_s_e.copy_(
+ _make_expert_weights(
+ global_expert_idx=g_i,
+ tensor_shape=w2_s_e.shape,
+ tensor_dtype=w2_s_e.dtype,
+ is_column_major=False,
+ )
+ )
+ if is_column_major:
+ w13_weight_scale_inv = torch.transpose(w13_weight_scale_inv, 1, 2)
+ w2_weight_scale_inv = torch.transpose(w2_weight_scale_inv, 1, 2)
+ assert not w13_weight_scale_inv.is_contiguous()
+ assert not w2_weight_scale_inv.is_contiguous()
+
+ # Add scales to the parameter list
+ fml.w13_weight_scale_inv = torch.nn.Parameter(
+ w13_weight_scale_inv, requires_grad=False
+ )
+ fml.w2_weight_scale_inv = torch.nn.Parameter(
+ w2_weight_scale_inv, requires_grad=False
+ )
+
+ return fml
+
+
+def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
+ # Initialize model parallel (using tensor parallel as an entrypoint
+ # to expert parallel)
+ set_env_vars_and_device(env)
+
+ vllm_config = VllmConfig()
+ vllm_config.parallel_config.tensor_parallel_size = world_size
+ vllm_config.parallel_config.enable_expert_parallel = True
+
+ with set_current_vllm_config(vllm_config):
+ ensure_model_parallel_initialized(
+ tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
+ )
+
+ ep_group = get_tp_group().cpu_group
+ ep_rank = torch.distributed.get_rank()
+
+ fml_layers = [
+ make_fused_moe_layer(ep_rank, layer_idx, test_config)
+ for layer_idx in range(test_config.num_layers)
+ ]
+ rank_expert_weights = [fml.get_expert_weights() for fml in fml_layers]
+
+ indices = torch.zeros(
+ test_config.num_layers, test_config.num_experts, dtype=torch.long
+ )
+ for lidx in range(test_config.num_layers):
+ indices[lidx] = torch.Tensor(range(test_config.num_experts))
+
+ shuffled_indices = torch.zeros_like(indices)
+ for lidx in range(test_config.num_layers):
+ shuffled_indices[lidx] = torch.randperm(test_config.num_experts)
+
+ rearrange_expert_weights_inplace(
+ indices,
+ shuffled_indices,
+ rank_expert_weights,
+ ep_group,
+ is_profile=False,
+ )
+
+ num_local_experts = test_config.num_local_experts
+ num_global_experts = test_config.num_experts
+ for lidx, fml in enumerate(fml_layers):
+ for name, w in fml.named_parameters():
+ for e in range(num_local_experts):
+ g_e = shuffled_indices[lidx][ep_rank * num_local_experts + e]
+ ref = make_expert_weights(
+ layer_idx=lidx,
+ global_expert_idx=int(g_e.item()),
+ global_num_experts=num_global_experts,
+ tensor_shape=w[e].shape,
+ tensor_dtype=w[e].dtype,
+ tensor_device=w[e].device,
+ is_column_major=not w[e].is_contiguous(),
+ )
+ assert w[e].shape == ref.shape and w[e].stride() == ref.stride(), (
+ f"w[{e}] {w[e].size()} {w[e].stride()} vs "
+ f"ref {ref.size()} {ref.stride()}"
+ )
+ torch.testing.assert_close(w[e], ref)
+
+
+@pytest.mark.parametrize("world_size", [2])
+@pytest.mark.parametrize("num_layers", [4])
+@pytest.mark.parametrize("num_experts", [16])
+@pytest.mark.parametrize("hidden_size", [256])
+@pytest.mark.parametrize("intermediate_size", [256])
+@pytest.mark.parametrize("column_major_scales", [True, False])
+def test_eplb_fml(
+ world_size: int,
+ num_layers: int,
+ num_experts: int,
+ hidden_size: int,
+ intermediate_size: int,
+ column_major_scales: bool,
+):
+ if torch.cuda.device_count() < world_size:
+ pytest.skip(f"Need at least {world_size} GPUs to run the test")
+
+ num_local_experts = num_experts // world_size
+ num_topk = 4
+ # The dtypes are fine as we are essentially just checking data-copies
+ weight_dtype = torch.bfloat16
+ weight_scale_dtype = torch.bfloat16
+
+ test_config = TestConfig(
+ num_layers=num_layers,
+ num_experts=num_experts,
+ num_local_experts=num_local_experts,
+ num_topk=num_topk,
+ hidden_size=hidden_size,
+ intermediate_size=intermediate_size,
+ weight_dtype=weight_dtype,
+ weight_scale_dtype=weight_scale_dtype,
+ column_major_scales=column_major_scales,
+ )
+
+ distributed_run(
+ _test_eplb_fml,
+ world_size,
+ test_config,
+ )
diff --git a/tests/distributed/test_eplb_spec_decode.py b/tests/distributed/test_eplb_spec_decode.py
index 11e23f128f331..c055b7a3f6dd7 100644
--- a/tests/distributed/test_eplb_spec_decode.py
+++ b/tests/distributed/test_eplb_spec_decode.py
@@ -10,10 +10,11 @@ from tests.utils import large_gpu_mark
def get_model_args(
model_name: str,
- spec_model_name: str,
+ spec_model_name: str | None,
spec_method: str,
tp_size: int,
model_max_len: int,
+ use_async: bool = False,
) -> dict:
speculative_config = {
"method": spec_method,
@@ -37,6 +38,8 @@ def get_model_args(
"enable_eplb": True,
"max_model_len": model_max_len,
}
+ if use_async:
+ model_args["eplb_config"] = {"use_async": True}
return model_args
@@ -94,3 +97,37 @@ def test_eplb_spec_decode(
measured_value - RTOL < expected_gsm8k_value
and measured_value + RTOL > expected_gsm8k_value
), f"Expected: {expected_gsm8k_value} | Measured: {measured_value}"
+
+
+@large_gpu_mark(min_gb=80)
+def test_eplb_spec_decode_qwen3_next_mtp_async() -> None:
+ """
+ Ensure async EPLB works with MTP speculative decoding for Qwen3-Next.
+ """
+
+ TASK = "gsm8k"
+ FILTER = "exact_match,strict-match"
+ RTOL = 0.03
+ expected_gsm8k_value = 0.86
+
+ model_args = get_model_args(
+ model_name="Qwen/Qwen3-Next-80B-A3B-Instruct",
+ spec_model_name=None,
+ spec_method="mtp",
+ tp_size=4,
+ model_max_len=4096,
+ use_async=True,
+ )
+
+ results = lm_eval.simple_evaluate(
+ model="vllm",
+ model_args=model_args,
+ tasks=TASK,
+ batch_size=64,
+ num_fewshot=8,
+ )
+ measured_value = results["results"][TASK][FILTER]
+ assert (
+ measured_value - RTOL < expected_gsm8k_value
+ and measured_value + RTOL > expected_gsm8k_value
+ ), f"Expected: {expected_gsm8k_value} | Measured: {measured_value}"
diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py
index 0ab94d30858fb..89f035d2cdd6f 100644
--- a/tests/distributed/test_pipeline_parallel.py
+++ b/tests/distributed/test_pipeline_parallel.py
@@ -130,6 +130,7 @@ TEXT_GENERATION_MODELS = {
"inceptionai/jais-13b-chat": PPTestSettings.fast(),
"ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
"pfnet/plamo-2-1b": PPTestSettings.fast(),
+ "pfnet/plamo-3-nict-2b-base": PPTestSettings.fast(),
"meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(),
# Tests TransformersForCausalLM
"hmellor/Ilama-3.2-1B": PPTestSettings.fast(),
diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py
index c3085beeb3564..c7c9d0602def0 100644
--- a/tests/distributed/test_pynccl.py
+++ b/tests/distributed/test_pynccl.py
@@ -1,9 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import multiprocessing
import os
+import multiprocess as mp
import numpy as np
import pytest
import torch
@@ -20,10 +20,12 @@ from vllm.distributed.parallel_state import (
)
from vllm.utils.system_utils import update_environment_variables
+mp.set_start_method("spawn", force=True)
+
def distributed_run(fn, world_size):
number_of_processes = world_size
- processes: list[multiprocessing.Process] = []
+ processes: list[mp.Process] = []
for i in range(number_of_processes):
env: dict[str, str] = {}
env["RANK"] = str(i)
@@ -32,7 +34,7 @@ def distributed_run(fn, world_size):
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
env["MASTER_ADDR"] = "localhost"
env["MASTER_PORT"] = "12345"
- p = multiprocessing.Process(target=fn, args=(env,))
+ p = mp.Process(target=fn, args=(env,))
processes.append(p)
p.start()
diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py
index 472b1487ef440..10827e3b4b9cd 100644
--- a/tests/engine/test_arg_utils.py
+++ b/tests/engine/test_arg_utils.py
@@ -279,7 +279,7 @@ def test_prefix_cache_default():
args = parser.parse_args([])
engine_args = EngineArgs.from_cli_args(args=args)
- assert not engine_args.enable_prefix_caching, "prefix caching defaults to off."
+ assert engine_args.enable_prefix_caching, "prefix caching should default to on."
# with flag to turn it on.
args = parser.parse_args(["--enable-prefix-caching"])
diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py
index 4e7b765d7713f..65a6fd20bd0d1 100644
--- a/tests/entrypoints/openai/test_metrics.py
+++ b/tests/entrypoints/openai/test_metrics.py
@@ -183,9 +183,6 @@ async def test_metrics_counts(
EXPECTED_METRICS_V1 = [
"vllm:num_requests_running",
"vllm:num_requests_waiting",
- "vllm:gpu_cache_usage_perc",
- "vllm:gpu_prefix_cache_queries",
- "vllm:gpu_prefix_cache_hits",
"vllm:kv_cache_usage_perc",
"vllm:prefix_cache_queries",
"vllm:prefix_cache_hits",
diff --git a/tests/entrypoints/openai/test_response_api_simple.py b/tests/entrypoints/openai/test_response_api_simple.py
new file mode 100644
index 0000000000000..425b8199a0fd0
--- /dev/null
+++ b/tests/entrypoints/openai/test_response_api_simple.py
@@ -0,0 +1,71 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+
+import pytest
+import pytest_asyncio
+from openai import OpenAI
+
+from ...utils import RemoteOpenAIServer
+
+MODEL_NAME = "Qwen/Qwen3-8B"
+
+
+@pytest.fixture(scope="module")
+def server():
+ args = ["--reasoning-parser", "qwen3", "--max_model_len", "5000"]
+ env_dict = dict(
+ VLLM_ENABLE_RESPONSES_API_STORE="1",
+ # uncomment for tool calling
+ # PYTHON_EXECUTION_BACKEND="dangerously_use_uv",
+ )
+
+ with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server:
+ yield remote_server
+
+
+@pytest_asyncio.fixture
+async def client(server):
+ async with server.get_async_client() as async_client:
+ yield async_client
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_basic(client: OpenAI, model_name: str):
+ response = await client.responses.create(
+ model=model_name,
+ input="What is 13 * 24?",
+ )
+ assert response is not None
+ print("response: ", response)
+ assert response.status == "completed"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_reasoning_item(client: OpenAI, model_name: str):
+ response = await client.responses.create(
+ model=model_name,
+ input=[
+ {"type": "message", "content": "Hello.", "role": "user"},
+ {
+ "type": "reasoning",
+ "id": "lol",
+ "content": [
+ {
+ "type": "reasoning_text",
+ "text": "We need to respond: greeting.",
+ }
+ ],
+ "summary": [],
+ },
+ ],
+ temperature=0.0,
+ )
+ assert response is not None
+ assert response.status == "completed"
+ # make sure we get a reasoning and text output
+ assert response.output[0].type == "reasoning"
+ assert response.output[1].type == "message"
+ assert type(response.output[1].content[0].text) is str
diff --git a/tests/entrypoints/openai/test_response_api_with_harmony.py b/tests/entrypoints/openai/test_response_api_with_harmony.py
index dea8d2d28f61a..8fd3545eccffa 100644
--- a/tests/entrypoints/openai/test_response_api_with_harmony.py
+++ b/tests/entrypoints/openai/test_response_api_with_harmony.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
+import importlib
import json
import time
@@ -35,7 +35,11 @@ GET_WEATHER_SCHEMA = {
@pytest.fixture(scope="module")
def server():
- args = ["--enforce-eager", "--tool-server", "demo"]
+ assert importlib.util.find_spec("gpt_oss") is not None, (
+ "Harmony tests require gpt_oss package to be installed"
+ )
+
+ args = ["--enforce-eager", "--tool-server", "demo", "--max_model_len", "5000"]
env_dict = dict(
VLLM_ENABLE_RESPONSES_API_STORE="1",
PYTHON_EXECUTION_BACKEND="dangerously_use_uv",
@@ -550,6 +554,31 @@ def call_function(name, args):
raise ValueError(f"Unknown function: {name}")
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_reasoning_item(client: OpenAI, model_name: str):
+ response = await client.responses.create(
+ model=model_name,
+ input=[
+ {"type": "message", "content": "Hello.", "role": "user"},
+ {
+ "type": "reasoning",
+ "id": "lol",
+ "content": [
+ {
+ "type": "reasoning_text",
+ "text": "We need to respond: greeting.",
+ }
+ ],
+ "summary": [],
+ },
+ ],
+ temperature=0.0,
+ )
+ assert response is not None
+ assert response.status == "completed"
+
+
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_function_calling(client: OpenAI, model_name: str):
diff --git a/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py
index 2b68a653f4600..37e52d2cdf609 100644
--- a/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py
+++ b/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from unittest.mock import MagicMock, patch
+
import pytest
from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation
@@ -132,3 +134,129 @@ def test_extract_tool_calls_multiple_json_with_surrounding_text(parser):
assert result.tool_calls[0].function.name == "searchTool"
assert result.tool_calls[1].function.name == "getOpenIncidentsTool"
assert result.tool_calls[2].function.name == "searchTool"
+
+
+def test_extract_tool_calls_deeply_nested_json(parser):
+ # Test with deeply nested JSON parameters (5 levels)
+ model_output = (
+ '{"name": "complexTool", '
+ '"parameters": {'
+ '"level1": {'
+ '"level2": {'
+ '"level3": {'
+ '"level4": {'
+ '"value": "deep"'
+ "}}}}}}"
+ )
+ result = parser.extract_tool_calls(model_output, None)
+
+ assert result.tools_called is True
+ assert len(result.tool_calls) == 1
+ assert result.tool_calls[0].function.name == "complexTool"
+ # Verify the nested structure is preserved in the arguments
+ import json
+
+ args = json.loads(result.tool_calls[0].function.arguments)
+ assert args["level1"]["level2"]["level3"]["level4"]["value"] == "deep"
+
+
+def test_extract_tool_calls_multiple_with_deep_nesting(parser):
+ # Test with multiple tool calls where some have deeply nested parameters
+ model_output = (
+ '{"name": "simpleTool", "parameters": {"value": "test"}}; '
+ '{"name": "complexTool", "parameters": '
+ '{"config": {"database": {"connection": {"pool": {"size": 10}}}}}}'
+ )
+ result = parser.extract_tool_calls(model_output, None)
+
+ assert result.tools_called is True
+ assert len(result.tool_calls) == 2
+
+ # Check first tool call
+ assert result.tool_calls[0].function.name == "simpleTool"
+ import json
+
+ args0 = json.loads(result.tool_calls[0].function.arguments)
+ assert args0["value"] == "test"
+
+ # Check second tool call with deep nesting
+ assert result.tool_calls[1].function.name == "complexTool"
+ args1 = json.loads(result.tool_calls[1].function.arguments)
+ assert args1["config"]["database"]["connection"]["pool"]["size"] == 10
+
+
+def test_extract_tool_calls_with_quotes_and_brackets_in_string(parser):
+ # Test with quotes and brackets inside quoted string values
+ model_output = (
+ '{"name": "searchTool", '
+ '"parameters": {'
+ '"query": "test {value} [complex]",'
+ '"nested": {"inner": "more {brackets}"}'
+ "}}"
+ )
+ result = parser.extract_tool_calls(model_output, None)
+
+ assert result.tools_called is True
+ assert len(result.tool_calls) == 1
+ assert result.tool_calls[0].function.name == "searchTool"
+ # Verify the string values are preserved including brackets and quotes
+ import json
+
+ args = json.loads(result.tool_calls[0].function.arguments)
+ assert args["query"] == "test {value} [complex]"
+ assert args["nested"]["inner"] == "more {brackets}"
+
+
+def test_extract_tool_calls_with_escaped_quotes_in_nested_json(parser):
+ # Test with escaped quotes in deeply nested JSON
+ model_output = (
+ '{"name": "parserTool", "parameters": {"text": "He said \\"Hello {world}\\""}}'
+ )
+ result = parser.extract_tool_calls(model_output, None)
+
+ assert result.tools_called is True
+ assert len(result.tool_calls) == 1
+ assert result.tool_calls[0].function.name == "parserTool"
+ # Verify escaped quotes are preserved
+ import json
+
+ args = json.loads(result.tool_calls[0].function.arguments)
+ assert args["text"] == 'He said "Hello {world}"'
+
+
+def test_extract_tool_calls_missing_name_key(parser):
+ # Test that missing "name" key returns content
+ model_output = '{"parameters": {}}'
+ result = parser.extract_tool_calls(model_output, None)
+
+ assert result.tools_called is False
+ assert len(result.tool_calls) == 0
+ assert result.content == model_output
+
+
+def test_extract_tool_calls_missing_parameters_and_arguments_key(parser):
+ # Test that missing both "parameters" and "arguments" keys returns content
+ model_output = '{"name": "toolWithoutParams"}'
+ result = parser.extract_tool_calls(model_output, None)
+
+ assert result.tools_called is False
+ assert len(result.tool_calls) == 0
+ assert result.content == model_output
+
+
+def test_regex_timeout_handling(parser):
+ """Test regex timeout is handled gracefully"""
+ fake_problematic_input = "{hello world[A(A=" + "\t)A(A=,\t" * 2
+
+ # create a mock regex that raises TimeoutError
+ mock_regex = MagicMock()
+ mock_regex.finditer.side_effect = TimeoutError("Regex timeout")
+
+ with patch.object(parser, "tool_call_start_regex", mock_regex):
+ result = parser.extract_tool_calls(fake_problematic_input, None)
+
+ # should treat as regular text when regex times out
+ assert result.content == fake_problematic_input
+ assert result.tools_called is False
+ assert len(result.tool_calls) == 0
+ mock_regex.finditer.assert_called_once()
diff --git a/tests/entrypoints/pooling/llm/__init__.py b/tests/entrypoints/pooling/basic/__init__.py
similarity index 100%
rename from tests/entrypoints/pooling/llm/__init__.py
rename to tests/entrypoints/pooling/basic/__init__.py
diff --git a/tests/entrypoints/pooling/llm/test_encode.py b/tests/entrypoints/pooling/basic/test_encode.py
similarity index 92%
rename from tests/entrypoints/pooling/llm/test_encode.py
rename to tests/entrypoints/pooling/basic/test_encode.py
index ca85d2758fce4..f86ecef2e4744 100644
--- a/tests/entrypoints/pooling/llm/test_encode.py
+++ b/tests/entrypoints/pooling/basic/test_encode.py
@@ -7,6 +7,12 @@ import pytest
from vllm import LLM, PoolingParams
from vllm.distributed import cleanup_dist_env_and_memory
+from vllm.platforms import current_platform
+
+if current_platform.is_rocm():
+ pytest.skip(
+ "Encoder self-attention is not implemented on ROCm.", allow_module_level=True
+ )
MODEL_NAME = "intfloat/multilingual-e5-small"
diff --git a/tests/entrypoints/pooling/openai/test_truncation.py b/tests/entrypoints/pooling/basic/test_truncation.py
similarity index 95%
rename from tests/entrypoints/pooling/openai/test_truncation.py
rename to tests/entrypoints/pooling/basic/test_truncation.py
index 6889628dc9145..0d2d385840402 100644
--- a/tests/entrypoints/pooling/openai/test_truncation.py
+++ b/tests/entrypoints/pooling/basic/test_truncation.py
@@ -7,6 +7,12 @@ import pytest
import pytest_asyncio
from tests.utils import RemoteOpenAIServer
+from vllm.platforms import current_platform
+
+if current_platform.is_rocm():
+ pytest.skip(
+ "Encoder self-attention is not implemented on ROCm.", allow_module_level=True
+ )
MODEL_NAME = "sentence-transformers/all-MiniLM-L12-v2"
max_model_len = 128
diff --git a/tests/entrypoints/pooling/openai/__init__.py b/tests/entrypoints/pooling/classify/__init__.py
similarity index 100%
rename from tests/entrypoints/pooling/openai/__init__.py
rename to tests/entrypoints/pooling/classify/__init__.py
diff --git a/tests/entrypoints/pooling/llm/test_classify.py b/tests/entrypoints/pooling/classify/test_offline.py
similarity index 100%
rename from tests/entrypoints/pooling/llm/test_classify.py
rename to tests/entrypoints/pooling/classify/test_offline.py
diff --git a/tests/entrypoints/pooling/openai/test_classification.py b/tests/entrypoints/pooling/classify/test_online.py
similarity index 100%
rename from tests/entrypoints/pooling/openai/test_classification.py
rename to tests/entrypoints/pooling/classify/test_online.py
diff --git a/tests/entrypoints/pooling/openai/test_vision_classification.py b/tests/entrypoints/pooling/classify/test_online_vision.py
similarity index 100%
rename from tests/entrypoints/pooling/openai/test_vision_classification.py
rename to tests/entrypoints/pooling/classify/test_online_vision.py
diff --git a/tests/entrypoints/pooling/embed/__init__.py b/tests/entrypoints/pooling/embed/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/entrypoints/pooling/correctness/test_mteb_embed.py b/tests/entrypoints/pooling/embed/test_correctness_mteb.py
similarity index 87%
rename from tests/entrypoints/pooling/correctness/test_mteb_embed.py
rename to tests/entrypoints/pooling/embed/test_correctness_mteb.py
index 7f16638e51e2c..64673534fd32a 100644
--- a/tests/entrypoints/pooling/correctness/test_mteb_embed.py
+++ b/tests/entrypoints/pooling/embed/test_correctness_mteb.py
@@ -11,6 +11,12 @@ from tests.models.language.pooling_mteb_test.mteb_utils import (
run_mteb_embed_task,
)
from tests.utils import RemoteOpenAIServer
+from vllm.platforms import current_platform
+
+if current_platform.is_rocm():
+ pytest.skip(
+ "Encoder self-attention is not implemented on ROCm.", allow_module_level=True
+ )
os.environ["VLLM_LOGGING_LEVEL"] = "WARNING"
diff --git a/tests/entrypoints/pooling/llm/test_embedding.py b/tests/entrypoints/pooling/embed/test_offline.py
similarity index 90%
rename from tests/entrypoints/pooling/llm/test_embedding.py
rename to tests/entrypoints/pooling/embed/test_offline.py
index 5455b5f91fc09..f5eab4c29ae18 100644
--- a/tests/entrypoints/pooling/llm/test_embedding.py
+++ b/tests/entrypoints/pooling/embed/test_offline.py
@@ -9,6 +9,12 @@ import torch.nn.functional as F
from vllm import LLM, PoolingParams
from vllm.distributed import cleanup_dist_env_and_memory
+from vllm.platforms import current_platform
+
+if current_platform.is_rocm():
+ pytest.skip(
+ "Encoder self-attention is not implemented on ROCm.", allow_module_level=True
+ )
MODEL_NAME = "intfloat/multilingual-e5-small"
diff --git a/tests/entrypoints/pooling/openai/test_embedding.py b/tests/entrypoints/pooling/embed/test_online.py
similarity index 99%
rename from tests/entrypoints/pooling/openai/test_embedding.py
rename to tests/entrypoints/pooling/embed/test_online.py
index e971b23e8f1a0..0c88d800e2f99 100644
--- a/tests/entrypoints/pooling/openai/test_embedding.py
+++ b/tests/entrypoints/pooling/embed/test_online.py
@@ -19,6 +19,7 @@ from vllm.entrypoints.openai.protocol import (
EmbeddingResponse,
PoolingResponse,
)
+from vllm.platforms import current_platform
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils.serial_utils import (
EMBED_DTYPE_TO_TORCH_DTYPE,
@@ -28,6 +29,11 @@ from vllm.utils.serial_utils import (
decode_pooling_output,
)
+if current_platform.is_rocm():
+ pytest.skip(
+ "Encoder self-attention is not implemented on ROCm.", allow_module_level=True
+ )
+
MODEL_NAME = "intfloat/multilingual-e5-small"
DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501
DTYPE = "bfloat16"
diff --git a/tests/entrypoints/pooling/openai/test_embedding_dimensions.py b/tests/entrypoints/pooling/embed/test_online_dimensions.py
similarity index 95%
rename from tests/entrypoints/pooling/openai/test_embedding_dimensions.py
rename to tests/entrypoints/pooling/embed/test_online_dimensions.py
index ba9fb64262772..8018dac2d3ffe 100644
--- a/tests/entrypoints/pooling/openai/test_embedding_dimensions.py
+++ b/tests/entrypoints/pooling/embed/test_online_dimensions.py
@@ -12,6 +12,12 @@ from tests.models.language.pooling.embed_utils import run_embedding_correctness_
from tests.models.utils import EmbedModelInfo
from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import EmbeddingResponse
+from vllm.platforms import current_platform
+
+if current_platform.is_rocm():
+ pytest.skip(
+ "Encoder self-attention is not implemented on ROCm.", allow_module_level=True
+ )
MODELS = [
EmbedModelInfo("intfloat/multilingual-e5-small", is_matryoshka=False),
diff --git a/tests/entrypoints/pooling/openai/test_embedding_long_text.py b/tests/entrypoints/pooling/embed/test_online_long_text.py
similarity index 98%
rename from tests/entrypoints/pooling/openai/test_embedding_long_text.py
rename to tests/entrypoints/pooling/embed/test_online_long_text.py
index f977c81a9084e..a9ade09dad0b5 100644
--- a/tests/entrypoints/pooling/openai/test_embedding_long_text.py
+++ b/tests/entrypoints/pooling/embed/test_online_long_text.py
@@ -16,6 +16,12 @@ import pytest_asyncio
from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import EmbeddingResponse
+from vllm.platforms import current_platform
+
+if current_platform.is_rocm():
+ pytest.skip(
+ "Encoder self-attention is not implemented on ROCm.", allow_module_level=True
+ )
def _generate_random_text(word_count: int) -> str:
diff --git a/tests/entrypoints/pooling/openai/test_vision_embedding.py b/tests/entrypoints/pooling/embed/test_online_vision.py
similarity index 100%
rename from tests/entrypoints/pooling/openai/test_vision_embedding.py
rename to tests/entrypoints/pooling/embed/test_online_vision.py
diff --git a/tests/entrypoints/pooling/pooling/__init__.py b/tests/entrypoints/pooling/pooling/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/entrypoints/pooling/openai/test_pooling.py b/tests/entrypoints/pooling/pooling/test_online.py
similarity index 100%
rename from tests/entrypoints/pooling/openai/test_pooling.py
rename to tests/entrypoints/pooling/pooling/test_online.py
diff --git a/tests/entrypoints/pooling/reward/__init__.py b/tests/entrypoints/pooling/reward/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/entrypoints/pooling/llm/test_reward.py b/tests/entrypoints/pooling/reward/test_offline.py
similarity index 100%
rename from tests/entrypoints/pooling/llm/test_reward.py
rename to tests/entrypoints/pooling/reward/test_offline.py
diff --git a/tests/entrypoints/pooling/score/__init__.py b/tests/entrypoints/pooling/score/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/entrypoints/pooling/correctness/test_mteb_score.py b/tests/entrypoints/pooling/score/test_correctness_mteb.py
similarity index 91%
rename from tests/entrypoints/pooling/correctness/test_mteb_score.py
rename to tests/entrypoints/pooling/score/test_correctness_mteb.py
index 1afe68b189db8..81ad0097187b0 100644
--- a/tests/entrypoints/pooling/correctness/test_mteb_score.py
+++ b/tests/entrypoints/pooling/score/test_correctness_mteb.py
@@ -13,6 +13,12 @@ from tests.models.language.pooling_mteb_test.mteb_utils import (
run_mteb_rerank,
)
from tests.utils import RemoteOpenAIServer
+from vllm.platforms import current_platform
+
+if current_platform.is_rocm():
+ pytest.skip(
+ "Encoder self-attention is not implemented on ROCm.", allow_module_level=True
+ )
os.environ["VLLM_LOGGING_LEVEL"] = "WARNING"
diff --git a/tests/entrypoints/pooling/llm/test_score.py b/tests/entrypoints/pooling/score/test_offline.py
similarity index 90%
rename from tests/entrypoints/pooling/llm/test_score.py
rename to tests/entrypoints/pooling/score/test_offline.py
index b69c6a47c1913..ce36d61cb8476 100644
--- a/tests/entrypoints/pooling/llm/test_score.py
+++ b/tests/entrypoints/pooling/score/test_offline.py
@@ -9,6 +9,12 @@ import torch
from tests.models.utils import softmax
from vllm import LLM, PoolingParams
from vllm.distributed import cleanup_dist_env_and_memory
+from vllm.platforms import current_platform
+
+if current_platform.is_rocm():
+ pytest.skip(
+ "Encoder self-attention is not implemented on ROCm.", allow_module_level=True
+ )
MODEL_NAME = "tomaarsen/Qwen3-Reranker-0.6B-seq-cls"
diff --git a/tests/entrypoints/pooling/openai/test_rerank.py b/tests/entrypoints/pooling/score/test_online_rerank.py
similarity index 97%
rename from tests/entrypoints/pooling/openai/test_rerank.py
rename to tests/entrypoints/pooling/score/test_online_rerank.py
index 1d85190c12a19..5a772e22a7414 100644
--- a/tests/entrypoints/pooling/openai/test_rerank.py
+++ b/tests/entrypoints/pooling/score/test_online_rerank.py
@@ -8,6 +8,12 @@ import torch.nn.functional as F
from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import PoolingResponse, RerankResponse
+from vllm.platforms import current_platform
+
+if current_platform.is_rocm():
+ pytest.skip(
+ "Encoder self-attention is not implemented on ROCm.", allow_module_level=True
+ )
MODEL_NAME = "BAAI/bge-reranker-base"
DTYPE = "bfloat16"
diff --git a/tests/entrypoints/pooling/openai/test_score.py b/tests/entrypoints/pooling/score/test_online_score.py
similarity index 97%
rename from tests/entrypoints/pooling/openai/test_score.py
rename to tests/entrypoints/pooling/score/test_online_score.py
index b8f796d47efaa..ceff9d0181825 100644
--- a/tests/entrypoints/pooling/openai/test_score.py
+++ b/tests/entrypoints/pooling/score/test_online_score.py
@@ -10,6 +10,12 @@ from torch import tensor
from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import ScoreResponse
+from vllm.platforms import current_platform
+
+if current_platform.is_rocm():
+ pytest.skip(
+ "Encoder self-attention is not implemented on ROCm.", allow_module_level=True
+ )
MODELS = [
{"name": "BAAI/bge-reranker-v2-m3", "is_cross_encoder": True},
diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py
index ca87b3e76b3f4..7baf564ad01a4 100644
--- a/tests/entrypoints/test_chat_utils.py
+++ b/tests/entrypoints/test_chat_utils.py
@@ -103,6 +103,19 @@ def qwen2_audio_model_config():
)
+@pytest.fixture(scope="function")
+def audio_embeds_model_config():
+ return ModelConfig(
+ QWEN2AUDIO_MODEL_ID,
+ runner="generate",
+ trust_remote_code=True,
+ limit_mm_per_prompt={
+ "audio": 2,
+ },
+ enable_mm_embeds=True,
+ )
+
+
@pytest.fixture(scope="module")
def qwen2_audio_tokenizer():
return get_tokenizer(QWEN2AUDIO_MODEL_ID)
@@ -843,6 +856,138 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid(
_assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid])
+def test_parse_chat_messages_empty_audio_embeds_with_uuid(
+ audio_embeds_model_config,
+ qwen2_audio_tokenizer,
+):
+ """Test audio_embeds with UUID (no actual embeds data)."""
+ uuid = "test-audio-uuid-123"
+
+ conversation, mm_data, mm_uuids = parse_chat_messages(
+ [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Describe this audio"},
+ {"type": "audio_embeds", "audio_embeds": None, "uuid": uuid},
+ ],
+ }
+ ],
+ audio_embeds_model_config,
+ qwen2_audio_tokenizer,
+ content_format="string",
+ )
+
+ # Should have audio in mm_data as None (UUID provided)
+ assert mm_data is not None
+ assert "audio" in mm_data
+ assert mm_data["audio"] is None
+ # UUID should be recorded
+ assert mm_uuids is not None
+ assert "audio" in mm_uuids
+ _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[uuid])
+
+
+def test_parse_chat_messages_audio_embeds_with_string(
+ audio_embeds_model_config,
+ qwen2_audio_tokenizer,
+):
+ """Test audio_embeds with base64 string embedding data."""
+ import base64
+ import io
+
+ import torch
+
+ # Create a sample audio embedding tensor
+ audio_embedding = torch.randn(1, 128, 768)
+
+ # Encode it as base64
+ buffer = io.BytesIO()
+ torch.save(audio_embedding, buffer)
+ buffer.seek(0)
+ binary_data = buffer.read()
+ base64_audio_embedding = base64.b64encode(binary_data).decode("utf-8")
+
+ conversation, mm_data, mm_uuids = parse_chat_messages(
+ [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Describe this audio"},
+ {
+ "type": "audio_embeds",
+ "audio_embeds": base64_audio_embedding,
+ },
+ ],
+ }
+ ],
+ audio_embeds_model_config,
+ qwen2_audio_tokenizer,
+ content_format="string",
+ )
+
+ # Should have audio embedding in mm_data (single tensor, not a list)
+ assert mm_data is not None
+ assert "audio" in mm_data
+ assert isinstance(mm_data["audio"], torch.Tensor)
+ assert mm_data["audio"].shape == audio_embedding.shape
+ # No UUID provided
+ assert mm_uuids is not None
+ assert "audio" in mm_uuids
+ _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[None])
+
+
+@pytest.mark.asyncio
+async def test_parse_chat_messages_audio_embeds_async(
+ audio_embeds_model_config,
+ qwen2_audio_tokenizer,
+):
+ """Test audio_embeds with async futures."""
+ import base64
+ import io
+
+ import torch
+
+ # Create a sample audio embedding tensor
+ audio_embedding = torch.randn(1, 128, 768)
+
+ # Encode it as base64
+ buffer = io.BytesIO()
+ torch.save(audio_embedding, buffer)
+ buffer.seek(0)
+ binary_data = buffer.read()
+ base64_audio_embedding = base64.b64encode(binary_data).decode("utf-8")
+
+ conversation, mm_future, mm_uuids = parse_chat_messages_futures(
+ [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Describe this audio"},
+ {
+ "type": "audio_embeds",
+ "audio_embeds": base64_audio_embedding,
+ },
+ ],
+ }
+ ],
+ audio_embeds_model_config,
+ qwen2_audio_tokenizer,
+ content_format="string",
+ )
+
+ # Should have audio embedding in mm_data (single tensor, not a list)
+ mm_data = await mm_future
+ assert mm_data is not None
+ assert "audio" in mm_data
+ assert isinstance(mm_data["audio"], torch.Tensor)
+ assert mm_data["audio"].shape == audio_embedding.shape
+ # No UUID provided
+ assert mm_uuids is not None
+ assert "audio" in mm_uuids
+ _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[None])
+
+
@pytest.mark.asyncio
async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
phi3v_model_config_image_embeds,
diff --git a/tests/entrypoints/test_responses_utils.py b/tests/entrypoints/test_responses_utils.py
index 48bf06088bc05..91c818374e3fd 100644
--- a/tests/entrypoints/test_responses_utils.py
+++ b/tests/entrypoints/test_responses_utils.py
@@ -1,7 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+from openai.types.responses.response_reasoning_item import (
+ Content,
+ ResponseReasoningItem,
+ Summary,
+)
+
from vllm.entrypoints.responses_utils import (
+ construct_chat_message_with_tool_call,
convert_tool_responses_to_completions_format,
)
@@ -28,3 +36,53 @@ class TestResponsesUtils:
result = convert_tool_responses_to_completions_format(input_tool)
assert result == {"type": "function", "function": input_tool}
+
+ def test_construct_chat_message_with_tool_call(self):
+ item = ResponseReasoningItem(
+ id="lol",
+ summary=[],
+ type="reasoning",
+ content=[
+ Content(
+ text="Leroy Jenkins",
+ type="reasoning_text",
+ )
+ ],
+ encrypted_content=None,
+ status=None,
+ )
+ formatted_item = construct_chat_message_with_tool_call(item)
+ assert formatted_item["role"] == "assistant"
+ assert formatted_item["reasoning"] == "Leroy Jenkins"
+
+ item = ResponseReasoningItem(
+ id="lol",
+ summary=[
+ Summary(
+ text='Hmm, the user has just started with a simple "Hello,"',
+ type="summary_text",
+ )
+ ],
+ type="reasoning",
+ content=None,
+ encrypted_content=None,
+ status=None,
+ )
+
+ formatted_item = construct_chat_message_with_tool_call(item)
+ assert formatted_item["role"] == "assistant"
+ assert (
+ formatted_item["reasoning"]
+ == 'Hmm, the user has just started with a simple "Hello,"'
+ )
+
+ item = ResponseReasoningItem(
+ id="lol",
+ summary=[],
+ type="reasoning",
+ content=None,
+ encrypted_content="TOP_SECRET_MESSAGE",
+ status=None,
+ )
+ with pytest.raises(ValueError):
+ construct_chat_message_with_tool_call(item)
diff --git a/tests/kernels/attention/test_aiter_flash_attn.py b/tests/kernels/attention/test_aiter_flash_attn.py
index 1dec46e33f22e..8f58c470d217a 100644
--- a/tests/kernels/attention/test_aiter_flash_attn.py
+++ b/tests/kernels/attention/test_aiter_flash_attn.py
@@ -6,6 +6,7 @@ import pytest
import torch
import vllm.v1.attention.backends.rocm_aiter_fa # noqa: F401
+from vllm.attention.utils.fa_utils import is_flash_attn_varlen_func_available
from vllm.platforms import current_platform
NUM_HEADS = [(4, 4), (8, 2)]
@@ -100,6 +101,8 @@ def test_varlen_with_paged_kv(
num_blocks: int,
q_dtype: torch.dtype | None,
) -> None:
+ if not is_flash_attn_varlen_func_available():
+ pytest.skip("flash_attn_varlen_func required to run this test.")
torch.set_default_device("cuda")
current_platform.seed_everything(0)
num_seqs = len(seq_lens)
diff --git a/tests/kernels/attention/test_attention.py b/tests/kernels/attention/test_attention.py
index 9662e73321ebe..1a7d5ce0ddc1e 100644
--- a/tests/kernels/attention/test_attention.py
+++ b/tests/kernels/attention/test_attention.py
@@ -13,12 +13,6 @@ from vllm.attention.layer import Attention, MultiHeadAttention
from vllm.platforms import current_platform
from vllm.utils.mem_utils import get_max_shared_memory_bytes
-if not current_platform.is_rocm():
- from xformers import ops as xops
- from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
-
- from tests.kernels.utils import make_alibi_bias
-
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
@@ -448,129 +442,6 @@ def ref_multi_query_kv_attention(
return torch.cat(ref_outputs, dim=0)
-@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
-@pytest.mark.parametrize("num_heads", NUM_HEADS)
-@pytest.mark.parametrize("head_size", HEAD_SIZES)
-@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.skipif(
- current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
-)
-@torch.inference_mode()
-def test_multi_query_kv_attention(
- num_seqs: int,
- num_heads: tuple[int, int],
- head_size: int,
- dtype: torch.dtype,
- seed: int,
- device: str,
- use_alibi: bool = False,
-) -> None:
- current_platform.seed_everything(seed)
- torch.set_default_device(device)
- # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
- # As the xformers library is already tested with its own tests, we can use
- # a smaller MAX_SEQ_LEN here.
- max_len = min(MAX_SEQ_LEN, 4096)
- seq_lens = random.sample(range(1, max_len), num_seqs)
- num_tokens = sum(seq_lens)
-
- scale = float(1.0 / (head_size**0.5))
- num_query_heads, num_kv_heads = num_heads
- qkv = torch.empty(
- num_tokens, num_query_heads + 2 * num_kv_heads, head_size, dtype=dtype
- )
- qkv.uniform_(-scale, scale)
- query, key, value = qkv.split([num_query_heads, num_kv_heads, num_kv_heads], dim=1)
-
- num_queries_per_kv = num_query_heads // num_kv_heads
- if num_queries_per_kv > 1:
- # Handle MQA and GQA
- key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
- value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
- alibi_bias = None
- if use_alibi:
- alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
- attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
- output = torch.empty_like(query)
- start = 0
- # Dynamic sequence length not supported with custom attn_bias.
- for i, seq_len in enumerate(seq_lens):
- end = start + seq_len
- out = xops.memory_efficient_attention_forward(
- query[None, start:end],
- key[None, start:end],
- value[None, start:end],
- attn_bias=attn_bias[i],
- p=0.0,
- scale=scale,
- )
- output[start:end].copy_(out.view_as(query[start:end]))
- start += seq_len
- # xformers.AttentionBias to Tensor for use in reference impl.
- alibi_bias = [
- b.materialize((1, num_query_heads, i, i), device=device).squeeze()
- for b, i in zip(attn_bias, seq_lens)
- ]
- else:
- attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
- output = xops.memory_efficient_attention_forward(
- query.unsqueeze(0),
- key.unsqueeze(0),
- value.unsqueeze(0),
- attn_bias=attn_bias,
- p=0.0,
- scale=scale,
- )
- output = output.squeeze(0)
-
- cu_seq_lens = [0]
- for seq_len in seq_lens:
- cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
- ref_output = ref_multi_query_kv_attention(
- cu_seq_lens,
- query,
- key,
- value,
- scale,
- alibi_bias,
- dtype,
- )
- atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
- rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
- torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
-
-
-@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
-@pytest.mark.parametrize("num_heads", NUM_HEADS)
-@pytest.mark.parametrize("head_size", [64])
-@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.skipif(
- current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
-)
-@torch.inference_mode()
-def test_multi_query_kv_attention_with_alibi(
- num_seqs: int,
- num_heads: tuple[int, int],
- head_size: int,
- dtype: torch.dtype,
- seed: int,
- device: str,
-) -> None:
- return test_multi_query_kv_attention(
- num_seqs,
- num_heads,
- head_size,
- dtype,
- seed,
- device,
- use_alibi=True,
- )
-
-
@pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention])
def test_num_heads_not_divisble_by_num_kv_heads(attention_cls: type) -> None:
head_size = 64
diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py
index 3b8e939300a27..cd34b520ea71b 100644
--- a/tests/kernels/attention/test_attention_selector.py
+++ b/tests/kernels/attention/test_attention_selector.py
@@ -7,6 +7,7 @@ import pytest
import torch
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
+from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform
@@ -33,7 +34,7 @@ DEVICE_MLA_BACKENDS = {
}
DEVICE_REGULAR_ATTN_BACKENDS = {
- "cuda": ["XFORMERS", "FLASHINFER", "FLASH_ATTN"],
+ "cuda": ["FLASHINFER", "FLASH_ATTN"],
"hip": ["ROCM_ATTN"],
"cpu": ["CPU_ATTN"],
}
@@ -47,9 +48,11 @@ DEVICE_MLA_BLOCK_SIZES = {
def generate_params():
+ is_rocm = current_platform.is_rocm()
params = []
+ device_list = ["cuda", "cpu"] if not is_rocm else ["hip", "cpu"]
for use_mla in [True, False]:
- for device in ["cuda", "hip", "cpu"]:
+ for device in device_list:
backends = (
DEVICE_MLA_BACKENDS[device]
if use_mla
@@ -204,12 +207,6 @@ def test_env(
)
expected = "FLASHINFER"
assert backend.get_name() == expected
- elif name == "XFORMERS":
- backend = get_attn_backend(
- 32, torch.float16, None, block_size, use_mla=use_mla
- )
- expected = "XFORMERS"
- assert backend.get_name() == expected
elif name == "FLASH_ATTN":
backend = get_attn_backend(
32, torch.float16, None, block_size, use_mla=use_mla
diff --git a/tests/kernels/attention/test_cache.py b/tests/kernels/attention/test_cache.py
index f33a27d1fd85a..acf46d75d62eb 100644
--- a/tests/kernels/attention/test_cache.py
+++ b/tests/kernels/attention/test_cache.py
@@ -68,6 +68,7 @@ def test_copy_blocks(
pytest.skip()
current_platform.seed_everything(seed)
torch.set_default_device(device)
+ torch.cuda.set_device(device)
# Generate random block mappings where each source block is mapped to two
# destination blocks.
assert 2 * num_mappings <= num_blocks
@@ -152,6 +153,7 @@ def test_reshape_and_cache(
pytest.skip()
current_platform.seed_everything(seed)
torch.set_default_device(device)
+ torch.cuda.set_device(device)
# Create a random slot mapping.
num_slots = block_size * num_blocks
slot_mapping_lst = random.sample(range(num_slots), num_tokens)
@@ -272,6 +274,7 @@ def test_reshape_and_cache_flash(
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
+ torch.cuda.set_device(device)
assert implementation in ["cuda", "triton"]
if implementation == "triton" and kv_cache_layout == "HND":
pytest.skip("Triton implementation only supports NHD layout.")
@@ -593,6 +596,7 @@ def test_concat_and_cache_mla(
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
+ torch.cuda.set_device(device)
total_slots = num_blocks * block_size
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
@@ -662,11 +666,14 @@ def test_concat_and_cache_ds_mla(
seed: int,
device: str,
) -> None:
+ if current_platform.is_rocm():
+ pytest.skip("concat_and_cache_mla doesn't support fp8_ds_mla on ROCm")
if dtype.itemsize != 2:
pytest.skip("ds_mla only supports 16-bit input")
kv_cache_dtype = "fp8_ds_mla"
current_platform.seed_everything(seed)
torch.set_default_device(device)
+ torch.cuda.set_device(device)
total_slots = num_blocks * block_size
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
@@ -779,6 +786,7 @@ def test_copy_blocks_mla(
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
+ torch.cuda.set_device(device)
entry_size = kv_lora_rank + qk_rope_head_dim
@@ -843,6 +851,7 @@ def test_swap_blocks_mla(
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
+ torch.cuda.set_device(device)
entry_size = kv_lora_rank + qk_rope_head_dim
@@ -912,12 +921,16 @@ def test_gather_and_maybe_dequant_cache_mla(
)
_fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
- seq_len_tensor = torch.randint(0, max_seq_len + 1, (batch_size,), device=device)
+ seq_len_tensor = torch.randint(
+ max_seq_len, max_seq_len + 1, (batch_size,), device=device
+ )
total_tokens = seq_len_tensor.sum()
cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device)
cu_seq_lens[0] = 0
cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)
+ token_to_seq = torch.arange(0, batch_size, dtype=torch.int32, device=device)
+ token_to_seq = torch.repeat_interleave(token_to_seq, seq_len_tensor)
print("seq_len_tensor", seq_len_tensor)
tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
@@ -968,7 +981,8 @@ def test_gather_and_maybe_dequant_cache_mla(
dst,
block_table,
cu_seq_lens,
- batch_size,
+ token_to_seq,
+ total_tokens,
kv_cache_dtype,
scale,
None,
@@ -981,7 +995,8 @@ def test_gather_and_maybe_dequant_cache_mla(
dst,
block_table,
cu_seq_lens,
- batch_size,
+ token_to_seq,
+ total_tokens,
kv_cache_dtype,
scale,
None,
diff --git a/tests/kernels/attention/test_cascade_flash_attn.py b/tests/kernels/attention/test_cascade_flash_attn.py
index 20f573821b25f..d86041d71febd 100755
--- a/tests/kernels/attention/test_cascade_flash_attn.py
+++ b/tests/kernels/attention/test_cascade_flash_attn.py
@@ -7,11 +7,19 @@ import torch
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import cascade_attention, merge_attn_states
-from vllm.vllm_flash_attn import (
- fa_version_unsupported_reason,
- flash_attn_varlen_func,
- is_fa_version_supported,
-)
+
+try:
+ from vllm.vllm_flash_attn import (
+ fa_version_unsupported_reason,
+ flash_attn_varlen_func,
+ is_fa_version_supported,
+ )
+except ImportError:
+ if current_platform.is_rocm():
+ pytest.skip(
+ "vllm_flash_attn is not supported for vLLM on ROCm.",
+ allow_module_level=True,
+ )
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 192, 256]
diff --git a/tests/kernels/attention/test_flash_attn.py b/tests/kernels/attention/test_flash_attn.py
index 6e5468969bf25..bbd5df5419f80 100644
--- a/tests/kernels/attention/test_flash_attn.py
+++ b/tests/kernels/attention/test_flash_attn.py
@@ -6,21 +6,30 @@ import pytest
import torch
from vllm.platforms import current_platform
-from vllm.vllm_flash_attn import (
- fa_version_unsupported_reason,
- flash_attn_varlen_func,
- is_fa_version_supported,
-)
+
+try:
+ from vllm.vllm_flash_attn import (
+ fa_version_unsupported_reason,
+ flash_attn_varlen_func,
+ is_fa_version_supported,
+ )
+except ImportError:
+ if current_platform.is_rocm():
+ pytest.skip(
+ "vllm_flash_attn is not supported for vLLM on ROCm.",
+ allow_module_level=True,
+ )
+
NUM_HEADS = [(4, 4), (8, 2)]
-HEAD_SIZES = [128, 256]
+HEAD_SIZES = [40, 72, 80, 128, 256]
BLOCK_SIZES = [16]
DTYPES = [torch.bfloat16]
QDTYPES = [None, torch.float8_e4m3fn]
# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048]
-SOFT_CAPS = [None, 50.0]
+SOFT_CAPS = [None]
SLIDING_WINDOWS = [None, 256]
diff --git a/tests/kernels/attention/test_flashinfer.py b/tests/kernels/attention/test_flashinfer.py
index 82ec2ef14e56c..eedeec33e0d45 100644
--- a/tests/kernels/attention/test_flashinfer.py
+++ b/tests/kernels/attention/test_flashinfer.py
@@ -2,12 +2,20 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import flashinfer
import pytest
-import torch
from vllm.platforms import current_platform
+try:
+ import flashinfer
+except ImportError:
+ if current_platform.is_rocm():
+ pytest.skip(
+ "flashinfer is not supported for vLLM on ROCm.", allow_module_level=True
+ )
+
+import torch
+
NUM_HEADS = [(32, 8), (6, 1)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]
diff --git a/tests/kernels/attention/test_flashinfer_mla_decode.py b/tests/kernels/attention/test_flashinfer_mla_decode.py
index 0350136677c6b..d183f67d3919e 100644
--- a/tests/kernels/attention/test_flashinfer_mla_decode.py
+++ b/tests/kernels/attention/test_flashinfer_mla_decode.py
@@ -3,7 +3,6 @@
import pytest
import torch
import torch.nn.functional as F
-from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
from torch import Tensor
from vllm.platforms import current_platform
@@ -15,6 +14,8 @@ if not current_platform.has_device_capability(100):
reason="FlashInfer MLA Requires compute capability of 10 or above.",
allow_module_level=True,
)
+else:
+ from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
def ref_mla(
diff --git a/tests/kernels/attention/test_flashinfer_trtllm_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_attention.py
index 693b849ebc5d7..98ea40608b468 100644
--- a/tests/kernels/attention/test_flashinfer_trtllm_attention.py
+++ b/tests/kernels/attention/test_flashinfer_trtllm_attention.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import flashinfer
import pytest
import torch
@@ -16,6 +15,8 @@ if not current_platform.is_device_capability(100):
pytest.skip(
"This TRTLLM kernel requires NVIDIA Blackwell.", allow_module_level=True
)
+else:
+ import flashinfer
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
FP8_DTYPE = current_platform.fp8_dtype()
diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py
index 183bbf3bf4e03..ae3c63cc62d6b 100644
--- a/tests/kernels/attention/test_mha_attn.py
+++ b/tests/kernels/attention/test_mha_attn.py
@@ -24,10 +24,6 @@ from vllm.platforms.rocm import RocmPlatform
def clear_cache():
"""Clear lru cache to ensure each test case runs without caching."""
_cached_get_attn_backend.cache_clear()
- # Clear xformers availability cache
- import vllm.attention.layer as layer_module
-
- layer_module.USE_XFORMERS_OPS = None
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
@@ -62,38 +58,10 @@ def test_mha_attn_platform(device: str):
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
# Test CUDA with head_size=72 (not divisible by 32)
- # - with upstream FA not available
- # - should use xformers
+ # - should use vLLM's FlashAttention
with (
patch("vllm.attention.layer.current_platform", CudaPlatform()),
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
- patch(
- "vllm.attention.layer.check_upstream_fa_availability",
- return_value=False,
- ),
- ):
- attn = MultiHeadAttention(16, 72, scale=1)
- assert attn.attn_backend == AttentionBackendEnum.XFORMERS
-
- # Test CUDA with head_size=72 (not divisible by 32)
- # - with upstream FA available
- # - should use upstream FA
- with (
- patch("vllm.attention.layer.current_platform", CudaPlatform()),
- patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
- patch(
- "vllm.attention.layer.check_upstream_fa_availability", return_value=True
- ),
- patch.dict(
- "sys.modules",
- {
- "flash_attn": type(
- "MockFlashAttn",
- (),
- {"flash_attn_varlen_func": lambda *args, **kwargs: None},
- )()
- },
- ),
):
attn = MultiHeadAttention(16, 72, scale=1)
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
diff --git a/tests/kernels/attention/test_prefix_prefill.py b/tests/kernels/attention/test_prefix_prefill.py
index 78cdbbbf7379d..e041e8c8d2ffa 100644
--- a/tests/kernels/attention/test_prefix_prefill.py
+++ b/tests/kernels/attention/test_prefix_prefill.py
@@ -174,11 +174,11 @@ def test_contexted_kv_attention(
block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
- b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
+ b_start_loc = torch.cumsum(torch.tensor([0] + query_lens), dim=0).to(torch.int32)
max_input_len = MAX_SEQ_LEN
# copy kv to cache
- b_seq_start_loc = torch.cumsum(
- torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0
+ b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1]), dim=0).to(
+ torch.int32
)
for i in range(BS):
for j in range(query_lens[i]):
@@ -417,11 +417,11 @@ def test_contexted_kv_attention_alibi(
block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
- b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
+ b_start_loc = torch.cumsum(torch.tensor([0] + query_lens), dim=0).to(torch.int32)
max_input_len = MAX_SEQ_LEN
# copy kv to cache
- b_seq_start_loc = torch.cumsum(
- torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0
+ b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1]), dim=0).to(
+ torch.int32
)
for i in range(BS):
for j in range(query_lens[i]):
diff --git a/tests/kernels/core/test_mrope.py b/tests/kernels/core/test_mrope.py
index 02b795721f46e..43b242ab2d586 100644
--- a/tests/kernels/core/test_mrope.py
+++ b/tests/kernels/core/test_mrope.py
@@ -5,11 +5,11 @@ from typing import NamedTuple
import pytest
import torch
from packaging.version import Version
-from transformers import AutoConfig
from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform
+from vllm.transformers_utils.config import get_config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -98,8 +98,7 @@ def test_mrope(
atol = model_info.atol
rtol = model_info.rtol
- config = AutoConfig.from_pretrained(model_name)
- config = config.get_text_config()
+ config = get_config(model_name, False).get_text_config()
# get the model config
total_num_kv_heads = config.num_key_value_heads
@@ -113,7 +112,6 @@ def test_mrope(
)
is_neox_style = True
- rope_theta = config.rope_theta
max_position = config.max_position_embeddings
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
rotary_dim = int(head_dim * partial_rotary_factor)
@@ -122,9 +120,8 @@ def test_mrope(
head_size=head_dim,
rotary_dim=rotary_dim,
max_position=max_position,
- base=rope_theta,
is_neox_style=is_neox_style,
- rope_scaling=config.rope_scaling,
+ rope_parameters=config.rope_parameters,
dtype=dtype,
).to(device=device)
@@ -173,8 +170,7 @@ def test_mrope_torch_compile_tracing(
atol = model_info.atol
rtol = model_info.rtol
- config = AutoConfig.from_pretrained(model_name)
- config = config.get_text_config()
+ config = get_config(model_name, False).get_text_config()
# get the model config
total_num_kv_heads = config.num_key_value_heads
@@ -187,7 +183,6 @@ def test_mrope_torch_compile_tracing(
else config.hidden_size // total_num_heads
)
is_neox_style = True
- rope_theta = config.rope_theta
max_position = config.max_position_embeddings
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
rotary_dim = int(head_dim * partial_rotary_factor)
@@ -196,9 +191,8 @@ def test_mrope_torch_compile_tracing(
head_size=head_dim,
rotary_dim=rotary_dim,
max_position=max_position,
- base=rope_theta,
is_neox_style=is_neox_style,
- rope_scaling=config.rope_scaling,
+ rope_parameters=config.rope_parameters,
dtype=dtype,
).to(device=device)
diff --git a/tests/kernels/core/test_pos_encoding.py b/tests/kernels/core/test_pos_encoding.py
index c35ee5016ba05..a8ed3825689d3 100644
--- a/tests/kernels/core/test_pos_encoding.py
+++ b/tests/kernels/core/test_pos_encoding.py
@@ -74,7 +74,7 @@ def test_rotary_embedding(
device: str,
use_key: bool,
max_position: int = 8192,
- base: float = 10000,
+ rope_theta: float = 10000,
) -> None:
if rotary_dim is None:
rotary_dim = head_size
@@ -83,7 +83,8 @@ def test_rotary_embedding(
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
- rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
+ rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
+ rope = get_rope(head_size, rotary_dim, max_position, is_neox_style, rope_parameters)
rope = rope.to(dtype=dtype, device=torch.get_default_device())
positions = torch.randint(0, max_position, (batch_size, seq_len))
@@ -120,9 +121,9 @@ def test_rotary_embedding(
@torch.inference_mode()
def test_rope_module_cache():
MAX_POSITIONS = [123, 1234]
- BASES = [10000, 1000000]
- ROPE_SCALINGS = (
- None,
+ ROPE_THETAS = [10000, 1000000]
+ ROPE_PARAMETERS = (
+ {"rope_type": "default"},
{"rope_type": "linear", "factor": (1,)},
{"rope_type": "dynamic", "factor": 1},
)
@@ -130,9 +131,9 @@ def test_rope_module_cache():
HEAD_SIZES,
ROTARY_DIMS,
MAX_POSITIONS,
- BASES,
+ ROPE_THETAS,
IS_NEOX_STYLE,
- ROPE_SCALINGS,
+ ROPE_PARAMETERS,
DTYPES,
)
rope_setting_id_map: dict[str, int] = {}
@@ -141,20 +142,20 @@ def test_rope_module_cache():
head_size,
rotary_dim,
max_position,
- base,
- is_neox_stype,
- rope_scaling,
+ rope_theta,
+ is_neox_style,
+ rope_parameters,
dtype,
) = setting
if rotary_dim is None:
rotary_dim = head_size
+ rope_parameters["rope_theta"] = rope_theta
rope = get_rope(
head_size,
rotary_dim,
max_position,
- base,
- is_neox_stype,
- rope_scaling,
+ is_neox_style,
+ rope_parameters,
dtype,
)
# different settings cannot share the same rope module
@@ -168,20 +169,20 @@ def test_rope_module_cache():
head_size,
rotary_dim,
max_position,
- base,
- is_neox_stype,
- rope_scaling,
+ rope_theta,
+ is_neox_style,
+ rope_parameters,
dtype,
) = setting
if rotary_dim is None:
rotary_dim = head_size
+ rope_parameters["rope_theta"] = rope_theta
rope = get_rope(
head_size,
rotary_dim,
max_position,
- base,
- is_neox_stype,
- rope_scaling,
+ is_neox_style,
+ rope_parameters,
dtype,
)
# check if cache take effect
diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py
index 1d925dc1bea8f..d95c22fdf0a5b 100644
--- a/tests/kernels/moe/modular_kernel_tools/common.py
+++ b/tests/kernels/moe/modular_kernel_tools/common.py
@@ -15,7 +15,11 @@ from tests.kernels.quantization.nvfp4_utils import (
)
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig
-from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import (
+ get_dp_group,
+ get_pcp_group,
+ get_tensor_model_parallel_world_size,
+)
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
@@ -561,6 +565,7 @@ def make_modular_kernel(
# make moe config
moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
tp_size_=get_tensor_model_parallel_world_size(),
+ pcp_size_=get_pcp_group().world_size,
dp_size_=get_dp_group().world_size,
vllm_parallel_config=vllm_config.parallel_config,
)
diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py
index 2285709fa7d60..dab1207d78031 100644
--- a/tests/kernels/moe/test_batched_moe.py
+++ b/tests/kernels/moe/test_batched_moe.py
@@ -39,6 +39,11 @@ MNK_FACTORS = [
NUM_EXPERTS = [8, 64]
TOP_KS = [1, 2, 6]
+DTYPES = [torch.bfloat16]
+
+if not current_platform.is_fp8_fnuz():
+ DTYPES.append(torch.float8_e4m3fn)
+
vllm_config = VllmConfig()
@@ -96,7 +101,7 @@ class BatchedMMTensors:
@pytest.mark.parametrize("max_tokens_per_expert", [32, 224, 512])
@pytest.mark.parametrize("K", [128, 1024])
@pytest.mark.parametrize("N", [128, 1024])
-@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
+@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
def test_batched_mm(
@@ -229,7 +234,7 @@ def test_batched_mm(
@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
-@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
+@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("input_scales", [False])
diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py
index 88db4b3e537c2..b0ff1e64e3219 100644
--- a/tests/kernels/moe/test_block_fp8.py
+++ b/tests/kernels/moe/test_block_fp8.py
@@ -31,6 +31,11 @@ dg_available = has_deep_gemm()
if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
+if current_platform.is_fp8_fnuz():
+ pytest.skip(
+ "Tests in this file require float8_e4m3fn and platform does not support",
+ allow_module_level=True,
+ )
vllm_config = VllmConfig()
diff --git a/tests/kernels/moe/test_cutedsl_moe.py b/tests/kernels/moe/test_cutedsl_moe.py
new file mode 100644
index 0000000000000..af1a34d17d48b
--- /dev/null
+++ b/tests/kernels/moe/test_cutedsl_moe.py
@@ -0,0 +1,582 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+
+import pytest
+
+from vllm.platforms import current_platform
+
+if not current_platform.has_device_capability(100):
+ pytest.skip(
+ reason="Nvfp4 Requires compute capability of 10 or above.",
+ allow_module_level=True,
+ )
+
+import torch
+from flashinfer import fp4_quantize
+from torch.nn import functional as F
+
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
+ flashinfer_cutedsl_moe_masked,
+)
+from vllm.utils.flashinfer import (
+ flashinfer_cutedsl_grouped_gemm_nt_masked as cutedsl_gmm_masked,
+)
+from vllm.utils.flashinfer import (
+ scaled_fp4_grouped_quantize,
+)
+
+kE2M1ToFloat = torch.tensor(
+ [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
+)
+
+FLOAT8_E4M3_MAX = 448.0
+FLOAT4_E2M1_MAX = 6.0
+
+
+def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
+ m_tiles = (m + 128 - 1) // 128
+ f = block_size * 4
+ k_tiles = (k + f - 1) // f
+ tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
+ tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
+ out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
+ return out[0:m, 0:k]
+
+
+def dequantize_nvfp4_to_dtype(
+ tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
+):
+ """Dequantize the fp4 tensor back to high precision."""
+ # Two fp4 values are packed into one uint8.
+ assert tensor_fp4.dtype == torch.uint8
+ m, packed_k = tensor_fp4.shape
+ k = packed_k * 2
+ tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
+ tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
+ tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
+ tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
+ tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
+
+ # scale the tensor
+ out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
+ return out.to(dtype=dtype)
+
+
+def break_fp4_bytes(a, dtype):
+ assert a.dtype == torch.uint8
+ m, n = a.shape
+
+ # Vectorized nibble processing
+ a_flat = a.flatten()
+ high = (a_flat & 0xF0) >> 4 # Upper nibbles
+ low = a_flat & 0x0F # Lower nibbles
+
+ # Combine nibbles for batch processing
+ combined = torch.stack((low, high), dim=1).flatten()
+
+ # Vectorized sign and magnitude extraction
+ signs = (combined & 0x08).to(torch.bool) # Sign bits
+ abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices
+
+ # Device-aware lookup and sign application
+ kE2M1 = kE2M1ToFloat.to(device=a.device)
+ values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)
+
+ # Reshape to final form
+ return values.reshape(m, n * 2).to(dtype=dtype)
+
+
+def generate_balanced_routing(
+ hidden_states: torch.Tensor, num_experts: int, top_k: int
+):
+ """
+ Generate routing weights and topk indices such that every expert is active.
+ Returns routing_weights, topk_idx
+ """
+
+ num_tokens, hidden_dim = hidden_states.shape
+ # num_tokens = batch_size * seq_len
+
+ # First, assign at least one token per expert
+ tokens_per_expert = torch.arange(num_tokens) % num_experts
+ tokens_per_expert = tokens_per_expert[torch.randperm(num_tokens)] # shuffle
+
+ # Each token has top_k experts — start with one guaranteed expert
+ topk_idx = torch.full((num_tokens, top_k), -1, dtype=torch.long)
+ topk_idx[:, 0] = tokens_per_expert
+
+ # For remaining top_k - 1 experts, pick randomly (allowing repeats)
+ if top_k > 1:
+ random_choices = torch.randint(0, num_experts, (num_tokens, top_k - 1))
+ topk_idx[:, 1:] = random_choices
+
+ # Normalize routing weights so each token's weights sum to 1
+ routing_weights = torch.rand(num_tokens, top_k)
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+
+ # Reshape back if needed
+ routing_weights = routing_weights.view(num_tokens, top_k)
+ topk_idx = topk_idx.view(num_tokens, top_k)
+
+ return routing_weights, topk_idx
+
+
+def prepare_inputs(
+ hidden_states: torch.Tensor,
+ router_logits: torch.Tensor,
+ num_experts: int,
+ topk: int,
+):
+ routing_weights, topk_idx = generate_balanced_routing(
+ router_logits, num_experts, topk
+ )
+
+ masked_m = []
+ for i in range(num_experts):
+ mask = topk_idx.view(-1) == i
+ masked_m.append(mask.sum())
+
+ masked_m = torch.tensor(masked_m, dtype=torch.int32)
+    # Initialize the hidden_states_3d with ones instead of empty to avoid nan
+    # issue.
+ hidden_states_3d = torch.ones(
+ (num_experts, max(masked_m), hidden_states.shape[1]), dtype=hidden_states.dtype
+ )
+ for i in range(num_experts):
+ hidden_states_3d[i, : masked_m[i], :] = hidden_states[topk_idx.view(-1) == i]
+
+ return hidden_states_3d, masked_m, topk_idx, routing_weights
+
+
+MNK_FACTORS = [
+ (2, 1024, 1024),
+ (2, 1024, 1536),
+ (2, 3072, 1024),
+ (2, 3072, 1536),
+ (64, 1024, 1024),
+ (64, 1024, 1536),
+ (64, 3072, 1024),
+ (64, 2048, 1024),
+ (224, 1024, 1024),
+ (224, 1024, 1536),
+]
+
+
+# Reference implementation of torch_moe
+def torch_moe(a, w1, w2, score, topk, expert_map):
+ B, D = a.shape
+ a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
+ out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
+ score = torch.softmax(score, dim=-1, dtype=torch.float32)
+ topk_weight, topk_ids = torch.topk(score, topk)
+ topk_weight = topk_weight.view(-1)
+ topk_ids = topk_ids.view(-1)
+ if expert_map is not None:
+ topk_ids = expert_map[topk_ids]
+ for i in range(w1.shape[0]):
+ mask = topk_ids == i
+ if mask.sum():
+ out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(
+ 0, 1
+ )
+ return (
+ out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
+ ).sum(dim=1)
+
+
+def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids):
+ B, D = a.shape
+ a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
+ out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
+
+ topk_weight = topk_weight.view(-1)
+ topk_ids = topk_ids.view(-1)
+
+ for i in range(w1.shape[0]):
+ mask = topk_ids == i
+ if mask.sum():
+ m = w1[i].shape[0]
+ assert m % 2 == 0
+ # Note: w1 and w3 are swapped!
+ w3_expert, w1_expert = w1[i][m // 2 :, :], w1[i][: m // 2, :]
+ inter = F.silu(a[mask] @ w1_expert.t()) * (a[mask] @ w3_expert.t())
+ inter_gs = torch.tensor(1.0).cuda()
+ inter_q, inter_blockscale = fp4_quantize(inter, inter_gs)
+ inter = dequantize_nvfp4_to_dtype(
+ inter_q,
+ inter_blockscale,
+ inter_gs,
+ dtype=inter.dtype,
+ device=inter.device,
+ block_size=16,
+ ).cuda()
+ out[mask] = inter @ w2[i].transpose(0, 1)
+ return (
+ out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
+ ).sum(dim=1)
+
+
+def grouped_gemm_ref(
+ hidden_states_expanded: torch.Tensor,
+ hidden_states_3d: torch.Tensor,
+ weights: torch.Tensor,
+ topk_idx: torch.Tensor,
+ masked_m: torch.Tensor,
+ B: int,
+ topk: int,
+ num_experts: int,
+ *,
+ block_size: int = 16,
+) -> torch.Tensor:
+ """
+ Computes the reference grouped GEMM (fp4 quantized per-expert loop),
+ computes flashinfer grouped GEMM (for scale consistency),
+ and returns ONLY the repacked reference output: out_ref.
+
+ Returns:
+ out_ref: Tensor [num_experts, max_m, n_out]
+ """
+ device_hs = hidden_states_expanded.device
+ device_w = weights.device
+ out_dtype = weights.dtype
+ n_out = weights.shape[1]
+
+ # Flattened reference output (B*topk, n_out)
+ out = torch.zeros((B * topk, n_out), dtype=out_dtype, device=device_w)
+
+ # Per-expert reference compute loop
+ for i in range(num_experts):
+ mask = topk_idx.view(-1) == i
+ if mask.any():
+ lhs = hidden_states_expanded[mask]
+ rhs = weights[i]
+
+ a_amax = lhs.abs().max().to(torch.float32).to(device_hs)
+ b_amax = rhs.abs().max().to(torch.float32).to(device_w)
+
+ a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
+ b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
+
+ lhsq, lhsq_sf = fp4_quantize(lhs, a_gs)
+ rhsq, rhsq_sf = fp4_quantize(rhs, b_gs)
+
+ lhs_in_dtype = dequantize_nvfp4_to_dtype(
+ lhsq,
+ lhsq_sf,
+ a_gs,
+ dtype=lhs.dtype,
+ device=device_hs,
+ block_size=block_size,
+ )
+ rhs_in_dtype = dequantize_nvfp4_to_dtype(
+ rhsq,
+ rhsq_sf,
+ b_gs,
+ dtype=rhs.dtype,
+ device=device_w,
+ block_size=block_size,
+ )
+
+ out[mask] = lhs_in_dtype @ rhs_in_dtype.t()
+
+ # Determine per-expert max_m
+ max_m_val = int(masked_m.max().item())
+
+ # Repack into [num_experts, max_m, n_out]
+ out_ref = torch.zeros(
+ (num_experts, max_m_val, n_out),
+ dtype=out.dtype,
+ device=out.device,
+ )
+ expert_slot = [0] * num_experts
+
+ for i, expert_id in enumerate(topk_idx.view(-1).tolist()):
+ slot = expert_slot[expert_id]
+ if slot < max_m_val:
+ out_ref[expert_id, slot, :] = out[i]
+ expert_slot[expert_id] += 1
+ else:
+ raise IndexError(
+ f"Expert {expert_id} exceeded max slots ({max_m_val}). "
+ "Increase max_m or check masked_m."
+ )
+
+ return out_ref
+
+
+def flashinfer_cutedsl_grouped_gemm_nt_masked(
+ hidden_states: torch.Tensor, # 3d
+ input_global_scale: torch.Tensor, # (l,)
+ weights: torch.Tensor,
+ w_global_scale: torch.Tensor, # (l,)
+ masked_m: torch.Tensor,
+):
+ # hidden_states: [l, m, k]
+ # weights: [l, n, k]
+ aq, aq_sf = scaled_fp4_grouped_quantize(
+ hidden_states,
+ masked_m.to(hidden_states.device),
+ input_global_scale,
+ )
+ num_experts, n, k = weights.shape
+ bq, bq_sf = scaled_fp4_grouped_quantize(
+ weights,
+ torch.full((num_experts,), n, device=weights.device, dtype=torch.int32),
+ w_global_scale,
+ )
+
+ out = torch.zeros(
+ (num_experts, max(masked_m), n), dtype=weights.dtype, device=aq.device
+ )
+ out = out.permute(1, 2, 0) # requirement of kernel
+ sf_vec_size = 16
+ ab_dtype = "float4_e2m1fn"
+ sf_dtype = "float8_e4m3fn"
+ c_dtype = "bfloat16"
+ alpha = 1.0 / (input_global_scale * w_global_scale).to(out.dtype).view(
+ 1, 1, num_experts
+ )
+
+ def get_cute_dtype(input: torch.Tensor) -> str:
+ if input.dtype == torch.bfloat16:
+ return "bfloat16"
+ elif input.dtype == torch.float16:
+ return "float16"
+ elif input.dtype == torch.float32:
+ return "float32"
+ else:
+ raise ValueError(f"Unsupported cute dtype {input.dtype}")
+
+ cutedsl_gmm_masked(
+ (aq, aq_sf),
+ (bq, bq_sf),
+ out,
+ masked_m.to(aq.device),
+ ab_dtype=ab_dtype,
+ sf_dtype=sf_dtype,
+ c_dtype=c_dtype,
+ sf_vec_size=sf_vec_size,
+ alpha=alpha,
+ alpha_dtype=get_cute_dtype(alpha),
+ )
+
+ return out
+
+
+@pytest.mark.parametrize("bs, hidden_dim, inter_dim", [(2, 128, 256), (16, 128, 512)])
+@pytest.mark.parametrize("topk", [1, 2, 4])
+@torch.inference_mode()
+def test_flashinfer_cutedsl_moe_masked(
+ bs: int, hidden_dim: int, inter_dim: int, topk: int
+):
+ torch.manual_seed(42)
+ device = "cuda"
+ num_experts = 8
+ hidden_states = (
+ torch.randn(bs, hidden_dim, dtype=torch.bfloat16, device=device) / 5.0
+ )
+ w1 = (
+ torch.randn(
+ num_experts, 2 * inter_dim, hidden_dim, dtype=torch.bfloat16, device=device
+ )
+ / 10.0
+ )
+ w2 = (
+ torch.randn(
+ num_experts, hidden_dim, inter_dim, dtype=torch.bfloat16, device=device
+ )
+ / 10.0
+ )
+ router_logits = torch.randn(bs, num_experts, dtype=torch.float32)
+
+ hidden_states_expanded = (
+ hidden_states.view(bs, -1, hidden_dim)
+ .repeat(1, topk, 1)
+ .reshape(-1, hidden_dim)
+ )
+ hidden_states_3d, masked_m, topk_idx, routing_weights = prepare_inputs(
+ hidden_states_expanded, router_logits, num_experts, topk
+ )
+
+ w1_amax = w1.abs().amax(dim=(1, 2)).to(torch.float32).to(w1.device)
+ w2_amax = w2.abs().amax(dim=(1, 2)).to(torch.float32).to(w2.device)
+ input_global_scale = torch.ones(
+ (num_experts,), dtype=torch.float32, device=hidden_states.device
+ )
+
+ w1_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
+ w2_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
+ a2_global_scale = torch.ones(
+ (num_experts,), dtype=torch.float32, device=hidden_states.device
+ ) # assume intermediate scale is 1.0
+
+ w1_fp4, w1_blockscale = scaled_fp4_grouped_quantize(
+ w1,
+ torch.ones(num_experts, dtype=torch.int32, device=w1.device) * 2 * inter_dim,
+ w1_global_scale,
+ )
+ w2_fp4, w2_blockscale = scaled_fp4_grouped_quantize(
+ w2,
+ torch.ones(num_experts, dtype=torch.int32, device=w2.device) * hidden_dim,
+ w2_global_scale,
+ )
+
+ w1_alpha = 1.0 / (input_global_scale * w1_global_scale)
+ w2_alpha = 1.0 / (a2_global_scale * w2_global_scale)
+
+ out = torch.empty_like(hidden_states_3d)
+ # Note: the 1st dim shouldn't be bs
+ wk = torch.empty(
+ num_experts,
+ hidden_states_3d.shape[1],
+ inter_dim * 2,
+ dtype=hidden_states_3d.dtype,
+ device=hidden_states.device,
+ )
+ flashinfer_cutedsl_moe_masked(
+ hidden_states_3d.to(hidden_states.device),
+ input_global_scale,
+ w1_fp4.permute(2, 0, 1),
+ w1_blockscale,
+ w1_alpha,
+ w2_fp4.permute(2, 0, 1),
+ a2_global_scale,
+ w2_blockscale,
+ w2_alpha,
+ masked_m.to(hidden_states.device),
+ wk,
+ out,
+ )
+
+ # reference
+ a_fp4, a_scale_interleaved = fp4_quantize(hidden_states, input_global_scale)
+ a_in_dtype = dequantize_nvfp4_to_dtype(
+ a_fp4,
+ a_scale_interleaved,
+ input_global_scale,
+ dtype=hidden_states.dtype,
+ device=hidden_states.device,
+ block_size=16,
+ )
+ w1_d = torch.empty(
+ (num_experts, 2 * inter_dim, hidden_dim), device=w1.device, dtype=w1.dtype
+ )
+ w2_d = torch.empty(
+ (num_experts, hidden_dim, inter_dim), device=w2.device, dtype=w2.dtype
+ )
+
+ for idx in range(0, num_experts):
+ w1_fp4_sliced, w1_blockscale_sliced = fp4_quantize(
+ w1[idx], w1_global_scale[idx]
+ )
+ w2_fp4_sliced, w2_blockscale_sliced = fp4_quantize(
+ w2[idx], w2_global_scale[idx]
+ )
+ w1_d[idx] = dequantize_nvfp4_to_dtype(
+ w1_fp4_sliced,
+ w1_blockscale_sliced,
+ w1_global_scale[idx],
+ dtype=w1.dtype,
+ device=w1.device,
+ block_size=16,
+ )
+ w2_d[idx] = dequantize_nvfp4_to_dtype(
+ w2_fp4_sliced,
+ w2_blockscale_sliced,
+ w2_global_scale[idx],
+ dtype=w2.dtype,
+ device=w2.device,
+ block_size=16,
+ )
+
+ ref_output = torch_moe_nvfp4(
+ a_in_dtype,
+ w1_d,
+ w2_d,
+ topk,
+ routing_weights.to(a_in_dtype.device),
+ topk_idx.to(a_in_dtype.device),
+ )
+ out_weighted = torch.zeros_like(ref_output, device=out.device, dtype=out.dtype)
+
+ positions = torch.nonzero(masked_m[topk_idx], as_tuple=False)
+ rows, cols = positions[:, 0], positions[:, 1]
+ experts = topk_idx[rows, cols]
+ for i in range(num_experts):
+ mask = experts == i
+ if mask.any():
+ idx = torch.nonzero(mask, as_tuple=False).squeeze(-1)
+ r, c = rows[idx], cols[idx]
+ out_weighted[r] += out[i, : len(r), :] * routing_weights[r, c].to(
+ out.device
+ ).unsqueeze(-1)
+ torch.testing.assert_close(
+ out_weighted.cpu(), ref_output.cpu(), atol=2e-1, rtol=2e-1
+ )
+
+
+@pytest.mark.parametrize(
+ "bs, hidden_dim, inter_dim, topk", [(2, 128, 256, 2), (16, 128, 512, 5)]
+)
+@torch.inference_mode()
+def test_grouped_gemm_nt_masked(
+ bs: int, hidden_dim: int, inter_dim: int, topk: int
+) -> None:
+ torch.manual_seed(42)
+ B = bs
+ D = hidden_dim
+ N = inter_dim
+ # CuteDSL group gemm has issue when not all experts are active.
+ # i.e. masked = [2, 3, 0, 0, 1] where the 2nd and 3rd experts are inactive
+ # see https://github.com/flashinfer-ai/flashinfer/issues/1856
+ num_experts = bs
+ hidden_states = torch.randn(B, D, dtype=torch.bfloat16, device="cuda")
+ weights = torch.randn(num_experts, N, D, dtype=torch.bfloat16, device="cuda")
+ router_logits = torch.randn(B, num_experts, dtype=torch.float32)
+
+ hidden_states_expanded = (
+ hidden_states.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
+ )
+ hidden_states_3d, masked_m, topk_idx, _ = prepare_inputs(
+ hidden_states_expanded, router_logits, num_experts, topk
+ )
+
+ a_amax = (
+ hidden_states_3d.abs()
+ .amax(dim=(1, 2))
+ .to(torch.float32)
+ .to(hidden_states.device)
+ )
+ b_amax = weights.abs().amax(dim=(1, 2)).to(torch.float32).to(weights.device)
+ a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
+ b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
+ out_flashinfer = flashinfer_cutedsl_grouped_gemm_nt_masked(
+ hidden_states_3d.to(hidden_states.device), a_gs, weights, b_gs, masked_m
+ )
+ # reference
+ out_ref = grouped_gemm_ref(
+ hidden_states_expanded=hidden_states_expanded,
+ hidden_states_3d=hidden_states_3d,
+ weights=weights,
+ topk_idx=topk_idx,
+ masked_m=masked_m,
+ B=B,
+ topk=topk,
+ num_experts=num_experts,
+ )
+    # Note: only compare the masked positions because cutedsl may write nan
+    # into unmasked positions.
+ for i in range(num_experts):
+ torch.testing.assert_close(
+ out_flashinfer.permute(2, 0, 1)[i, : masked_m[i]],
+ out_ref.to(out_flashinfer.device)[i, : masked_m[i]],
+ atol=1e-1,
+ rtol=1e-1,
+ )
+
+
+if __name__ == "__main__":
+ test_flashinfer_cutedsl_moe_masked(16, 128, 512, 4)
+ test_grouped_gemm_nt_masked(16, 128, 512, 4)
diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py
index 218df4a2632c3..a6977f222408d 100644
--- a/tests/kernels/moe/test_flashinfer.py
+++ b/tests/kernels/moe/test_flashinfer.py
@@ -11,7 +11,6 @@ from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
-from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8,
flashinfer_cutlass_moe_fp8,
@@ -22,7 +21,14 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8
from vllm.model_executor.models.llama4 import Llama4MoE
from vllm.platforms import current_platform
-from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
+
+try:
+ from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
+except ImportError:
+ if current_platform.is_rocm():
+ pytest.skip(
+ "flashinfer not supported for vLLM on ROCm", allow_module_level=True
+ )
if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
90
@@ -144,14 +150,11 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True)
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids = Llama4MoE.custom_routing_function(
hidden_states=td.hidden_states,
- router_logits=score,
- use_grouped_topk=False,
- top_k=topk,
+ gating_output=score,
+ topk=topk,
renormalize=False,
- custom_routing_function=Llama4MoE.custom_routing_function,
- scoring_func="softmax",
)
quant_config = fp8_w8a8_moe_quant_config(
@@ -212,14 +215,11 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
)
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids = Llama4MoE.custom_routing_function(
hidden_states=td.hidden_states,
- router_logits=score,
- use_grouped_topk=False,
- top_k=topk,
+ gating_output=score,
+ topk=topk,
renormalize=False,
- custom_routing_function=Llama4MoE.custom_routing_function,
- scoring_func="softmax",
)
quant_config = fp8_w8a8_moe_quant_config(
diff --git a/tests/kernels/moe/test_gpt_oss_triton_kernels.py b/tests/kernels/moe/test_gpt_oss_triton_kernels.py
index dfd317bcf72f1..98e80ec029777 100644
--- a/tests/kernels/moe/test_gpt_oss_triton_kernels.py
+++ b/tests/kernels/moe/test_gpt_oss_triton_kernels.py
@@ -201,7 +201,7 @@ class ModelConfig:
sliding_window: int = 128
initial_context_length: int = 4096
rope_theta: float = 150000.0
- rope_scaling_factor: float = 32.0
+ rope_parameters_factor: float = 32.0
rope_ntk_alpha: float = 1.0
rope_ntk_beta: float = 32.0
@@ -270,6 +270,11 @@ class Case:
@pytest.mark.parametrize("num_token", [2])
@pytest.mark.parametrize("tp", [1, 2, 4, 8])
def test_equiv(num_token, a_dtype, w_dtype, tp):
+ from triton_kernels.tensor_details import layout
+
+ if not hasattr(layout, "make_default_matmul_mxfp4_w_layout"):
+ pytest.skip("make_default_matmul_mxfp4_w_layout not available")
+
M = num_token
E = ModelConfig.num_experts
K = ModelConfig.hidden_size
diff --git a/tests/kernels/moe/test_modular_kernel_combinations.py b/tests/kernels/moe/test_modular_kernel_combinations.py
index e3b8621b452fa..2a30ef2355529 100644
--- a/tests/kernels/moe/test_modular_kernel_combinations.py
+++ b/tests/kernels/moe/test_modular_kernel_combinations.py
@@ -46,6 +46,12 @@ meets_multi_gpu_requirements = pytest.mark.skipif(
reason="Requires deep_ep or deep_gemm or pplx or flashinfer packages",
)
+if current_platform.is_fp8_fnuz():
+ pytest.skip(
+ "Tests in this file require float8_e4m3fn and platform does not support",
+ allow_module_level=True,
+ )
+
def format_result(verbose, msg, ex=None):
if ex is not None:
diff --git a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py
index ba1f657b3ecda..12dd322dccc52 100644
--- a/tests/kernels/moe/test_moe_permute_unpermute.py
+++ b/tests/kernels/moe/test_moe_permute_unpermute.py
@@ -23,6 +23,12 @@ TOP_KS = [2, 6, 8]
EP_SIZE = [1, 4, 16]
current_platform.seed_everything(0)
+if current_platform.is_rocm():
+ pytest.skip(
+ "moe_permute_unpermute_supported is not defined for ROCm",
+ allow_module_level=True,
+ )
+
def torch_permute(
hidden_states: torch.Tensor,
diff --git a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
index d6b78dd2c2323..b220205759e2d 100644
--- a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
+++ b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
@@ -14,6 +14,12 @@ from vllm.platforms import current_platform
from vllm.utils.deep_gemm import DeepGemmQuantScaleFMT, has_deep_gemm
from vllm.utils.math_utils import cdiv, round_up
+if current_platform.is_fp8_fnuz():
+ pytest.skip(
+ "Tests in this file require float8_e4m3fn and platform does not support",
+ allow_module_level=True,
+ )
+
fp8_dtype = torch.float8_e4m3fn
CASES = [
diff --git a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py
index 7a467e160b784..0ab025dceca40 100644
--- a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py
+++ b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py
@@ -19,6 +19,12 @@ if current_platform.get_device_capability() < (9, 0):
vllm_config = VllmConfig()
+if current_platform.is_fp8_fnuz():
+ pytest.skip(
+ "Tests in this file require float8_e4m3fn and platform does not support",
+ allow_module_level=True,
+ )
+
def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
"""Matrix multiplication function that supports per-token input
diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py
index e9973c1fcc15e..d0e4f6554a91f 100644
--- a/tests/kernels/quantization/test_block_fp8.py
+++ b/tests/kernels/quantization/test_block_fp8.py
@@ -22,6 +22,7 @@ from vllm.utils.deep_gemm import (
fp8_gemm_nt,
get_col_major_tma_aligned_tensor,
per_block_cast_to_fp8,
+ should_use_deepgemm_for_fp8_linear,
)
from vllm.utils.import_utils import has_deep_gemm
@@ -157,10 +158,6 @@ def test_w8a8_block_fp8_cutlass_matmul():
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGemm kernels not available.")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
- # only aligned sizes
- if M % 4 != 0 or K % 128 != 0 or N % 64 != 0:
- pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}")
-
torch.manual_seed(seed)
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max = fp8_info.max
@@ -168,6 +165,12 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
+ # only aligned sizes are supported by deepgemm
+ if not should_use_deepgemm_for_fp8_linear(
+ output_dtype=out_dtype, weight=B_fp32, supports_deep_gemm=True
+ ):
+ pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}")
+
A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_size[1])
B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32, block_size=block_size)
diff --git a/tests/kernels/test_cache_kernels.py b/tests/kernels/test_cache_kernels.py
new file mode 100644
index 0000000000000..b5d66b4ede886
--- /dev/null
+++ b/tests/kernels/test_cache_kernels.py
@@ -0,0 +1,65 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Unit tests for CUDA kernels in cache_kernels.cu."""
+
+import pytest
+import torch
+
+try:
+ from vllm import _custom_ops as ops
+except ImportError:
+ pytest.skip(
+ "Could not import vllm._custom_ops. (pip install -e .)", allow_module_level=True
+ )
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Need CUDA device")
+def test_gather_cache_oob():
+ """
+ Tests for OOB read in gather_and_maybe_dequant_cache (Issue #27909).
+ This test constructs a boundary case identified in the issue where
+ seq_starts causes the block_table offset to read out of bounds.
+ """
+
+ batch_size = 1
+ block_size = 64
+ entry_size = 128
+
+ block_table = torch.tensor([[1, 2]], dtype=torch.int32, device="cuda")
+
+ # This will result in offset = 128 / block_size = 128 / 64 = 2
+ # This will cause the kernel to try to read from
+ # block_table[0, 2], but its size is only 2.
+ seq_starts = torch.tensor([128], dtype=torch.int32, device="cuda")
+
+ seq_len = 65
+ cu_seq_lens = torch.tensor([0, seq_len], dtype=torch.int32, device="cuda")
+
+ # src_cache: [num_blocks, block_size, entry_size]
+ num_blocks = 5
+ src_cache = torch.randn(
+ (num_blocks, block_size, entry_size), dtype=torch.float16, device="cuda"
+ )
+
+ dst = torch.empty((seq_len, entry_size), dtype=torch.float16, device="cuda")
+
+ scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
+
+ # Calling the C++ function gather_and_maybe_dequant_cache
+ ops.gather_and_maybe_dequant_cache(
+ src_cache,
+ dst,
+ block_table,
+ cu_seq_lens,
+ batch_size,
+ "auto", # kv_cache_dtype
+ scale,
+ seq_starts,
+ )
+
+ torch.cuda.synchronize()
+ assert True
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])
diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py
index 5d5a26fbfc2cd..9307ef7814a8b 100644
--- a/tests/kernels/utils.py
+++ b/tests/kernels/utils.py
@@ -509,43 +509,6 @@ def pack_qkv(qkv: QKVInputs, device: torch.device | str) -> PackedQKVInputs:
)
-def make_alibi_bias(
- alibi_slopes: torch.Tensor,
- num_kv_heads: int,
- dtype: torch.dtype,
- seq_lens: list[int],
-) -> list[Any]:
- """Create ALiBi biases compatible with xFormers attention tests."""
- from xformers.ops.fmha.attn_bias import LowerTriangularMaskWithTensorBias
-
- if alibi_slopes is None:
- return [None for _ in seq_lens]
-
- attn_biases: list[Any] = []
- num_heads = alibi_slopes.shape[0]
- assert num_heads >= num_kv_heads, (
- "ALiBi slopes expect at least as many heads as KV heads"
- )
-
- for seq_len in seq_lens:
- bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
- bias = bias[None, :] - bias[:, None]
-
- padded_len = (seq_len + 7) // 8 * 8
- bias_tensor = torch.empty(
- 1,
- num_heads,
- seq_len,
- padded_len,
- device=alibi_slopes.device,
- dtype=dtype,
- )[:, :, :, :seq_len].copy_(bias)
- bias_tensor.mul_(alibi_slopes[:, None, None])
- attn_biases.append(LowerTriangularMaskWithTensorBias(bias_tensor))
-
- return attn_biases
-
-
def _make_metadata_tensors(
seq_lens: list[int] | None,
context_lens: list[int] | None,
@@ -649,23 +612,12 @@ def make_kv_cache(
Returns:
- * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
- * for backend 'XFORMERS'
* kv_cache: 2 x num_blocks x block_size x num_heads x head_size
* for backend 'FLASH_ATTN'
"""
- if backend == "XFORMERS":
- kv_cache = torch.rand((2, num_blocks, block_size * num_heads * head_size)).to(
- device
- )
- elif backend == "FLASH_ATTN":
- kv_cache = torch.rand((2, num_blocks, block_size, num_heads, head_size)).to(
- device
- )
- else:
- raise ValueError(
- f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or 'FLASH_ATTN'."
- )
+ if backend != "FLASH_ATTN":
+ raise ValueError(f"Unknown backend value: '{backend}'. Expected 'FLASH_ATTN'.")
+ kv_cache = torch.rand((2, num_blocks, block_size, num_heads, head_size)).to(device)
if default_val is not None:
kv_cache[:, :, :] = default_val
return kv_cache
@@ -843,22 +795,14 @@ def assert_actual_matches_ideal(
* output_under_test: actually observed output value
"""
ideal_output = test_params.packed_qkvo.ideal_output
- if backend == "XFORMERS":
- torch.testing.assert_close(
- ideal_output, output_under_test.view_as(ideal_output)
- )
-
- elif backend == "FLASH_ATTN":
- # For FlashAttention override the accuracy thresholds to non default
- # values since we notice a higher difference between the ideal and
- # actual output.
- torch.testing.assert_close(
- ideal_output, output_under_test.view_as(ideal_output), atol=0.01, rtol=0.016
- )
- else:
- raise ValueError(
- f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or 'FLASH_ATTN'."
- )
+ if backend != "FLASH_ATTN":
+ raise ValueError(f"Unknown backend value: '{backend}'. Expected 'FLASH_ATTN'.")
+ # For FlashAttention override the accuracy thresholds to non default
+ # values since we notice a higher difference between the ideal and
+ # actual output.
+ torch.testing.assert_close(
+ ideal_output, output_under_test.view_as(ideal_output), atol=0.01, rtol=0.016
+ )
# Copied/modified from torch._refs.__init__.py
diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py
index d8ff9339bb49b..9d38ec5422794 100644
--- a/tests/lora/conftest.py
+++ b/tests/lora/conftest.py
@@ -250,6 +250,16 @@ def olmoe_lora_files():
return snapshot_download(repo_id="jeeejeee/olmoe-instruct-text2sql-spider")
+@pytest.fixture(scope="session")
+def qwen3_lora_files():
+ return snapshot_download(repo_id="charent/self_cognition_Alice")
+
+
+@pytest.fixture(scope="session")
+def llama32_lora_files():
+ return snapshot_download(repo_id="jeeejeee/llama32-3b-text2sql-spider")
+
+
@pytest.fixture
def reset_default_device():
"""
diff --git a/tests/lora/test_default_mm_loras.py b/tests/lora/test_default_mm_loras.py
index dfc45e78e464f..407b29fdd1d58 100644
--- a/tests/lora/test_default_mm_loras.py
+++ b/tests/lora/test_default_mm_loras.py
@@ -5,7 +5,9 @@ Tests for applying default registered multimodal loras.
"""
import os
+import unittest.mock as mock
+import pytest
from huggingface_hub import snapshot_download
from vllm.lora.request import LoRARequest
@@ -114,3 +116,36 @@ def test_default_mm_lora_fails_with_overridden_lora_request(
default_mm_loras={"audio": IMAGE_LORA_PATH},
expected_suffix=RESPONSE_SUFFIX_WITH_LORA,
)
+
+
+def test_default_mm_lora_does_not_expand_string_reqs(vllm_runner):
+ class MockEngineException(Exception):
+ pass
+
+ # Regression test for ensuring default multimodal lora resolution
+ # does not expand the lora req if the prompt type is a string.
+ vllm_runner_kwargs = {
+ **VLLM_RUNNER_BASE_KWARGS,
+ **{"default_mm_loras": {"audio": AUDIO_LORA_PATH}},
+ }
+
+ # Avoid the full generation call since these tests are expensive;
+ # just check what lora request is actually submitted to the engine
+ mock_err = "Engine is mocked for this test"
+
+ with (
+ mock.patch(
+ "vllm.v1.engine.llm_engine.LLMEngine.add_request",
+ side_effect=MockEngineException(mock_err),
+ ) as mock_add_request,
+ vllm_runner(**vllm_runner_kwargs) as vllm_model,
+ ):
+ # Die once we actually submit the request to the engine
+ with pytest.raises(MockEngineException):
+ vllm_model.llm.generate(prompts=AUDIO_PROMPT)
+
+ # Then check to make sure the submitted lora request
+ # and text prompt were zipped together correctly
+ engine_args, engine_kwargs = mock_add_request.call_args
+ assert engine_kwargs["lora_request"] is None
+ assert engine_kwargs["prompt_text"] == AUDIO_PROMPT
diff --git a/tests/lora/test_fused_moe_lora_kernel.py b/tests/lora/test_fused_moe_lora_kernel.py
index 91ab4a87c65f8..91c8b861c3c5c 100644
--- a/tests/lora/test_fused_moe_lora_kernel.py
+++ b/tests/lora/test_fused_moe_lora_kernel.py
@@ -1,13 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
import random
import pytest
import torch
+from tests.utils import multi_gpu_test
from vllm import _custom_ops as ops
+from vllm.distributed import (
+ init_distributed_environment,
+ initialize_model_parallel,
+ tensor_model_parallel_all_gather,
+ tensor_model_parallel_all_reduce,
+)
+from vllm.distributed.parallel_state import (
+ get_tensor_model_parallel_world_size,
+)
from vllm.lora.ops.triton_ops import fused_moe_lora
from vllm.platforms import current_platform
+from vllm.utils.network_utils import get_open_port
@pytest.fixture(autouse=True)
@@ -122,6 +134,8 @@ def use_fused_moe_lora_kernel(
max_loras,
num_experts,
block_size,
+ fully_sharded=False,
+ offset=0,
):
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
@@ -195,10 +209,10 @@ def use_fused_moe_lora_kernel(
config["NUM_STAGES"],
config["SPLIT_K"],
mul_routed_weight,
+ fully_sharded=fully_sharded,
+ offset=offset,
)
- return output
-
def use_torch(
hidden_states,
@@ -317,3 +331,193 @@ def test_fused_moe_lora_kernel(
)
torch.testing.assert_close(output, output2, atol=1e-1, rtol=1e-1)
+
+
+@multi_gpu_test(num_gpus=2)
+@pytest.mark.parametrize("num_tokens", [100])
+@pytest.mark.parametrize("top_k_num", [6])
+@pytest.mark.parametrize("num_experts", [64])
+@pytest.mark.parametrize("max_loras", [4])
+@pytest.mark.parametrize("N", [1408])
+@pytest.mark.parametrize("K", [2048])
+@pytest.mark.parametrize("max_lora_rank", [16, 32, 64])
+@pytest.mark.parametrize("block_size", [16])
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEED)
+@pytest.mark.parametrize("column_parallel", [True, False])
+def test_fused_moe_lora_kernel_fully_sharded(
+ num_tokens,
+ top_k_num,
+ num_experts,
+ max_loras,
+ N,
+ K,
+ max_lora_rank,
+ block_size,
+ dtype,
+ seed,
+ column_parallel,
+):
+ current_platform.seed_everything(seed)
+ # the number of randomly generated sentences.
+ num_sequences = 10
+ # generate data
+ topk_ids, topk_weights, token_lora_mapping = sample_data(
+ num_tokens, num_sequences, max_loras, num_experts, top_k_num
+ )
+
+ def run_torch_spawn(fn, nprocs):
+ torch.multiprocessing.spawn(
+ fn,
+ args=(
+ nprocs,
+ f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
+ dtype,
+ seed,
+ N,
+ K,
+ num_tokens,
+ topk_ids,
+ topk_weights,
+ token_lora_mapping,
+ max_lora_rank,
+ top_k_num,
+ max_loras,
+ num_experts,
+ block_size,
+ column_parallel,
+ ),
+ nprocs=nprocs,
+ )
+
+ run_torch_spawn(use_fused_moe_lora_kernel_tensor_parallel, nprocs=2)
+
+
+def use_fused_moe_lora_kernel_tensor_parallel(
+ local_rank,
+ world_size,
+ init_method,
+ dtype,
+ seed,
+ N,
+ K,
+ num_tokens,
+ topk_ids,
+ topk_weights,
+ token_lora_mapping,
+ max_lora_rank,
+ top_k_num,
+ max_loras,
+ num_experts,
+ block_size,
+ column_parallel,
+):
+ def _get_shard_slice(shard_size):
+ return slice(local_rank * shard_size, (local_rank + 1) * shard_size)
+
+ current_platform.seed_everything(seed)
+
+ device = torch.device(f"cuda:{local_rank}")
+ torch.cuda.set_device(device)
+ torch.set_default_device(device)
+ torch.set_default_dtype(dtype)
+
+ init_distributed_environment(
+ world_size=world_size,
+ rank=local_rank,
+ local_rank=local_rank,
+ distributed_init_method=init_method,
+ )
+ initialize_model_parallel(world_size, 1)
+ tp_size = get_tensor_model_parallel_world_size()
+
+ input_dim = K if column_parallel else N
+ output_dim = N if column_parallel else K
+
+ # init lora weights
+ lora_a = torch.rand(
+ (
+ max_loras,
+ num_experts,
+ max_lora_rank,
+ input_dim,
+ ),
+ dtype=dtype,
+ )
+ lora_b = torch.rand(
+ (
+ max_loras,
+ num_experts,
+ output_dim,
+ max_lora_rank,
+ ),
+ dtype=dtype,
+ )
+
+ hidden_states = torch.rand(
+ (
+ num_tokens,
+ input_dim,
+ ),
+ dtype=dtype,
+ )
+
+ output = torch.zeros((num_tokens, top_k_num, output_dim), dtype=dtype)
+ topk_ids = topk_ids.to(device)
+ topk_weights = topk_weights.to(device)
+ token_lora_mapping = token_lora_mapping.to(device)
+
+ ref_output = use_torch(
+ hidden_states,
+ token_lora_mapping,
+ topk_ids,
+ [lora_a],
+ [lora_b],
+ top_k_num,
+ )
+
+ if column_parallel:
+ # Column parallel (e.g. gate_up_proj): LoRA A is sliced along the rank dim,
+ # and Lora B is sliced along the output dim
+ lora_a_shard_size = max_lora_rank // tp_size
+ lora_a = lora_a[:, :, _get_shard_slice(lora_a_shard_size), :]
+ max_lora_rank = lora_a_shard_size
+ offset = 0
+
+ lora_b_shard_size = output_dim // tp_size
+ lora_b = lora_b[:, :, _get_shard_slice(lora_b_shard_size), :]
+ output = output[:, :, _get_shard_slice(lora_b_shard_size)].contiguous()
+ else:
+ # Row parallel (e.g. down proj): LoRA A is sliced along the input dim,
+ # and LoRA B is sliced along the output dim
+ lora_a_shard_size = input_dim // tp_size
+ lora_a = lora_a[:, :, :, _get_shard_slice(lora_a_shard_size)]
+ hidden_states = hidden_states[:, _get_shard_slice(lora_a_shard_size)]
+
+ lora_b_shard_size = output_dim // tp_size
+ lora_b = lora_b[:, :, _get_shard_slice(lora_b_shard_size), :]
+ offset = lora_b_shard_size * local_rank
+
+ use_fused_moe_lora_kernel(
+ topk_ids,
+ topk_weights,
+ token_lora_mapping,
+ max_lora_rank,
+ top_k_num,
+ [lora_a],
+ [lora_b],
+ hidden_states,
+ output,
+ max_loras,
+ num_experts,
+ block_size,
+ fully_sharded=True,
+ offset=offset,
+ )
+
+ if column_parallel:
+ output = tensor_model_parallel_all_gather(output)
+ else:
+ output = tensor_model_parallel_all_reduce(output)
+
+ torch.testing.assert_close(output, ref_output, atol=1e-1, rtol=1e-1)
diff --git a/tests/lora/test_gptoss_tp.py b/tests/lora/test_gptoss_tp.py
index 711d514a39eb3..f4269750feb6b 100644
--- a/tests/lora/test_gptoss_tp.py
+++ b/tests/lora/test_gptoss_tp.py
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+
import vllm
from vllm.lora.request import LoRARequest
@@ -84,14 +86,17 @@ def test_gpt_oss_lora(gptoss20b_lora_files):
@multi_gpu_test(num_gpus=2)
-def test_gpt_oss_lora_tp2(gptoss20b_lora_files):
+@pytest.mark.parametrize("fully_sharded_loras", [False, True])
+def test_gpt_oss_lora_tp2(gptoss20b_lora_files, fully_sharded_loras):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=2,
max_lora_rank=8,
+ max_num_seqs=16,
tensor_parallel_size=2,
+ fully_sharded_loras=fully_sharded_loras,
compilation_config=vllm.config.CompilationConfig( # Avoid OOM
cudagraph_specialize_lora=False,
),
diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py
index 8f18f01441932..9df3a07a9e5e9 100644
--- a/tests/lora/test_layers.py
+++ b/tests/lora/test_layers.py
@@ -136,7 +136,6 @@ def populate_loras(
id_to_index: list[int | None],
layer: BaseLayerWithLoRA,
layer_weights: torch.Tensor,
- generate_embeddings_tensor: int = 0,
repeats: int = 1,
) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
"""This method populates the lora layers with lora weights.
@@ -148,8 +147,6 @@ def populate_loras(
layer: the LoRAlayer to populate.
layer_weights: the PyTorch tensor containing the layer's
weights.
- generate_embeddings_tensor: whether to generate an
- embeddings tensor for each LoRA.
repeats: must only be set for column parallel packed
layers. Indicates the number of loras to compose
together to create a single lora layer.
@@ -171,7 +168,6 @@ def populate_loras(
sublora = DummyLoRAManager(layer_weights.device).init_random_lora(
module_name=f"fake_{i}",
weight=layer_weights,
- generate_embeddings_tensor=generate_embeddings_tensor,
)
sublora.lora_b = sublora.lora_b[
(sublora_len * i) : (sublora_len * (i + 1)), :
@@ -185,7 +181,6 @@ def populate_loras(
slot_idx,
lora_a=lora.lora_a,
lora_b=lora.lora_b,
- embeddings_tensor=lora.embeddings_tensor,
)
lora_dict[lora_id] = lora
@@ -306,7 +301,6 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
id_to_index,
max_loras,
vocab_size,
- lora_config.lora_extra_vocab_size,
)
lora_result = lora_embedding(torch.cat(inputs))
@@ -344,7 +338,6 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
id_to_index,
max_loras,
vocab_size,
- lora_config.lora_extra_vocab_size,
)
lora_result = lora_embedding(torch.cat(inputs))
@@ -354,149 +347,6 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
-@torch.inference_mode()
-# @pytest.mark.skip(
-# reason="Fails when loras are in any slot other than the first.")
-@pytest.mark.parametrize("num_loras", [1, 2, 4])
-@pytest.mark.parametrize("device", DEVICES)
-@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
-@pytest.mark.parametrize("stage", STAGES)
-def test_embeddings_with_new_embeddings(
- dist_init, num_loras, device, vocab_size, stage
-) -> None:
- if current_platform.is_cuda_alike():
- torch.cuda.set_device(device)
-
- torch.set_default_device(device)
- max_loras = 8
- punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
- assert check_punica_wrapper(punica_wrapper)
- lora_config = LoRAConfig(
- max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
- )
-
- def create_random_embedding_layer():
- embedding = VocabParallelEmbedding(vocab_size, 256)
- embedding_data = torch.rand_like(embedding.weight.data)
- embedding.weight.data = embedding_data
- embedding.weight.data[vocab_size:, :] = 0
- expanded_embedding = VocabParallelEmbedding(
- vocab_size + lora_config.lora_extra_vocab_size * max_loras,
- 256,
- org_num_embeddings=vocab_size,
- )
- expanded_embedding.weight.data[:vocab_size, :] = embedding_data
- # We need to deepcopy the embedding as it will be modified
- # in place
- lora_embedding = VocabParallelEmbeddingWithLoRA(deepcopy(expanded_embedding))
- lora_embedding.create_lora_weights(max_loras, lora_config)
-
- return expanded_embedding, lora_embedding
-
- for i in range(NUM_RANDOM_SEEDS):
- set_random_seed(i)
-
- id_to_index = get_random_id_to_index(num_loras, max_loras)
- expanded_embedding, lora_embedding = create_random_embedding_layer()
- lora_dict, _ = populate_loras(
- id_to_index,
- layer=lora_embedding,
- layer_weights=torch.zeros(
- (256, vocab_size + lora_config.lora_extra_vocab_size)
- ),
- generate_embeddings_tensor=256,
- )
-
- lora_embedding.set_mapping(punica_wrapper)
- # All embeddings tensors have the same shape.
- embeddings_tensors = [
- lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys())
- ]
- embeddings_tensor_len = embeddings_tensors[0].shape[0]
-
- # Add empty embeddings_tensors for unoccupied lora slots.
- for _ in range(max_loras - len(embeddings_tensors)):
- embeddings_tensors.append(torch.zeros(embeddings_tensors[0].shape))
-
- inputs, index_mapping, prompt_mapping = create_random_inputs(
- active_lora_ids=list(lora_dict.keys()),
- num_inputs=num_loras * 3,
- input_size=(200,),
- input_range=(1, vocab_size),
- device=device,
- )
- lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
- punica_wrapper.update_metadata(
- lora_mapping,
- id_to_index,
- max_loras,
- vocab_size,
- lora_config.lora_extra_vocab_size,
- )
- original_inputs = deepcopy(inputs)
-
- # Force some of the inputs to be in the extended embeddings range
- # to guarantee that their behavior is tested.
- for input_, original_input_, lora_id in zip(
- inputs, original_inputs, prompt_mapping
- ):
- embedding_id = lora_id - 1
- input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len)
- original_input_[-1] = vocab_size
- input_[-2] = vocab_size + ((embedding_id + 1) * embeddings_tensor_len - 1)
- original_input_[-2] = vocab_size + embeddings_tensor_len - 1
-
- expanded_embedding.weight[
- vocab_size : vocab_size + (embeddings_tensor_len * max_loras)
- ] = torch.cat(embeddings_tensors)
-
- lora_result = lora_embedding(torch.cat(original_inputs))
-
- expected_results: list[torch.Tensor] = []
- for input_, original_input_, lora_id in zip(
- inputs, original_inputs, prompt_mapping
- ):
- lora = lora_dict[lora_id]
- result = expanded_embedding(input_)
- after_a = F.embedding(
- original_input_,
- lora.lora_a.T,
- )
- result += after_a @ lora.lora_b.T
- expected_results.append(result)
- expected_result = torch.cat(expected_results)
-
- rtol, atol = TOLERANCES[lora_result.dtype]
- torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
-
- # Check that resetting the lora weights succeeds
-
- for slot_idx in range(max_loras):
- lora_embedding.reset_lora(slot_idx)
-
- inputs, index_mapping, prompt_mapping = create_random_inputs(
- active_lora_ids=[0],
- num_inputs=num_loras * 3,
- input_size=(200,),
- input_range=(1, vocab_size),
- device=device,
- )
- original_inputs = deepcopy(inputs)
- lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
- punica_wrapper.update_metadata(
- lora_mapping,
- id_to_index,
- max_loras,
- vocab_size,
- lora_config.lora_extra_vocab_size,
- )
- lora_result = lora_embedding(torch.cat(original_inputs))
- expected_result = expanded_embedding(torch.cat(inputs))
-
- rtol, atol = TOLERANCES[lora_result.dtype]
- torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
-
-
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@@ -518,16 +368,13 @@ def test_lm_head_logits_processor(
def _pretest():
linear = ParallelLMHead(
- vocab_size + lora_config.lora_extra_vocab_size,
- 1024,
- vocab_size,
+ num_embeddings=vocab_size,
+ embedding_dim=1024,
params_dtype=torch.float16,
)
linear.weight.data = torch.rand_like(linear.weight.data)
linear.weight.data[:, vocab_size:] = 0
- logits_processor = LogitsProcessor(
- vocab_size + lora_config.lora_extra_vocab_size, vocab_size
- )
+ logits_processor = LogitsProcessor(vocab_size)
lora_logits_processor = LogitsProcessorWithLoRA(
logits_processor, 1024, linear.weight.dtype, linear.weight.device, None
)
@@ -541,15 +388,12 @@ def test_lm_head_logits_processor(
id_to_index = get_random_id_to_index(num_loras, max_loras)
linear, logits_processor, lora_logits_processor = _pretest()
lora_logits_processor.set_mapping(punica_wrapper)
- # NOTE: all the generated loras share the same embeddings tensor.
+
lora_dict, _ = populate_loras(
id_to_index,
layer=lora_logits_processor,
layer_weights=linear.weight,
- generate_embeddings_tensor=1024,
)
- embeddings_tensor = list(lora_dict.values())[0].embeddings_tensor
- embeddings_tensor_len = embeddings_tensor.shape[0]
inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=list(lora_dict.keys()),
@@ -565,7 +409,6 @@ def test_lm_head_logits_processor(
id_to_index,
max_loras,
vocab_size,
- lora_config.lora_extra_vocab_size,
)
input_ = torch.rand(20, 1024)
@@ -575,23 +418,16 @@ def test_lm_head_logits_processor(
original_lm_head = deepcopy(linear)
- linear.weight[
- logits_processor.org_vocab_size : logits_processor.org_vocab_size
- + embeddings_tensor_len
- ] = embeddings_tensor
-
- logits_processor.org_vocab_size = vocab_size + lora_config.lora_extra_vocab_size
expected_results: list[torch.Tensor] = []
for input_, lora_id in zip(inputs, prompt_mapping):
lora = lora_dict[lora_id]
result = logits_processor._get_logits(
hidden_states=input_, lm_head=linear, embedding_bias=None
)
- result[:, vocab_size + embeddings_tensor_len :] = float("-inf")
+
result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
expected_results.append(result)
expected_result = torch.cat(expected_results)
- logits_processor.org_vocab_size = vocab_size
# Check that resetting the lora weights succeeds
@@ -612,7 +448,6 @@ def test_lm_head_logits_processor(
id_to_index,
max_loras,
vocab_size,
- lora_config.lora_extra_vocab_size,
)
lora_result = lora_logits_processor._get_logits(
@@ -694,7 +529,6 @@ def test_linear_replicated(
id_to_index,
max_loras,
512,
- lora_config.lora_extra_vocab_size,
)
lora_result = lora_linear(torch.cat(inputs))[0]
@@ -726,7 +560,10 @@ def test_linear_replicated(
lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
punica_wrapper.update_metadata(
- lora_mapping, id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size
+ lora_mapping,
+ id_to_index,
+ max_loras,
+ 512,
)
lora_result = lora_linear(torch.cat(inputs))[0]
@@ -817,7 +654,6 @@ def test_linear_parallel(
id_to_index,
max_loras,
512,
- lora_config.lora_extra_vocab_size,
)
lora_result = lora_linear(torch.cat(inputs))[0]
@@ -849,7 +685,10 @@ def test_linear_parallel(
lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
punica_wrapper.update_metadata(
- lora_mapping, id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size
+ lora_mapping,
+ id_to_index,
+ max_loras,
+ 512,
)
lora_result = lora_linear(torch.cat(inputs))[0]
@@ -963,7 +802,6 @@ def test_column_parallel_packed(
id_to_index,
max_loras,
512,
- lora_config.lora_extra_vocab_size,
)
lora_result = lora_linear(torch.cat(inputs))[0]
@@ -1000,7 +838,6 @@ def test_column_parallel_packed(
id_to_index,
max_loras,
512,
- lora_config.lora_extra_vocab_size,
)
lora_result = lora_linear(torch.cat(inputs))[0]
diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py
index 7bbd1e364d19e..18704fa6e45de 100644
--- a/tests/lora/test_llama_tp.py
+++ b/tests/lora/test_llama_tp.py
@@ -13,17 +13,27 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from ..utils import VLLM_PATH, create_new_process_for_each_test, multi_gpu_test
-MODEL_PATH = "meta-llama/Llama-2-7b-hf"
+PROMPT_TEMPLATE = """<|eot_id|><|start_header_id|>user<|end_header_id|>
+I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.
+"
+##Instruction:
+candidate_poll contains tables such as candidate, people. Table candidate has columns such as Candidate_ID, People_ID, Poll_Source, Date, Support_rate, Consider_rate, Oppose_rate, Unsure_rate. Candidate_ID is the primary key.
+Table people has columns such as People_ID, Sex, Name, Date_of_Birth, Height, Weight. People_ID is the primary key.
+The People_ID of candidate is the foreign key of People_ID of people.
+###Input:
+{context}
+###Response:<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+""" # noqa: E501
EXPECTED_LORA_OUTPUT = [
- " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501
- " SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ",
- " SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", # noqa: E501
- " SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", # noqa: E501
- " SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ",
- " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' ", # noqa: E501
+ "SELECT count(*) FROM candidate",
+ "SELECT count(*) FROM candidate",
+ "SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501
+ "SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501
]
+MODEL_PATH = "meta-llama/Llama-3.2-3B-Instruct"
+
def do_sample(
llm: vllm.LLM,
@@ -32,18 +42,19 @@ def do_sample(
tensorizer_config_dict: dict | None = None,
) -> list[str]:
prompts = [
- "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
- "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
- "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]", # noqa: E501
- "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501
- "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", # noqa: E501
- "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]", # noqa: E501
+ PROMPT_TEMPLATE.format(context="How many candidates are there?"),
+ PROMPT_TEMPLATE.format(context="Count the number of candidates."),
+ PROMPT_TEMPLATE.format(
+ context="Which poll resource provided the most number of candidate information?" # noqa: E501
+ ),
+ PROMPT_TEMPLATE.format(
+ context="Return the poll resource associated with the most candidates."
+ ),
]
sampling_params = vllm.SamplingParams(
- temperature=0, max_tokens=256, skip_special_tokens=False, stop=["[/assistant]"]
+ temperature=0, max_tokens=64, stop=["<|im_end|>"]
)
-
if tensorizer_config_dict is not None:
outputs = llm.generate(
prompts,
@@ -75,13 +86,15 @@ def do_sample(
return generated_texts
-def generate_and_test(llm, sql_lora_files, tensorizer_config_dict: dict | None = None):
+def generate_and_test(
+ llm, llama32_lora_files, tensorizer_config_dict: dict | None = None
+):
print("lora adapter created")
print("lora 1")
assert (
do_sample(
llm,
- sql_lora_files,
+ llama32_lora_files,
tensorizer_config_dict=tensorizer_config_dict,
lora_id=1,
)
@@ -92,7 +105,7 @@ def generate_and_test(llm, sql_lora_files, tensorizer_config_dict: dict | None =
assert (
do_sample(
llm,
- sql_lora_files,
+ llama32_lora_files,
tensorizer_config_dict=tensorizer_config_dict,
lora_id=2,
)
@@ -104,51 +117,52 @@ def generate_and_test(llm, sql_lora_files, tensorizer_config_dict: dict | None =
@create_new_process_for_each_test()
@pytest.mark.parametrize("cudagraph_specialize_lora", [True, False])
-def test_llama_lora(sql_lora_files, cudagraph_specialize_lora: bool):
+def test_llama_lora(llama32_lora_files, cudagraph_specialize_lora: bool):
llm = vllm.LLM(
MODEL_PATH,
- tokenizer=sql_lora_files,
enable_lora=True,
# also test odd max_num_seqs
- max_num_seqs=13,
+ max_num_seqs=7,
+ max_model_len=1024,
max_loras=4,
compilation_config=vllm.config.CompilationConfig(
cudagraph_specialize_lora=cudagraph_specialize_lora,
),
)
- generate_and_test(llm, sql_lora_files)
+ generate_and_test(llm, llama32_lora_files)
@multi_gpu_test(num_gpus=4)
-def test_llama_lora_tp4(sql_lora_files):
+def test_llama_lora_tp4(llama32_lora_files):
llm = vllm.LLM(
MODEL_PATH,
- tokenizer=sql_lora_files,
enable_lora=True,
- max_num_seqs=16,
+ max_num_seqs=7,
+ max_model_len=1024,
max_loras=4,
tensor_parallel_size=4,
)
- generate_and_test(llm, sql_lora_files)
+ generate_and_test(llm, llama32_lora_files)
@multi_gpu_test(num_gpus=4)
-def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
+def test_llama_lora_tp4_fully_sharded_loras(llama32_lora_files):
llm = vllm.LLM(
MODEL_PATH,
- tokenizer=sql_lora_files,
enable_lora=True,
- max_num_seqs=16,
+ max_num_seqs=8,
max_loras=4,
+ max_model_len=1024,
tensor_parallel_size=4,
fully_sharded_loras=True,
)
- generate_and_test(llm, sql_lora_files)
+ generate_and_test(llm, llama32_lora_files)
@multi_gpu_test(num_gpus=2)
def test_tp2_serialize_and_deserialize_lora(
- tmp_path, sql_lora_files, sql_lora_huggingface_id
+ tmp_path,
+ llama32_lora_files,
):
# Run the tensorizing of the LoRA adapter and the model in a subprocess
# to guarantee cleanup
@@ -157,7 +171,7 @@ def test_tp2_serialize_and_deserialize_lora(
model_name = "model-rank-%03d.tensors"
model_ref = MODEL_PATH
- lora_path = sql_lora_huggingface_id
+ lora_path = llama32_lora_files
suffix = "test"
try:
result = subprocess.run(
@@ -195,12 +209,12 @@ def test_tp2_serialize_and_deserialize_lora(
loaded_llm = LLM(
model=model_ref,
- tokenizer=sql_lora_files,
load_format="tensorizer",
enable_lora=True,
enforce_eager=True,
model_loader_extra_config=tensorizer_config,
- max_num_seqs=13,
+ max_num_seqs=7,
+ max_model_len=1024,
tensor_parallel_size=2,
max_loras=2,
)
@@ -211,7 +225,7 @@ def test_tp2_serialize_and_deserialize_lora(
print("lora 1")
assert (
do_sample(
- loaded_llm, sql_lora_files, tensorizer_config_dict=tc_as_dict, lora_id=1
+ loaded_llm, llama32_lora_files, tensorizer_config_dict=tc_as_dict, lora_id=1
)
== EXPECTED_LORA_OUTPUT
)
diff --git a/tests/lora/test_lora_functions.py b/tests/lora/test_lora_functions.py
index e914393fee8aa..1c692630284d0 100644
--- a/tests/lora/test_lora_functions.py
+++ b/tests/lora/test_lora_functions.py
@@ -13,8 +13,8 @@ from vllm.entrypoints.openai.api_server import (
from vllm.lora.request import LoRARequest
from vllm.v1.engine.llm_engine import LLMEngine
-MODEL_PATH = "meta-llama/Llama-2-7b-hf"
-LORA_MODULE_PATH = "yard1/llama-2-7b-sql-lora-test"
+MODEL_PATH = "Qwen/Qwen3-0.6B"
+LORA_MODULE_PATH = "charent/self_cognition_Alice"
LORA_RANK = 8
diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py
index e7816031142e3..24d4dfca46d62 100644
--- a/tests/lora/test_lora_manager.py
+++ b/tests/lora/test_lora_manager.py
@@ -48,9 +48,6 @@ DEFAULT_DTYPE = torch.get_default_dtype()
@pytest.mark.parametrize("device", DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
tensors = load_file(os.path.join(sql_lora_files, "adapter_model.safetensors"))
- new_embeddings = load_file(
- os.path.join(sql_lora_files, "new_embeddings.safetensors")
- )
peft_helper = PEFTHelper.from_local_dir(
sql_lora_files, max_position_embeddings=4096
@@ -60,7 +57,6 @@ def test_from_lora_tensors(sql_lora_files, device):
tensors,
peft_helper=peft_helper,
device=device,
- embeddings=new_embeddings,
embedding_modules=EMBEDDING_MODULES,
embedding_padding_modules=EMBEDDING_PADDING_MODULES,
)
@@ -76,18 +72,6 @@ def test_from_lora_tensors(sql_lora_files, device):
f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
)
assert lora.lora_a.shape[0] == 8
- embeddings_module = next(
- (k for k in EMBEDDING_MODULES if k in module_name), None
- )
- if embeddings_module:
- assert torch.equal(
- lora.embeddings_tensor,
- new_embeddings[EMBEDDING_MODULES[embeddings_module]].to(
- device=lora.embeddings_tensor.device
- ),
- )
- else:
- assert lora.embeddings_tensor is None
def create_lora(
@@ -552,9 +536,7 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path
worker_adapter_manager = WorkerLoRAManager(
vllm_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES
)
- worker_adapter_manager.vocab_size = (
- dummy_model_gate_up.unpadded_vocab_size - lora_config.lora_extra_vocab_size
- )
+ worker_adapter_manager.vocab_size = dummy_model_gate_up.unpadded_vocab_size
worker_adapter_manager.create_lora_manager(dummy_model_gate_up)
dummy_lora_files = f"{tmp_path}/lora_adapter"
diff --git a/tests/lora/test_minicpmv_tp.py b/tests/lora/test_minicpmv_tp.py
index 1cf8ed602b6a4..e430826461a14 100644
--- a/tests/lora/test_minicpmv_tp.py
+++ b/tests/lora/test_minicpmv_tp.py
@@ -57,10 +57,6 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
return generated_texts
-@pytest.mark.xfail(
- current_platform.is_rocm(),
- reason="MiniCPM-V dependency xformers incompatible with ROCm",
-)
def test_minicpmv_lora(minicpmv_lora_files):
llm = vllm.LLM(
MODEL_PATH,
@@ -84,10 +80,6 @@ def test_minicpmv_lora(minicpmv_lora_files):
@pytest.mark.skipif(
current_platform.is_cuda_alike(), reason="Skipping to avoid redundant model tests"
)
-@pytest.mark.xfail(
- current_platform.is_rocm(),
- reason="MiniCPM-V dependency xformers incompatible with ROCm",
-)
@multi_gpu_test(num_gpus=4)
def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files):
llm = vllm.LLM(
@@ -108,10 +100,6 @@ def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files):
@pytest.mark.skipif(
current_platform.is_cuda_alike(), reason="Skipping to avoid redundant model tests"
)
-@pytest.mark.xfail(
- current_platform.is_rocm(),
- reason="MiniCPM-V dependency xformers incompatible with ROCm",
-)
@multi_gpu_test(num_gpus=4)
def test_minicpmv_tp4_fully_sharded_loras(minicpmv_lora_files):
llm = vllm.LLM(
diff --git a/tests/lora/test_olmoe_tp.py b/tests/lora/test_olmoe_tp.py
index e659c1e1a9a07..e3c9816625ba7 100644
--- a/tests/lora/test_olmoe_tp.py
+++ b/tests/lora/test_olmoe_tp.py
@@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+
import vllm
from vllm.lora.request import LoRARequest
@@ -111,8 +113,9 @@ def test_olmoe_lora_mixed(olmoe_lora_files):
generate_and_test(llm, olmoe_lora_files, lora_id=[1, None, 3, None])
+@pytest.mark.parametrize("fully_sharded_loras", [False, True])
@multi_gpu_test(num_gpus=2)
-def test_olmoe_lora_tp2(olmoe_lora_files):
+def test_olmoe_lora_tp2(olmoe_lora_files, fully_sharded_loras):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
@@ -122,14 +125,16 @@ def test_olmoe_lora_tp2(olmoe_lora_files):
trust_remote_code=True,
enable_chunked_prefill=True,
tensor_parallel_size=2,
+ fully_sharded_loras=fully_sharded_loras,
)
generate_and_test(llm, olmoe_lora_files, lora_id=1)
generate_and_test(llm, olmoe_lora_files, lora_id=2)
+@pytest.mark.parametrize("fully_sharded_loras", [False, True])
@multi_gpu_test(num_gpus=4)
-def test_olmoe_lora_tp4(olmoe_lora_files):
+def test_olmoe_lora_tp4(olmoe_lora_files, fully_sharded_loras):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
@@ -139,6 +144,7 @@ def test_olmoe_lora_tp4(olmoe_lora_files):
trust_remote_code=True,
enable_chunked_prefill=True,
tensor_parallel_size=4,
+ fully_sharded_loras=fully_sharded_loras,
)
generate_and_test(llm, olmoe_lora_files, lora_id=1)
diff --git a/tests/lora/test_qwen2vl.py b/tests/lora/test_qwen2vl.py
index 1800ca107a426..7d8c940100ca4 100644
--- a/tests/lora/test_qwen2vl.py
+++ b/tests/lora/test_qwen2vl.py
@@ -2,12 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
-import pytest
-
import vllm
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest
-from vllm.platforms import current_platform
from vllm.sampling_params import BeamSearchParams
@@ -142,10 +139,6 @@ QWEN2VL_MODEL_PATH = "Qwen/Qwen2-VL-2B-Instruct"
QWEN25VL_MODEL_PATH = "Qwen/Qwen2.5-VL-3B-Instruct"
-@pytest.mark.xfail(
- current_platform.is_rocm(),
- reason="Qwen2-VL dependency xformers incompatible with ROCm",
-)
def test_qwen2vl_lora(qwen2vl_lora_files):
"""Test Qwen 2.0 VL model with LoRA"""
config = TestConfig(model_path=QWEN2VL_MODEL_PATH, lora_path=qwen2vl_lora_files)
@@ -156,10 +149,6 @@ def test_qwen2vl_lora(qwen2vl_lora_files):
tester.run_test(TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS, lora_id=lora_id)
-@pytest.mark.xfail(
- current_platform.is_rocm(),
- reason="Qwen2-VL dependency xformers incompatible with ROCm",
-)
def test_qwen2vl_lora_beam_search(qwen2vl_lora_files):
"""Test Qwen 2.0 VL model with LoRA through beam search."""
config = TestConfig(model_path=QWEN2VL_MODEL_PATH, lora_path=qwen2vl_lora_files)
@@ -178,10 +167,6 @@ def test_qwen2vl_lora_beam_search(qwen2vl_lora_files):
)
-@pytest.mark.xfail(
- current_platform.is_rocm(),
- reason="Qwen2.5-VL dependency xformers incompatible with ROCm",
-)
def test_qwen25vl_lora(qwen25vl_lora_files):
"""Test Qwen 2.5 VL model with LoRA"""
config = TestConfig(model_path=QWEN25VL_MODEL_PATH, lora_path=qwen25vl_lora_files)
diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py
index c97f8debd1b9a..b163559a9414d 100644
--- a/tests/lora/test_worker.py
+++ b/tests/lora/test_worker.py
@@ -20,11 +20,12 @@ from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.v1.worker.gpu_worker import Worker
+MODEL_PATH = "Qwen/Qwen3-0.6B"
NUM_LORAS = 16
@patch.dict(os.environ, {"RANK": "0"})
-def test_worker_apply_lora(sql_lora_files):
+def test_worker_apply_lora(qwen3_lora_files):
def set_active_loras(worker: Worker, lora_requests: list[LoRARequest]):
lora_mapping = LoRAMapping([], [])
@@ -34,9 +35,10 @@ def test_worker_apply_lora(sql_lora_files):
vllm_config = VllmConfig(
model_config=ModelConfig(
- "meta-llama/Llama-2-7b-hf",
+ MODEL_PATH,
seed=0,
dtype="float16",
+ max_model_len=127,
enforce_eager=True,
),
load_config=LoadConfig(
@@ -73,7 +75,7 @@ def test_worker_apply_lora(sql_lora_files):
assert worker.list_loras() == set()
lora_requests = [
- LoRARequest(str(i + 1), i + 1, sql_lora_files) for i in range(NUM_LORAS)
+ LoRARequest(str(i + 1), i + 1, qwen3_lora_files) for i in range(NUM_LORAS)
]
set_active_loras(worker, lora_requests)
diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index d30b77f094665..6aba5299b5829 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -28,7 +28,6 @@ class DummyLoRAManager:
module_name: str,
weight: torch.Tensor,
rank: int = 8,
- generate_embeddings_tensor: int = 0,
):
lora = LoRALayerWeights(
module_name,
@@ -41,13 +40,6 @@ class DummyLoRAManager:
[weight.shape[0], rank], dtype=weight.dtype, device=self._device
),
)
- if generate_embeddings_tensor:
- lora.embeddings_tensor = torch.rand(
- 5,
- generate_embeddings_tensor,
- dtype=weight.dtype,
- device=self._device,
- )
self.set_module_lora(module_name, lora)
return lora
diff --git a/tests/model_executor/model_loader/fastsafetensors_loader/test_fastsafetensors_loader.py b/tests/model_executor/model_loader/fastsafetensors_loader/test_fastsafetensors_loader.py
index f154df6dfc232..c5b3c731ffc64 100644
--- a/tests/model_executor/model_loader/fastsafetensors_loader/test_fastsafetensors_loader.py
+++ b/tests/model_executor/model_loader/fastsafetensors_loader/test_fastsafetensors_loader.py
@@ -19,7 +19,8 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
@pytest.mark.skipif(
- not current_platform.is_cuda(), reason="fastsafetensors requires CUDA/NVIDIA GPUs"
+ not current_platform.is_cuda_alike(),
+ reason="fastsafetensors requires NVIDIA/AMD GPUs",
)
def test_model_loader_download_files(vllm_runner):
with vllm_runner(test_model, load_format="fastsafetensors") as llm:
diff --git a/tests/model_executor/model_loader/fastsafetensors_loader/test_weight_utils.py b/tests/model_executor/model_loader/fastsafetensors_loader/test_weight_utils.py
index bd216f0e41a47..1975eb61b25da 100644
--- a/tests/model_executor/model_loader/fastsafetensors_loader/test_weight_utils.py
+++ b/tests/model_executor/model_loader/fastsafetensors_loader/test_weight_utils.py
@@ -17,7 +17,8 @@ from vllm.platforms import current_platform
@pytest.mark.skipif(
- not current_platform.is_cuda(), reason="fastsafetensors requires CUDA/NVIDIA GPUs"
+ not current_platform.is_cuda_alike(),
+ reason="fastsafetensors requires NVIDIA/AMD GPUs",
)
def test_fastsafetensors_model_loader():
with tempfile.TemporaryDirectory() as tmpdir:
diff --git a/tests/model_executor/model_loader/test_sharded_state_loader.py b/tests/model_executor/model_loader/test_sharded_state_loader.py
index 5bb841bf2fa0e..cf06b000efb51 100644
--- a/tests/model_executor/model_loader/test_sharded_state_loader.py
+++ b/tests/model_executor/model_loader/test_sharded_state_loader.py
@@ -60,18 +60,9 @@ def llama_3p2_1b_files():
def _run_writer(input_dir, output_dir, weights_patterns, **kwargs):
llm_sharded_writer = LLM(model=input_dir, **kwargs)
- # Check which engine version is being used
- is_v1_engine = hasattr(llm_sharded_writer.llm_engine, "engine_core")
+
# Dump worker states to output directory
- if is_v1_engine:
- # For V1 engine, we need to use engine_core.save_sharded_state
- print("Using V1 engine save path")
- llm_sharded_writer.llm_engine.engine_core.save_sharded_state(path=output_dir)
- else:
- # For V0 engine
- print("Using V0 engine save path")
- model_executor = llm_sharded_writer.llm_engine.model_executor
- model_executor.save_sharded_state(path=output_dir)
+ llm_sharded_writer.llm_engine.engine_core.save_sharded_state(path=output_dir)
# Copy metadata files to output directory
for file in os.listdir(input_dir):
diff --git a/tests/model_executor/test_qwen3_omni.py b/tests/model_executor/test_qwen3_omni.py
new file mode 100644
index 0000000000000..c92c61dcd3bc2
--- /dev/null
+++ b/tests/model_executor/test_qwen3_omni.py
@@ -0,0 +1,221 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from unittest.mock import Mock
+
+import pytest
+from transformers import PretrainedConfig
+
+from vllm.multimodal.processing import InputProcessingContext
+
+
+# Helper function to print input IDs with coalesced audio/video tokens.
+def print_input_ids(input_ids):
+ """
+ Print input IDs, compressing consecutive special tokens.
+ - 151675: <|audio_pad|>
+ - 151656: <|video_pad|>
+ """
+ if not input_ids:
+ print("[]")
+ return
+
+ result = []
+ i = 0
+
+ while i < len(input_ids):
+ current_id = input_ids[i]
+
+ # Check if it's a special token that should be compressed
+ if current_id in [151675, 151656]:
+ # Count consecutive occurrences
+ count = 1
+ while i + count < len(input_ids) and input_ids[i + count] == current_id:
+ count += 1
+
+ # Add compressed representation
+ token_name = "<|audio_pad|>" if current_id == 151675 else "<|video_pad|>"
+ result.append(f"{token_name} * {count}")
+ i += count
+ else:
+ # Regular token, just add it
+ result.append(str(current_id))
+ i += 1
+
+ print(", ".join(result))
+
+
+@pytest.fixture
+def mock_qwen3_omni_config():
+ """Create a mock Qwen3OmniMoeThinker config."""
+ config = Mock(spec=PretrainedConfig)
+ # Token IDs from https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Instruct/blob/main/tokenizer_config.json
+ config.audio_token_id = 151675 # <|audio_pad|>
+ config.video_token_id = 151656 # <|video_pad|>
+ config.image_token_id = 151655 # <|image_pad|>
+ config.audio_start_token_id = 151669 # <|audio_start|>
+ config.audio_end_token_id = 151670 # <|audio_end|>
+ config.vision_start_token_id = 151652 # <|vision_start|>
+ config.position_id_per_seconds = 12.5
+
+ # Vision config
+ vision_config = Mock()
+ vision_config.spatial_merge_size = 2
+ config.vision_config = vision_config
+
+ return config
+
+
+@pytest.fixture
+def mock_processor():
+ """Create a mock HF processor."""
+ from transformers.models.whisper import WhisperFeatureExtractor
+
+ processor = Mock()
+ processor.audio_token = "<|audio_pad|>"
+ processor.image_token = "<|image_pad|>"
+ processor.video_token = "<|video_pad|>"
+
+ # Create a real WhisperFeatureExtractor instance for the feature_extractor attribute
+ feature_extractor = WhisperFeatureExtractor()
+ processor.feature_extractor = feature_extractor
+
+ return processor
+
+
+@pytest.fixture
+def mock_tokenizer():
+ """Create a mock tokenizer."""
+ tokenizer = Mock()
+ # Token IDs from https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Instruct/blob/main/tokenizer_config.json
+ tokenizer.get_vocab = Mock(
+ return_value={
+ "<|audio_pad|>": 151675,
+ "<|video_pad|>": 151656,
+ "<|image_pad|>": 151655,
+ "<|audio_start|>": 151669,
+ "<|audio_end|>": 151670,
+ "<|vision_start|>": 151652,
+ "<|vision_end|>": 151653,
+ }
+ )
+ tokenizer.encode = Mock(
+ side_effect=lambda x: {
+ "<|vision_start|>": [151652],
+ "<|vision_end|>": [151653],
+ "<|audio_start|>": [151669],
+ "<|audio_end|>": [151670],
+ "<|audio_pad|>": [151675],
+ "<|image_pad|>": [151655],
+ "<|video_pad|>": [151656],
+ }.get(x, [0])
+ )
+ tokenizer.vision_bos_token = "<|vision_start|>"
+ tokenizer.vision_eos_token = "<|vision_end|>"
+ tokenizer.audio_bos_token = "<|audio_start|>"
+ tokenizer.audio_eos_token = "<|audio_end|>"
+ return tokenizer
+
+
+@pytest.fixture
+def mock_image_processor():
+ """Create a mock image processor."""
+ image_processor = Mock()
+ image_processor.merge_size = 2
+ return image_processor
+
+
+def test_qwen3_omni_get_updates_use_audio_in_video(
+ mock_qwen3_omni_config,
+ mock_processor,
+ mock_tokenizer,
+ mock_image_processor,
+):
+ """Test the get_updates_use_audio_in_video method directly."""
+
+ from vllm.model_executor.models.qwen3_omni_moe_thinker import (
+ Qwen3OmniMoeThinkerMultiModalProcessor,
+ Qwen3OmniMoeThinkerProcessingInfo,
+ )
+
+ # Create a mock context
+ mock_ctx = Mock(spec=InputProcessingContext)
+
+ # Create processing info
+ info = Qwen3OmniMoeThinkerProcessingInfo(mock_ctx)
+ info.get_hf_config = Mock(return_value=mock_qwen3_omni_config)
+ info.get_hf_processor = Mock(return_value=mock_processor)
+ info.get_tokenizer = Mock(return_value=mock_tokenizer)
+ info.get_image_processor = Mock(return_value=mock_image_processor)
+
+ # Create a mock dummy_inputs builder
+ mock_dummy_inputs = Mock()
+
+ # Create the processor
+ processor = Qwen3OmniMoeThinkerMultiModalProcessor(info, mock_dummy_inputs)
+
+ # Test parameters from reference video
+ # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/draw.mp4
+ audio_len = 85
+ video_grid_thw = [6, 36, 64]
+ video_second_per_grid_t = 2.0
+
+ # Call the method
+ updates = processor.get_updates_use_audio_in_video(
+ thinker_config=mock_qwen3_omni_config,
+ audio_len=audio_len,
+ video_grid_thw=video_grid_thw,
+ video_second_per_grid_t=video_second_per_grid_t,
+ )
+
+ # Updated input ids should align with HF implementation.
+ # 151669,
+ # <|video_pad|> * 576, <|audio_pad|> * 25,
+ # <|video_pad|> * 576, <|audio_pad|> * 25,
+ # <|video_pad|> * 576, <|audio_pad|> * 25,
+ # <|video_pad|> * 576, <|audio_pad|> * 10,
+ # <|video_pad|> * 1152,
+ # 151670
+ print_input_ids(updates)
+
+ # Verify structure
+ assert isinstance(updates, list)
+ assert len(updates) > 0
+
+ # Verify start and end tokens
+ audio_start_token_id = mock_qwen3_omni_config.audio_start_token_id
+ audio_end_token_id = mock_qwen3_omni_config.audio_end_token_id
+
+ assert updates[0] == audio_start_token_id
+ assert updates[-1] == audio_end_token_id
+
+ # Verify both audio and video tokens are present
+ audio_token_id = mock_qwen3_omni_config.audio_token_id
+ video_token_id = mock_qwen3_omni_config.video_token_id
+
+ audio_count = updates.count(audio_token_id)
+ video_count = updates.count(video_token_id)
+
+ assert audio_count == audio_len, (
+ f"Expected {audio_len} audio tokens, got {audio_count}"
+ )
+
+ # Calculate expected video token count
+ spatial_merge_size = mock_qwen3_omni_config.vision_config.spatial_merge_size
+ height = video_grid_thw[1] // spatial_merge_size
+ width = video_grid_thw[2] // spatial_merge_size
+ expected_video_count = video_grid_thw[0] * height * width
+
+ assert video_count == expected_video_count, (
+ f"Expected {expected_video_count} video tokens, got {video_count}"
+ )
+
+ # Total tokens should be: 1 (start) + audio_len + video_count + 1 (end)
+ expected_total = 1 + audio_len + expected_video_count + 1
+ assert len(updates) == expected_total, (
+ f"Expected {expected_total} total tokens, got {len(updates)}"
+ )
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py
index 0cdb7c9a603f2..df6c2cab7814b 100644
--- a/tests/models/language/generation/test_common.py
+++ b/tests/models/language/generation/test_common.py
@@ -10,13 +10,6 @@ from ....utils import large_gpu_mark
from ...registry import HF_EXAMPLE_MODELS
from ...utils import check_logprobs_close
-# These have unsupported head_dim for FA. We do not
-# have a clean way to fall back, so we fail with
-# a clear msg when it happens.
-# https://github.com/vllm-project/vllm/issues/14524
-# NOTE(woosuk): Skipping these tests until V1 supports them.
-# REQUIRES_V0 = ["microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"]
-
# This list contains the model that are using AITER kernel.
# Skip model that are not using AITER tests.
# When more AITER kernels are added, this list will not be
diff --git a/tests/models/language/generation/test_mistral.py b/tests/models/language/generation/test_mistral.py
index 0ae83ec16020a..80e337d570a36 100644
--- a/tests/models/language/generation/test_mistral.py
+++ b/tests/models/language/generation/test_mistral.py
@@ -208,7 +208,7 @@ def test_mistral_format(
with vllm_runner(
model,
dtype=dtype,
- tokenizer_mode="auto",
+ tokenizer_mode="hf",
load_format="safetensors",
config_format="hf",
) as hf_format_model:
diff --git a/tests/models/language/pooling/test_nomic_max_model_len.py b/tests/models/language/pooling/test_nomic_max_model_len.py
index 88f088c603276..d6216a87a229e 100644
--- a/tests/models/language/pooling/test_nomic_max_model_len.py
+++ b/tests/models/language/pooling/test_nomic_max_model_len.py
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: SIM117
+from typing import Any
+
import pytest
from ...utils import EmbedModelInfo
@@ -79,8 +81,8 @@ def test_set_max_model_len_illegal(model_info, vllm_runner):
@pytest.mark.parametrize("model_info", MODELS)
def test_use_rope_scaling_legal(model_info, vllm_runner):
hf_overrides = {
- "rope_theta": rope_theta,
- "rope_scaling": {
+ "rope_parameters": {
+ "rope_theta": rope_theta,
"rope_type": "yarn",
"factor": factor,
"original_max_position_embeddings": original_max_position_embeddings,
@@ -96,9 +98,9 @@ def test_use_rope_scaling_legal(model_info, vllm_runner):
@pytest.mark.parametrize("model_info", MODELS)
def test_use_rope_scaling_illegal(model_info, vllm_runner):
- hf_overrides = {
- "rope_theta": rope_theta,
- "rope_scaling": {
+ hf_overrides: dict[str, Any] = {
+ "rope_parameters": {
+ "rope_theta": rope_theta,
"rope_type": "yarn",
"factor": factor,
"original_max_position_embeddings": original_max_position_embeddings,
@@ -115,8 +117,8 @@ def test_use_rope_scaling_illegal(model_info, vllm_runner):
pass
hf_overrides = {
- "rope_theta": rope_theta,
- "rope_scaling": {
+ "rope_parameters": {
+ "rope_theta": rope_theta,
"rope_type": "yarn",
"factor": factor,
"original_max_position_embeddings": original_max_position_embeddings,
diff --git a/tests/models/language/pooling_mteb_test/mteb_utils.py b/tests/models/language/pooling_mteb_test/mteb_utils.py
index 0384ff82790f0..189cdbae99dcd 100644
--- a/tests/models/language/pooling_mteb_test/mteb_utils.py
+++ b/tests/models/language/pooling_mteb_test/mteb_utils.py
@@ -2,12 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile
-from collections.abc import Sequence
import mteb
import numpy as np
import requests
import torch
+from mteb.models import ModelMeta
+from mteb.types import Array
+from torch.utils.data import DataLoader
import tests.ci_envs as ci_envs
from tests.models.utils import (
@@ -27,24 +29,47 @@ MTEB_EMBED_TOL = 1e-4
# See #19344
MTEB_RERANK_TASKS = ["NFCorpus"]
-MTEB_RERANK_LANGS = ["en"]
+MTEB_RERANK_LANGS = ["eng"]
MTEB_RERANK_TOL = 2e-3
+_empty_model_meta = ModelMeta(
+ loader=None,
+ name="vllm/model",
+ revision="1",
+ release_date=None,
+ languages=None,
+ framework=[],
+ similarity_fn_name=None,
+ n_parameters=None,
+ memory_usage_mb=None,
+ max_tokens=None,
+ embed_dim=None,
+ license=None,
+ open_weights=None,
+ public_training_code=None,
+ public_training_data=None,
+ use_instructions=None,
+ training_datasets=None,
+ modalities=["text"], # 'image' can be added to evaluate multimodal models
+)
+
+
+class VllmMtebEncoder(mteb.EncoderProtocol):
+ mteb_model_meta = _empty_model_meta
-class VllmMtebEncoder(mteb.Encoder):
def __init__(self, vllm_model):
- super().__init__()
self.llm = vllm_model
self.rng = np.random.default_rng(seed=42)
def encode(
self,
- sentences: Sequence[str],
+ inputs: DataLoader[mteb.types.BatchedInput],
*args,
**kwargs,
) -> np.ndarray:
# Hoping to discover potential scheduling
# issues by randomizing the order.
+ sentences = [text for batch in inputs for text in batch["text"]]
r = self.rng.permutation(len(sentences))
sentences = [sentences[i] for i in r]
outputs = self.llm.embed(sentences, use_tqdm=False)
@@ -52,36 +77,70 @@ class VllmMtebEncoder(mteb.Encoder):
embeds = embeds[np.argsort(r)]
return embeds
+ def similarity(
+ self,
+ embeddings1: np.ndarray,
+ embeddings2: np.ndarray,
+ ) -> np.ndarray:
+ # Cosine similarity
+ norm1 = np.linalg.norm(embeddings1, axis=1, keepdims=True)
+ norm2 = np.linalg.norm(embeddings2, axis=1, keepdims=True)
+ sim = np.dot(embeddings1, embeddings2.T) / (norm1 * norm2.T)
+ return sim
+
+ def similarity_pairwise(
+ self,
+ embeddings1: Array,
+ embeddings2: Array,
+ ) -> Array:
+ # Cosine similarity
+ norm1 = np.linalg.norm(embeddings1, axis=1, keepdims=True)
+ norm2 = np.linalg.norm(embeddings2, axis=1, keepdims=True)
+ sim = np.sum(embeddings1 * embeddings2, axis=1) / (
+ norm1.flatten() * norm2.flatten()
+ )
+ return sim
+
+
+class VllmMtebCrossEncoder(mteb.CrossEncoderProtocol):
+ mteb_model_meta = _empty_model_meta
+
+ def __init__(self, vllm_model):
+ self.llm = vllm_model
+ self.rng = np.random.default_rng(seed=42)
+
def predict(
self,
- sentences: list[tuple[str, str, str | None]], # query, corpus, prompt
+ inputs1: DataLoader[mteb.types.BatchedInput],
+ inputs2: DataLoader[mteb.types.BatchedInput],
*args,
**kwargs,
) -> np.ndarray:
- r = self.rng.permutation(len(sentences))
- sentences = [sentences[i] for i in r]
-
- queries = [s[0] for s in sentences]
- corpus = [s[1] for s in sentences]
+ queries = [text for batch in inputs1 for text in batch["text"]]
+ corpus = [text for batch in inputs2 for text in batch["text"]]
outputs = self.llm.score(
queries, corpus, truncate_prompt_tokens=-1, use_tqdm=False
)
scores = np.array(outputs)
- scores = scores[np.argsort(r)]
return scores
-class OpenAIClientMtebEncoder(mteb.Encoder):
+class OpenAIClientMtebEncoder(VllmMtebEncoder):
def __init__(self, model_name: str, client):
- super().__init__()
self.model_name = model_name
self.client = client
self.rng = np.random.default_rng(seed=42)
- def encode(self, sentences: Sequence[str], *args, **kwargs) -> np.ndarray:
+ def encode(
+ self,
+ inputs: DataLoader[mteb.types.BatchedInput],
+ *args,
+ **kwargs,
+ ) -> np.ndarray:
# Hoping to discover potential scheduling
# issues by randomizing the order.
+ sentences = [text for batch in inputs for text in batch["text"]]
r = self.rng.permutation(len(sentences))
sentences = [sentences[i] for i in r]
@@ -94,28 +153,29 @@ class OpenAIClientMtebEncoder(mteb.Encoder):
return embeds
-class ScoreClientMtebEncoder(mteb.Encoder):
+class ScoreClientMtebEncoder(mteb.CrossEncoderProtocol):
+ mteb_model_meta = _empty_model_meta
+
def __init__(self, model_name: str, url):
- super().__init__()
self.model_name = model_name
self.url = url
self.rng = np.random.default_rng(seed=42)
def predict(
self,
- sentences: list[tuple[str, str, str | None]], # query, corpus, prompt
+ inputs1: DataLoader[mteb.types.BatchedInput],
+ inputs2: DataLoader[mteb.types.BatchedInput],
*args,
**kwargs,
) -> np.ndarray:
- r = self.rng.permutation(len(sentences))
- sentences = [sentences[i] for i in r]
+ queries = [text for batch in inputs1 for text in batch["text"]]
+ full_corpus = [text for batch in inputs2 for text in batch["text"]]
outputs = []
- for query, corpus, prompt in sentences:
+ for query, corpus in zip(queries, full_corpus):
outputs.append(self.get_score(query, corpus))
scores = np.array(outputs)
- scores = scores[np.argsort(r)]
return scores
def get_score(self, query, corpus):
@@ -145,16 +205,13 @@ class RerankClientMtebEncoder(ScoreClientMtebEncoder):
return response["results"][0]["relevance_score"]
-def run_mteb_embed_task(encoder, tasks):
+def run_mteb_embed_task(encoder: mteb.EncoderProtocol, tasks):
tasks = mteb.get_tasks(tasks=tasks)
- evaluation = mteb.MTEB(tasks=tasks)
- results = evaluation.run(
+ results = mteb.evaluate(
encoder,
- verbosity=0,
- output_folder=None,
- encode_kwargs={
- "show_progress_bar": False,
- },
+ tasks,
+ cache=None,
+ show_progress_bar=False,
)
main_score = results[0].scores["test"][0]["main_score"]
@@ -244,33 +301,39 @@ def mteb_test_embed_models(
assert st_main_score - vllm_main_score < atol
-def run_mteb_rerank(cross_encoder, tasks, languages):
- with tempfile.TemporaryDirectory() as results_folder:
+def run_mteb_rerank(cross_encoder: mteb.CrossEncoderProtocol, tasks, languages):
+ with tempfile.TemporaryDirectory() as prediction_folder:
bm25s = mteb.get_model("bm25s")
- tasks = mteb.get_tasks(tasks=tasks, languages=languages)
-
- subset = "default"
eval_splits = ["test"]
- evaluation = mteb.MTEB(tasks=tasks)
- evaluation.run(
- bm25s,
- verbosity=0,
- eval_splits=eval_splits,
- save_predictions=True,
- output_folder=f"{results_folder}/stage1",
- encode_kwargs={"show_progress_bar": False},
+ mteb_tasks: list[mteb.abstasks.AbsTaskRetrieval] = mteb.get_tasks(
+ tasks=tasks, languages=languages, eval_splits=eval_splits
)
- results = evaluation.run(
+ mteb.evaluate(
+ bm25s,
+ mteb_tasks,
+ prediction_folder=prediction_folder,
+ show_progress_bar=False,
+ # don't save results for test runs
+ cache=None,
+ overwrite_strategy="always",
+ )
+
+ second_stage_tasks = []
+ for task in mteb_tasks:
+ second_stage_tasks.append(
+ task.convert_to_reranking(
+ prediction_folder,
+ top_k=10,
+ )
+ )
+
+ results = mteb.evaluate(
cross_encoder,
- verbosity=0,
- eval_splits=eval_splits,
- top_k=10,
- save_predictions=True,
- output_folder=f"{results_folder}/stage2",
- previous_results=f"{results_folder}/stage1/NFCorpus_{subset}_predictions.json",
- encode_kwargs={"show_progress_bar": False},
+ second_stage_tasks,
+ show_progress_bar=False,
+ cache=None,
)
main_score = results[0].scores["test"][0]["main_score"]
return main_score
@@ -280,20 +343,6 @@ def mteb_test_rerank_models_hf(
hf_runner, model_name, hf_dtype="float32", hf_model_callback=None
):
with hf_runner(model_name, is_cross_encoder=True, dtype=hf_dtype) as hf_model:
- original_predict = hf_model.predict
-
- def _predict(
- sentences: list[tuple[str, str, str | None]], # query, corpus, prompt
- *args,
- **kwargs,
- ):
- # vllm and st both remove the prompt, fair comparison.
- prompts = [(s[0], s[1]) for s in sentences]
- return original_predict(prompts, *args, **kwargs, batch_size=8)
-
- hf_model.predict = _predict
- hf_model.original_predict = original_predict
-
if hf_model_callback is not None:
hf_model_callback(hf_model)
@@ -310,7 +359,7 @@ def mteb_test_rerank_models(
model_info: RerankModelInfo,
vllm_extra_kwargs=None,
hf_model_callback=None,
- vllm_mteb_encoder=VllmMtebEncoder,
+ vllm_mteb_encoder=VllmMtebCrossEncoder,
atol=MTEB_RERANK_TOL,
):
vllm_extra_kwargs = get_vllm_extra_kwargs(model_info, vllm_extra_kwargs)
diff --git a/tests/models/language/pooling_mteb_test/test_bge_reranker_v2_gemma.py b/tests/models/language/pooling_mteb_test/test_bge_reranker_v2_gemma.py
index 2927a37111364..6b2e469644926 100644
--- a/tests/models/language/pooling_mteb_test/test_bge_reranker_v2_gemma.py
+++ b/tests/models/language/pooling_mteb_test/test_bge_reranker_v2_gemma.py
@@ -2,13 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
+import mteb
import numpy as np
import pytest
import torch
+from torch.utils.data import DataLoader
from tests.conftest import HfRunner
from tests.models.language.pooling_mteb_test.mteb_utils import (
- VllmMtebEncoder,
+ VllmMtebCrossEncoder,
mteb_test_rerank_models,
)
from tests.models.utils import LASTPoolingRerankModelInfo, RerankModelInfo
@@ -103,7 +105,7 @@ class GemmaRerankerHfRunner(HfRunner):
return torch.Tensor(scores)
-class GemmaMtebEncoder(VllmMtebEncoder):
+class GemmaMtebEncoder(VllmMtebCrossEncoder):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.query_template = "A: {query}\n"
@@ -111,17 +113,26 @@ class GemmaMtebEncoder(VllmMtebEncoder):
def predict(
self,
- sentences: list[tuple[str, str, str | None]], # query, corpus, prompt
+ inputs1: DataLoader[mteb.types.BatchedInput],
+ inputs2: DataLoader[mteb.types.BatchedInput],
*args,
**kwargs,
) -> np.ndarray:
- _sentences = []
- for query, corpus, prompt in sentences:
- query = self.query_template.format(query=query)
- corpus = self.document_template.format(doc=corpus, prompt=PROMPT)
- _sentences.append((query, corpus, prompt))
-
- return super().predict(_sentences, *args, **kwargs)
+ queries = [
+ self.query_template.format(query=text)
+ for batch in inputs1
+ for text in batch["text"]
+ ]
+ corpus = [
+ self.document_template.format(doc=text, prompt=PROMPT)
+ for batch in inputs2
+ for text in batch["text"]
+ ]
+ outputs = self.llm.score(
+ queries, corpus, truncate_prompt_tokens=-1, use_tqdm=False
+ )
+ scores = np.array(outputs)
+ return scores
@pytest.mark.parametrize("model_info", RERANK_MODELS)
diff --git a/tests/models/language/pooling_mteb_test/test_mxbai_rerank.py b/tests/models/language/pooling_mteb_test/test_mxbai_rerank.py
index fd04dc1990238..a6f2a89b268f1 100644
--- a/tests/models/language/pooling_mteb_test/test_mxbai_rerank.py
+++ b/tests/models/language/pooling_mteb_test/test_mxbai_rerank.py
@@ -70,8 +70,9 @@ class MxbaiRerankerHfRunner(HfRunner):
return scores
scores = []
- for prompt in prompts:
- inputs = process_inputs([prompt])
+ for query, doc, *_ in prompts:
+ pairs = [(query, doc)]
+ inputs = process_inputs(pairs)
score = compute_logits(inputs)
scores.append(score[0].item())
return torch.Tensor(scores)
diff --git a/tests/models/language/pooling_mteb_test/test_qwen3_reranker.py b/tests/models/language/pooling_mteb_test/test_qwen3_reranker.py
index 00e99f44cfdb1..9a1be6c0be1d6 100644
--- a/tests/models/language/pooling_mteb_test/test_qwen3_reranker.py
+++ b/tests/models/language/pooling_mteb_test/test_qwen3_reranker.py
@@ -72,8 +72,9 @@ class Qwen3RerankerHfRunner(HfRunner):
return scores
scores = []
- for prompt in prompts:
- inputs = process_inputs([prompt])
+ for query, doc, *_ in prompts:
+ pairs = [(query, doc)]
+ inputs = process_inputs(pairs)
score = compute_logits(inputs)
scores.append(score[0].item())
return torch.Tensor(scores)
diff --git a/tests/models/multimodal/generation/test_qwen2_vl.py b/tests/models/multimodal/generation/test_qwen2_vl.py
index e10b8e1e77af1..e1b7dbf99f1fd 100644
--- a/tests/models/multimodal/generation/test_qwen2_vl.py
+++ b/tests/models/multimodal/generation/test_qwen2_vl.py
@@ -128,12 +128,7 @@ def batch_make_image_embeddings(
visual = model.visual
pixel_values_on_device = pixel_values.to(visual.device, dtype=visual.dtype)
- image_grid_thw_on_device = image_grid_thw.to(
- visual.device, dtype=torch.int64
- )
- return visual(
- pixel_values_on_device, grid_thw=image_grid_thw_on_device
- ).cpu()
+ return visual(pixel_values_on_device, grid_thw=image_grid_thw).cpu()
image_embeds = torch.concat(llm.apply_model(get_image_embeds))
@@ -217,12 +212,7 @@ def batch_make_video_embeddings(
visual = model.visual
pixel_values_on_device = pixel_values.to(visual.device, dtype=visual.dtype)
- video_grid_thw_on_device = video_grid_thw.to(
- visual.device, dtype=torch.int64
- )
- return visual(
- pixel_values_on_device, grid_thw=video_grid_thw_on_device
- ).cpu()
+ return visual(pixel_values_on_device, grid_thw=video_grid_thw).cpu()
video_embeds = torch.concat(llm.apply_model(get_image_embeds))
diff --git a/tests/models/multimodal/test_mapping.py b/tests/models/multimodal/test_mapping.py
index 2f38dc450ef96..0d2eaca95504e 100644
--- a/tests/models/multimodal/test_mapping.py
+++ b/tests/models/multimodal/test_mapping.py
@@ -50,12 +50,24 @@ def test_hf_model_weights_mapper(model_arch: str):
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
+ is_mistral_model = model_arch in [
+ "Mistral3ForConditionalGeneration",
+ "PixtralForConditionalGeneration",
+ "VoxtralForConditionalGeneration",
+ ]
+
+ if not is_mistral_model or model_info.tokenizer_mode == "mistral":
+ tokenizer_mode = model_info.tokenizer_mode
+ else:
+ tokenizer_mode = "hf"
+
model_id = model_info.default
model_config = ModelConfig(
model_id,
tokenizer=model_info.tokenizer or model_id,
- tokenizer_mode=model_info.tokenizer_mode,
+ tokenizer_mode=tokenizer_mode,
+ config_format="hf",
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
diff --git a/tests/models/quantization/test_bitsandbytes.py b/tests/models/quantization/test_bitsandbytes.py
index dc4b4546e451b..5b8aaa299fdc1 100644
--- a/tests/models/quantization/test_bitsandbytes.py
+++ b/tests/models/quantization/test_bitsandbytes.py
@@ -259,6 +259,9 @@ def validate_generated_texts(
tensor_parallel_size=vllm_tp_size,
enforce_eager=False,
default_torch_num_threads=1,
+ tokenizer_mode="hf",
+ load_format="hf",
+ config_format="hf",
) as llm:
vllm_outputs = llm.generate_greedy(prompts, max_tokens)
vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner")
diff --git a/tests/models/registry.py b/tests/models/registry.py
index b33f3ab2b5a11..fdbce31d507b5 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -370,7 +370,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
),
"OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"),
"Olmo2ForCausalLM": _HfExamplesInfo("allenai/OLMo-2-0425-1B"),
- "Olmo3ForCausalLM": _HfExamplesInfo("shanearora/2025-sep-a-base-model"),
+ "Olmo3ForCausalLM": _HfExamplesInfo("allenai/Olmo-3-7B-Instruct"),
"OlmoeForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"),
"OpenPanguMTPModel": _HfExamplesInfo(
"FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1",
@@ -407,6 +407,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"pfnet/plamo-2-1b",
trust_remote_code=True,
),
+ "Plamo3ForCausalLM": _HfExamplesInfo(
+ "pfnet/plamo-3-nict-2b-base",
+ trust_remote_code=True,
+ ),
"QWenLMHeadModel": _HfExamplesInfo(
"Qwen/Qwen-7B-Chat",
max_transformers_version="4.53",
@@ -627,6 +631,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B",
trust_remote_code=True,
),
+ "HunYuanVLForConditionalGeneration": _HfExamplesInfo(
+ "tencent/HunyuanOCR",
+ is_available_online=False,
+ ),
"Idefics3ForConditionalGeneration": _HfExamplesInfo(
"HuggingFaceM4/Idefics3-8B-Llama3",
extras={"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"},
@@ -726,6 +734,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"NemotronH_Nano_VL_V2": _HfExamplesInfo(
"nano_vl_dummy", is_available_online=False, trust_remote_code=True
),
+ "OpenCUAForConditionalGeneration": _HfExamplesInfo(
+ "xlangai/OpenCUA-7B", trust_remote_code=True
+ ),
"Ovis": _HfExamplesInfo(
"AIDC-AI/Ovis2-1B",
trust_remote_code=True,
diff --git a/tests/multimodal/assets/corrupted.mp4 b/tests/multimodal/assets/corrupted.mp4
new file mode 100644
index 0000000000000..c355bb932ceee
Binary files /dev/null and b/tests/multimodal/assets/corrupted.mp4 differ
diff --git a/tests/multimodal/test_video.py b/tests/multimodal/test_video.py
index 6572616769a91..6ed21de368ac3 100644
--- a/tests/multimodal/test_video.py
+++ b/tests/multimodal/test_video.py
@@ -18,6 +18,7 @@ from .utils import cosine_similarity, create_video_from_image, normalize_image
pytestmark = pytest.mark.cpu_test
+ASSETS_DIR = Path(__file__).parent / "assets"
NUM_FRAMES = 10
FAKE_OUTPUT_1 = np.random.rand(NUM_FRAMES, 1280, 720, 3)
FAKE_OUTPUT_2 = np.random.rand(NUM_FRAMES, 1280, 720, 3)
@@ -140,3 +141,39 @@ def test_opencv_video_io_colorspace(is_color: bool, fourcc: str, ext: str):
)
assert np.sum(np.isnan(sim)) / sim.size < 0.001
assert np.nanmean(sim) > 0.99
+
+
+def test_video_backend_handles_broken_frames(monkeypatch: pytest.MonkeyPatch):
+ """
+ Regression test for handling videos with broken frames.
+ This test uses a pre-corrupted video file (assets/corrupted.mp4) that
+ contains broken/unreadable frames to verify the video loader handles
+ them gracefully without crashing and returns accurate metadata.
+ """
+ with monkeypatch.context() as m:
+ m.setenv("VLLM_VIDEO_LOADER_BACKEND", "opencv")
+
+ # Load the pre-corrupted video file that contains broken frames
+ corrupted_video_path = ASSETS_DIR / "corrupted.mp4"
+
+ with open(corrupted_video_path, "rb") as f:
+ video_data = f.read()
+
+ loader = VIDEO_LOADER_REGISTRY.load("opencv")
+ frames, metadata = loader.load_bytes(video_data, num_frames=-1)
+
+ # Verify metadata consistency:
+ # frames_indices must match actual loaded frames
+ assert frames.shape[0] == len(metadata["frames_indices"]), (
+ f"Frames array size must equal frames_indices length. "
+ f"Got {frames.shape[0]} frames but "
+ f"{len(metadata['frames_indices'])} indices"
+ )
+
+ # Verify that broken frames were skipped:
+ # loaded frames should be less than total
+ assert frames.shape[0] < metadata["total_num_frames"], (
+ f"Should load fewer frames than total due to broken frames. "
+ f"Expected fewer than {metadata['total_num_frames']} frames, "
+ f"but loaded {frames.shape[0]} frames"
+ )
diff --git a/tests/quantization/test_torchao.py b/tests/quantization/test_torchao.py
index fb8d6130c3779..f35c3973ab6e6 100644
--- a/tests/quantization/test_torchao.py
+++ b/tests/quantization/test_torchao.py
@@ -225,13 +225,12 @@ def test_reload_weights():
@pytest.mark.skip(
reason="since torchao nightly is only compatible with torch nightly"
"currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
- "torchao tests that requires newer versions (0.14.0.dev+) for now"
+ "torchao tests that require newer versions (0.15.0.dev+) for now"
)
-def test_opt_125m_float8_weight_only_safetensors_model_loading_with_params(vllm_runner):
+def test_safetensors_model_loading_with_params(vllm_runner):
torch._dynamo.reset()
- model_name = (
- "torchao-testing/opt-125m-Float8WeightOnlyConfig-v2-0.14.0.dev-safetensors"
- )
+ # using this model to test safetensors loading with file sharding
+ model_name = "torchao-testing/Qwen3-8B-INT4-0.15.0dev-safetensors"
with vllm_runner(model_name=model_name, dtype="bfloat16") as llm:
output = llm.generate_greedy(["The capital of France is"], max_tokens=4)
diff --git a/tests/test_config.py b/tests/test_config.py
index bba2fbec3db29..16f68d18fc68b 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -249,45 +249,48 @@ def test_get_bert_tokenization_sentence_transformer_config():
def test_rope_customization():
- TEST_ROPE_SCALING = {"rope_type": "dynamic", "factor": 2.0}
- TEST_ROPE_THETA = 16_000_000.0
- LONGCHAT_ROPE_SCALING = {"rope_type": "linear", "factor": 8.0}
+ TEST_ROPE_PARAMETERS = {
+ "rope_theta": 16_000_000.0,
+ "rope_type": "dynamic",
+ "factor": 2.0,
+ }
+ LLAMA_ROPE_PARAMETERS = {"rope_theta": 500000.0, "rope_type": "default"}
+ LONGCHAT_ROPE_PARAMETERS = {"rope_type": "linear", "factor": 8.0}
llama_model_config = ModelConfig("meta-llama/Meta-Llama-3-8B-Instruct")
- assert getattr(llama_model_config.hf_config, "rope_scaling", None) is None
- assert getattr(llama_model_config.hf_config, "rope_theta", None) == 500_000
+ assert (
+ getattr(llama_model_config.hf_config, "rope_parameters", None)
+ == LLAMA_ROPE_PARAMETERS
+ )
assert llama_model_config.max_model_len == 8192
llama_model_config = ModelConfig(
"meta-llama/Meta-Llama-3-8B-Instruct",
- hf_overrides={
- "rope_scaling": TEST_ROPE_SCALING,
- "rope_theta": TEST_ROPE_THETA,
- },
+ hf_overrides={"rope_parameters": TEST_ROPE_PARAMETERS},
)
assert (
- getattr(llama_model_config.hf_config, "rope_scaling", None) == TEST_ROPE_SCALING
+ getattr(llama_model_config.hf_config, "rope_parameters", None)
+ == TEST_ROPE_PARAMETERS
)
- assert getattr(llama_model_config.hf_config, "rope_theta", None) == TEST_ROPE_THETA
assert llama_model_config.max_model_len == 16384
longchat_model_config = ModelConfig("lmsys/longchat-13b-16k")
- # Check if LONGCHAT_ROPE_SCALING entries are in longchat_model_config
+ # Check if LONGCHAT_ROPE_PARAMETERS entries are in longchat_model_config
assert all(
- longchat_model_config.hf_config.rope_scaling.get(key) == value
- for key, value in LONGCHAT_ROPE_SCALING.items()
+ longchat_model_config.hf_config.rope_parameters.get(key) == value
+ for key, value in LONGCHAT_ROPE_PARAMETERS.items()
)
assert longchat_model_config.max_model_len == 16384
longchat_model_config = ModelConfig(
"lmsys/longchat-13b-16k",
hf_overrides={
- "rope_scaling": TEST_ROPE_SCALING,
+ "rope_parameters": TEST_ROPE_PARAMETERS,
},
)
assert (
- getattr(longchat_model_config.hf_config, "rope_scaling", None)
- == TEST_ROPE_SCALING
+ getattr(longchat_model_config.hf_config, "rope_parameters", None)
+ == TEST_ROPE_PARAMETERS
)
assert longchat_model_config.max_model_len == 4096
diff --git a/tests/test_logger.py b/tests/test_logger.py
index 01672358902f9..8900e9c2a1e69 100644
--- a/tests/test_logger.py
+++ b/tests/test_logger.py
@@ -49,10 +49,13 @@ def test_trace_function_call():
os.remove(path)
-def test_default_vllm_root_logger_configuration():
+def test_default_vllm_root_logger_configuration(monkeypatch):
"""This test presumes that VLLM_CONFIGURE_LOGGING (default: True) and
VLLM_LOGGING_CONFIG_PATH (default: None) are not configured and default
behavior is activated."""
+ monkeypatch.setenv("VLLM_LOGGING_COLOR", "0")
+ _configure_vllm_root_logger()
+
logger = logging.getLogger("vllm")
assert logger.level == logging.DEBUG
assert not logger.propagate
@@ -70,12 +73,13 @@ def test_default_vllm_root_logger_configuration():
assert formatter.datefmt == _DATE_FORMAT
-@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1)
-@patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", None)
-def test_descendent_loggers_depend_on_and_propagate_logs_to_root_logger():
+def test_descendent_loggers_depend_on_and_propagate_logs_to_root_logger(monkeypatch):
"""This test presumes that VLLM_CONFIGURE_LOGGING (default: True) and
VLLM_LOGGING_CONFIG_PATH (default: None) are not configured and default
behavior is activated."""
+ monkeypatch.setenv("VLLM_CONFIGURE_LOGGING", "1")
+ monkeypatch.delenv("VLLM_LOGGING_CONFIG_PATH", raising=False)
+
root_logger = logging.getLogger("vllm")
root_handler = root_logger.handlers[0]
@@ -99,49 +103,50 @@ def test_descendent_loggers_depend_on_and_propagate_logs_to_root_logger():
assert log_record.levelno == logging.INFO
-@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 0)
-@patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", None)
-def test_logger_configuring_can_be_disabled():
+def test_logger_configuring_can_be_disabled(monkeypatch):
"""This test calls _configure_vllm_root_logger again to test custom logging
config behavior, however mocks are used to ensure no changes in behavior or
configuration occur."""
+ monkeypatch.setenv("VLLM_CONFIGURE_LOGGING", "0")
+ monkeypatch.delenv("VLLM_LOGGING_CONFIG_PATH", raising=False)
with patch("vllm.logger.dictConfig") as dict_config_mock:
_configure_vllm_root_logger()
dict_config_mock.assert_not_called()
-@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1)
-@patch(
- "vllm.logger.VLLM_LOGGING_CONFIG_PATH",
- "/if/there/is/a/file/here/then/you/did/this/to/yourself.json",
-)
-def test_an_error_is_raised_when_custom_logging_config_file_does_not_exist():
+def test_an_error_is_raised_when_custom_logging_config_file_does_not_exist(monkeypatch):
"""This test calls _configure_vllm_root_logger again to test custom logging
config behavior, however it fails before any change in behavior or
configuration occurs."""
+ monkeypatch.setenv("VLLM_CONFIGURE_LOGGING", "1")
+ monkeypatch.setenv(
+ "VLLM_LOGGING_CONFIG_PATH",
+ "/if/there/is/a/file/here/then/you/did/this/to/yourself.json",
+ )
+
with pytest.raises(RuntimeError) as ex_info:
_configure_vllm_root_logger()
assert ex_info.type == RuntimeError # noqa: E721
assert "File does not exist" in str(ex_info)
-@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1)
-def test_an_error_is_raised_when_custom_logging_config_is_invalid_json():
+def test_an_error_is_raised_when_custom_logging_config_is_invalid_json(monkeypatch):
"""This test calls _configure_vllm_root_logger again to test custom logging
config behavior, however it fails before any change in behavior or
configuration occurs."""
+ monkeypatch.setenv("VLLM_CONFIGURE_LOGGING", "1")
+
with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file:
logging_config_file.write("---\nloggers: []\nversion: 1")
logging_config_file.flush()
- with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", logging_config_file.name):
- with pytest.raises(JSONDecodeError) as ex_info:
- _configure_vllm_root_logger()
- assert ex_info.type == JSONDecodeError
- assert "Expecting value" in str(ex_info)
+ monkeypatch.setenv("VLLM_LOGGING_CONFIG_PATH", logging_config_file.name)
+ with pytest.raises(JSONDecodeError) as ex_info:
+ _configure_vllm_root_logger()
+ assert ex_info.type == JSONDecodeError
+ assert "Expecting value" in str(ex_info)
-@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1)
@pytest.mark.parametrize(
"unexpected_config",
(
@@ -151,26 +156,30 @@ def test_an_error_is_raised_when_custom_logging_config_is_invalid_json():
),
)
def test_an_error_is_raised_when_custom_logging_config_is_unexpected_json(
+ monkeypatch,
unexpected_config: Any,
):
"""This test calls _configure_vllm_root_logger again to test custom logging
config behavior, however it fails before any change in behavior or
configuration occurs."""
+ monkeypatch.setenv("VLLM_CONFIGURE_LOGGING", "1")
+
with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file:
logging_config_file.write(json.dumps(unexpected_config))
logging_config_file.flush()
- with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", logging_config_file.name):
- with pytest.raises(ValueError) as ex_info:
- _configure_vllm_root_logger()
- assert ex_info.type == ValueError # noqa: E721
- assert "Invalid logging config. Expected dict, got" in str(ex_info)
+ monkeypatch.setenv("VLLM_LOGGING_CONFIG_PATH", logging_config_file.name)
+ with pytest.raises(ValueError) as ex_info:
+ _configure_vllm_root_logger()
+ assert ex_info.type == ValueError # noqa: E721
+ assert "Invalid logging config. Expected dict, got" in str(ex_info)
-@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1)
-def test_custom_logging_config_is_parsed_and_used_when_provided():
+def test_custom_logging_config_is_parsed_and_used_when_provided(monkeypatch):
"""This test calls _configure_vllm_root_logger again to test custom logging
config behavior, however mocks are used to ensure no changes in behavior or
configuration occur."""
+ monkeypatch.setenv("VLLM_CONFIGURE_LOGGING", "1")
+
valid_logging_config = {
"loggers": {
"vllm.test_logger.logger": {
@@ -183,19 +192,18 @@ def test_custom_logging_config_is_parsed_and_used_when_provided():
with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file:
logging_config_file.write(json.dumps(valid_logging_config))
logging_config_file.flush()
- with (
- patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", logging_config_file.name),
- patch("vllm.logger.dictConfig") as dict_config_mock,
- ):
+ monkeypatch.setenv("VLLM_LOGGING_CONFIG_PATH", logging_config_file.name)
+ with patch("vllm.logger.dictConfig") as dict_config_mock:
_configure_vllm_root_logger()
dict_config_mock.assert_called_with(valid_logging_config)
-@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 0)
-def test_custom_logging_config_causes_an_error_if_configure_logging_is_off():
+def test_custom_logging_config_causes_an_error_if_configure_logging_is_off(monkeypatch):
"""This test calls _configure_vllm_root_logger again to test custom logging
config behavior, however mocks are used to ensure no changes in behavior or
configuration occur."""
+ monkeypatch.setenv("VLLM_CONFIGURE_LOGGING", "0")
+
valid_logging_config = {
"loggers": {
"vllm.test_logger.logger": {
@@ -207,15 +215,15 @@ def test_custom_logging_config_causes_an_error_if_configure_logging_is_off():
with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file:
logging_config_file.write(json.dumps(valid_logging_config))
logging_config_file.flush()
- with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", logging_config_file.name):
- with pytest.raises(RuntimeError) as ex_info:
- _configure_vllm_root_logger()
- assert ex_info.type is RuntimeError
- expected_message_snippet = (
- "VLLM_CONFIGURE_LOGGING evaluated to false, but "
- "VLLM_LOGGING_CONFIG_PATH was given."
- )
- assert expected_message_snippet in str(ex_info)
+ monkeypatch.setenv("VLLM_LOGGING_CONFIG_PATH", logging_config_file.name)
+ with pytest.raises(RuntimeError) as ex_info:
+ _configure_vllm_root_logger()
+ assert ex_info.type is RuntimeError
+ expected_message_snippet = (
+ "VLLM_CONFIGURE_LOGGING evaluated to false, but "
+ "VLLM_LOGGING_CONFIG_PATH was given."
+ )
+ assert expected_message_snippet in str(ex_info)
# Remember! The root logger is assumed to have been configured as
# though VLLM_CONFIGURE_LOGGING=1 and VLLM_LOGGING_CONFIG_PATH=None.
diff --git a/tests/test_routing_simulator.py b/tests/test_routing_simulator.py
index 5a162fa8f791b..e8826eb441a24 100644
--- a/tests/test_routing_simulator.py
+++ b/tests/test_routing_simulator.py
@@ -9,9 +9,16 @@ different routing strategies and analyze their performance, including
integration tests with FusedMoE layer.
"""
+import tempfile
+
import pytest
import torch
+from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.distributed import (
+ init_distributed_environment,
+ initialize_model_parallel,
+)
from vllm.model_executor.layers.fused_moe.routing_simulator import (
DistributionBasedRouting,
RoutingSimulator,
@@ -89,6 +96,28 @@ def test_routing_strategy_integration(monkeypatch, device):
# Test different routing strategies
strategies = RoutingSimulator.get_available_strategies()
+ vllm_config = VllmConfig()
+ with set_current_vllm_config(vllm_config):
+ temp_file = tempfile.mkstemp()[1]
+ init_distributed_environment(
+ world_size=1,
+ rank=0,
+ local_rank=0,
+ distributed_init_method=f"file://{temp_file}",
+ )
+ initialize_model_parallel(
+ tensor_model_parallel_size=1,
+ pipeline_model_parallel_size=1,
+ )
+ fused_moe = FusedMoE(
+ num_experts=num_experts,
+ top_k=top_k,
+ hidden_size=hidden_size,
+ intermediate_size=0,
+ use_grouped_topk=False,
+ renormalize=True,
+ )
+
for strategy in strategies:
# Set environment variable
env_name = "VLLM_MOE_ROUTING_SIMULATION_STRATEGY"
@@ -98,13 +127,9 @@ def test_routing_strategy_integration(monkeypatch, device):
envs.environment_variables[env_name] = lambda s=strategy: s
# Test the select_experts method
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = fused_moe.select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
- top_k=top_k,
- use_grouped_topk=False,
- renormalize=True,
- indices_type=torch.long,
)
# Verify output shapes
diff --git a/tests/tool_use/test_parallel_tool_calls.py b/tests/tool_use/test_parallel_tool_calls.py
index 9af94a6a64a25..77084ec2d9456 100644
--- a/tests/tool_use/test_parallel_tool_calls.py
+++ b/tests/tool_use/test_parallel_tool_calls.py
@@ -212,3 +212,60 @@ async def test_parallel_tool_calls_with_results(
assert finish_reason_count == 1
assert len(chunks)
assert "".join(chunks) == choice.message.content
+
+
+@pytest.mark.asyncio
+async def test_parallel_tool_calls_false(client: openai.AsyncOpenAI):
+ """
+ Ensure only one tool call is returned when parallel_tool_calls is False.
+ """
+
+ models = await client.models.list()
+ model_name: str = models.data[0].id
+ chat_completion = await client.chat.completions.create(
+ messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
+ temperature=0,
+ max_completion_tokens=200,
+ model=model_name,
+ tools=[WEATHER_TOOL, SEARCH_TOOL],
+ logprobs=False,
+ parallel_tool_calls=False,
+ )
+
+ stop_reason = chat_completion.choices[0].finish_reason
+ non_streamed_tool_calls = chat_completion.choices[0].message.tool_calls
+
+ # make sure only 1 tool call is present
+ assert len(non_streamed_tool_calls) == 1
+ assert stop_reason == "tool_calls"
+
+ # make the same request, streaming
+ stream = await client.chat.completions.create(
+ model=model_name,
+ messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
+ temperature=0,
+ max_completion_tokens=200,
+ tools=[WEATHER_TOOL, SEARCH_TOOL],
+ logprobs=False,
+ parallel_tool_calls=False,
+ stream=True,
+ )
+
+ finish_reason_count: int = 0
+ tool_call_id_count: int = 0
+
+ async for chunk in stream:
+ # if there's a finish reason make sure it's tools
+ if chunk.choices[0].finish_reason:
+ finish_reason_count += 1
+ assert chunk.choices[0].finish_reason == "tool_calls"
+
+ streamed_tool_calls = chunk.choices[0].delta.tool_calls
+ if streamed_tool_calls and len(streamed_tool_calls) > 0:
+ tool_call = streamed_tool_calls[0]
+ if tool_call.id:
+ tool_call_id_count += 1
+
+ # make sure only 1 streaming tool call is present
+ assert tool_call_id_count == 1
+ assert finish_reason_count == 1
diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py
index 38def6f874d7d..7584b903156b7 100644
--- a/tests/tool_use/utils.py
+++ b/tests/tool_use/utils.py
@@ -128,6 +128,12 @@ CONFIGS: dict[str, ServerConfig] = {
"arguments": [
"--enforce-eager",
"--no-enable-prefix-caching",
+ "--tokenizer_mode",
+ "hf",
+ "--load_format",
+ "hf",
+ "--config_format",
+ "hf",
"--tool-call-parser",
"mistral",
"--chat-template",
@@ -140,21 +146,22 @@ CONFIGS: dict[str, ServerConfig] = {
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
"to the user's question - just respond to it normally.",
},
- # V1 Test: Passing locally but failing in CI. This runs the
- # V0 Engine because of CPU offloading. Need to debug why.
+ # FIXME: This test currently fails, need to debug why.
# "granite20b": {
- # "model":
- # "mbayser/granite-20b-functioncalling-FP8-KV",
+ # "model": "mbayser/granite-20b-functioncalling-FP8-KV",
# "arguments": [
- # "--tool-call-parser", "granite-20b-fc", "--chat-template",
- # str(VLLM_PATH /
- # "examples/tool_chat_template_granite_20b_fc.jinja"),
- # "--max_num_seqs", "1", "--enforce-eager", "--cpu-offload-gb", "20"
+ # "--tool-call-parser",
+ # "granite-20b-fc",
+ # "--chat-template",
+ # str(VLLM_PATH / "examples/tool_chat_template_granite_20b_fc.jinja"),
+ # "--max_num_seqs",
+ # "1",
+ # "--enforce-eager",
+ # "--cpu-offload-gb",
+ # "20",
# ],
- # "supports_parallel":
- # False,
- # "supports_rocm":
- # False,
+ # "supports_parallel": False,
+ # "supports_rocm": False,
# },
"granite-3.0-8b": {
"model": "ibm-granite/granite-3.0-8b-instruct",
diff --git a/tests/transformers_utils/test_config.py b/tests/transformers_utils/test_config.py
new file mode 100644
index 0000000000000..de28ab5f99e8c
--- /dev/null
+++ b/tests/transformers_utils/test_config.py
@@ -0,0 +1,62 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+
+import tempfile
+from pathlib import Path
+from unittest.mock import MagicMock, call, patch
+
+import pytest
+
+from vllm.transformers_utils.config import list_filtered_repo_files
+
+
+@pytest.mark.parametrize(
+ "allow_patterns,expected_relative_files",
+ [
+ (
+ ["*.json", "correct*.txt"],
+ ["json_file.json", "subfolder/correct.txt", "correct_2.txt"],
+ ),
+ ],
+)
+def test_list_filtered_repo_files(
+ allow_patterns: list[str], expected_relative_files: list[str]
+):
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ # Prep folder and files
+ path_tmp_dir = Path(tmp_dir)
+ subfolder = path_tmp_dir / "subfolder"
+ subfolder.mkdir()
+ (path_tmp_dir / "json_file.json").touch()
+ (path_tmp_dir / "correct_2.txt").touch()
+ (path_tmp_dir / "uncorrect.txt").touch()
+ (path_tmp_dir / "uncorrect.jpeg").touch()
+ (subfolder / "correct.txt").touch()
+ (subfolder / "uncorrect_sub.txt").touch()
+
+ def _glob_path() -> list[str]:
+ return [
+ str(file.relative_to(path_tmp_dir))
+ for file in path_tmp_dir.glob("**/*")
+ if file.is_file()
+ ]
+
+ # Patch list_repo_files called by fn
+ with patch(
+ "vllm.transformers_utils.config.list_repo_files",
+ MagicMock(return_value=_glob_path()),
+ ) as mock_list_repo_files:
+ out_files = sorted(
+ list_filtered_repo_files(
+ tmp_dir, allow_patterns, "revision", "model", "token"
+ )
+ )
+ assert out_files == sorted(expected_relative_files)
+ assert mock_list_repo_files.call_count == 1
+ assert mock_list_repo_files.call_args_list[0] == call(
+ repo_id=tmp_dir,
+ revision="revision",
+ repo_type="model",
+ token="token",
+ )
diff --git a/tests/transformers_utils/test_utils.py b/tests/transformers_utils/test_utils.py
index beaef04d766bf..bfe1cec76c138 100644
--- a/tests/transformers_utils/test_utils.py
+++ b/tests/transformers_utils/test_utils.py
@@ -2,7 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from vllm.transformers_utils.utils import is_cloud_storage, is_gcs, is_s3
+from vllm.transformers_utils.utils import (
+ is_cloud_storage,
+ is_gcs,
+ is_s3,
+)
def test_is_gcs():
diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py
index 1bd05e6183dc2..783e02ce89bdb 100644
--- a/tests/v1/attention/test_mla_backends.py
+++ b/tests/v1/attention/test_mla_backends.py
@@ -61,7 +61,7 @@ for backend in BACKENDS_TO_TEST:
BACKEND_BLOCK_SIZES = {}
for backend in BACKENDS_TO_TEST:
- supported_sizes = backend.get_class().supported_kernel_block_sizes
+ supported_sizes = backend.get_class().get_supported_kernel_block_sizes()
if supported_sizes:
default_size = supported_sizes[0]
block_size = (
diff --git a/tests/v1/attention/test_rocm_attention_backends_selection.py b/tests/v1/attention/test_rocm_attention_backends_selection.py
new file mode 100644
index 0000000000000..80158d4b7278c
--- /dev/null
+++ b/tests/v1/attention/test_rocm_attention_backends_selection.py
@@ -0,0 +1,343 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests for attention backend selectors."""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+
+from vllm.attention.backends.registry import AttentionBackendEnum
+from vllm.platforms import current_platform
+
+# ROCm-specific attention backend selection tests
+pytestmark = pytest.mark.skipif(
+ not current_platform.is_rocm(), reason="ROCm-specific tests"
+)
+
+
+@pytest.fixture
+def mock_vllm_config():
+ """Create a mock VllmConfig for testing."""
+ config = MagicMock()
+ config.model_config.dtype = torch.float16
+ config.model_config.hf_config.architectures = ["LlamaForCausalLM"]
+ config.cache_config.block_size = 16
+ return config
+
+
+@pytest.fixture
+def mock_on_gfx9():
+ """Mock the on_gfx9 function to return True."""
+ with patch("vllm.platforms.rocm.on_gfx9", return_value=True):
+ yield
+
+
+@pytest.mark.parametrize(
+ "env_vars, selected_backend, expected_backend_path",
+ [
+ # Test Case 0: Explicit FLEX_ATTENTION backend
+ (
+ {},
+ "FLEX_ATTENTION",
+ AttentionBackendEnum.FLEX_ATTENTION.get_path(),
+ ),
+ # Test Case 1: Default (no env vars, no explicit backend)
+ (
+ {},
+ None,
+ AttentionBackendEnum.TRITON_ATTN.get_path(),
+ ),
+ # Test Case 2: Explicit TRITON_ATTN backend
+ (
+ {},
+ "TRITON_ATTN",
+ AttentionBackendEnum.TRITON_ATTN.get_path(),
+ ),
+ # Test Case 3: Explicit ROCM_ATTN backend
+ (
+ {},
+ "ROCM_ATTN",
+ AttentionBackendEnum.ROCM_ATTN.get_path(),
+ ),
+ # Test Case 4: Explicit ROCM_AITER_FA backend
+ (
+ {},
+ "ROCM_AITER_FA",
+ AttentionBackendEnum.ROCM_AITER_FA.get_path(),
+ ),
+ # Test Case 5: Explicit ROCM_AITER_UNIFIED_ATTN backend
+ (
+ {},
+ "ROCM_AITER_UNIFIED_ATTN",
+ AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path(),
+ ),
+ # Test Case 6: VLLM_ROCM_USE_AITER=1
+ # (defaults to AITER FA when MHA not explicitly disabled)
+ (
+ {"VLLM_ROCM_USE_AITER": "1"},
+ None,
+ AttentionBackendEnum.ROCM_AITER_FA.get_path(),
+ ),
+ # Test Case 7: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=1
+ (
+ {"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "1"},
+ None,
+ AttentionBackendEnum.ROCM_AITER_FA.get_path(),
+ ),
+ # Test Case 8: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION=1
+ (
+ {
+ "VLLM_ROCM_USE_AITER": "1",
+ "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": "1",
+ },
+ None,
+ AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path(),
+ ),
+ # Test Case 9: VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1
+ (
+ {"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1"},
+ None,
+ AttentionBackendEnum.ROCM_ATTN.get_path(),
+ ),
+ # Test Case 10: VLLM_ROCM_USE_AITER=1 + explicit TRITON_ATTN
+ (
+ {"VLLM_ROCM_USE_AITER": "1"},
+ "TRITON_ATTN",
+ AttentionBackendEnum.TRITON_ATTN.get_path(),
+ ),
+ # Test Case 11: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=0
+ # (explicitly disabled)
+ (
+ {"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "0"},
+ None,
+ AttentionBackendEnum.TRITON_ATTN.get_path(),
+ ),
+ # Test Case 12: VLLM_ROCM_USE_AITER=1 + explicit ROCM_ATTN
+ (
+ {"VLLM_ROCM_USE_AITER": "1"},
+ "ROCM_ATTN",
+ AttentionBackendEnum.ROCM_ATTN.get_path(),
+ ),
+ ],
+)
+def test_standard_attention_backend_selection(
+ env_vars,
+ selected_backend,
+ expected_backend_path,
+ mock_vllm_config,
+ mock_on_gfx9,
+ monkeypatch,
+):
+ """Test standard attention backend selection with various configurations."""
+ # Set environment variables
+ for key, value in env_vars.items():
+ monkeypatch.setenv(key, value)
+
+ # Import after setting env vars to ensure they're picked up
+ # Reload envs to pick up new environment variables
+ import importlib
+
+ import vllm.envs as envs
+ from vllm.attention.backends.registry import _Backend
+
+ importlib.reload(envs)
+
+ # Convert string backend to enum if provided
+ backend_enum = None
+ if selected_backend:
+ backend_enum = getattr(_Backend, selected_backend)
+
+ # Get the backend class path
+ from vllm.platforms.rocm import RocmPlatform
+
+ backend_path = RocmPlatform.get_attn_backend_cls(
+ selected_backend=backend_enum,
+ head_size=128,
+ dtype=torch.float16,
+ kv_cache_dtype="auto",
+ block_size=16,
+ use_mla=False,
+ has_sink=False,
+ use_sparse=False,
+ )
+ assert backend_path == expected_backend_path
+
+
+@pytest.mark.parametrize(
+ "env_vars, selected_backend, block_size, expected_backend_path, should_raise",
+ [
+ # Test Case 1: TRITON_MLA with block_size != 1
+ (
+ {},
+ "TRITON_MLA",
+ 16,
+ AttentionBackendEnum.TRITON_MLA.get_path(),
+ False,
+ ),
+ # Test Case 2: TRITON_MLA with block_size == 1 (should raise)
+ (
+ {},
+ "TRITON_MLA",
+ 1,
+ None,
+ True,
+ ),
+ # Test Case 3: ROCM_AITER_MLA with block_size == 1
+ (
+ {},
+ "ROCM_AITER_MLA",
+ 1,
+ AttentionBackendEnum.ROCM_AITER_MLA.get_path(),
+ False,
+ ),
+ # Test Case 4: ROCM_AITER_MLA with block_size != 1 (supported now, no raise)
+ (
+ {},
+ "ROCM_AITER_MLA",
+ 16,
+ AttentionBackendEnum.ROCM_AITER_MLA.get_path(),
+ False,
+ ),
+ # Test Case 5: VLLM_ROCM_USE_AITER=1 with block_size == 1
+ (
+ {"VLLM_ROCM_USE_AITER": "1"},
+ None,
+ 1,
+ AttentionBackendEnum.ROCM_AITER_MLA.get_path(),
+ False,
+ ),
+ # Test Case 6: VLLM_ROCM_USE_AITER=1 with block_size == 16
+ # (should use ROCM_AITER_MLA now, as it supports block_size 16)
+ (
+ {"VLLM_ROCM_USE_AITER": "1"},
+ None,
+ 16,
+ AttentionBackendEnum.ROCM_AITER_MLA.get_path(),
+ False,
+ ),
+ # Test Case 7: VLLM_ROCM_USE_AITER=1 + explicit TRITON_MLA
+ (
+ {"VLLM_ROCM_USE_AITER": "1"},
+ "TRITON_MLA",
+ 16,
+ AttentionBackendEnum.TRITON_MLA.get_path(),
+ False,
+ ),
+ # Test Case 8: Explicit ROCM_AITER_TRITON_MLA
+ (
+ {},
+ "ROCM_AITER_TRITON_MLA",
+ 16,
+ AttentionBackendEnum.ROCM_AITER_TRITON_MLA.get_path(),
+ False,
+ ),
+ ],
+)
+def test_mla_backend_selection(
+ env_vars,
+ selected_backend,
+ block_size,
+ expected_backend_path,
+ should_raise,
+ mock_vllm_config,
+ monkeypatch,
+):
+ """Test MLA backend selection with various configurations."""
+ # Set environment variables
+ for key, value in env_vars.items():
+ monkeypatch.setenv(key, value)
+
+ # Import after setting env vars
+ # Reload envs
+ import importlib
+
+ import vllm.envs as envs
+ from vllm.attention.backends.registry import _Backend
+
+ importlib.reload(envs)
+
+ # Mock is_aiter_mla_enabled based on env vars and block_size
+ aiter_enabled = env_vars.get("VLLM_ROCM_USE_AITER") == "1"
+
+ mock_rocm_ops = MagicMock()
+ mock_rocm_ops.is_mla_enabled.return_value = aiter_enabled
+ mock_aiter_module = MagicMock()
+ mock_aiter_module.rocm_aiter_ops = mock_rocm_ops
+
+ with patch.dict("sys.modules", {"vllm._aiter_ops": mock_aiter_module}):
+ # Convert string backend to enum if provided
+ backend_enum = None
+ if selected_backend:
+ backend_enum = getattr(_Backend, selected_backend)
+
+ from vllm.platforms.rocm import RocmPlatform
+
+ if should_raise:
+ with pytest.raises(ValueError):
+ RocmPlatform.get_attn_backend_cls(
+ selected_backend=backend_enum,
+ head_size=128,
+ dtype=torch.float16,
+ kv_cache_dtype="auto",
+ block_size=block_size,
+ use_mla=True,
+ has_sink=False,
+ use_sparse=False,
+ )
+ else:
+ backend_path = RocmPlatform.get_attn_backend_cls(
+ selected_backend=backend_enum,
+ head_size=128,
+ dtype=torch.float16,
+ kv_cache_dtype="auto",
+ block_size=block_size,
+ use_mla=True,
+ has_sink=False,
+ use_sparse=False,
+ )
+ assert backend_path == expected_backend_path
+
+
+def test_aiter_fa_requires_gfx9(mock_vllm_config):
+ """Test that ROCM_AITER_FA requires gfx9 architecture."""
+ from vllm.attention.backends.registry import _Backend
+ from vllm.platforms.rocm import RocmPlatform
+
+ # Mock on_gfx9 to return False
+ with (
+ patch("vllm.platforms.rocm.on_gfx9", return_value=False),
+ pytest.raises(
+ ValueError,
+ match="only supported on gfx9",
+ ),
+ ):
+ RocmPlatform.get_attn_backend_cls(
+ selected_backend=_Backend.ROCM_AITER_FA,
+ head_size=128,
+ dtype=torch.float16,
+ kv_cache_dtype="auto",
+ block_size=16,
+ use_mla=False,
+ has_sink=False,
+ use_sparse=False,
+ )
+
+
+def test_sparse_not_supported(mock_vllm_config):
+ """Test that sparse attention is not supported on ROCm."""
+ from vllm.platforms.rocm import RocmPlatform
+
+ with pytest.raises(
+ AssertionError, match="Sparse MLA backend on ROCm only supports block size 1"
+ ):
+ RocmPlatform.get_attn_backend_cls(
+ selected_backend=None,
+ head_size=128,
+ dtype=torch.float16,
+ kv_cache_dtype="auto",
+ block_size=16,
+ use_mla=False,
+ has_sink=False,
+ use_sparse=True,
+ )
diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py
index dea89babd4b47..df3d53332c7cd 100644
--- a/tests/v1/attention/utils.py
+++ b/tests/v1/attention/utils.py
@@ -340,4 +340,11 @@ full_cg_backend_configs = {
"cudagraph_mode": "FULL_AND_PIECEWISE",
},
),
+ "RocmAttn": BackendConfig(
+ name="RocmAttn",
+ env_vars={"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1"},
+ comp_config={
+ "cudagraph_mode": "FULL",
+ },
+ ),
}
diff --git a/tests/v1/core/test_async_scheduler.py b/tests/v1/core/test_async_scheduler.py
index 1d80ee9875913..e0645ed43015e 100644
--- a/tests/v1/core/test_async_scheduler.py
+++ b/tests/v1/core/test_async_scheduler.py
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import deque
-import numpy as np
import pytest
from vllm.v1.core.sched.output import SchedulerOutput
@@ -22,7 +21,7 @@ def _make_model_runner_output(
return ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index={req_id: i for i, req_id in enumerate(req_ids)},
- sampled_token_ids=[np.array([i]) for i in range(len(req_ids))],
+ sampled_token_ids=[[i] for i in range(len(req_ids))],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py
index 24611a4aaa1b8..12ed59b6e863b 100644
--- a/tests/v1/core/test_kv_cache_utils.py
+++ b/tests/v1/core/test_kv_cache_utils.py
@@ -1436,6 +1436,65 @@ def test_get_kv_cache_config_one_worker():
],
)
+ # 6 full + 5 sliding, pad to 6 full + 6 sliding. This is a typical case for gpt-oss
+ # eagle where there is only one more full attention layer than sliding window layers
+ kv_cache_specs_hybrid = {
+ "layer_1": new_kv_cache_spec(),
+ "layer_2": new_kv_cache_spec(),
+ "layer_3": new_kv_cache_spec(),
+ "layer_4": new_kv_cache_spec(),
+ "layer_5": new_kv_cache_spec(),
+ "layer_6": new_kv_cache_spec(),
+ "layer_7": new_sliding_window_spec(),
+ "layer_8": new_sliding_window_spec(),
+ "layer_9": new_sliding_window_spec(),
+ "layer_10": new_sliding_window_spec(),
+ "layer_11": new_sliding_window_spec(),
+ }
+
+ kv_cache_config_hybrid = get_kv_cache_configs(
+ vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 6 * 32]
+ )[0]
+ print(kv_cache_config_hybrid)
+ assert kv_cache_config_hybrid == KVCacheConfig(
+ num_blocks=32,
+ kv_cache_tensors=[
+ KVCacheTensor(
+ size=mem_per_block_per_layer * 32,
+ shared_by=["layer_1", "layer_7"],
+ ),
+ KVCacheTensor(
+ size=mem_per_block_per_layer * 32,
+ shared_by=["layer_2", "layer_8"],
+ ),
+ KVCacheTensor(
+ size=mem_per_block_per_layer * 32,
+ shared_by=["layer_3", "layer_9"],
+ ),
+ KVCacheTensor(
+ size=mem_per_block_per_layer * 32,
+ shared_by=["layer_4", "layer_10"],
+ ),
+ KVCacheTensor(
+ size=mem_per_block_per_layer * 32,
+ shared_by=["layer_5", "layer_11"],
+ ),
+ KVCacheTensor(
+ size=mem_per_block_per_layer * 32,
+ shared_by=["layer_6"],
+ ),
+ ],
+ kv_cache_groups=[
+ KVCacheGroupSpec(
+ ["layer_1", "layer_2", "layer_3", "layer_4", "layer_5", "layer_6"],
+ new_kv_cache_spec(),
+ ),
+ KVCacheGroupSpec(
+ ["layer_7", "layer_8", "layer_9", "layer_10", "layer_11"],
+ new_sliding_window_spec(),
+ ),
+ ],
+ )
# different hidden size
kv_cache_specs_hybrid = {
"layer_1": new_kv_cache_spec(head_size=128),
diff --git a/tests/v1/core/test_priority_scheduler_random.py b/tests/v1/core/test_priority_scheduler_random.py
index ba0b703302e38..b4805be802723 100644
--- a/tests/v1/core/test_priority_scheduler_random.py
+++ b/tests/v1/core/test_priority_scheduler_random.py
@@ -3,7 +3,6 @@
import random
import uuid
-import numpy as np
import pytest
from vllm.config import VllmConfig
@@ -100,7 +99,8 @@ def _mock_execute_model(
random.randint(*num_output_tokens_range) for _ in range(len(request_ids))
]
sampled_token_ids = [
- np.random.randint(0, 100, size=num_tokens) for num_tokens in num_output_tokens
+ [random.randint(0, 100) for _ in range(num_tokens)]
+ for num_tokens in num_output_tokens
]
return ModelRunnerOutput(
@@ -196,8 +196,6 @@ def test_priority_scheduling_blast(
num_blocks: int,
):
random.seed(42)
- np.random.seed(42)
-
seen_request_prompt_length = dict[str, int]()
seen_request_ids = set[str]()
seen_mm_hashes = set[str]()
diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index 0570c0854c678..fe4153e609971 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -3,7 +3,6 @@
import dataclasses
from unittest.mock import Mock
-import numpy as np
import pytest
import torch
@@ -77,11 +76,11 @@ def test_get_num_unfinished_requests():
@pytest.mark.parametrize(
"enable_prefix_caching, prompt_logprobs",
[
- (None, None),
+ (False, None),
(True, 5),
],
)
-def test_schedule(enable_prefix_caching: bool | None, prompt_logprobs: int | None):
+def test_schedule(enable_prefix_caching: bool, prompt_logprobs: int | None):
"""Test scheduling.
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
"""
@@ -170,7 +169,7 @@ def test_schedule_partial_requests():
req_id_to_index=req_to_index,
# Only the first request has a sampled token id because
# the rest requests are still being prefilled.
- sampled_token_ids=[np.array([0]), np.array([]), np.array([])],
+ sampled_token_ids=[[0], [], []],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -217,7 +216,7 @@ def test_no_mm_input_chunking():
model_runner_output = ModelRunnerOutput(
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
- sampled_token_ids=[np.array([]) for _ in range(len(requests))],
+ sampled_token_ids=[[] for _ in range(len(requests))],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -277,7 +276,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
model_runner_output = ModelRunnerOutput(
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
- sampled_token_ids=[np.array([]) for _ in range(len(requests))],
+ sampled_token_ids=[[] for _ in range(len(requests))],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -301,8 +300,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
model_runner_output = ModelRunnerOutput(
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
- sampled_token_ids=[np.array([0]), np.array([0])]
- + [np.array([]) for _ in range(len(requests) - 2)],
+ sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -349,8 +347,8 @@ def test_stop_via_update_from_output():
req_ids=[req.request_id for req in requests],
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
sampled_token_ids=[
- np.array([EOS_TOKEN_ID]),
- np.array([10, 11]),
+ [EOS_TOKEN_ID],
+ [10, 11],
], # First request hits EOS, second continues
logprobs=None,
prompt_logprobs_dict={},
@@ -394,10 +392,7 @@ def test_stop_via_update_from_output():
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
- sampled_token_ids=[
- np.array([10, 42, 12]),
- np.array([13, 14]),
- ], # First request hits stop token
+ sampled_token_ids=[[10, 42, 12], [13, 14]], # First request hits stop token
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -441,10 +436,7 @@ def test_stop_via_update_from_output():
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
- sampled_token_ids=[
- np.array([10, 11, 12]),
- np.array([13]),
- ], # First request exceeds max_tokens
+ sampled_token_ids=[[10, 11, 12], [13]], # First request exceeds max_tokens
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -483,7 +475,7 @@ def test_stop_via_update_from_output():
model_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
- sampled_token_ids=[np.array([EOS_TOKEN_ID, 10, 11])],
+ sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -590,12 +582,12 @@ def test_check_stop_min_tokens():
@pytest.mark.parametrize(
"enable_prefix_caching, prompt_logprobs",
[
- (None, None),
+ (False, None),
(True, 5),
],
)
def test_schedule_concurrent_batches(
- enable_prefix_caching: bool | None, prompt_logprobs: int | None
+ enable_prefix_caching: bool, prompt_logprobs: int | None
):
scheduler = create_scheduler(
max_num_batched_tokens=1024,
@@ -624,7 +616,7 @@ def test_schedule_concurrent_batches(
model_runner_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
- sampled_token_ids=[np.array([0])],
+ sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -641,7 +633,7 @@ def test_schedule_concurrent_batches(
model_runner_output = ModelRunnerOutput(
req_ids=[requests[1].request_id],
req_id_to_index={requests[1].request_id: 0},
- sampled_token_ids=[np.array([0])],
+ sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -649,6 +641,34 @@ def test_schedule_concurrent_batches(
scheduler.update_from_output(scheduler_output1, model_runner_output)
+@pytest.mark.parametrize("enable_chunked_prefill", [True, False])
+def test_schedule_order(enable_chunked_prefill: bool):
+ scheduler = create_scheduler(
+ max_num_batched_tokens=1024,
+ max_num_seqs=3,
+ enable_chunked_prefill=enable_chunked_prefill,
+ )
+
+ # long requests
+ requests = create_requests(num_requests=2, num_tokens=800)
+ # short requests
+ requests += create_requests(num_requests=2, num_tokens=10)
+
+ for request in requests:
+ scheduler.add_request(request)
+
+ scheduler_output1 = scheduler.schedule()
+
+ if enable_chunked_prefill:
+ # When chunked prefill is enabled, long requests will be chunked.
+ assert len(scheduler_output1.scheduled_new_reqs) == 2
+ else:
+ # When chunked prefill is disabled, the scheduler should not skip
+ # the long requests to schedule subsequent short requests early,
+ # even though there is still token budget remaining.
+ assert len(scheduler_output1.scheduled_new_reqs) == 1
+
+
def test_preempt_during_execution():
# NOTE(woosuk): The actual number of available blocks is 10 instead of 11
# because block 0 is reserved as the null block.
@@ -678,7 +698,7 @@ def test_preempt_during_execution():
model_runner_output0 = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
- sampled_token_ids=[np.array([0])],
+ sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -695,7 +715,7 @@ def test_preempt_during_execution():
model_runner_output1 = ModelRunnerOutput(
req_ids=[requests[1].request_id],
req_id_to_index={requests[1].request_id: 0},
- sampled_token_ids=[np.array([42])],
+ sampled_token_ids=[[42]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -712,18 +732,14 @@ def test_preempt_during_execution():
@pytest.mark.parametrize(
"spec_tokens,output_tokens,expected",
[
- ([[1, 2, 3]], [np.array([1, 2, 3, 4])], (1, 3, 3, [1, 1, 1])), # perfect match
- ([[1, 2, 3]], [np.array([1, 5])], (1, 3, 1, [1, 0, 0])), # early mismatch
- (
- [[1, 2], [3]],
- [np.array([1, 2, 5]), np.array([3, 4])],
- (2, 3, 3, [2, 1]),
- ), # multiple sequences
- ([[1]], [np.array([1, 2])], (1, 1, 1, [1])), # single token sequence
- ([[]], [np.array([5])], (0, 0, 0, [0])), # empty sequence
+ ([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])), # perfect match
+ ([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])), # early mismatch
+ ([[1, 2], [3]], [[1, 2, 5], [3, 4]], (2, 3, 3, [2, 1])), # multiple sequences
+ ([[1]], [[1, 2]], (1, 1, 1, [1])), # single token sequence
+ ([[]], [[5]], (0, 0, 0, [0])), # empty sequence
(
[[1, 2, 3], [4, 5, 6]],
- [np.array([1, 2, 7]), np.array([4, 8])],
+ [[1, 2, 7], [4, 8]],
(2, 6, 3, [2, 1, 0]),
), # multiple mismatches
],
@@ -757,7 +773,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
- sampled_token_ids=[np.array([0]) for _ in range(len(requests))],
+ sampled_token_ids=[[0] for _ in range(len(requests))],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -984,7 +1000,7 @@ def test_kv_connector_basic(is_async: bool):
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
- sampled_token_ids=[np.array([1000])] * len(req_ids),
+ sampled_token_ids=[[1000]] * len(req_ids),
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -1037,7 +1053,7 @@ def test_kv_connector_basic(is_async: bool):
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
- sampled_token_ids=[np.array([1000])] * len(req_ids),
+ sampled_token_ids=[[1000]] * len(req_ids),
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -1069,7 +1085,8 @@ def test_kv_connector_basic(is_async: bool):
)
-def test_external_prefix_cache_metrics():
+@pytest.mark.parametrize("is_async", [False, True])
+def test_external_prefix_cache_metrics(is_async: bool):
"""
Verify connector prefix cache metrics are updated
correctly when the scheduler processes requests with KV connector hits.
@@ -1079,7 +1096,9 @@ def test_external_prefix_cache_metrics():
NUM_MATCHED_NEW_TOKENS = 4
scheduler = create_scheduler(
enable_prefix_caching=False,
- use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False),
+ use_kv_connector=mock_kv(
+ matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=is_async
+ ),
)
# --- Prepare simple requests ---
@@ -1091,16 +1110,22 @@ def test_external_prefix_cache_metrics():
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS,
)
+ req_ids = []
+ req_to_index = {}
+ for i, request in enumerate(requests):
+ scheduler.add_request(request)
+ req_ids.append(request.request_id)
+ req_to_index[request.request_id] = i
- for req in requests:
- scheduler.add_request(req)
+ if is_async:
+ _step_until_kv_transfer_finished(scheduler, req_ids)
# --- Trigger scheduling and simulate model output ---
output = scheduler.schedule()
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=[r.request_id for r in requests],
req_id_to_index={r.request_id: i for i, r in enumerate(requests)},
- sampled_token_ids=[np.array([1000])] * NUM_REQUESTS,
+ sampled_token_ids=[[1000]] * NUM_REQUESTS,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -1166,7 +1191,7 @@ def test_kv_connector_unable_to_allocate(use_ec_connector, ec_role):
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
- sampled_token_ids=[np.array([1000])] * len(req_ids),
+ sampled_token_ids=[[1000]] * len(req_ids),
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -1251,7 +1276,7 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role):
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
- sampled_token_ids=[np.array([1000])] * len(req_ids),
+ sampled_token_ids=[[1000]] * len(req_ids),
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -1344,7 +1369,7 @@ def make_output(scheduler: Scheduler):
return ModelRunnerOutput(
req_ids=[req.request_id for req in scheduler.running],
req_id_to_index={req.request_id: i for i, req in enumerate(scheduler.running)},
- sampled_token_ids=[np.array([1000])] * len(scheduler.running),
+ sampled_token_ids=[[1000]] * len(scheduler.running),
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -1428,7 +1453,7 @@ def create_scheduler_with_priority(
model: str = "facebook/opt-125m",
max_num_seqs: int = 16,
max_num_batched_tokens: int = 8192,
- enable_prefix_caching: bool | None = None,
+ enable_prefix_caching: bool = False,
long_prefill_token_threshold: int = 0,
disable_chunked_mm_input: bool = False,
use_kv_connector: bool = False,
@@ -1447,7 +1472,7 @@ def create_scheduler_with_priority(
max_num_batch_tokens: max num tokens to batch
enable_prefix_caching: optionally force APC config
(True/False) or use default
- (None)
+ (False)
Returns:
{class}`Scheduler` instance with priority scheduling
@@ -1470,17 +1495,12 @@ def create_scheduler_with_priority(
seed=42,
)
# Cache config, optionally force APC
- kwargs_cache = (
- {}
- if enable_prefix_caching is None
- else {"enable_prefix_caching": enable_prefix_caching}
- )
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
- **kwargs_cache,
+ enable_prefix_caching=enable_prefix_caching,
)
kv_transfer_config = (
KVTransferConfig(
@@ -1761,7 +1781,7 @@ def test_priority_scheduling_preemption():
req_id_to_index={
req.request_id: i for i, req in enumerate(low_priority_requests)
},
- sampled_token_ids=[np.array([100]) for _ in low_priority_requests],
+ sampled_token_ids=[[100] for _ in low_priority_requests],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -1830,7 +1850,7 @@ def test_priority_scheduling_no_preemption_when_space_available():
req_id_to_index={
req.request_id: i for i, req in enumerate(low_priority_requests)
},
- sampled_token_ids=[np.array([100]) for _ in low_priority_requests],
+ sampled_token_ids=[[100] for _ in low_priority_requests],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -2076,7 +2096,7 @@ def test_priority_scheduling_heap_property():
model_output = ModelRunnerOutput(
req_ids=[req.req_id],
req_id_to_index={req.req_id: 0},
- sampled_token_ids=[np.array([100])],
+ sampled_token_ids=[[100]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -2162,7 +2182,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(
model_output = ModelRunnerOutput(
req_ids=[request_low.request_id],
req_id_to_index={request_low.request_id: 0},
- sampled_token_ids=[np.array([100])],
+ sampled_token_ids=[[100]],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
@@ -2193,7 +2213,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
- sampled_token_ids=[np.array([100]) for _ in requests],
+ sampled_token_ids=[[100] for _ in requests],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
@@ -2219,7 +2239,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
- sampled_token_ids=[np.array([]), np.array([100])],
+ sampled_token_ids=[[], [100]],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
@@ -2636,7 +2656,7 @@ def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector):
model_output = ModelRunnerOutput(
req_ids=[request1.request_id],
req_id_to_index={request1.request_id: 0},
- sampled_token_ids=[np.array([100])],
+ sampled_token_ids=[[100]],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
@@ -2842,7 +2862,7 @@ def test_ec_connector_unable_to_allocate(use_kv_connector):
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
- sampled_token_ids=[np.array([1000])] * len(req_ids),
+ sampled_token_ids=[[1000]] * len(req_ids),
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -2955,7 +2975,7 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption(
model_output = ModelRunnerOutput(
req_ids=[request_low.request_id],
req_id_to_index={request_low.request_id: 0},
- sampled_token_ids=[np.array([100])],
+ sampled_token_ids=[[100]],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
@@ -3006,7 +3026,7 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption(
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
- sampled_token_ids=[np.array([100]) for _ in requests],
+ sampled_token_ids=[[100] for _ in requests],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
@@ -3041,7 +3061,7 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption(
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
- sampled_token_ids=[np.array([100]), np.array([100, 200])],
+ sampled_token_ids=[[100], [100, 200]],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
@@ -3227,7 +3247,7 @@ def test_ec_connector_allocate_encoder_tokens_with_external_load(use_kv_connecto
model_output = ModelRunnerOutput(
req_ids=[request1.request_id, request2.request_id],
req_id_to_index={request1.request_id: 0, request2.request_id: 1},
- sampled_token_ids=[np.array([100]), np.array([121])],
+ sampled_token_ids=[[100], [121]],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py
index 65511c17473b2..7537c7a60476b 100644
--- a/tests/v1/core/utils.py
+++ b/tests/v1/core/utils.py
@@ -42,7 +42,8 @@ def create_scheduler(
model: str = "facebook/opt-125m",
max_num_seqs: int = 16,
max_num_batched_tokens: int = 8192,
- enable_prefix_caching: bool | None = None,
+ enable_chunked_prefill: bool = True,
+ enable_prefix_caching: bool = False,
long_prefill_token_threshold: int = 0,
disable_chunked_mm_input: bool = False,
use_kv_connector: None | bool | MockKVConfig = None,
@@ -63,7 +64,7 @@ def create_scheduler(
max_num_batch_tokens: max num tokens to batch
enable_prefix_caching: optionally force APC config
(True/False) or use default
- (None)
+ (False)
Returns:
{class}`Scheduler` instance
@@ -76,7 +77,7 @@ def create_scheduler(
max_model_len=max_model_len,
long_prefill_token_threshold=long_prefill_token_threshold,
disable_chunked_mm_input=disable_chunked_mm_input,
- enable_chunked_prefill=True,
+ enable_chunked_prefill=enable_chunked_prefill,
async_scheduling=async_scheduling,
)
model_config = ModelConfig(
@@ -87,17 +88,12 @@ def create_scheduler(
skip_tokenizer_init=skip_tokenizer_init,
)
# Cache config, optionally force APC
- kwargs_cache = (
- {}
- if enable_prefix_caching is None
- else {"enable_prefix_caching": enable_prefix_caching}
- )
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
- **kwargs_cache,
+ enable_prefix_caching=enable_prefix_caching,
)
kv_transfer_config = None
if isinstance(use_kv_connector, MockKVConfig):
diff --git a/tests/v1/cudagraph/test_cudagraph_mode.py b/tests/v1/cudagraph/test_cudagraph_mode.py
index d6bde16eba36b..7f9c2a0571c3c 100644
--- a/tests/v1/cudagraph/test_cudagraph_mode.py
+++ b/tests/v1/cudagraph/test_cudagraph_mode.py
@@ -35,14 +35,22 @@ def temporary_environ(env_vars):
# test attention backend and cudagraph_mode combo
# (backend_name, cudagraph_mode, supported)
-combo_cases_1 = [
- ("FA3", "FULL", True),
- ("FA3", "FULL_AND_PIECEWISE", True),
- ("FA2", "FULL", True), # Should fallback to FULL_AND_PIECEWISE
- ("FA2", "FULL_AND_PIECEWISE", True),
- ("FlashInfer", "FULL", True), # Should fallback to FULL_AND_PIECEWISE
- ("FlashInfer", "FULL_AND_PIECEWISE", True),
-]
+if current_platform.is_rocm():
+ combo_cases_1 = [
+ ("RocmAttn", "FULL", True),
+ ("RocmAttn", "FULL_AND_PIECEWISE", True),
+ ("TritonAttn", "FULL", True),
+ ("TritonAttn", "FULL_AND_PIECEWISE", True),
+ ]
+else:
+ combo_cases_1 = [
+ ("FA3", "FULL", True),
+ ("FA3", "FULL_AND_PIECEWISE", True),
+ ("FA2", "FULL", True), # Should fallback to FULL_AND_PIECEWISE
+ ("FA2", "FULL_AND_PIECEWISE", True),
+ ("FlashInfer", "FULL", True), # Should fallback to FULL_AND_PIECEWISE
+ ("FlashInfer", "FULL_AND_PIECEWISE", True),
+ ]
@pytest.mark.parametrize("backend_name, cudagraph_mode, supported", combo_cases_1)
@@ -92,18 +100,32 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
# test cudagraph_mode with different compilation mode.
# (backend_name, cudagraph_mode, compilation_mode, supported)
-combo_cases_2 = [
- ("FA2", "FULL", CompilationMode.NONE, True),
- ("FA2", "FULL", CompilationMode.VLLM_COMPILE, True),
- ("FA2", "PIECEWISE", CompilationMode.NONE, False),
- ("FA2", "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
- ("FA2", "FULL_AND_PIECEWISE", CompilationMode.NONE, False),
- ("FA2", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
- ("FA2", "FULL_DECODE_ONLY", CompilationMode.NONE, True),
- ("FA2", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
- ("FA2", "NONE", CompilationMode.NONE, True),
- ("FA2", "NONE", CompilationMode.VLLM_COMPILE, True),
-]
+if current_platform.is_rocm():
+ combo_cases_2 = [
+ ("RocmAttn", "FULL", CompilationMode.NONE, True),
+ ("RocmAttn", "FULL", CompilationMode.VLLM_COMPILE, True),
+ ("RocmAttn", "PIECEWISE", CompilationMode.NONE, False),
+ ("RocmAttn", "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
+ ("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.NONE, False),
+ ("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
+ ("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.NONE, True),
+ ("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
+ ("RocmAttn", "NONE", CompilationMode.NONE, True),
+ ("RocmAttn", "NONE", CompilationMode.VLLM_COMPILE, True),
+ ]
+else:
+ combo_cases_2 = [
+ ("FA2", "FULL", CompilationMode.NONE, True),
+ ("FA2", "FULL", CompilationMode.VLLM_COMPILE, True),
+ ("FA2", "PIECEWISE", CompilationMode.NONE, False),
+ ("FA2", "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
+ ("FA2", "FULL_AND_PIECEWISE", CompilationMode.NONE, False),
+ ("FA2", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
+ ("FA2", "FULL_DECODE_ONLY", CompilationMode.NONE, True),
+ ("FA2", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
+ ("FA2", "NONE", CompilationMode.NONE, True),
+ ("FA2", "NONE", CompilationMode.VLLM_COMPILE, True),
+ ]
@pytest.mark.parametrize(
diff --git a/tests/v1/determinism/conftest.py b/tests/v1/determinism/conftest.py
index 3c2136e005849..bde02bbd0d5c6 100644
--- a/tests/v1/determinism/conftest.py
+++ b/tests/v1/determinism/conftest.py
@@ -1,11 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
import pytest
+import vllm.model_executor.layers.batch_invariant as batch_invariant
+
@pytest.fixture(autouse=True)
def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
"""Automatically enable batch invariant kernel overrides for all tests."""
+ monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", True)
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
- yield
diff --git a/tests/v1/determinism/test_batch_invariance.py b/tests/v1/determinism/test_batch_invariance.py
index f018ee551dbfe..b9e2daafb8705 100644
--- a/tests/v1/determinism/test_batch_invariance.py
+++ b/tests/v1/determinism/test_batch_invariance.py
@@ -6,8 +6,15 @@ import random
import pytest
import torch
-from utils import _extract_step_logprobs, _random_prompt, skip_unsupported
+from utils import (
+ BACKENDS,
+ _extract_step_logprobs,
+ _random_prompt,
+ resolve_model_name,
+ skip_unsupported,
+)
+import vllm.model_executor.layers.batch_invariant as batch_invariant
from vllm import LLM, SamplingParams
@@ -15,7 +22,7 @@ from vllm import LLM, SamplingParams
@pytest.mark.timeout(1000)
@pytest.mark.parametrize(
"backend",
- ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
+ BACKENDS,
)
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
backend, monkeypatch: pytest.MonkeyPatch
@@ -47,7 +54,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
# Allow overrides from environment (useful for CI tuning)
# "facebook/opt-125m" is too small, doesn't reliably test determinism
- model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+ model = resolve_model_name(backend)
num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128"))
min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
@@ -150,7 +157,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
@skip_unsupported
@pytest.mark.parametrize(
"backend",
- ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
+ BACKENDS,
)
@pytest.mark.forked
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
@@ -160,7 +167,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)
- model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+ model_name = resolve_model_name(backend)
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
# For batch invariance, disable custom all-reduce to ensure deterministic
@@ -183,6 +190,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
max_num_seqs=32,
max_model_len=8192,
dtype="bfloat16", # not everything is supported
+ gpu_memory_utilization=0.9,
)
# Use more realistic prompts for better token generation
@@ -369,7 +377,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
@skip_unsupported
@pytest.mark.parametrize(
"backend",
- ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
+ BACKENDS,
)
def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
"""
@@ -377,7 +385,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
Useful for quick smoke testing and debugging.
"""
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
- model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+ model = resolve_model_name(backend)
llm = LLM(
model=model,
@@ -419,7 +427,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
@skip_unsupported
@pytest.mark.parametrize(
"backend",
- ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
+ BACKENDS,
)
@pytest.mark.forked
def test_logprobs_without_batch_invariance_should_fail(
@@ -438,10 +446,10 @@ def test_logprobs_without_batch_invariance_should_fail(
# CRITICAL: Disable batch invariance for this test
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")
-
+ monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False)
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)
- model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+ model_name = resolve_model_name(backend)
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
print(f"\n{'=' * 80}")
@@ -659,7 +667,7 @@ def test_decode_logprobs_match_prefill_logprobs(
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)
- model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+ model_name = resolve_model_name(backend)
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
from vllm.model_executor.layers.batch_invariant import (
diff --git a/tests/v1/determinism/test_online_batch_invariance.py b/tests/v1/determinism/test_online_batch_invariance.py
index 23f47863dd23f..d74b435797f8f 100644
--- a/tests/v1/determinism/test_online_batch_invariance.py
+++ b/tests/v1/determinism/test_online_batch_invariance.py
@@ -16,7 +16,8 @@ import sys
from typing import Any
import openai
-from utils import _random_prompt, skip_unsupported
+import pytest
+from utils import BACKENDS, _random_prompt, resolve_model_name, skip_unsupported
from tests.utils import RemoteOpenAIServer
@@ -133,9 +134,14 @@ def _compare_bs1_vs_bsn_single_process(
@skip_unsupported
-def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN():
+@pytest.mark.parametrize("backend", BACKENDS)
+def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
+ backend: str, monkeypatch: pytest.MonkeyPatch
+) -> None:
random.seed(int(os.getenv("VLLM_TEST_SEED", "12345")))
- model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+ # Override backend for this test (and the RemoteOpenAIServer child process).
+ monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
+ model_name = resolve_model_name(backend)
prompts_all = [_random_prompt(10, 50) for _ in range(32)]
sp_kwargs: dict[str, Any] = {
diff --git a/tests/v1/determinism/utils.py b/tests/v1/determinism/utils.py
index 5141837faea04..ecbb6a1126933 100644
--- a/tests/v1/determinism/utils.py
+++ b/tests/v1/determinism/utils.py
@@ -1,10 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
import random
import pytest
import torch
+from vllm.attention.utils.fa_utils import flash_attn_supports_mla
from vllm.platforms import current_platform
skip_unsupported = pytest.mark.skipif(
@@ -12,6 +14,25 @@ skip_unsupported = pytest.mark.skipif(
reason="Requires CUDA and >= Hopper (SM90)",
)
+BACKENDS: list[str] = [
+ "FLASH_ATTN",
+ "FLASHINFER",
+]
+
+if flash_attn_supports_mla():
+ BACKENDS.append("FLASH_ATTN_MLA")
+
+DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
+MLA_MODEL = "deepseek-ai/DeepSeek-V2-Lite-Chat"
+
+
+def resolve_model_name(backend: str) -> str:
+ """Resolve the model name for the given backend."""
+ model = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
+ if backend.endswith("MLA") and model == DEFAULT_MODEL:
+ return MLA_MODEL
+ return model
+
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
# Generate more realistic prompts that will actually produce varied tokens
diff --git a/tests/v1/distributed/test_eagle_dp.py b/tests/v1/distributed/test_eagle_dp.py
new file mode 100644
index 0000000000000..9f6a6614fc1fd
--- /dev/null
+++ b/tests/v1/distributed/test_eagle_dp.py
@@ -0,0 +1,77 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import asyncio
+import os
+from contextlib import AsyncExitStack
+from dataclasses import replace
+
+import pytest
+
+from vllm import SamplingParams
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.sampling_params import RequestOutputKind
+from vllm.v1.engine.async_llm import AsyncLLM
+
+DP_SIZE = int(os.getenv("DP_SIZE", 2))
+
+
+@pytest.mark.asyncio
+async def test_run_eagle_dp():
+ target_model = "meta-llama/Llama-3.1-8B-Instruct"
+ draft_model = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
+
+ engine_args = AsyncEngineArgs(
+ model=target_model,
+ tokenizer_mode="auto",
+ enforce_eager=False,
+ tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
+ data_parallel_size=DP_SIZE,
+ data_parallel_backend="mp", # ray takes more time
+ trust_remote_code=True,
+ max_model_len=16384,
+ )
+
+ eagle_engine_args = replace(
+ engine_args,
+ speculative_config={
+ "model": draft_model,
+ "method": "eagle",
+ "num_speculative_tokens": 3,
+ },
+ )
+
+ prompt = "This is a test of data parallel with eagle"
+ num_expected_tokens = 100
+ sampling_params = SamplingParams(
+ min_tokens=num_expected_tokens,
+ max_tokens=num_expected_tokens,
+ ignore_eos=True,
+ output_kind=RequestOutputKind.FINAL_ONLY,
+ temperature=0,
+ )
+
+ async def generate_with_timeout(given_engine: AsyncLLM):
+ async for out in given_engine.generate(
+ request_id="test-eagle-dp", prompt=prompt, sampling_params=sampling_params
+ ):
+ token_ids = out.outputs[0].token_ids
+ assert len(token_ids) == num_expected_tokens
+ return token_ids
+
+ async def engine_create_and_generate(engine_args: AsyncEngineArgs):
+ async with AsyncExitStack() as after:
+ engine = AsyncLLM.from_engine_args(engine_args)
+ after.callback(engine.shutdown)
+
+ token_ids = await asyncio.wait_for(
+ generate_with_timeout(engine), timeout=30
+ )
+
+ assert not engine.output_processor.has_unfinished_requests()
+ return token_ids
+
+ token_ids_with_eagle = await engine_create_and_generate(eagle_engine_args)
+ token_ids_no_eagle = await engine_create_and_generate(engine_args)
+
+ # Test for correctness
+ assert token_ids_with_eagle == token_ids_no_eagle
diff --git a/tests/v1/e2e/test_lora_with_spec_decode.py b/tests/v1/e2e/test_lora_with_spec_decode.py
index 14532f2795443..8c9ab58c3c0ab 100644
--- a/tests/v1/e2e/test_lora_with_spec_decode.py
+++ b/tests/v1/e2e/test_lora_with_spec_decode.py
@@ -61,8 +61,6 @@ def test_batch_inference_correctness(
model_setup: (method, model_name, spec_model_name, lora_path, tp_size)
"""
with monkeypatch.context() as m:
- m.setenv("VLLM_USE_V1", "1")
-
# Disable randomness
m.setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
torch.manual_seed(SEED)
diff --git a/tests/v1/ec_connector/integration/README.md b/tests/v1/ec_connector/integration/README.md
index 30426e055ade8..2dbcb307fda32 100644
--- a/tests/v1/ec_connector/integration/README.md
+++ b/tests/v1/ec_connector/integration/README.md
@@ -113,7 +113,7 @@ Quick sanity check:
- Outputs differ between baseline and disagg
- Server startup fails
-- Encoder cache not found (should fallback to local execution)
+- Encoder cache not found (should fall back to local execution)
- Proxy routing errors
## Notes
diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py
index a7d769c8542a9..85f108786c05a 100644
--- a/tests/v1/entrypoints/llm/test_struct_output_generate.py
+++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py
@@ -3,7 +3,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
-from dataclasses import fields
from enum import Enum
from typing import TYPE_CHECKING, Any
@@ -21,7 +20,6 @@ from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
from vllm.sampling_params import (
- GuidedDecodingParams,
SamplingParams,
StructuredOutputsParams,
)
@@ -46,17 +44,45 @@ EAGLE_SPEC_CONFIG = {
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None),
- ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
- ("mistralai/Ministral-8B-Instruct-2410", "lm-format-enforcer", "auto", None),
+ # FIXME: Since "auto" will use Mistral tokenizer and these backends do not support
+ # it, we skip these tests for now.
+ # ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
+ # ("mistralai/Ministral-8B-Instruct-2410", "lm-format-enforcer", "auto", None),
+ ("mistralai/Ministral-8B-Instruct-2410", "guidance", "hf", None),
+ pytest.param(
+ "mistralai/Ministral-8B-Instruct-2410",
+ "lm-format-enforcer",
+ "hf",
+ None,
+ marks=pytest.mark.skip(
+ reason=(
+ "Flaky: lm-format-enforcer intermittently returns "
+ "incomplete JSON. "
+ "See https://github.com/noamgat/lm-format-enforcer/issues/169"
+ )
+ ),
+ ),
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None),
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None),
- ("Qwen/Qwen2.5-1.5B-Instruct", "lm-format-enforcer", "auto", None),
+ pytest.param(
+ "Qwen/Qwen2.5-1.5B-Instruct",
+ "lm-format-enforcer",
+ "auto",
+ None,
+ marks=pytest.mark.skip(
+ reason=(
+ "Flaky: lm-format-enforcer intermittently returns "
+ "incomplete JSON. "
+ "See https://github.com/noamgat/lm-format-enforcer/issues/169"
+ )
+ ),
+ ),
# FIXME: This tests are flaky on CI thus disabled. Tracking in Issue #24402
# ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None),
# ("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None),
# ("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", NGRAM_SPEC_CONFIG),
- ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", NGRAM_SPEC_CONFIG),
+ ("mistralai/Ministral-8B-Instruct-2410", "guidance", "hf", NGRAM_SPEC_CONFIG),
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", NGRAM_SPEC_CONFIG),
("meta-llama/Meta-Llama-3.1-8B-Instruct", "xgrammar", "auto", EAGLE_SPEC_CONFIG),
]
@@ -80,24 +106,6 @@ class CarDescription(BaseModel):
car_type: CarType
-def test_guided_decoding_deprecated():
- with pytest.warns(DeprecationWarning, match="GuidedDecodingParams is deprecated.*"):
- guided_decoding = GuidedDecodingParams(json_object=True)
-
- structured_outputs = StructuredOutputsParams(json_object=True)
- assert fields(guided_decoding) == fields(structured_outputs)
-
- with pytest.warns(DeprecationWarning, match="guided_decoding is deprecated.*"):
- sp1 = SamplingParams(guided_decoding=guided_decoding)
-
- with pytest.warns(DeprecationWarning, match="guided_decoding is deprecated.*"):
- sp2 = SamplingParams.from_optional(guided_decoding=guided_decoding)
-
- assert sp1 == sp2
- assert sp1.structured_outputs == guided_decoding
-
-
-@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize(
"model_name, backend, tokenizer_mode, speculative_config",
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE,
@@ -128,6 +136,8 @@ def test_structured_output(
),
seed=120,
tokenizer_mode=tokenizer_mode,
+ load_format="auto" if not model_name.startswith("mistralai/") else "hf",
+ config_format="auto" if not model_name.startswith("mistralai/") else "hf",
speculative_config=speculative_config,
)
@@ -602,7 +612,6 @@ Make the response as short as possible.
)
-@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize(
"model_name, backend, tokenizer_mode, reasoning_parser, speculative_config", # noqa: E501
[
@@ -687,7 +696,6 @@ def test_structured_output_with_reasoning_matrices(
jsonschema.validate(instance=output_json, schema=reasoning_schema)
-@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("model_name, tokenizer_mode", PARAMS_MODELS_TOKENIZER_MODE)
def test_structured_output_auto_mode(
unsupported_json_schema: dict[str, Any],
@@ -699,6 +707,8 @@ def test_structured_output_auto_mode(
max_model_len=1024,
structured_outputs_config=dict(backend="auto"),
tokenizer_mode=tokenizer_mode,
+ load_format="auto",
+ config_format="auto",
)
sampling_params = SamplingParams(
@@ -734,7 +744,6 @@ def test_structured_output_auto_mode(
assert isinstance(parsed_json, dict)
-@pytest.mark.skip_global_cleanup
def test_guidance_no_additional_properties():
llm = LLM(
model="Qwen/Qwen2.5-1.5B-Instruct",
@@ -871,13 +880,11 @@ def test_structured_output_batched_with_non_structured_outputs_requests(
output_json = json.loads(generated_text)
-@pytest.mark.parametrize("guided_decoding_backend", ["xgrammar"])
-def test_structured_output_with_structural_tag(
- guided_decoding_backend: str,
-):
+@pytest.mark.parametrize("backend", ["xgrammar"])
+def test_structured_output_with_structural_tag(backend: str):
llm = LLM(
model="Qwen/Qwen2.5-1.5B-Instruct",
- guided_decoding_backend=guided_decoding_backend,
+ structured_outputs_config=StructuredOutputsConfig(backend=backend),
)
structural_tag_config = {
@@ -895,7 +902,7 @@ def test_structured_output_with_structural_tag(
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=500,
- guided_decoding=StructuredOutputsParams(
+ structured_outputs=StructuredOutputsParams(
structural_tag=json.dumps(structural_tag_config)
),
)
diff --git a/tests/v1/kv_connector/unit/test_lmcache_integration.py b/tests/v1/kv_connector/unit/test_lmcache_integration.py
index 11507d7cd4e7b..33418edc325af 100644
--- a/tests/v1/kv_connector/unit/test_lmcache_integration.py
+++ b/tests/v1/kv_connector/unit/test_lmcache_integration.py
@@ -9,6 +9,12 @@
# Assumption vs. Correctness Tests:
# these unit tests do *not* test correctness of LMCache-side or vLLM-side logic
# it is to ensure that assumptions LMCache makes about vLLM's interface are stable
+
+import pytest
+
+from vllm.platforms import current_platform
+
+
def assumes(obj, attr, is_callable=False, is_instance_of=None):
import inspect
from dataclasses import is_dataclass
@@ -48,6 +54,9 @@ def assumes(obj, attr, is_callable=False, is_instance_of=None):
assert isinstance(attr_value, is_instance_of), assumption_msg
+@pytest.mark.skipif(
+ current_platform.is_rocm(), reason="Requires libcudart.so, not available on ROCm"
+)
def test_multimodal_interface():
# protect against interface changes
from vllm.multimodal.inputs import PlaceholderRange
@@ -72,6 +81,9 @@ def test_multimodal_interface():
assert token_ids.tolist() == [0, 0, 0, 0, 4, 4369, 4369, 4369, 4369, 9]
+@pytest.mark.skipif(
+ current_platform.is_rocm(), reason="Requires libcudart.so, not available on ROCm"
+)
def test_config_interface():
# protect against interface changes
from vllm.config import VllmConfig
@@ -146,6 +158,9 @@ def test_config_interface():
)
+@pytest.mark.skipif(
+ current_platform.is_rocm(), reason="Requires libcudart.so, not available on ROCm"
+)
def test_request_interface():
# protect against interface changes
from types import NoneType
diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py
index 1c1ac915c758e..ffa7d884d2762 100644
--- a/tests/v1/kv_connector/unit/test_multi_connector.py
+++ b/tests/v1/kv_connector/unit/test_multi_connector.py
@@ -20,6 +20,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlKVConnectorStats,
)
+from vllm.platforms import current_platform
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
@@ -69,6 +70,13 @@ def _compare_directories(dir1: Path, dir2: Path) -> bool:
return True
+@pytest.mark.skipif(
+ current_platform.is_rocm(),
+ reason=(
+ "hipErrorLaunchFailure when running this test, see issue:"
+ "https://github.com/ROCm/pytorch/issues/2822"
+ ),
+)
def test_multi_shared_storage_connector_consistency():
"""
Tests that MultiConnector with two SharedStorageConnectors saves
diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index b264e5108c16d..b7d7a10057b8b 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -11,7 +11,6 @@ import uuid
from collections import defaultdict
from unittest.mock import patch
-import numpy as np
import pytest
import ray
import torch
@@ -827,7 +826,7 @@ def test_kv_connector_stats_aggregation():
output = ModelRunnerOutput(
req_ids=[f"req_{i}"],
req_id_to_index={f"req_{i}": 0},
- sampled_token_ids=[np.array([123])], # dummy token
+ sampled_token_ids=[[123]], # dummy token
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[None],
@@ -908,7 +907,7 @@ def test_multi_kv_connector_stats_aggregation():
output = ModelRunnerOutput(
req_ids=[f"req_{i}"],
req_id_to_index={f"req_{i}": 0},
- sampled_token_ids=[np.array([123])],
+ sampled_token_ids=[[123]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[None],
@@ -966,7 +965,7 @@ def test_scheduler_kv_connector_stats_aggregation():
model_output = ModelRunnerOutput(
req_ids=["req_0"],
req_id_to_index={"req_0": 0},
- sampled_token_ids=[np.array([123])],
+ sampled_token_ids=[[123]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[None],
diff --git a/tests/v1/kv_connector/unit/test_offloading_connector.py b/tests/v1/kv_connector/unit/test_offloading_connector.py
index 23b6c4802d106..69565f584ab89 100644
--- a/tests/v1/kv_connector/unit/test_offloading_connector.py
+++ b/tests/v1/kv_connector/unit/test_offloading_connector.py
@@ -19,6 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import (
)
from vllm.forward_context import ForwardContext
from vllm.utils.hashing import sha256
+from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.core.kv_cache_utils import (
BlockHash,
get_request_block_hasher,
@@ -92,7 +93,7 @@ class MockOffloadingSpec(OffloadingSpec):
return self.manager
def get_handlers(
- self, _
+ self, _, __
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
yield GPULoadStoreSpec, MockLoadStoreSpec, self.handler
yield MockLoadStoreSpec, GPULoadStoreSpec, self.handler
@@ -138,7 +139,10 @@ class RequestRunner:
self.worker_connector = OffloadingConnector(vllm_config, KVConnectorRole.WORKER)
# register worker kv_caches to enable OffloadingWorker creations
- self.worker_connector.register_kv_caches(kv_caches={"a": torch.empty(0)})
+ self.worker_connector.register_cross_layers_kv_cache(
+ kv_cache=torch.empty(0),
+ attn_backend=FlashAttentionBackend,
+ )
# extract connector of scheduler
scheduler_connector = self.scheduler.connector
diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py
index c248104d5b5ea..f35f91bb3adf8 100644
--- a/tests/v1/kv_connector/unit/utils.py
+++ b/tests/v1/kv_connector/unit/utils.py
@@ -7,7 +7,6 @@ from dataclasses import dataclass
from itertools import chain, count
from typing import Any
-import numpy as np
import torch
from vllm import SamplingParams
@@ -229,7 +228,7 @@ def create_model_runner_output(
# Make sampled tokens.
sampled_token = EOS_TOKEN_ID if use_eos else token_id
- sampled_token_ids = [np.array([sampled_token]) for _ in req_ids]
+ sampled_token_ids = [[sampled_token] for _ in req_ids]
kv_connector_output = (
None
diff --git a/tests/v1/kv_offload/test_cpu_gpu.py b/tests/v1/kv_offload/test_cpu_gpu.py
index 0d4fa344d298c..a248104e16d2d 100644
--- a/tests/v1/kv_offload/test_cpu_gpu.py
+++ b/tests/v1/kv_offload/test_cpu_gpu.py
@@ -103,8 +103,8 @@ def test_transfer(
for i in range(gpu_blocks_per_cpu_block):
cpu_blocks_in_gpu_block_size.append(i + base_block_id)
- # maybe skip a GPU block to test writing to the middle of a CPU block
- if gpu_to_cpu:
+ # maybe skip a GPU block to test reading from the middle of a CPU block
+ if not gpu_to_cpu:
gpu_blocks = gpu_blocks[gpu_blocks_per_cpu_block - 1 :]
cpu_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size[
gpu_blocks_per_cpu_block - 1 :
diff --git a/tests/v1/kv_offload/test_cpu_offloading.py b/tests/v1/kv_offload/test_cpu_offloading.py
index b654ea4298dbb..406d4c0b4c1fd 100644
--- a/tests/v1/kv_offload/test_cpu_offloading.py
+++ b/tests/v1/kv_offload/test_cpu_offloading.py
@@ -12,8 +12,14 @@ from tqdm import tqdm
from vllm import LLM, SamplingParams, TokensPrompt
from vllm.config import KVEventsConfig, KVTransferConfig
from vllm.distributed.kv_events import BlockStored, KVEventBatch
+from vllm.platforms import current_platform
+from vllm.utils.system_utils import set_env_var
-CPU_BLOCK_SIZES = [16, 48]
+CPU_BLOCK_SIZES = [48]
+ATTN_BACKENDS = ["FLASH_ATTN"]
+
+if current_platform.is_cuda():
+ ATTN_BACKENDS.append("FLASHINFER")
class MockSubscriber:
@@ -63,8 +69,88 @@ class MockSubscriber:
self.sub.close()
+def _latency_test(llm: LLM, subscriber: MockSubscriber):
+ sampling_params = SamplingParams(max_tokens=1)
+
+ num_times_cpu_better_than_cold = 0
+ num_tests = 10
+ total_cold_time = 0.0
+ total_gpu_hit_time = 0.0
+ total_cpu_hit_time = 0.0
+ prompt_token_ids = [0] * 10001
+ for i in tqdm(range(num_tests), desc="Running tests"):
+ prompt_token_ids[0] = i
+ prompts = [TokensPrompt(prompt_token_ids=prompt_token_ids)]
+
+ # run generation - this should trigger saving KV cache
+ start_time = time.time()
+ llm.generate(prompts, sampling_params, use_tqdm=False)
+ cold_time = time.time() - start_time
+ total_cold_time += cold_time
+
+ # run generation again - should hit the GPU prefix cache
+ start_time = time.time()
+ llm.generate(prompts, sampling_params, use_tqdm=False)
+ gpu_hit_time = time.time() - start_time
+ total_gpu_hit_time += gpu_hit_time
+
+ # reset prefix cache to avoid GPU hit.
+ llm.reset_prefix_cache()
+
+ assert subscriber.get_new_cpu_stored_events()
+
+ # run generation again - this should trigger loading from CPU
+ start_time = time.time()
+ llm.generate(prompts, sampling_params, use_tqdm=False)
+ cpu_hit_time = time.time() - start_time
+ total_cpu_hit_time += cpu_hit_time
+
+ if cpu_hit_time < cold_time:
+ num_times_cpu_better_than_cold += 1
+
+ print("Average times:")
+ print(f" Cold: {total_cold_time * 1000 / num_tests:.2f}ms")
+ print(f" GPU hit: {total_gpu_hit_time * 1000 / num_tests:.2f}ms")
+ print(f" CPU hit: {total_cpu_hit_time * 1000 / num_tests:.2f}ms")
+
+ assert num_times_cpu_better_than_cold >= 0.8 * num_tests
+
+
+def _accuracy_test(llm: LLM, subscriber: MockSubscriber):
+ sampling_params = SamplingParams(max_tokens=1)
+ cpu_block_size = (
+ llm.llm_engine.vllm_config.kv_transfer_config.kv_connector_extra_config[
+ "block_size"
+ ]
+ )
+
+ subscriber.get_new_cpu_stored_events()
+
+ # prepend prompt to be cpu block aligned
+ prompt = "Let's count to 10. One, two, three, four,"
+ while (
+ len(llm.generate(prompt, use_tqdm=False)[0].prompt_token_ids) % cpu_block_size
+ != 0
+ ):
+ prompt = ". " + prompt
+
+ assert subscriber.get_new_cpu_stored_events()
+
+ test_count = 100
+ success_count = 0
+ for i in range(test_count):
+ if (
+ llm.generate(prompt, sampling_params, use_tqdm=False)[0].outputs[0].text
+ == " five"
+ ):
+ success_count += 1
+
+ assert success_count >= 0.5 * test_count
+
+
@pytest.mark.parametrize("cpu_block_size", CPU_BLOCK_SIZES)
-def test_cpu_offloading(cpu_block_size: int) -> None:
+@pytest.mark.parametrize("attn_backend", ATTN_BACKENDS)
+def test_cpu_offloading(cpu_block_size: int, attn_backend: str) -> None:
"""
Tests OffloadingConnector with CPUOffloadingSpec.
"""
@@ -92,61 +178,20 @@ def test_cpu_offloading(cpu_block_size: int) -> None:
topic="test",
)
- llm = LLM(
- model="meta-llama/Llama-3.2-1B-Instruct",
- gpu_memory_utilization=0.5,
- kv_events_config=kv_events_config,
- kv_transfer_config=kv_transfer_config,
- )
-
- sampling_params = SamplingParams(temperature=0, max_tokens=1)
+ with set_env_var("VLLM_ATTENTION_BACKEND", attn_backend):
+ llm = LLM(
+ model="meta-llama/Llama-3.2-1B-Instruct",
+ gpu_memory_utilization=0.5,
+ kv_events_config=kv_events_config,
+ kv_transfer_config=kv_transfer_config,
+ )
events_endpoint = events_endpoint.replace("*", "127.0.0.1")
subscriber = MockSubscriber(events_endpoint, topic=kv_events_config.topic)
try:
- num_times_cpu_better_than_cold = 0
- num_tests = 10
- total_cold_time = 0.0
- total_gpu_hit_time = 0.0
- total_cpu_hit_time = 0.0
- prompt_token_ids = [0] * 10001
- for i in tqdm(range(num_tests), desc="Running tests"):
- prompt_token_ids[0] = i
- prompts = [TokensPrompt(prompt_token_ids=prompt_token_ids)]
-
- # run generation - this should trigger saving KV cache
- start_time = time.time()
- llm.generate(prompts, sampling_params, use_tqdm=False)
- cold_time = time.time() - start_time
- total_cold_time += cold_time
-
- # run generation again - should hit the GPU prefix cache
- start_time = time.time()
- llm.generate(prompts, sampling_params, use_tqdm=False)
- gpu_hit_time = time.time() - start_time
- total_gpu_hit_time += gpu_hit_time
-
- # reset prefix cache to avoid GPU hit.
- llm.reset_prefix_cache()
-
- assert subscriber.get_new_cpu_stored_events()
-
- # run generation again - this should trigger loading from CPU
- start_time = time.time()
- llm.generate(prompts, sampling_params, use_tqdm=False)
- cpu_hit_time = time.time() - start_time
- total_cpu_hit_time += cpu_hit_time
-
- if cpu_hit_time < cold_time:
- num_times_cpu_better_than_cold += 1
-
- print("Average times:")
- print(f" Cold: {total_cold_time * 1000 / num_tests:.2f}ms")
- print(f" GPU hit: {total_gpu_hit_time * 1000 / num_tests:.2f}ms")
- print(f" CPU hit: {total_cpu_hit_time * 1000 / num_tests:.2f}ms")
-
- assert num_times_cpu_better_than_cold >= 0.8 * num_tests
+ _latency_test(llm, subscriber)
+ _accuracy_test(llm, subscriber)
finally:
subscriber.close()
del llm
diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py
index 42584938bc06f..c89c33be80c10 100644
--- a/tests/v1/sample/test_logprobs.py
+++ b/tests/v1/sample/test_logprobs.py
@@ -521,8 +521,8 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
pytest.param(
(
"eagle",
- "meta-llama/Llama-3.1-8B-Instruct",
- "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
+ "meta-llama/Llama-3.2-1B-Instruct",
+ "nm-testing/Llama3_2_1B_speculator.eagle3",
),
marks=large_gpu_mark(min_gb=32),
),
@@ -541,7 +541,7 @@ def test_spec_decode_logprobs(
"""
from vllm import LLM
- prompt = "Hello world"
+ prompt = "Hello world " * 50
sampling_params = SamplingParams(
temperature=0, logprobs=3, max_tokens=10, ignore_eos=False
)
@@ -582,6 +582,9 @@ def test_spec_decode_logprobs(
seed=42,
logprobs_mode=logprobs_mode,
gpu_memory_utilization=0.4,
+ # Force prefill chunking
+ enable_chunked_prefill=True,
+ max_num_batched_tokens=32,
)
spec_results = spec_llm.generate([prompt], sampling_params)
# Collect logprobs outputs from spec decode LLM.
@@ -597,6 +600,84 @@ def test_spec_decode_logprobs(
# Per-token logprobs are expected to be the same.
assert len(ref_logprobs) == len(spec_logprobs)
for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
- assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, abs_tol=1e-3)
+ assert math.isclose(
+ ref_logprob.logprob, spec_logprob.logprob, rel_tol=5e-2, abs_tol=1e-1
+ )
assert ref_logprob.rank == spec_logprob.rank
assert ref_logprob.decoded_token == spec_logprob.decoded_token
+
+
+def test_prompt_logprobs_with_chunking_and_preemption():
+ """Test that prompt logprobs are correctly returned when using
+ both chunked prefill and preemption.
+
+ This test ensures that the num_prompt_logprobs tracking persists
+ across preemptions and prefill chunks.
+ """
+
+ # Create prompts that will trigger chunking and preemption
+ prompts = [
+ "The following numbers of the sequence "
+ + ", ".join(str(i) for i in range(10))
+ + " are:",
+ "In one word, the capital of France is ",
+ ] + [f"Tell me about the number {i}: " for i in range(32)]
+
+ sampling_params = SamplingParams(
+ temperature=0.0,
+ max_tokens=40,
+ min_tokens=20,
+ prompt_logprobs=2, # Request prompt logprobs
+ )
+
+ with VllmRunner(
+ "Qwen/Qwen3-0.6B",
+ max_model_len=512,
+ enable_chunked_prefill=True,
+ max_num_batched_tokens=48, # Force prefill chunking
+ num_gpu_blocks_override=32, # Force preemptions
+ disable_log_stats=False,
+ gpu_memory_utilization=0.25,
+ ) as vllm_model:
+ metrics_before = vllm_model.llm.get_metrics()
+
+ # Generate with prompt logprobs using generate_w_logprobs which
+ # returns (output_ids, output_str, output_logprobs, prompt_logprobs)
+ outputs = vllm_model.generate_w_logprobs(
+ prompts, sampling_params=sampling_params, include_prompt_token_ids=True
+ )
+
+ # Verify that all outputs have prompt logprobs
+ for i, output in enumerate(outputs):
+ _, _, _, prompt_token_ids, prompt_logprobs = output
+ assert prompt_logprobs is not None and len(prompt_logprobs) > 0, (
+ f"Output {i} missing prompt logprobs"
+ )
+ assert len(prompt_logprobs) == len(prompt_token_ids), (
+ "Unexpected number of prompt logprob positions"
+ )
+
+ # Each position should have the requested number of logprobs
+ for pos, logprobs_dict in enumerate(prompt_logprobs):
+ if logprobs_dict is not None: # First token may be None
+ assert (
+ sampling_params.prompt_logprobs
+ <= len(logprobs_dict)
+ <= sampling_params.prompt_logprobs + 1
+ ), (
+ f"Output {i} position {pos} has {len(logprobs_dict)} "
+ f"logprobs, expected {sampling_params.prompt_logprobs}"
+ )
+
+ # Check that we actually had preemptions
+ metrics_after = vllm_model.llm.get_metrics()
+ preemptions_before = next(
+ (m.value for m in metrics_before if m.name == "vllm:num_preemptions"), 0
+ )
+ preemptions_after = next(
+ (m.value for m in metrics_after if m.name == "vllm:num_preemptions"), 0
+ )
+ preemptions = preemptions_after - preemptions_before
+ assert preemptions > 0, "Test did not trigger any preemptions"
+
+ print(f"Test passed with {preemptions} preemptions")
diff --git a/tests/v1/sample/test_sampling_params_e2e.py b/tests/v1/sample/test_sampling_params_e2e.py
index 915b9957031d8..1684252174d3d 100644
--- a/tests/v1/sample/test_sampling_params_e2e.py
+++ b/tests/v1/sample/test_sampling_params_e2e.py
@@ -22,14 +22,6 @@ def test_n_gt_1(llm):
assert len(outputs[0].outputs) == 3
-def test_best_of(llm):
- """Raise a ValueError since best_of is deprecated."""
-
- params = SamplingParams(n=2, best_of=3)
- with pytest.raises(ValueError):
- _ = llm.generate(PROMPT, params)
-
-
def test_penalties(llm):
"""Check that we do not get errors if applied."""
diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py
index 805b8c86b0804..c93c59d1f4c42 100644
--- a/tests/v1/spec_decode/test_eagle.py
+++ b/tests/v1/spec_decode/test_eagle.py
@@ -3,7 +3,6 @@
from unittest import mock
-import numpy as np
import pytest
import torch
@@ -113,9 +112,7 @@ def test_prepare_next_token_ids():
sampled_token_ids_tensor = torch.tensor(
sampled_token_ids, dtype=torch.int32, device=device
)
- sampled_token_ids_cpu = [
- np.array([i for i in seq if i != -1]) for seq in sampled_token_ids
- ]
+ sampled_token_ids_cpu = [[i for i in seq if i != -1] for seq in sampled_token_ids]
expected_next_token_ids_cpu = [1, 4, 30, 40]
expected_next_token_ids_tensor = torch.tensor(
diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py
index 563bc1d957f41..692c39282c372 100644
--- a/tests/v1/spec_decode/test_ngram.py
+++ b/tests/v1/spec_decode/test_ngram.py
@@ -77,7 +77,7 @@ def test_ngram_proposer():
# No match.
token_ids_cpu = np.array([[1, 2, 3, 4, 5]])
result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
- sampled_token_ids=[np.array([0])],
+ sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
@@ -88,7 +88,7 @@ def test_ngram_proposer():
# No match for 4-gram.
token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
result = get_ngram_proposer(min_n=4, max_n=4, k=2).propose(
- sampled_token_ids=[np.array([0])],
+ sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
@@ -99,7 +99,7 @@ def test_ngram_proposer():
# No match for 4-gram but match for 3-gram.
token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
- sampled_token_ids=[np.array([0])],
+ sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
@@ -111,7 +111,7 @@ def test_ngram_proposer():
# In this case, the proposer should return the 4-gram match.
token_ids_cpu = np.array([[2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]])
result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
- sampled_token_ids=[np.array([0])],
+ sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
@@ -122,7 +122,7 @@ def test_ngram_proposer():
# Match for 2-gram and 3-gram, but not 4-gram.
token_ids_cpu = np.array([[3, 4, 5, 2, 3, 4, 1, 2, 3, 4]])
result = get_ngram_proposer(min_n=2, max_n=4, k=2).propose(
- sampled_token_ids=[np.array([0])],
+ sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
@@ -133,7 +133,7 @@ def test_ngram_proposer():
# Multiple 3-gram matched, but always pick the first one.
token_ids_cpu = np.array([[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]])
result = get_ngram_proposer(min_n=3, max_n=3, k=2).propose(
- sampled_token_ids=[np.array([0])],
+ sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
@@ -144,7 +144,7 @@ def test_ngram_proposer():
# check empty input
token_ids_cpu = np.array([[]])
result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
- sampled_token_ids=[np.array([0])],
+ sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
@@ -157,7 +157,7 @@ def test_ngram_proposer():
# second request has 3 tokens and no match. Padded with -1 for max len 5
token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]])
result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
- sampled_token_ids=[np.array([0]), np.array([1])],
+ sampled_token_ids=[[0], [1]],
req_ids=["0", "1"],
num_tokens_no_spec=np.array([5, 3]),
token_ids_cpu=token_ids_cpu,
@@ -181,7 +181,7 @@ def test_ngram_proposer():
input_2[:3] = [4, 5, 6]
token_ids_cpu = np.array([input_1, input_2])
result = ngram_proposer.propose(
- sampled_token_ids=[np.array([0]), np.array([1])],
+ sampled_token_ids=[[0], [1]],
req_ids=["0", "1"],
num_tokens_no_spec=np.array([len(input_1), 3]),
token_ids_cpu=token_ids_cpu,
diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py
index 6958d62dc7e90..a4ee53008ce82 100644
--- a/tests/v1/spec_decode/test_tree_attention.py
+++ b/tests/v1/spec_decode/test_tree_attention.py
@@ -3,6 +3,7 @@
import math
+import pytest
import torch
from tests.v1.attention.utils import (
@@ -11,9 +12,16 @@ from tests.v1.attention.utils import (
try_get_attention_backend,
)
from vllm.attention.backends.registry import AttentionBackendEnum
+from vllm.attention.utils.fa_utils import is_flash_attn_varlen_func_available
from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+if not is_flash_attn_varlen_func_available():
+ pytest.skip(
+ "This test requires flash_attn_varlen_func, but it's not available.",
+ allow_module_level=True,
+ )
+
class MockAttentionLayer(torch.nn.Module):
_q_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index b95c8df3469b3..d0f1b703fcb92 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -185,7 +185,9 @@ def _make_mock_backend_for_kernel_block_size(
supported_sizes: list[int | MultipleOf],
):
class _MockBackend:
- supported_kernel_block_sizes = supported_sizes
+ @staticmethod
+ def get_supported_kernel_block_sizes():
+ return supported_sizes
return _MockBackend()
@@ -483,7 +485,10 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
# Permutation that gets you back to expected kv shape
for test_stride in ((1, 4, 0, 2, 3), (0, 1, 2, 3, 4)):
- def rnd_stride_order(test_stride=test_stride):
+ def rnd_stride_order(
+ include_num_layers_dimension: bool = False, test_stride=test_stride
+ ):
+ assert not include_num_layers_dimension
return test_stride
# Patch the attention backend class and re-trigger the KV cache creation
@@ -956,7 +961,7 @@ def test_hybrid_block_table_initialization():
max_num_reqs = 10
max_num_blocks_per_req = 20
max_num_batched_tokens = 512
- dcp_kv_cache_interleave_size = 8
+ cp_kv_cache_interleave_size = 8
block_table = BlockTable(
block_size=block_size,
@@ -966,7 +971,7 @@ def test_hybrid_block_table_initialization():
pin_memory=False,
device=torch.device(DEVICE),
kernel_block_size=kernel_block_sizes[0],
- dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
+ cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
)
# Verify hybrid block configuration
diff --git a/tests/v1/worker/test_gpu_profiler.py b/tests/v1/worker/test_gpu_profiler.py
new file mode 100644
index 0000000000000..f7255fae05a4e
--- /dev/null
+++ b/tests/v1/worker/test_gpu_profiler.py
@@ -0,0 +1,203 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+
+import vllm.envs as envs
+from vllm.profiler.gpu_profiler import WorkerProfiler
+
+
+class ConcreteWorkerProfiler(WorkerProfiler):
+ """
+ A basic implementation of a worker profiler for testing purposes.
+ """
+
+ def __init__(self):
+ self.start_call_count = 0
+ self.stop_call_count = 0
+ self.should_fail_start = False
+ super().__init__()
+
+ def _start(self) -> None:
+ if self.should_fail_start:
+ raise RuntimeError("Simulated start failure")
+ self.start_call_count += 1
+
+ def _stop(self) -> None:
+ self.stop_call_count += 1
+
+
+@pytest.fixture(autouse=True)
+def reset_mocks():
+ """Fixture to reset mocks and env variables before each test."""
+ envs.VLLM_PROFILER_DELAY_ITERS = 0
+ envs.VLLM_PROFILER_MAX_ITERS = 0
+
+
+def test_immediate_start_stop():
+ """Test standard start without delay."""
+ profiler = ConcreteWorkerProfiler()
+
+ profiler.start()
+ assert profiler._running is True
+ assert profiler._active is True
+ assert profiler.start_call_count == 1
+
+ profiler.stop()
+ assert profiler._running is False
+ assert profiler._active is False
+ assert profiler.stop_call_count == 1
+
+
+def test_delayed_start():
+ """Test that profiler waits for N steps before actually starting."""
+ envs.VLLM_PROFILER_DELAY_ITERS = 2
+ profiler = ConcreteWorkerProfiler()
+
+ # User requests start
+ profiler.start()
+
+ # Should be active (request accepted) but not running (waiting for delay)
+ assert profiler._active is True
+ assert profiler._running is False
+ assert profiler.start_call_count == 0
+
+ # Step 1
+ profiler.step()
+ assert profiler._running is False
+
+ # Step 2 (Threshold reached)
+ profiler.step()
+ assert profiler._running is True
+ assert profiler.start_call_count == 1
+
+
+def test_max_iterations():
+ """Test that profiler stops automatically after max iterations."""
+ envs.VLLM_PROFILER_MAX_ITERS = 2
+ profiler = ConcreteWorkerProfiler()
+
+ profiler.start()
+ assert profiler._running is True
+
+ # Iteration 1
+ profiler.step() # profiling_count becomes 1
+ assert profiler._running is True
+
+ # Iteration 2
+ profiler.step() # profiling_count becomes 2
+ assert profiler._running is True
+
+ # Iteration 3 (Exceeds max)
+ profiler.step() # profiling_count becomes 3
+
+ # Should have stopped now
+ assert profiler._running is False
+ assert profiler.stop_call_count == 1
+
+
+def test_delayed_start_and_max_iters():
+ """Test combined delayed start and max iterations."""
+ envs.VLLM_PROFILER_DELAY_ITERS = 2
+ envs.VLLM_PROFILER_MAX_ITERS = 2
+ profiler = ConcreteWorkerProfiler()
+
+ profiler.start()
+
+ # Step 1
+ profiler.step()
+ assert profiler._running is False
+ assert profiler._active is True
+
+ # Step 2 (Starts now)
+ profiler.step()
+ assert profiler._profiling_for_iters == 1
+ assert profiler._running is True
+ assert profiler._active is True
+
+ # Next iteration
+ profiler.step()
+ assert profiler._profiling_for_iters == 2
+ assert profiler._running is True
+
+    # Iteration 3 (exceeds max)
+ profiler.step()
+
+ # Should have stopped now
+ assert profiler._running is False
+ assert profiler.stop_call_count == 1
+
+
+def test_idempotency():
+ """Test that calling start/stop multiple times doesn't break logic."""
+ profiler = ConcreteWorkerProfiler()
+
+ # Double Start
+ profiler.start()
+ profiler.start()
+ assert profiler.start_call_count == 1 # Should only start once
+
+ # Double Stop
+ profiler.stop()
+ profiler.stop()
+ assert profiler.stop_call_count == 1 # Should only stop once
+
+
+def test_step_inactive():
+ """Test that stepping while inactive does nothing."""
+ envs.VLLM_PROFILER_DELAY_ITERS = 2
+ profiler = ConcreteWorkerProfiler()
+
+ # Not started yet
+ profiler.step()
+ profiler.step()
+
+ # Even though we stepped 2 times, start shouldn't happen because active=False
+ assert profiler.start_call_count == 0
+
+
+def test_start_failure():
+ """Test behavior when the underlying _start method raises exception."""
+ profiler = ConcreteWorkerProfiler()
+ profiler.should_fail_start = True
+
+ profiler.start()
+
+ # Exception caught in _call_start
+ assert profiler._running is False # Should not mark as running
+ assert profiler._active is True # Request is still considered active
+ assert profiler.start_call_count == 0 # Logic failed inside start
+
+
+def test_shutdown():
+ """Test that shutdown calls stop only if running."""
+ profiler = ConcreteWorkerProfiler()
+
+ # Case 1: Not running
+ profiler.shutdown()
+ assert profiler.stop_call_count == 0
+
+ # Case 2: Running
+ profiler.start()
+ profiler.shutdown()
+ assert profiler.stop_call_count == 1
+
+
+def test_mixed_delay_and_stop():
+ """Test manual stop during the delay period."""
+ envs.VLLM_PROFILER_DELAY_ITERS = 5
+ profiler = ConcreteWorkerProfiler()
+
+ profiler.start()
+ profiler.step()
+ profiler.step()
+
+ # User cancels before delay finishes
+ profiler.stop()
+ assert profiler._active is False
+
+ # Further steps should not trigger start
+ profiler.step()
+ profiler.step()
+ profiler.step()
+
+ assert profiler.start_call_count == 0
diff --git a/tests/weight_loading/models-amd.txt b/tests/weight_loading/models-amd.txt
new file mode 100644
index 0000000000000..e31e904c08af4
--- /dev/null
+++ b/tests/weight_loading/models-amd.txt
@@ -0,0 +1,3 @@
+fp8, amd/Meta-Llama-3.1-8B-Instruct-FP8-KV, main
+None, amd/Llama-3.2-1B-Instruct-FP8-KV, main
+fp8, amd/Mixtral-8x7B-Instruct-v0.1-FP8-KV, main
diff --git a/tests/weight_loading/models-large-amd.txt b/tests/weight_loading/models-large-amd.txt
new file mode 100644
index 0000000000000..b6f5b4b16b37f
--- /dev/null
+++ b/tests/weight_loading/models-large-amd.txt
@@ -0,0 +1,3 @@
+fp8, amd/Meta-Llama-3.1-70B-Instruct-FP8-KV, main
+None, microsoft/phi-4, main
+fp8, amd/Mixtral-8x22B-Instruct-v0.1-FP8-KV, main
diff --git a/tools/ep_kernels/install_python_libraries.sh b/tools/ep_kernels/install_python_libraries.sh
index 5ea543f4cb1e8..1cea1bef8dbc9 100755
--- a/tools/ep_kernels/install_python_libraries.sh
+++ b/tools/ep_kernels/install_python_libraries.sh
@@ -1,94 +1,79 @@
#!/usr/bin/env bash
set -ex
-# prepare workspace directory
-WORKSPACE=$1
-if [ -z "$WORKSPACE" ]; then
- export WORKSPACE=$(pwd)/ep_kernels_workspace
-fi
+# usage: ./build.sh [workspace_dir] [mode]
+# mode: "install" (default) → install directly into current Python env
+# "wheel" → build wheels into WORKSPACE/dist
-if [ ! -d "$WORKSPACE" ]; then
- mkdir -p $WORKSPACE
-fi
+WORKSPACE=${1:-$(pwd)/ep_kernels_workspace}
+MODE=${2:-install}
+mkdir -p "$WORKSPACE"
+
+WHEEL_DIR="$WORKSPACE/dist"
+mkdir -p "$WHEEL_DIR"
+NVSHMEM_VER=3.3.9
+
+pushd "$WORKSPACE"
-# configurable pip command (default: pip3)
-PIP_CMD=${PIP_CMD:-pip3}
CUDA_HOME=${CUDA_HOME:-/usr/local/cuda}
# install dependencies if not installed
-$PIP_CMD install cmake torch ninja
-
-# build nvshmem
-pushd $WORKSPACE
-mkdir -p nvshmem_src
-wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.2.5/source/nvshmem_src_3.2.5-1.txz
-tar -xvf nvshmem_src_3.2.5-1.txz -C nvshmem_src --strip-components=1
-pushd nvshmem_src
-wget https://github.com/deepseek-ai/DeepEP/raw/main/third-party/nvshmem.patch
-git init
-git apply -vvv nvshmem.patch
-
-# assume CUDA_HOME is set correctly
-if [ -z "$CUDA_HOME" ]; then
- echo "CUDA_HOME is not set, please set it to your CUDA installation directory."
- exit 1
+if [ -z "$VIRTUAL_ENV" ]; then
+ uv pip install --system cmake torch ninja
+else
+ uv pip install cmake torch ninja
fi
-# assume TORCH_CUDA_ARCH_LIST is set correctly
-if [ -z "$TORCH_CUDA_ARCH_LIST" ]; then
- echo "TORCH_CUDA_ARCH_LIST is not set, please set it to your desired architecture."
+# fetch nvshmem
+ARCH=$(uname -m)
+case "${ARCH,,}" in
+ x86_64|amd64)
+ NVSHMEM_SUBDIR="linux-x86_64"
+ NVSHMEM_FILE="libnvshmem-linux-x86_64-${NVSHMEM_VER}_cuda12-archive.tar.xz"
+ ;;
+ aarch64|arm64)
+ NVSHMEM_SUBDIR="linux-sbsa"
+ NVSHMEM_FILE="libnvshmem-linux-sbsa-${NVSHMEM_VER}_cuda12-archive.tar.xz"
+ ;;
+ *)
+ echo "Unsupported architecture: ${ARCH}" >&2
exit 1
-fi
+ ;;
+esac
-# disable all features except IBGDA
-export NVSHMEM_IBGDA_SUPPORT=1
-
-export NVSHMEM_SHMEM_SUPPORT=0
-export NVSHMEM_UCX_SUPPORT=0
-export NVSHMEM_USE_NCCL=0
-export NVSHMEM_PMIX_SUPPORT=0
-export NVSHMEM_TIMEOUT_DEVICE_POLLING=0
-export NVSHMEM_USE_GDRCOPY=0
-export NVSHMEM_IBRC_SUPPORT=0
-export NVSHMEM_BUILD_TESTS=0
-export NVSHMEM_BUILD_EXAMPLES=0
-export NVSHMEM_MPI_SUPPORT=0
-export NVSHMEM_BUILD_HYDRA_LAUNCHER=0
-export NVSHMEM_BUILD_TXZ_PACKAGE=0
-export NVSHMEM_TIMEOUT_DEVICE_POLLING=0
-
-cmake -G Ninja -S . -B $WORKSPACE/nvshmem_build/ -DCMAKE_INSTALL_PREFIX=$WORKSPACE/nvshmem_install
-cmake --build $WORKSPACE/nvshmem_build/ --target install
+NVSHMEM_URL="https://developer.download.nvidia.com/compute/nvshmem/redist/libnvshmem/${NVSHMEM_SUBDIR}/${NVSHMEM_FILE}"
+pushd "$WORKSPACE"
+echo "Downloading NVSHMEM ${NVSHMEM_VER} for ${NVSHMEM_SUBDIR} ..."
+curl -fSL "${NVSHMEM_URL}" -o "${NVSHMEM_FILE}"
+tar -xf "${NVSHMEM_FILE}"
+mv "${NVSHMEM_FILE%.tar.xz}" nvshmem
+rm -f "${NVSHMEM_FILE}"
+rm -rf nvshmem/lib/bin nvshmem/lib/share
popd
-export CMAKE_PREFIX_PATH=$WORKSPACE/nvshmem_install:$CMAKE_PREFIX_PATH
+export CMAKE_PREFIX_PATH=$WORKSPACE/nvshmem/lib/cmake:$CMAKE_PREFIX_PATH
is_git_dirty() {
local dir=$1
pushd "$dir" > /dev/null
-
- if [ -d ".git" ] && [ -n "$(git status --porcelain 2>/dev/null)" ]; then
+    if [ -d ".git" ] && [ -n "$(git status --porcelain 2>/dev/null)" ]; then
popd > /dev/null
- return 0 # dirty (true)
+ return 0
else
popd > /dev/null
- return 1 # clean (false)
+ return 1
fi
}
-# Function to handle git repository cloning with dirty/incomplete checks
clone_repo() {
local repo_url=$1
local dir_name=$2
local key_file=$3
local commit_hash=$4
-
if [ -d "$dir_name" ]; then
- # Check if directory has uncommitted changes (dirty)
if is_git_dirty "$dir_name"; then
echo "$dir_name directory is dirty, skipping clone"
- # Check if clone failed (directory exists but not a valid git repo or missing key files)
elif [ ! -d "$dir_name/.git" ] || [ ! -f "$dir_name/$key_file" ]; then
echo "$dir_name directory exists but clone appears incomplete, cleaning up and re-cloning"
rm -rf "$dir_name"
@@ -99,7 +84,7 @@ clone_repo() {
cd ..
fi
else
- echo "$dir_name directory exists and appears complete; manually update if needed"
+ echo "$dir_name directory exists and appears complete"
fi
else
git clone "$repo_url"
@@ -111,17 +96,55 @@ clone_repo() {
fi
}
-# build and install pplx, require pytorch installed
-pushd $WORKSPACE
-clone_repo "https://github.com/ppl-ai/pplx-kernels" "pplx-kernels" "setup.py" "c336faf"
-cd pplx-kernels
-$PIP_CMD install --no-build-isolation -vvv -e .
-popd
+deepep_cuda13_patch() {
+ cuda_version_major=$(${CUDA_HOME}/bin/nvcc --version | egrep -o "release [0-9]+" | cut -d ' ' -f 2)
+ if [ ${cuda_version_major} -ge 13 ]; then
+ sed -i "s|f'{nvshmem_dir}/include']|f'{nvshmem_dir}/include', '${CUDA_HOME}/include/cccl']|" "setup.py"
+ fi
+}
-# build and install deepep, require pytorch installed
-pushd $WORKSPACE
-clone_repo "https://github.com/deepseek-ai/DeepEP" "DeepEP" "setup.py" "73b6ea4"
-cd DeepEP
-export NVSHMEM_DIR=$WORKSPACE/nvshmem_install
-$PIP_CMD install --no-build-isolation -vvv -e .
-popd
+do_build() {
+ local repo=$1
+ local name=$2
+ local key=$3
+ local commit=$4
+ local extra_env=$5
+
+ pushd "$WORKSPACE"
+ clone_repo "$repo" "$name" "$key" "$commit"
+ cd "$name"
+
+ if [ "$name" == "DeepEP" ]; then
+ deepep_cuda13_patch
+ fi
+
+ if [ "$MODE" = "install" ]; then
+ echo "Installing $name into environment"
+ eval "$extra_env" uv pip install --no-build-isolation -vvv .
+ else
+ echo "Building $name wheel into $WHEEL_DIR"
+ eval "$extra_env" uv build --wheel --no-build-isolation -vvv --out-dir "$WHEEL_DIR" .
+ fi
+ popd
+}
+
+# build pplx-kernels
+do_build \
+ "https://github.com/ppl-ai/pplx-kernels" \
+ "pplx-kernels" \
+ "setup.py" \
+ "12cecfd" \
+ ""
+
+# build DeepEP
+do_build \
+ "https://github.com/deepseek-ai/DeepEP" \
+ "DeepEP" \
+ "setup.py" \
+ "73b6ea4" \
+ "export NVSHMEM_DIR=$WORKSPACE/nvshmem; "
+
+if [ "$MODE" = "wheel" ]; then
+ echo "All wheels written to $WHEEL_DIR"
+ ls -l "$WHEEL_DIR"
+fi
diff --git a/tools/install_deepgemm.sh b/tools/install_deepgemm.sh
index 4f2cd302c3eff..ee9a5dd4aa643 100755
--- a/tools/install_deepgemm.sh
+++ b/tools/install_deepgemm.sh
@@ -1,12 +1,13 @@
#!/bin/bash
-# Script to install DeepGEMM from source
-# This script can be used both in Docker builds and by users locally
-
+# Script to build and/or install DeepGEMM from source
+# Default: build and install immediately
+# Optional: build wheels to a directory for later installation (useful in multi-stage builds)
set -e
# Default values
DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git"
DEEPGEMM_GIT_REF="594953acce41793ae00a1233eb516044d604bcb6"
+WHEEL_DIR=""
# Parse command line arguments
while [[ $# -gt 0 ]]; do
@@ -27,11 +28,20 @@ while [[ $# -gt 0 ]]; do
CUDA_VERSION="$2"
shift 2
;;
+ --wheel-dir)
+ if [[ -z "$2" || "$2" =~ ^- ]]; then
+ echo "Error: --wheel-dir requires a directory path." >&2
+ exit 1
+ fi
+ WHEEL_DIR="$2"
+ shift 2
+ ;;
-h|--help)
echo "Usage: $0 [OPTIONS]"
echo "Options:"
echo " --ref REF Git reference to checkout (default: $DEEPGEMM_GIT_REF)"
echo " --cuda-version VER CUDA version (auto-detected if not provided)"
+ echo " --wheel-dir PATH If set, build wheel into PATH but do not install"
echo " -h, --help Show this help message"
exit 0
;;
@@ -57,16 +67,15 @@ fi
CUDA_MAJOR="${CUDA_VERSION%%.*}"
CUDA_MINOR="${CUDA_VERSION#${CUDA_MAJOR}.}"
CUDA_MINOR="${CUDA_MINOR%%.*}"
-
echo "CUDA version: $CUDA_VERSION (major: $CUDA_MAJOR, minor: $CUDA_MINOR)"
# Check CUDA version requirement
if [ "$CUDA_MAJOR" -lt 12 ] || { [ "$CUDA_MAJOR" -eq 12 ] && [ "$CUDA_MINOR" -lt 8 ]; }; then
- echo "Skipping DeepGEMM installation (requires CUDA 12.8+ but got ${CUDA_VERSION})"
+ echo "Skipping DeepGEMM build/installation (requires CUDA 12.8+ but got ${CUDA_VERSION})"
exit 0
fi
-echo "Installing DeepGEMM from source..."
+echo "Preparing DeepGEMM build..."
echo "Repository: $DEEPGEMM_GIT_REPO"
echo "Reference: $DEEPGEMM_GIT_REF"
@@ -76,23 +85,31 @@ trap 'rm -rf "$INSTALL_DIR"' EXIT
# Clone the repository
git clone --recursive --shallow-submodules "$DEEPGEMM_GIT_REPO" "$INSTALL_DIR/deepgemm"
-
-echo "🏗️ Building DeepGEMM"
pushd "$INSTALL_DIR/deepgemm"
# Checkout the specific reference
git checkout "$DEEPGEMM_GIT_REF"
-# Build DeepGEMM
+# Clean previous build artifacts
# (Based on https://github.com/deepseek-ai/DeepGEMM/blob/main/install.sh)
-rm -rf build dist
-rm -rf *.egg-info
+rm -rf build dist *.egg-info
+
+# Build wheel
+echo "🏗️ Building DeepGEMM wheel..."
python3 setup.py bdist_wheel
-# Install the wheel
+# If --wheel-dir was specified, copy wheels there and exit
+if [ -n "$WHEEL_DIR" ]; then
+ mkdir -p "$WHEEL_DIR"
+ cp dist/*.whl "$WHEEL_DIR"/
+ echo "✅ Wheel built and copied to $WHEEL_DIR"
+ popd
+ exit 0
+fi
+
+# Default behaviour: install built wheel
if command -v uv >/dev/null 2>&1; then
echo "Installing DeepGEMM wheel using uv..."
- # Use --system in Docker contexts, respect user's environment otherwise
if [ -n "$VLLM_DOCKER_BUILD_CONTEXT" ]; then
uv pip install --system dist/*.whl
else
@@ -104,5 +121,4 @@ else
fi
popd
-
echo "✅ DeepGEMM installation completed successfully"
diff --git a/tools/pre_commit/mypy.py b/tools/pre_commit/mypy.py
index 8d04848f8f780..34f6e8c928ffb 100755
--- a/tools/pre_commit/mypy.py
+++ b/tools/pre_commit/mypy.py
@@ -38,6 +38,7 @@ FILES = [
"vllm/usage",
"vllm/v1/core",
"vllm/v1/engine",
+ "vllm/v1/worker",
]
# After fixing errors resulting from changing follow_imports
@@ -62,7 +63,6 @@ SEPARATE_GROUPS = [
"vllm/v1/sample",
"vllm/v1/spec_decode",
"vllm/v1/structured_output",
- "vllm/v1/worker",
]
# TODO(woosuk): Include the code from Megatron and HuggingFace.
diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
index e53e4ae6e5296..a8f472d147a0d 100644
--- a/vllm/_aiter_ops.py
+++ b/vllm/_aiter_ops.py
@@ -294,6 +294,8 @@ def _rocm_aiter_mla_decode_fwd_impl(
kv_last_page_lens: torch.Tensor | None = None,
sm_scale: float = 1.0,
logit_cap: float = 0.0,
+ q_scale: torch.Tensor | None = None,
+ kv_scale: torch.Tensor | None = None,
) -> None:
from aiter.mla import mla_decode_fwd
@@ -308,6 +310,8 @@ def _rocm_aiter_mla_decode_fwd_impl(
max_seqlen_qo,
sm_scale=sm_scale,
logit_cap=logit_cap,
+ q_scale=q_scale,
+ kv_scale=kv_scale,
)
@@ -322,6 +326,8 @@ def _rocm_aiter_mla_decode_fwd_fake(
kv_last_page_lens: torch.Tensor | None = None,
sm_scale: float = 1.0,
logit_cap: float = 0.0,
+ q_scale: torch.Tensor | None = None,
+ kv_scale: torch.Tensor | None = None,
) -> None:
pass
@@ -806,6 +812,8 @@ class rocm_aiter_ops:
kv_indices: torch.Tensor | None = None,
kv_last_page_lens: torch.Tensor | None = None,
logit_cap: float = 0.0,
+ q_scale: torch.Tensor | None = None,
+ kv_scale: torch.Tensor | None = None,
):
torch.ops.vllm.rocm_aiter_mla_decode_fwd(
q,
@@ -818,6 +826,8 @@ class rocm_aiter_ops:
kv_last_page_lens,
sm_scale=sm_scale,
logit_cap=logit_cap,
+ q_scale=q_scale,
+ kv_scale=kv_scale,
)
@staticmethod
@@ -948,6 +958,31 @@ class rocm_aiter_ops:
(8192, 32768),
]
+ @staticmethod
+ def is_triton_gemm_afp4wfp4_presh_ws_tuned(n: int, k: int) -> bool:
+ return (n, k) in [
+ (8192, 4096),
+ (1280, 8192),
+ (16384, 53248),
+ (106496, 16384),
+ (57344, 8192),
+ (8192, 2048),
+ (2560, 8192),
+ (10240, 8192),
+ (16384, 16384),
+ (8192, 28672),
+ (28672, 8192),
+ (18432, 16384),
+ (8192, 1024),
+ (7168, 8192),
+ (5120, 8192),
+ (8192, 8192),
+ (8192, 7168),
+ (14336, 8192),
+ (8192, 14336),
+ (8192, 3584),
+ ]
+
@staticmethod
def shuffle_weight(
self, tensor: torch.Tensor, layout: tuple[int, int] = (16, 16)
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 66cf6472eee40..4a1bcc761f994 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -328,10 +328,7 @@ def rotary_embedding(
def rms_norm(
out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, epsilon: float
) -> None:
- # TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input
- # If removed, also need to remove contiguous in MatcherRMSNorm
- input_contiguous = input.contiguous()
- torch.ops._C.rms_norm(out, input_contiguous, weight, epsilon)
+ torch.ops._C.rms_norm(out, input, weight, epsilon)
def fused_add_rms_norm(
@@ -2204,7 +2201,8 @@ def gather_and_maybe_dequant_cache(
dst: torch.Tensor,
block_table: torch.Tensor,
cu_seq_lens: torch.Tensor,
- batch_size: int,
+ token_to_seq: torch.Tensor,
+ num_tokens: int,
kv_cache_dtype: str,
scale: torch.Tensor,
seq_starts: torch.Tensor | None = None,
@@ -2214,7 +2212,8 @@ def gather_and_maybe_dequant_cache(
dst,
block_table,
cu_seq_lens,
- batch_size,
+ token_to_seq,
+ num_tokens,
kv_cache_dtype,
scale,
seq_starts,
diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py
index dd35165d5415e..8b4dc4013362e 100644
--- a/vllm/attention/__init__.py
+++ b/vllm/attention/__init__.py
@@ -7,7 +7,7 @@ from vllm.attention.backends.abstract import (
AttentionType,
)
from vllm.attention.layer import Attention
-from vllm.attention.selector import get_attn_backend
+from vllm.attention.selector import get_attn_backend, get_mamba_attn_backend
__all__ = [
"Attention",
@@ -15,4 +15,5 @@ __all__ = [
"AttentionMetadata",
"AttentionType",
"get_attn_backend",
+ "get_mamba_attn_backend",
]
diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 9275d70fd86a4..bd7e81b15bfc3 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -46,9 +46,12 @@ class AttentionBackend(ABC):
# makes sure the output tensor is allocated inside the cudagraph.
accept_output_buffer: bool = False
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
- supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(1)]
supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"]
+ @staticmethod
+ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+ return [MultipleOf(1)]
+
@staticmethod
@abstractmethod
def get_name() -> str:
@@ -76,7 +79,34 @@ class AttentionBackend(ABC):
raise NotImplementedError
@staticmethod
- def get_kv_cache_stride_order() -> tuple[int, ...]:
+ def get_kv_cache_stride_order(
+ include_num_layers_dimension: bool = False,
+ ) -> tuple[int, ...]:
+ """
+ Get the physical (memory layout) ordering of the kv cache dimensions.
+ e.g. if the KV cache shape is
+ [2, num_blocks, block_size, num_heads, head_size],
+ and get_kv_cache_stride_order returns (1, 3, 0, 2, 4) then the physical
+ ordering of dimensions is
+ [num_blocks, num_heads, 2, block_size, head_size].
+
+ If this function is unimplemented / raises NotImplementedError,
+ the physical layout of the KV cache will match the logical shape.
+
+ Args:
+ include_num_layers_dimension: if True, includes an additional
+ num_layers dimension, which is assumed to be prepended
+ to the logical KV cache shape.
+ With the above example, a return value (2, 4, 0, 1, 3, 5)
+ corresponds to
+ [num_blocks, num_heads, num_layers, 2, block_size, head_size].
+
+ If an additional dimension is NOT included in the returned
+ tuple, the physical layout will not include a layers dimension.
+
+ Returns:
+ A tuple of ints which is a permutation of range(len(shape)).
+ """
raise NotImplementedError
@classmethod
@@ -115,18 +145,17 @@ class AttentionBackend(ABC):
if block_size not in valid_sizes:
return False
- if not cls.supported_kernel_block_sizes:
+ supported_kernel_block_sizes = cls.get_supported_kernel_block_sizes()
+ if not supported_kernel_block_sizes:
return True
- for supported_size in cls.supported_kernel_block_sizes:
- is_multiple_of = (
- isinstance(supported_size, MultipleOf)
- and block_size % supported_size.base == 0
- )
- is_int_equal = (
- isinstance(supported_size, int) and block_size == supported_size
- )
- if is_multiple_of or is_int_equal:
+ for supported_size in supported_kernel_block_sizes:
+ if isinstance(supported_size, MultipleOf):
+ supported_size = supported_size.base
+ # With hybrid_blocks feature, the framework-level block size
+ # only needs to be a multiple of the kernel's requirement,
+ # even if the kernel requires a fixed block_size.
+ if block_size % supported_size == 0:
return True
return False
@@ -266,6 +295,12 @@ class AttentionImpl(ABC, Generic[T]):
dcp_world_size: int
dcp_rank: int
+ pcp_world_size: int
+ pcp_rank: int
+
+ total_cp_world_size: int
+ total_cp_rank: int
+
def __new__(cls, *args, **kwargs):
# use __new__ so that all subclasses will call this
self = super().__new__(cls)
@@ -278,6 +313,17 @@ class AttentionImpl(ABC, Generic[T]):
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
+ try:
+ from vllm.distributed.parallel_state import get_pcp_group
+
+ self.pcp_world_size = get_pcp_group().world_size
+ self.pcp_rank = get_pcp_group().rank_in_group
+ except AssertionError:
+ self.pcp_world_size = 1
+ self.pcp_rank = 0
+ self.total_cp_world_size = self.pcp_world_size * self.dcp_world_size
+ self.total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
+
self.need_to_return_lse_for_decode = (
self.dcp_world_size > 1 and self.can_return_lse_for_decode
)
diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py
index 0602607966720..e69f1b7ce25e0 100644
--- a/vllm/attention/backends/registry.py
+++ b/vllm/attention/backends/registry.py
@@ -2,8 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention backend registry"""
-import enum
from collections.abc import Callable
+from enum import Enum, EnumMeta
from typing import TYPE_CHECKING, cast
from vllm.logger import init_logger
@@ -15,7 +15,7 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
-class _AttentionBackendEnumMeta(enum.EnumMeta):
+class _AttentionBackendEnumMeta(EnumMeta):
"""Metaclass for AttentionBackendEnum to provide better error messages."""
def __getitem__(cls, name: str):
@@ -23,15 +23,15 @@ class _AttentionBackendEnumMeta(enum.EnumMeta):
try:
return super().__getitem__(name)
except KeyError:
- members = cast("dict[str, AttentionBackendEnum]", cls.__members__).values()
- valid_backends = ", ".join(m.name for m in members)
+ members = cast("dict[str, Enum]", cls.__members__).keys()
+ valid_backends = ", ".join(members)
raise ValueError(
f"Unknown attention backend: '{name}'. "
f"Valid options are: {valid_backends}"
) from None
-class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta):
+class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
"""Enumeration of all supported attention backends.
The enum value is the default class path, but this can be overridden
@@ -46,12 +46,17 @@ class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta):
"vllm.v1.attention.backends.flash_sink_attn.FlashSinkAttentionBackend"
)
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
- XFORMERS = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"
ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"
+ ROCM_AITER_TRITON_MLA = (
+ "vllm.v1.attention.backends.mla.aiter_triton_mla.AiterTritonMLABackend"
+ )
ROCM_AITER_FA = (
"vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
)
+ ROCM_AITER_MLA_SPARSE = (
+ "vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse.ROCMAiterMLASparseBackend"
+ )
TORCH_SDPA = "" # this tag is only used for ViT
FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
FLASHINFER_MLA = (
@@ -86,7 +91,7 @@ class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta):
Raises:
ValueError: If Backend.CUSTOM is used without being registered
"""
- path = _OVERRIDES.get(self, self.value)
+ path = _ATTN_OVERRIDES.get(self, self.value)
if not path:
raise ValueError(
f"Backend {self.name} must be registered before use. "
@@ -114,18 +119,93 @@ class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta):
Returns:
True if the backend has a registered override
"""
- return self in _OVERRIDES
+ return self in _ATTN_OVERRIDES
def clear_override(self) -> None:
"""Clear any override for this backend, reverting to the default."""
- _OVERRIDES.pop(self, None)
+ _ATTN_OVERRIDES.pop(self, None)
-_OVERRIDES: dict[AttentionBackendEnum, str] = {}
+class MambaAttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
+ """Enumeration of all supported mamba attention backends.
+
+ The enum value is the default class path, but this can be overridden
+ at runtime using register_backend().
+
+ To get the actual backend class (respecting overrides), use:
+ backend.get_class()
+ """
+
+ MAMBA1 = "vllm.v1.attention.backends.mamba1_attn.Mamba1AttentionBackend"
+ MAMBA2 = "vllm.v1.attention.backends.mamba2_attn.Mamba2AttentionBackend"
+ SHORT_CONV = "vllm.v1.attention.backends.short_conv_attn.ShortConvAttentionBackend"
+ LINEAR = "vllm.v1.attention.backends.linear_attn.LinearAttentionBackend"
+ GDN_ATTN = "vllm.v1.attention.backends.gdn_attn.GDNAttentionBackend"
+ # Placeholder for third-party/custom backends - must be registered before use
+ CUSTOM = ""
+
+ def get_path(self, include_classname: bool = True) -> str:
+ """Get the class path for this backend (respects overrides).
+
+ Returns:
+ The fully qualified class path string
+
+ Raises:
+ ValueError: If Backend.CUSTOM is used without being registered
+ """
+ path = _MAMBA_ATTN_OVERRIDES.get(self, self.value)
+ if not path:
+ raise ValueError(
+ f"Backend {self.name} must be registered before use. "
+ f"Use register_backend(Backend.{self.name}, 'your.module.YourClass')"
+ )
+ if not include_classname:
+ path = path.rsplit(".", 1)[0]
+ return path
+
+ def get_class(self) -> "type[AttentionBackend]":
+ """Get the backend class (respects overrides).
+
+ Returns:
+ The backend class
+
+ Raises:
+ ImportError: If the backend class cannot be imported
+ ValueError: If Backend.CUSTOM is used without being registered
+ """
+ return resolve_obj_by_qualname(self.get_path())
+
+ def is_overridden(self) -> bool:
+ """Check if this backend has been overridden.
+
+ Returns:
+ True if the backend has a registered override
+ """
+ return self in _MAMBA_ATTN_OVERRIDES
+
+ def clear_override(self) -> None:
+ """Clear any override for this backend, reverting to the default."""
+ _MAMBA_ATTN_OVERRIDES.pop(self, None)
+
+
+MAMBA_TYPE_TO_BACKEND_MAP = {
+ "mamba1": MambaAttentionBackendEnum.MAMBA1.name,
+ "mamba2": MambaAttentionBackendEnum.MAMBA2.name,
+ "short_conv": MambaAttentionBackendEnum.SHORT_CONV.name,
+ "linear_attention": MambaAttentionBackendEnum.LINEAR.name,
+ "gdn_attention": MambaAttentionBackendEnum.GDN_ATTN.name,
+ "custom": MambaAttentionBackendEnum.CUSTOM.name,
+}
+
+
+_ATTN_OVERRIDES: dict[AttentionBackendEnum, str] = {}
+_MAMBA_ATTN_OVERRIDES: dict[MambaAttentionBackendEnum, str] = {}
def register_backend(
- backend: AttentionBackendEnum, class_path: str | None = None
+ backend: AttentionBackendEnum | MambaAttentionBackendEnum,
+ is_mamba: bool = False,
+ class_path: str | None = None,
) -> Callable[[type], type]:
"""Register or override a backend implementation.
@@ -138,12 +218,17 @@ def register_backend(
Decorator function if class_path is None, otherwise a no-op
Examples:
- # Override an existing backend
+ # Override an existing attention backend
@register_backend(AttentionBackendEnum.FLASH_ATTN)
class MyCustomFlashAttn:
...
- # Register a custom third-party backend
+ # Override an existing mamba attention backend
+ @register_backend(MambaAttentionBackendEnum.LINEAR, is_mamba=True)
+ class MyCustomMambaAttn:
+ ...
+
+ # Register a custom third-party attention backend
@register_backend(AttentionBackendEnum.CUSTOM)
class MyCustomBackend:
...
@@ -156,11 +241,17 @@ def register_backend(
"""
def decorator(cls: type) -> type:
- _OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}"
+ if is_mamba:
+ _MAMBA_ATTN_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}" # type: ignore[index]
+ else:
+ _ATTN_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}" # type: ignore[index]
return cls
if class_path is not None:
- _OVERRIDES[backend] = class_path
+ if is_mamba:
+ _MAMBA_ATTN_OVERRIDES[backend] = class_path # type: ignore[index]
+ else:
+ _ATTN_OVERRIDES[backend] = class_path # type: ignore[index]
return lambda x: x
return decorator
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 9e540fd437bfb..376101e55e285 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -51,31 +51,6 @@ else:
FP8_DTYPE = current_platform.fp8_dtype()
logger = init_logger(__name__)
-USE_XFORMERS_OPS = None
-
-
-def check_xformers_availability():
- global USE_XFORMERS_OPS
- if USE_XFORMERS_OPS is not None:
- return USE_XFORMERS_OPS
-
- if current_platform.is_cuda() and current_platform.has_device_capability(100):
- # Xformers FA is not compatible with B200
- USE_XFORMERS_OPS = False
- else:
- try:
- from importlib.util import find_spec
-
- find_spec("xformers.ops")
- USE_XFORMERS_OPS = True
- except ImportError:
- USE_XFORMERS_OPS = False
-
- # the warning only needs to be shown once
- if not USE_XFORMERS_OPS:
- logger.warning("Xformers is not available, falling back.")
-
- return USE_XFORMERS_OPS
def check_upstream_fa_availability(dtype: torch.dtype):
@@ -533,7 +508,6 @@ class MultiHeadAttention(nn.Module):
if backend
in {
AttentionBackendEnum.TORCH_SDPA,
- AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.PALLAS,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.FLASH_ATTN,
@@ -549,12 +523,6 @@ class MultiHeadAttention(nn.Module):
)
)
- if (
- self.attn_backend == AttentionBackendEnum.XFORMERS
- and not check_xformers_availability()
- ):
- self.attn_backend = AttentionBackendEnum.TORCH_SDPA
-
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
@@ -614,12 +582,6 @@ class MultiHeadAttention(nn.Module):
max_seqlen_k=kv_len,
softmax_scale=self.scale,
)
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- from xformers import ops as xops
-
- out = xops.memory_efficient_attention_forward(
- query, key, value, scale=self.scale
- )
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py
index 2cbb5c91cc3b3..af6766bdd1615 100644
--- a/vllm/attention/ops/common.py
+++ b/vllm/attention/ops/common.py
@@ -169,12 +169,11 @@ def correct_attn_out(
return out, lse
-def cp_lse_ag_out_rs(
+def _cp_lse_common(
cp_attn_out: torch.Tensor,
cp_attn_lse: torch.Tensor,
cp_group: GroupCoordinator,
- ctx: CPTritonContext = None,
- return_lse=False,
+ ctx: CPTritonContext | None = None,
):
"""
cp_attn_out: [ B, H, D ]
@@ -195,6 +194,21 @@ def cp_lse_ag_out_rs(
cp_attn_lse = cp_attn_lse.contiguous()
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
+ return out, lse
+
+
+def cp_lse_ag_out_rs(
+ cp_attn_out: torch.Tensor,
+ cp_attn_lse: torch.Tensor,
+ cp_group: GroupCoordinator,
+ ctx: CPTritonContext | None = None,
+ return_lse: bool = False,
+):
+ """
+ cp_attn_out: [ B, H, D ]
+ cp_attn_lse: [ B, H ]
+ """
+ out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
out = cp_group.reduce_scatter(out, dim=1)
if return_lse:
@@ -205,6 +219,25 @@ def cp_lse_ag_out_rs(
return out
+def cp_lse_ag_out_ar(
+ cp_attn_out: torch.Tensor,
+ cp_attn_lse: torch.Tensor,
+ cp_group: GroupCoordinator,
+ ctx: CPTritonContext | None = None,
+ return_lse: bool = False,
+):
+ """
+ cp_attn_out: [ B, H, D ]
+ cp_attn_lse: [ B, H ]
+ """
+ out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
+ out = cp_group.all_reduce(out)
+
+ if return_lse:
+ return out, lse
+ return out
+
+
@triton.jit
def _pack_seq_kernel(
x_ptr, # [N, D]
diff --git a/vllm/attention/ops/rocm_aiter_mla_sparse.py b/vllm/attention/ops/rocm_aiter_mla_sparse.py
new file mode 100644
index 0000000000000..080e92ecc9408
--- /dev/null
+++ b/vllm/attention/ops/rocm_aiter_mla_sparse.py
@@ -0,0 +1,210 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import importlib
+from functools import lru_cache
+
+import torch
+
+from vllm._aiter_ops import rocm_aiter_ops
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+
+logger = init_logger(__name__)
+
+
+# Take from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L84
+def fp8_mqa_logits_torch(
+ q: torch.Tensor,
+ kv: tuple[torch.Tensor, torch.Tensor],
+ weights: torch.Tensor,
+ cu_seqlen_ks: torch.Tensor,
+ cu_seqlen_ke: torch.Tensor,
+) -> torch.Tensor:
+ """Compute FP8 MQA logits for a single sequence without KV paging.
+
+ Args:
+ q: Query tensor of shape [M, H, D]. Casted to
+ `torch.float8_e4m3fn` by caller.
+ kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
+ dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
+ [N, 1]) with dtype `torch.float32`.
+ weights: weights of shape [M, H], dtype `torch.float32`.
+ cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
+ shape [M], dtype int32.
+ cu_seqlen_ke: End indices (exclusive) for valid K per query position,
+ shape [M], dtype int32.
+
+ Returns:
+ Logits tensor of shape [M, N], dtype `torch.float32`.
+ """
+ kv, scale = kv
+ seq_len_kv = kv.shape[0]
+ k = kv.to(torch.bfloat16)
+ q = q.to(torch.bfloat16)
+
+ mask_lo = (
+ torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None]
+ )
+ mask_hi = (
+ torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
+ )
+ mask = mask_lo & mask_hi
+
+ score = torch.einsum("mhd,nd->hmn", q, k).float() * scale
+ logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
+ logits = logits.masked_fill(~mask, float("-inf"))
+
+ return logits
+
+
+def rocm_fp8_mqa_logits(
+ q: torch.Tensor,
+ kv: tuple[torch.Tensor, torch.Tensor],
+ weights: torch.Tensor,
+ cu_seqlen_ks: torch.Tensor,
+ cu_seqlen_ke: torch.Tensor,
+) -> torch.Tensor:
+ """Compute FP8 MQA logits for a single sequence without KV paging.
+
+ Args:
+ q: Query tensor of shape [M, H, D]. Casted to
+ `torch.float8_e4m3fn` by caller.
+ kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
+ dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
+ [N, 1]) with dtype `torch.float32`.
+ weights: weights of shape [M, H], dtype `torch.float32`.
+ cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
+ shape [M], dtype int32.
+ cu_seqlen_ke: End indices (exclusive) for valid K per query position,
+ shape [M], dtype int32.
+
+ Returns:
+ Logits tensor of shape [M, N], dtype `torch.float32`.
+ """
+
+ # TODO(ganyi): Temporarily workaround, will remove the module check and reference
+ # path after aiter merge this kernel into main
+ @lru_cache
+ def has_mqa_logits_module():
+ return importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None
+
+ if rocm_aiter_ops.is_enabled() and has_mqa_logits_module():
+ from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits
+
+ kv, scale = kv
+ return fp8_mqa_logits(q, kv, scale, weights, cu_seqlen_ks, cu_seqlen_ke)
+ else:
+ return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
+
+
+# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L156
+def fp8_paged_mqa_logits_torch(
+ q: torch.Tensor,
+ kv_cache: torch.Tensor,
+ weights: torch.Tensor,
+ context_lens: torch.Tensor,
+ block_tables: torch.Tensor,
+ max_model_len: int,
+):
+ from vllm.utils.math_utils import cdiv
+
+ fp8_dtype = current_platform.fp8_dtype()
+ batch_size, next_n, _, dim = q.size()
+ kv_cache, scale = kv_cache[..., :dim], kv_cache[..., dim:]
+ scale = scale.contiguous().view(torch.float)
+ q = q.float()
+ kv_cache = kv_cache.view(fp8_dtype).float() * scale
+ num_block, block_size, _, dim = kv_cache.size()
+ logits = torch.full(
+ [batch_size * next_n, max_model_len],
+ float("-inf"),
+ device=q.device,
+ dtype=torch.float32,
+ )
+ context_lens = context_lens.tolist()
+ for i in range(batch_size):
+ context_len = context_lens[i]
+ q_offsets = torch.arange(context_len - next_n, context_len, device="cuda")
+ weight_slice = (
+ weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
+ )
+ for block_rk in range(cdiv(context_len, block_size)):
+ block_idx = block_tables[i][block_rk]
+ qx, kx = q[i], kv_cache[block_idx]
+ k_offsets = torch.arange(
+ block_rk * block_size, (block_rk + 1) * block_size, device="cuda"
+ )
+ mask = (k_offsets[None, :] < context_len) & (
+ k_offsets[None, :] <= q_offsets[:, None]
+ )
+ s = torch.where(
+ mask[None, :, :],
+ (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
+ logits.dtype
+ ),
+ float("-inf"),
+ )
+ s = torch.relu(s) * weight_slice[..., None]
+ s = s.sum(dim=0)
+ logits[
+ i * next_n : (i + 1) * next_n,
+ block_rk * block_size : (block_rk + 1) * block_size,
+ ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf"))
+ return logits
+
+
+def rocm_fp8_paged_mqa_logits(
+ q_fp8: torch.Tensor,
+ kv_cache_fp8: torch.Tensor,
+ weights: torch.Tensor,
+ context_lens: torch.Tensor,
+ block_tables: torch.Tensor,
+ schedule_metadata: torch.Tensor,
+ max_model_len: int,
+) -> torch.Tensor:
+ """Compute FP8 MQA logits using paged KV-cache.
+
+ Args:
+ q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to
+ `torch.float8_e4m3fn` by caller.
+ kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape
+ [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
+ 4 bytes per (block,pos) store the `float` dequant scale.
+ weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
+ context_lens: Tensor of shape [B], dtype int32; effective context length
+ for each batch element.
+ block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical
+ block indices to physical blocks in the paged cache.
+ schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
+ used to distribute work across SMs.
+ max_model_len: Maximum sequence length used to size the logits output.
+
+ Returns:
+ Logits tensor of shape [B * next_n, max_model_len], dtype
+ `torch.float32`.
+ """
+
+ if rocm_aiter_ops.is_enabled():
+ from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits_stage1
+
+ batch_size, next_n, heads, _ = q_fp8.shape
+ out_qk = torch.full(
+ (heads, batch_size * next_n, max_model_len),
+ float("-inf"),
+ device="cuda",
+ dtype=torch.float32,
+ )
+ deepgemm_fp8_paged_mqa_logits_stage1(
+ q_fp8,
+ kv_cache_fp8,
+ weights,
+ out_qk,
+ context_lens,
+ block_tables,
+ max_model_len,
+ )
+ return out_qk.sum(dim=0)
+ else:
+ return fp8_paged_mqa_logits_torch(
+ q_fp8, kv_cache_fp8, weights, context_lens, block_tables, max_model_len
+ )
diff --git a/vllm/attention/ops/vit_attn_wrappers.py b/vllm/attention/ops/vit_attn_wrappers.py
index 06a9f7cd82266..46f8f5117f7a7 100644
--- a/vllm/attention/ops/vit_attn_wrappers.py
+++ b/vllm/attention/ops/vit_attn_wrappers.py
@@ -3,7 +3,7 @@
"""
This file contains ops for ViT attention to be compatible with torch.compile
as there are operations here not supported by torch.compile (for instance,
-`to_list` in xformers attn, or `.item()` in flash attention)
+`.item()` in flash attention)
Using these ops and wrapping vision blocks with `torch.compile` can speed up
throughput in vision models by ~5% relative on H100, and improve token
@@ -19,42 +19,6 @@ import torch.nn.functional as F
from vllm.utils.torch_utils import direct_register_custom_op
-def xformers_attn_seqlens_wrapper(
- q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
-) -> torch.Tensor:
- from xformers import ops as xops
- from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-
- attn_bias = BlockDiagonalMask.from_seqlens(
- q_seqlen=seqlens.tolist(), kv_seqlen=None, device=q.device
- )
- context_layer = xops.memory_efficient_attention_forward(
- q, k, v, attn_bias=attn_bias, p=0, scale=None
- )
- context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
- return context_layer
-
-
-def xformers_attn_seqlens_wrapper_fake(
- q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
-) -> torch.Tensor:
- b, s, h, d = q.shape
- return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
-
-
-direct_register_custom_op(
- op_name="xformers_attn_seqlens_wrapper",
- op_func=xformers_attn_seqlens_wrapper,
- fake_impl=xformers_attn_seqlens_wrapper_fake,
-)
-
-
-def vit_xformers_attn_wrapper(
- q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
-) -> torch.Tensor:
- return torch.ops.vllm.xformers_attn_seqlens_wrapper(q, k, v, seqlens)
-
-
def flash_attn_maxseqlen_wrapper(
q: torch.Tensor,
k: torch.Tensor,
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index 1a092db9ce378..ad19b58aa155c 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -12,7 +12,11 @@ import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
-from vllm.attention.backends.registry import AttentionBackendEnum
+from vllm.attention.backends.registry import (
+ MAMBA_TYPE_TO_BACKEND_MAP,
+ AttentionBackendEnum,
+ MambaAttentionBackendEnum,
+)
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.utils import STR_BACKEND_ENV_VAR
@@ -32,7 +36,14 @@ def get_env_variable_attn_backend() -> AttentionBackendEnum | None:
* None otherwise
"""
backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
- return None if backend_name is None else AttentionBackendEnum[backend_name]
+ if backend_name is None:
+ return None
+ if backend_name == "XFORMERS":
+ raise ValueError(
+ "Attention backend 'XFORMERS' has been removed (See PR #29262 for "
+ "details). Please select a supported attention backend."
+ )
+ return AttentionBackendEnum[backend_name]
# Global state allows a particular choice of backend
@@ -197,6 +208,33 @@ def _cached_get_attn_backend(
return backend
+def get_mamba_attn_backend(
+ mamba_type: str,
+) -> type[AttentionBackend]:
+ """Select which mamba attention backend to use and lazily import it."""
+ return _cached_get_mamba_attn_backend(mamba_type)
+
+
+@cache
+def _cached_get_mamba_attn_backend(
+ mamba_type: str,
+) -> type[AttentionBackend]:
+ assert mamba_type and isinstance(mamba_type, str)
+
+ selected_backend = None
+ try:
+ backend_name = MAMBA_TYPE_TO_BACKEND_MAP[mamba_type]
+ selected_backend = MambaAttentionBackendEnum[backend_name]
+ except KeyError as e:
+ raise ValueError(
+ f"Invalid mamba attention backend type: '{backend_name}'. Valid "
+ f"backends are: {list(MambaAttentionBackendEnum.__members__.keys())}"
+ ) from e
+
+ mamba_attn_backend = selected_backend.get_class()
+ return mamba_attn_backend
+
+
@contextmanager
def global_force_attn_backend_context_manager(
attn_backend: AttentionBackendEnum,
diff --git a/vllm/benchmarks/sweep/serve.py b/vllm/benchmarks/sweep/serve.py
index 45ac446a7aedf..1298e4acbd87d 100644
--- a/vllm/benchmarks/sweep/serve.py
+++ b/vllm/benchmarks/sweep/serve.py
@@ -211,6 +211,7 @@ def run_combs(
output_dir: Path,
num_runs: int,
dry_run: bool,
+ links: list[tuple[str, str]],
):
all_data = list[dict[str, object]]()
for serve_comb in serve_params:
@@ -226,6 +227,14 @@ def run_combs(
else contextlib.nullcontext()
) as server:
for bench_comb in bench_params:
+ should_run = all(
+ serve_key in serve_comb
+ and bench_key in bench_comb
+ and serve_comb[serve_key] == bench_comb[bench_key]
+ for serve_key, bench_key in links
+ )
+ if not should_run:
+ continue
base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)
comb_data = run_comb(
@@ -262,6 +271,7 @@ class SweepServeArgs:
num_runs: int
dry_run: bool
resume: str | None
+ link_vars: list[tuple[str, str]] | None
parser_name: ClassVar[str] = "serve"
parser_help: ClassVar[str] = "Run vLLM server benchmark under multiple settings."
@@ -285,7 +295,7 @@ class SweepServeArgs:
else:
# i.e.: run bench_cmd without any modification
bench_params = ParameterSweep.from_records([{}])
-
+ link_vars = cls.parse_link_vars(args.link_vars)
num_runs = args.num_runs
if num_runs < 1:
raise ValueError("`num_runs` should be at least 1.")
@@ -301,6 +311,7 @@ class SweepServeArgs:
num_runs=num_runs,
dry_run=args.dry_run,
resume=args.resume,
+ link_vars=link_vars,
)
@classmethod
@@ -376,8 +387,28 @@ class SweepServeArgs:
"parameter combinations for which there are still no output files.",
)
+ parser.add_argument(
+ "--link-vars",
+ type=str,
+ default="",
+ help=(
+ "Comma-separated list of linked variables between serve and bench, "
+ "e.g. max_num_seqs=max_concurrency,max_model_len=random_input_len"
+ ),
+ )
+
return parser
+ @staticmethod
+ def parse_link_vars(s: str) -> list[tuple[str, str]]:
+ if not s:
+ return []
+ pairs = []
+ for item in s.split(","):
+ a, b = item.split("=")
+ pairs.append((a.strip(), b.strip()))
+ return pairs
+
def run_main(args: SweepServeArgs):
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -397,6 +428,7 @@ def run_main(args: SweepServeArgs):
output_dir=output_dir,
num_runs=args.num_runs,
dry_run=args.dry_run,
+ links=args.link_vars,
)
except BaseException as exc:
raise RuntimeError(
diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index 60ef6eef21663..2d8dd4c51c7ef 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -4,12 +4,14 @@
import ast
import dataclasses
import hashlib
+import json
import operator
import os
import pprint
import time
from collections.abc import Callable, Sequence
from contextlib import contextmanager
+from functools import partial
from typing import Any
import torch
@@ -23,7 +25,9 @@ from vllm.compilation.partition_rules import (
should_split,
)
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
+from vllm.config.utils import hash_factors
from vllm.logger import init_logger
+from vllm.logging_utils import lazy
from vllm.platforms import current_platform
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.torch_utils import is_torch_equal_or_newer
@@ -59,13 +63,14 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
else:
logger.debug("Using InductorAdaptor")
return InductorAdaptor()
- else:
- assert compilation_config.backend == "eager", (
- "Custom backends not supported with CompilationMode.VLLM_COMPILE"
- )
-
+ elif compilation_config.backend == "eager":
logger.debug("Using EagerAdaptor")
return EagerAdaptor()
+ else:
+ logger.debug("Using custom backend: %s", compilation_config.backend)
+ compiler = resolve_obj_by_qualname(current_platform.get_compile_backend())()
+ assert isinstance(compiler, CompilerInterface)
+ return compiler
class CompilerManager:
@@ -541,7 +546,10 @@ class VllmBackend:
self.prefix = prefix or model_tag
# Passes to run on the graph post-grad.
- self.post_grad_pass_manager = PostGradPassManager()
+ self.pass_manager = resolve_obj_by_qualname(
+ current_platform.get_pass_manager_cls()
+ )()
+ self.pass_key = current_platform.pass_key
self.sym_tensor_indices = []
self.input_buffers = []
@@ -558,57 +566,65 @@ class VllmBackend:
def configure_post_pass(self):
config = self.compilation_config
- self.post_grad_pass_manager.configure(self.vllm_config)
+ self.pass_manager.configure(self.vllm_config)
# Post-grad custom passes are run using the post_grad_custom_post_pass
# hook. If a pass for that hook exists, add it to the pass manager.
inductor_config = config.inductor_compile_config
- PASS_KEY = "post_grad_custom_post_pass"
- if PASS_KEY in inductor_config:
- if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
+ if self.pass_key in inductor_config:
+ if isinstance(inductor_config[self.pass_key], PostGradPassManager):
# PassManager already added to config, make sure it's correct
- assert (
- inductor_config[PASS_KEY].uuid()
- == self.post_grad_pass_manager.uuid()
- )
+ assert inductor_config[self.pass_key].uuid() == self.pass_manager.uuid()
else:
# Config should automatically wrap all inductor passes
- assert isinstance(inductor_config[PASS_KEY], InductorPass)
- self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
- inductor_config[PASS_KEY] = self.post_grad_pass_manager
+ assert isinstance(inductor_config[self.pass_key], InductorPass)
+ self.pass_manager.add(inductor_config[self.pass_key])
+ inductor_config[self.pass_key] = self.pass_manager
def __call__(
self, graph: fx.GraphModule, example_inputs
) -> VllmSerializableFunction:
- from .caching import _compute_code_hash, compilation_config_hash_factors
-
vllm_config = self.vllm_config
+ # Minimal hashing here with existing utilities, reused below.
+
+ env_factors = envs.compile_factors()
+ env_hash = hash_factors(env_factors)
+ # Compute config/compiler/code hashes once and reuse
+ config_hash = vllm_config.compute_hash()
+ compiler_hash = self.compiler_manager.compute_hash(vllm_config)
+ forward_code_files = list(sorted(self.compilation_config.traced_files))
+
+ logger.debug(
+ "Traced files (to be considered for compilation cache):\n%s",
+ lazy(lambda: "\n".join(forward_code_files)),
+ )
+ hash_content = []
+ for filepath in forward_code_files:
+ hash_content.append(filepath)
+ if filepath == "":
+ # This means the function was dynamically generated, with
+ # e.g. exec(). We can't actually check these.
+ continue
+ try:
+ with open(filepath) as f:
+ hash_content.append(f.read())
+ except Exception:
+ logger.warning("Failed to read file %s", filepath)
+ continue
+ code_hash = hashlib.sha256("\n".join(hash_content).encode()).hexdigest()
+ # Clear after consumption
+ self.compilation_config.traced_files.clear()
if not self.compilation_config.cache_dir:
# no provided cache dir, generate one based on the known factors
# that affects the compilation. if none of the factors change,
# the cache dir will be the same so that we can reuse the compiled
# graph.
-
- factors = compilation_config_hash_factors(vllm_config)
- # 2. factors come from the code files that are traced by Dynamo (
- # it mainly summarizes how the model is used in forward pass)
- code_hash = _compute_code_hash(self.compilation_config.traced_files)
- self.compilation_config.traced_files.clear()
- factors.append(code_hash)
-
- # 3. compiler hash
- compiler_hash = self.compiler_manager.compute_hash(vllm_config)
- factors.append(compiler_hash)
-
- # combine all factors to generate the cache dir
- hash_key = hashlib.md5(
- str(factors).encode(), usedforsecurity=False
- ).hexdigest()[:10]
-
+ factors = [env_hash, config_hash, code_hash, compiler_hash]
+ # Use SHA-256 for cache key hashing to be consistent across
+ # compute_hash functions. Truncate for a short cache dir name.
+ hash_key = hashlib.sha256(str(factors).encode()).hexdigest()[:10]
cache_dir = os.path.join(
- envs.VLLM_CACHE_ROOT,
- "torch_compile_cache",
- hash_key,
+ envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key
)
self.compilation_config.cache_dir = cache_dir
@@ -621,6 +637,7 @@ class VllmBackend:
os.makedirs(local_cache_dir, exist_ok=True)
self.compilation_config.local_cache_dir = local_cache_dir
+ # Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
disable_cache = not is_compile_cache_enabled(
self.compilation_config.inductor_compile_config
)
@@ -638,6 +655,50 @@ class VllmBackend:
local_cache_dir, disable_cache, self.prefix
)
+ # Log the individual factors that produced the cache key chosen above.
+
+ logger.debug(
+ "torch.compile cache factors: env=%s cfg=%s comp=%s code=%s dir=%s",
+ env_hash,
+ config_hash,
+ compiler_hash,
+ code_hash,
+ local_cache_dir,
+ )
+
+ # Persist and log only hash-relevant factors together.
+ try:
+ logger.debug(
+ "Compile env factors (raw):\n%s\nVllm config hash: %s",
+ lazy(partial(pprint.pformat, env_factors, width=120)),
+ config_hash,
+ )
+ meta_path = os.path.join(local_cache_dir, "cache_key_factors.json")
+ if not os.path.exists(meta_path):
+ with open(meta_path, "w") as f:
+ json.dump(
+ {
+ "env": env_factors, # raw factors used for env_hash
+ "config_hash": config_hash,
+ "code_hash": code_hash,
+ "compiler_hash": compiler_hash,
+ },
+ f,
+ indent=2,
+ sort_keys=True,
+ )
+ except Exception:
+ # Best-effort only; metadata write failures are non-fatal.
+ logger.warning(
+ (
+ "Could not write compile cache metadata at %s; continuing without "
+ "metadata. Compiled cache remains valid; diagnostics may be "
+ "limited."
+ ),
+ local_cache_dir,
+ exc_info=True,
+ )
+
# when dynamo calls the backend, it means the bytecode
# transform and analysis are done
compilation_counter.num_graphs_seen += 1
diff --git a/vllm/compilation/caching.py b/vllm/compilation/caching.py
index 16e34c2711e9f..6297d9f995aa4 100644
--- a/vllm/compilation/caching.py
+++ b/vllm/compilation/caching.py
@@ -12,6 +12,7 @@ from torch.utils import _pytree as pytree
import vllm.envs as envs
from vllm.config import VllmConfig, get_current_vllm_config
+from vllm.config.utils import hash_factors
from vllm.logger import init_logger
try:
@@ -115,7 +116,8 @@ class VllmSerializableFunction(SerializableCallable):
the AOT compiled path.
"""
compile_inputs = [
- inp or example_inputs[i] for i, inp in enumerate(fn.example_inputs)
+ inp if inp is not None else example_inputs[i]
+ for i, inp in enumerate(fn.example_inputs)
]
with tracing(TracingContext(fake_mode)):
fn.optimized_call = vllm_backend(
@@ -138,7 +140,7 @@ def compilation_config_hash_factors(vllm_config: VllmConfig) -> list[str]:
factors = []
# 0. factors come from the env, for example, The values of
# VLLM_PP_LAYER_PARTITION will affect the computation graph.
- env_hash = envs.compute_hash()
+ env_hash = hash_factors(envs.compile_factors())
factors.append(env_hash)
# 1. factors come from the vllm_config (it mainly summarizes how the
diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py
index 11a18c0e6bb78..6d9da1c488c6d 100644
--- a/vllm/compilation/decorators.py
+++ b/vllm/compilation/decorators.py
@@ -24,6 +24,7 @@ from vllm.config import (
get_current_vllm_config,
set_current_vllm_config,
)
+from vllm.config.compilation import DynamicShapesType
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import resolve_obj_by_qualname
@@ -104,6 +105,7 @@ def support_torch_compile(
dynamic_arg_dims: dict[str, int | list[int]] | None = None,
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
enable_if: Callable[[VllmConfig], bool] | None = None,
+ shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> Callable[[_T], _T] | _T:
"""
A decorator to add support for compiling the forward method of a class.
@@ -161,6 +163,14 @@ def support_torch_compile(
dim to be decorated with `mark_unbacked`. This is useful if we would like to
enforce that dynamo does not specialize on 0/1 values in the case of dummy input
such as for vision model compilation
+
+ `shape_invariants` is a function that gets compiled right before forward.
+ The function should have the torch._check calls that are needed to set
+ the relationships between different input sizes. For example:
+ torch._check(input_ids.size()[0] == inputs_embeds.size()[0])
+ This enforces constraints on the symbolic shapes without hardcoding
+ specific values. It is needed for some models to avoid data dependent
+ errors.
"""
def cls_decorator_helper(cls: _T) -> _T:
@@ -199,7 +209,11 @@ def support_torch_compile(
f"Argument {k} not found in the forward method of {cls}"
)
return _support_torch_compile(
- cls, inferred_dynamic_arg_dims, mark_unbacked_dims, enable_if
+ cls,
+ inferred_dynamic_arg_dims,
+ mark_unbacked_dims,
+ enable_if,
+ shape_invariants,
)
if cls is not None:
@@ -242,6 +256,7 @@ def _support_torch_compile(
dynamic_arg_dims: dict[str, int | list[int]],
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
enable_if: Callable[[VllmConfig], bool] | None = None,
+ shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> _T:
"""
A decorator to add support for compiling the forward method of a class.
@@ -276,11 +291,12 @@ def _support_torch_compile(
old_init(self, **kwargs)
self.vllm_config = vllm_config
+ self.compilation_config = self.vllm_config.compilation_config
enable_compile = enable_if is None or enable_if(vllm_config)
# for CompilationMode.STOCK_TORCH_COMPILE , the upper level model runner
# will handle the compilation, so we don't need to do anything here.
self.do_not_compile = (
- vllm_config.compilation_config.mode
+ self.compilation_config.mode
in [CompilationMode.NONE, CompilationMode.STOCK_TORCH_COMPILE]
or not supports_dynamo()
or _should_ignore_torch_compile(self.__class__)
@@ -289,29 +305,38 @@ def _support_torch_compile(
if self.do_not_compile:
return
+ self._check_shape_invariants = shape_invariants
+
compilation_counter.num_models_seen += 1
self.compiled = False
TorchCompileWithNoGuardsWrapper.__init__(self)
cls.__init__ = __init__
- def _mark_dynamic_inputs(mod, *args, **kwargs):
+ def _mark_dynamic_inputs(mod, type, *args, **kwargs):
+ def mark_dynamic(arg, dims):
+ if type == DynamicShapesType.UNBACKED:
+ torch._dynamo.decorators.mark_unbacked(arg, dims)
+ else:
+ torch._dynamo.mark_dynamic(arg, dims)
+
sig = inspect.signature(mod.__class__.forward)
bound_args = sig.bind(mod, *args, **kwargs)
bound_args.apply_defaults()
for k, dims in dynamic_arg_dims.items():
arg = bound_args.arguments.get(k)
+
if arg is not None:
dims = [dims] if isinstance(dims, int) else dims
if isinstance(arg, torch.Tensor):
# In case dims is specified with negative indexing
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
- torch._dynamo.mark_dynamic(arg, dims)
+ mark_dynamic(arg, dims)
elif isinstance(arg, IntermediateTensors):
for tensor in arg.tensors.values():
# In case dims is specified with negative indexing
dims = [tensor.ndim + dim if dim < 0 else dim for dim in dims]
- torch._dynamo.mark_dynamic(tensor, dims)
+ mark_dynamic(tensor, dims)
else:
raise ValueError(
"Unsupported dynamic dimensions"
@@ -338,6 +363,7 @@ def _support_torch_compile(
if getattr(self, "aot_compiled_fn", None) is not None:
return self.aot_compiled_fn(self, *args, **kwargs)
+ ds_type = self.compilation_config.dynamic_shapes_config.type
cache_dir = None
aot_compilation_path = None
if envs.VLLM_USE_AOT_COMPILE:
@@ -352,6 +378,14 @@ def _support_torch_compile(
serialized backend artifacts), then we need to generate a new AOT
compile artifact from scratch.
"""
+ # Validate that AOT compile is not used with unbacked dynamic
+ # shapes. aot_compile re-allocates backed symbols post dynamo!
+ if ds_type == DynamicShapesType.UNBACKED:
+ raise ValueError(
+ "AOT compilation is not compatible with UNBACKED dynamic shapes. "
+ "Please use BACKED or BACKED_SIZE_OBLIVIOUS dynamic shapes type "
+ "when VLLM_USE_AOT_COMPILE is enabled."
+ )
from .caching import compilation_config_hash_factors
factors: list[str] = compilation_config_hash_factors(self.vllm_config)
@@ -401,7 +435,12 @@ def _support_torch_compile(
# This is the path for the first compilation.
# the first compilation needs to have dynamic shapes marked
- _mark_dynamic_inputs(self, *args, **kwargs)
+ _mark_dynamic_inputs(
+ self,
+ ds_type,
+ *args,
+ **kwargs,
+ )
# here, it is the starting point of the `torch.compile` process
start_monitoring_torch_compile(self.vllm_config)
@@ -417,9 +456,7 @@ def _support_torch_compile(
# properly when any of these files change.
# 1. the file containing the top-level forward function
- self.vllm_config.compilation_config.traced_files.add(
- original_code_object.co_filename
- )
+ self.compilation_config.traced_files.add(original_code_object.co_filename)
# 2. every time Dynamo sees a function call, it will inline
# the function by calling InliningInstructionTranslator.inline_call_
@@ -429,7 +466,7 @@ def _support_torch_compile(
def patched_inline_call(self_):
code = self_.f_code
- self.vllm_config.compilation_config.traced_files.add(code.co_filename)
+ self.compilation_config.traced_files.add(code.co_filename)
return inline_call(self_)
# Disable the C++ compilation of symbolic shape guards. C++-fication
@@ -445,12 +482,18 @@ def _support_torch_compile(
# if the config doesn't exist
logger.debug("enable_cpp_symbolic_shape_guards config not available")
+ # Prepare backed_size_oblivious config patch if needed
+ fx_config_patches = {}
+ if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS:
+ fx_config_patches["backed_size_oblivious"] = True
+
with (
patch.object(
InliningInstructionTranslator, "inline_call_", patched_inline_call
),
torch._dynamo.config.patch(**dynamo_config_patches),
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
+ torch.fx.experimental._config.patch(**fx_config_patches),
_torch27_patch_tensor_subclasses(),
):
if envs.VLLM_USE_AOT_COMPILE:
diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py
index 38eb4e5301a18..e4cd063d2aee1 100644
--- a/vllm/compilation/matcher_utils.py
+++ b/vllm/compilation/matcher_utils.py
@@ -162,12 +162,10 @@ class MatcherRMSNorm(MatcherCustomOp):
weight: torch.Tensor,
) -> torch.Tensor:
result = torch.empty_like(input)
- # TODO: support non-contiguous input for RMSNorm and remove this
- input_contiguous = input.contiguous()
_, result = auto_functionalized(
RMS_OP,
result=result,
- input=input_contiguous,
+ input=input,
weight=weight,
epsilon=self.epsilon,
)
diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py
index 0e8bb2fc97351..fe2547d7fecaf 100644
--- a/vllm/compilation/pass_manager.py
+++ b/vllm/compilation/pass_manager.py
@@ -127,7 +127,7 @@ class PostGradPassManager(CustomGraphPass):
affects compilation caching. Its uuid depends on the UUIDs of all
dependent passes and the pass config. See InductorPass for more info.
"""
- state = {"pass_config": self.pass_config.uuid(), "passes": []}
+ state = {"pass_config": self.pass_config.compute_hash(), "passes": []}
for pass_ in self.passes:
state["passes"].append(pass_.uuid())
state["passes"].append(self.fix_functionalization.uuid())
diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py
index 493e57f97f0f4..b120c85bf232e 100644
--- a/vllm/compilation/wrapper.py
+++ b/vllm/compilation/wrapper.py
@@ -6,6 +6,7 @@ import sys
from abc import abstractmethod
from contextlib import contextmanager
from types import CodeType
+from typing import Any
import torch
import torch._C._dynamo.guards
@@ -85,6 +86,12 @@ class TorchCompileWithNoGuardsWrapper:
since we drop all guards.
"""
+ def check_invariants_and_forward(self, *args, **kwargs):
+ assert hasattr(self, "_check_shape_invariants")
+ self._check_shape_invariants(*args, **kwargs)
+
+ return self.forward(*args, **kwargs)
+
def __init__(self):
self.compiled = False
@@ -104,6 +111,21 @@ class TorchCompileWithNoGuardsWrapper:
# Drop all the guards.
options["guard_filter_fn"] = lambda x: [False for _ in x]
+ # Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
+ from vllm.compilation.decorators import DynamicShapesType
+
+ ds_type = vllm_config.compilation_config.dynamic_shapes_config.type
+ compiled_ptr: Any = self.forward
+ if ds_type == DynamicShapesType.UNBACKED:
+ if envs.VLLM_USE_BYTECODE_HOOK:
+ # The bytecode hook calls torch._dynamo.eval_frame.
+ # remove_from_cache(self.original_code_object()) to force a new
+ # re-compilation, which is incompatible with unbacked shapes.
+ raise ValueError(
+ "UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0. "
+ )
+ compiled_ptr = self.check_invariants_and_forward
+
if envs.VLLM_USE_AOT_COMPILE:
if hasattr(torch._dynamo.config, "enable_aot_compile"):
torch._dynamo.config.enable_aot_compile = True
@@ -114,7 +136,7 @@ class TorchCompileWithNoGuardsWrapper:
logger.warning(msg)
self._compiled_callable = torch.compile(
- self.forward,
+ compiled_ptr,
fullgraph=True,
dynamic=False,
backend=backend,
diff --git a/vllm/config/cache.py b/vllm/config/cache.py
index 864cf1be81b20..ef6928d8ebd5c 100644
--- a/vllm/config/cache.py
+++ b/vllm/config/cache.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import hashlib
from dataclasses import field
from typing import TYPE_CHECKING, Any, Literal
@@ -74,8 +73,8 @@ class CacheConfig:
sliding_window: int | None = None
"""Sliding window size for the KV cache. This is primarily set in
`ModelConfig` and that value should be manually duplicated here."""
- enable_prefix_caching: bool | None = None
- """Whether to enable prefix caching. Enabled by default for V1."""
+ enable_prefix_caching: bool = True
+ """Whether to enable prefix caching."""
prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256"
"""Set the hash algorithm for prefix caching:\n
- "sha256" uses Pickle for object serialization before hashing.\n
@@ -160,13 +159,29 @@ class CacheConfig:
excluding anything before input ids/embeddings and after
the final hidden states.
"""
- factors: list[Any] = []
- factors.append(self.cache_dtype)
- factors.append(self.mamba_cache_dtype)
- factors.append(self.mamba_ssm_cache_dtype)
- # `cpu_offload_gb` does not use `torch.compile` yet.
- hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
- return hash_str
+ ignored_factors = {
+ # Runtime/derived knobs that don't affect compiled graph shape
+ "gpu_memory_utilization",
+ "swap_space",
+ "is_attention_free",
+ "num_gpu_blocks_override",
+ "enable_prefix_caching",
+ "prefix_caching_hash_algo",
+ # `cpu_offload_gb` does not use `torch.compile` yet.
+ "cpu_offload_gb",
+ "cpu_kvcache_space_bytes",
+ "mamba_page_size_padded",
+ # Post-init/derived counters
+ "num_gpu_blocks",
+ "num_cpu_blocks",
+ # WIP feature toggle not impacting compiled graph shape
+ "kv_sharing_fast_prefill",
+ }
+
+ from vllm.config.utils import get_hash_factors, hash_factors
+
+ factors = get_hash_factors(self, ignored_factors)
+ return hash_factors(factors)
def metrics_info(self):
# convert cache_config to dict(key: str, value: str) for prometheus
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 088d0b1af757a..556b2d9168b32 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
-import hashlib
from collections import Counter
from collections.abc import Callable
from dataclasses import asdict, field
@@ -160,7 +159,7 @@ class PassConfig:
current_platform.get_device_capability().to_int(), {}
)
- def uuid(self):
+ def compute_hash(self) -> str:
"""
Produces a hash unique to the pass configuration.
Any new fields that affect compilation should be added to the hash.
@@ -193,6 +192,54 @@ class PassConfig:
self.enable_qk_norm_rope_fusion = False
+class DynamicShapesType(str, enum.Enum):
+ """Types of dynamic shapes handling in torch.compile().
+ See "Dynamic shapes and vLLM guard dropping" in torch_compile.md
+ for more details."""
+
+ BACKED = "backed"
+ """Use backed dynamic shapes. torch.compile() guards on backed dynamic
+ shapes and may add guards. Symbols are specialized to 0, 1, or >=2 even
+ without encountering branching on those ranges."""
+
+ UNBACKED = "unbacked"
+ """Use unbacked dynamic shapes. Guaranteed not to be guarded on and not
+ 0/1 specialized, but may throw data dependent errors when branches require
+ their value without explicit unbacked handling."""
+
+ BACKED_SIZE_OBLIVIOUS = "backed_size_oblivious"
+ """Experimental flag that treats backed symbols as unbacked when explicit
+ unbacked handling is defined."""
+
+
+@config
+@dataclass
+class DynamicShapesConfig:
+ """Configuration to control/debug torch compile dynamic shapes."""
+
+ type: DynamicShapesType = DynamicShapesType.BACKED
+ """Controls the type of dynamic shapes handling to use with torch.compile().
+
+ - BACKED: Default PyTorch behavior with potential guards ignored.
+ - UNBACKED: No guards guaranteed (most sound) but may throw
+ data dependent errors.
+ - BACKED_SIZE_OBLIVIOUS: Experimental safer alternative to
+ backed/unbacked.
+ """
+
+ # TODO add a debug mode to fail
+
+ def compute_hash(self) -> str:
+ """
+ Provide a hash for DynamicShapesConfig
+ """
+
+ from vllm.config.utils import get_hash_factors, hash_factors
+
+ factors = get_hash_factors(self, {})
+ return hash_factors(factors)
+
+
@config
@dataclass
class CompilationConfig:
@@ -284,9 +331,9 @@ class CompilationConfig:
We use string to avoid serialization issues when using compilation in a
distributed setting. When the compilation mode is 1 or 2, the backend is
used for the compilation directly (it sees the whole graph). When the
- compilation mode is 3, the backend is used for the piecewise compilation
- (it sees a part of the graph). The backend can not be custom for compilation
- mode 3, i.e. the backend must be either eager or inductor. Furthermore,
+ compilation mode is 3, the backend supports both whole graph and piecewise
+ compilation, available backends include eager, inductor, and custom backends,
+ the latter of which can be defined via `get_compile_backend`. Furthermore,
compilation is only piecewise if splitting ops is set accordingly and
use_inductor_graph_partition is off. Note that the default options for
splitting ops are sufficient for piecewise compilation.
@@ -323,7 +370,7 @@ class CompilationConfig:
If empty list [], no ops are excluded (suitable for full cudagraphs)."""
compile_mm_encoder: bool = False
"""Whether or not to compile the multimodal encoder.
- Currently, this only works for `Qwen2_5_vl` on selected platforms.
+ Currently, this only works for `Qwen2_5_vl` on selected platforms.
Disabled by default until more models are supported/tested to work."""
# Inductor capture
@@ -349,9 +396,11 @@ class CompilationConfig:
"""Sizes to compile for inductor. In addition
to integers, it also supports "cudagraph_capture_sizes" to
specify the sizes for cudagraph capture."""
+
inductor_compile_config: dict = field(default_factory=dict)
"""Additional configurations for inductor.
- None: use default configurations."""
+
inductor_passes: dict[str, str] = field(default_factory=dict)
"""Additional passes for inductor. It is a dictionary
from pass name to pass function qualified name. We use function
@@ -461,8 +510,15 @@ class CompilationConfig:
max_num_seqs, and prevents capture of many large graphs (>512) that would
greatly increase startup time with limited performance benefit.
"""
+
+ dynamic_shapes_config: DynamicShapesConfig = field(
+ default_factory=DynamicShapesConfig
+ )
+ """Configuration for dynamic shapes options"""
+
local_cache_dir: str = field(default=None, init=False) # type: ignore
"""local cache dir for each rank"""
+
bs_to_padded_graph_size: list[int] = field(
default=None, # type: ignore
init=False,
@@ -506,28 +562,34 @@ class CompilationConfig:
def compute_hash(self) -> str:
"""
- WARNING: Whenever a new field is added to this config,
- ensure that it is included in the factors list if
- it affects the computation graph.
-
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
- factors: list[Any] = []
- factors.append(self.mode)
- factors.append(self.backend)
- factors.append(self.custom_ops)
- factors.append(self.splitting_ops)
- factors.append(self.use_inductor)
- factors.append(self.use_inductor_graph_partition)
- factors.append(self.inductor_compile_config)
- factors.append(self.inductor_passes)
- factors.append(self.pass_config.uuid())
- factors.append(self.compile_cache_save_format)
- return hashlib.sha256(str(factors).encode()).hexdigest()
+ # Opt-out: default-include declared fields; keep a tiny exclude set;
+ # normalize types; keep SHA-256. For nested opaque configs, include a
+ # stable identifier (e.g., pass_config.compute_hash()) instead of object id.
+
+ ignored_factors = {
+ # Paths/dirs and runtime/metrics that don’t affect compiled graph
+ "debug_dump_path",
+ "cache_dir",
+ "local_cache_dir",
+ "bs_to_padded_graph_size",
+ "traced_files",
+ "compilation_time",
+ "static_forward_context",
+ "pass_config", # handled separately below
+ }
+
+ from vllm.config.utils import get_hash_factors, hash_factors
+
+ factors = get_hash_factors(self, ignored_factors)
+
+ factors["pass_config"] = self.pass_config.compute_hash()
+ return hash_factors(factors)
def __repr__(self) -> str:
exclude = {
@@ -660,6 +722,8 @@ class CompilationConfig:
is_torch_equal_or_newer("2.9.0.dev")
and "combo_kernels" not in self.inductor_compile_config
and "benchmark_combo_kernel" not in self.inductor_compile_config
+ # FIXME(boyuan): combo kernels do not support CPU yet.
+ and not current_platform.is_cpu()
):
# use horizontal fusion, which is useful for fusing qk-norm and
# qk-rope when query and key have different shapes.
@@ -704,7 +768,7 @@ class CompilationConfig:
self.backend = "inductor" if self.use_inductor else "eager"
if self.backend == "":
- self.backend = current_platform.simple_compile_backend
+ self.backend = current_platform.get_compile_backend()
def init_backend(self, vllm_config: "VllmConfig") -> str | Callable:
"""
@@ -736,9 +800,7 @@ class CompilationConfig:
assert self.mode == CompilationMode.VLLM_COMPILE
if self.backend not in ["eager", "inductor"]:
- raise ValueError(
- f"Invalid backend for piecewise compilation: {self.backend}"
- )
+ logger.info("Using OOT custom backend for compilation.")
from vllm.compilation.backends import VllmBackend
@@ -917,7 +979,7 @@ class CompilationConfig:
self, uniform_decode_query_len: int, tensor_parallel_size: int
):
multiple_of = uniform_decode_query_len
- if tensor_parallel_size > 1:
+ if tensor_parallel_size > 1 and self.pass_config.enable_sequence_parallelism:
multiple_of = max(uniform_decode_query_len, tensor_parallel_size)
if (
multiple_of % uniform_decode_query_len != 0
@@ -944,14 +1006,18 @@ class CompilationConfig:
)
)
+ if len(rounded_sizes) == 0 and multiple_of <= self.max_cudagraph_capture_size:
+ # multiple_of itself is a valid size; use it rather than failing below.
+ rounded_sizes = [multiple_of]
+
if len(rounded_sizes) == 0:
- logger.warning(
- "No valid cudagraph sizes after rounding to multiple of "
- " num_speculative_tokens + 1 (%d); please adjust num_speculative_tokens"
- " or max_cudagraph_capture_size (or cudagraph_capture_sizes)",
- multiple_of,
+ raise ValueError(
+ f"No valid cudagraph sizes after rounding to multiple of {multiple_of} "
+ f"(num_speculative_tokens + 1 or tp if sequence parallelism is enabled)"
+ f" please adjust num_speculative_tokens ({uniform_decode_query_len - 1}"
+ f") or max_cudagraph_capture_size ({self.max_cudagraph_capture_size})"
+ f" or cudagraph_capture_sizes ({self.cudagraph_capture_sizes})"
)
- return
self.max_cudagraph_capture_size = rounded_sizes[-1]
self.cudagraph_capture_sizes = rounded_sizes
diff --git a/vllm/config/lora.py b/vllm/config/lora.py
index 84e92eef40077..072e0ec2104f5 100644
--- a/vllm/config/lora.py
+++ b/vllm/config/lora.py
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
-from typing import TYPE_CHECKING, Any, ClassVar, Literal
+from typing import TYPE_CHECKING, Any, Literal
import torch
from pydantic import ConfigDict, Field, model_validator
@@ -11,7 +11,6 @@ from typing_extensions import Self
from vllm.config.utils import config
from vllm.logger import init_logger
-from vllm.platforms import current_platform
if TYPE_CHECKING:
from vllm.config import ModelConfig
@@ -46,19 +45,6 @@ class LoRAConfig:
`max_loras`."""
lora_dtype: torch.dtype | LoRADType = "auto"
"""Data type for LoRA. If auto, will default to base model dtype."""
- lora_extra_vocab_size: LoRAExtraVocabSize = Field(
- default=256,
- deprecated=(
- "`lora_extra_vocab_size` is deprecated and will be removed "
- "in v0.12.0. Additional vocabulary support for "
- "LoRA adapters is being phased out."
- ),
- )
- """(Deprecated) Maximum size of extra vocabulary that can be present in a
- LoRA adapter. Will be removed in v0.12.0."""
- lora_vocab_padding_size: ClassVar[int] = (
- current_platform.get_lora_vocab_padding_size()
- )
default_mm_loras: dict[str, str] | None = None
"""Dictionary mapping specific modalities to LoRA model paths; this field
is only applicable to multimodal models and should be leveraged when a
@@ -87,8 +73,6 @@ class LoRAConfig:
factors.append(self.max_loras)
factors.append(self.fully_sharded_loras)
factors.append(self.lora_dtype)
- factors.append(self.lora_extra_vocab_size)
- factors.append(self.lora_vocab_padding_size)
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
diff --git a/vllm/config/model.py b/vllm/config/model.py
index 3e8790a26e0e3..caa9a3440c41d 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -1,8 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import hashlib
-import json
import warnings
from collections.abc import Callable
from dataclasses import InitVar, field
@@ -13,12 +11,13 @@ import torch
from pydantic import ConfigDict, SkipValidation, field_validator, model_validator
from pydantic.dataclasses import dataclass
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
+from transformers.configuration_utils import ALLOWED_LAYER_TYPES
import vllm.envs as envs
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig
from vllm.config.pooler import PoolerConfig
from vllm.config.scheduler import RunnerType
-from vllm.config.utils import assert_hashable, config, getattr_iter
+from vllm.config.utils import config, getattr_iter
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.transformers_utils.config import (
@@ -33,8 +32,8 @@ from vllm.transformers_utils.config import (
try_get_generation_config,
try_get_safetensors_metadata,
try_get_tokenizer_config,
- uses_custom_attention_masks,
uses_mrope,
+ uses_xdrope_dim,
)
from vllm.transformers_utils.gguf_utils import (
maybe_patch_hf_config_from_gguf,
@@ -83,7 +82,7 @@ TaskOption = Literal[
"transcription",
"draft",
]
-TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
+TokenizerMode = Literal["auto", "hf", "slow", "mistral", "custom"]
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
LogprobsMode = Literal[
"raw_logits", "raw_logprobs", "processed_logits", "processed_logprobs"
@@ -132,7 +131,8 @@ class ModelConfig:
name or path will be used."""
tokenizer_mode: TokenizerMode = "auto"
"""Tokenizer mode:\n
- - "auto" will use the fast tokenizer if available.\n
+ - "auto" will use "hf" tokenizer if Mistral's tokenizer is not available.\n
+ - "hf" will use the fast tokenizer if available.\n
- "slow" will always use the slow tokenizer.\n
- "mistral" will always use the tokenizer from `mistral_common`.\n
- "custom" will use --tokenizer to select the preregistered tokenizer."""
@@ -148,9 +148,12 @@ class ModelConfig:
- "bfloat16" for a balance between precision and range.\n
- "float" is shorthand for FP32 precision.\n
- "float32" for FP32 precision."""
- seed: int | None = None
- """Random seed for reproducibility. Initialized to None in V0, but
- initialized to 0 in V1."""
+ seed: int = 0
+ """Random seed for reproducibility.
+
+ We must set the global seed because otherwise,
+ different tensor parallel workers would sample different tokens,
+ leading to inconsistent results."""
hf_config: PretrainedConfig = field(init=False)
"""The Hugging Face config of the model."""
hf_text_config: PretrainedConfig = field(init=False)
@@ -240,8 +243,8 @@ class ModelConfig:
first one."""
config_format: str | ConfigFormat = "auto"
"""The format of the model config to load:\n
- - "auto" will try to load the config in hf format if available else it
- will try to load in mistral format.\n
+ - "auto" will try to load the config in hf format if available after trying
+ to load in mistral format.\n
- "hf" will load the config in hf format.\n
- "mistral" will load the config in mistral format."""
hf_token: bool | str | None = None
@@ -324,50 +327,50 @@ class ModelConfig:
excluding anything before input ids/embeddings and after
the final hidden states.
"""
- factors: list[Any] = []
- factors.append(self.model)
- factors.append(self.dtype)
- factors.append(self.quantization)
- factors.append(self.revision)
- factors.append(self.code_revision)
- factors.append(self.max_model_len)
- factors.append(self.max_logprobs)
- factors.append(self.disable_sliding_window)
- factors.append(self.trust_remote_code)
- factors.append(self.generation_config)
- factors.append(self.model_impl)
- factors.append(self.override_generation_config)
- factors.append(self.video_pruning_rate)
- factors.append(self.enable_prompt_embeds)
+ ignored_factors = {
+ "runner",
+ "convert",
+ "task",
+ "tokenizer",
+ "tokenizer_mode",
+ "seed",
+ "hf_config_path",
+ "allowed_local_media_path",
+ "allowed_media_domains",
+ "tokenizer_revision",
+ "spec_target_max_model_len",
+ "enforce_eager",
+ "logprobs_mode",
+ "disable_cascade_attn",
+ "skip_tokenizer_init",
+ "enable_prompt_embeds",
+ "served_model_name",
+ "config_format",
+ "hf_token",
+ "hf_overrides",
+ "logits_processor_pattern",
+ "enable_sleep_mode",
+ "override_attention_dtype",
+ "logits_processors",
+ "io_processor_plugin",
+ "pooler_config",
+ "override_pooler_config",
+ "multimodal_config",
+ "limit_mm_per_prompt",
+ "media_io_kwargs",
+ "mm_processor_kwargs",
+ "mm_processor_cache_gb",
+ "mm_processor_cache_type",
+ "mm_shm_cache_max_object_size_mb",
+ "mm_encoder_tp_mode",
+ "interleave_mm_strings",
+ "skip_mm_profiling",
+ }
- # hf_config can control how the model looks!
- try:
- hf_config_json = self.hf_config.to_json_string(use_diff=False)
- except TypeError:
- from transformers import PretrainedConfig
+ from vllm.config.utils import get_hash_factors, hash_factors
- from vllm.utils.jsontree import json_map_leaves
-
- # Handle nested HF configs with unserializable values gracefully
- hf_config_json = (
- json.dumps(
- json_map_leaves(
- lambda v: v.to_dict()
- if isinstance(v, PretrainedConfig)
- else str(v),
- self.hf_config.to_dict(),
- ),
- indent=2,
- sort_keys=True,
- )
- + "\n"
- )
-
- factors.append(hf_config_json)
-
- str_factors = str(factors)
- assert_hashable(str_factors)
- return hashlib.sha256(str(factors).encode()).hexdigest()
+ factors = get_hash_factors(self, ignored_factors)
+ return hash_factors(factors)
def _update_nested(
self,
@@ -417,7 +420,7 @@ class ModelConfig:
def __post_init__(
self,
# Multimodal config init vars
- limit_mm_per_prompt: dict[str, int] | None,
+ limit_mm_per_prompt: dict[str, int | dict[str, int]] | None,
enable_mm_embeds: bool | None,
media_io_kwargs: dict[str, dict[str, Any]] | None,
mm_processor_kwargs: dict[str, Any] | None,
@@ -430,23 +433,6 @@ class ModelConfig:
skip_mm_profiling: bool | None,
video_pruning_rate: float | None,
) -> None:
- # Set the default seed to 0 in V1.
- # NOTE(woosuk): In V1, we use separate processes for workers (unless
- # VLLM_ENABLE_V1_MULTIPROCESSING=0), so setting a seed here
- # doesn't affect the user process. However, without a consistent seed,
- # different tensor parallel workers would sample different tokens,
- # leading to inconsistent results.
- if self.seed is None:
- self.seed = 0
- if not envs.VLLM_ENABLE_V1_MULTIPROCESSING:
- logger.warning(
- "The global random seed is set to %d. Since "
- "VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may "
- "affect the random state of the Python process that "
- "launched vLLM.",
- self.seed,
- )
-
# Keep set served_model_name before maybe_model_redirect(self.model)
self.served_model_name = get_served_model_name(
self.model, self.served_model_name
@@ -600,16 +586,26 @@ class ModelConfig:
else: # task == "auto"
pass
else:
- debug_info = {
- "architectures": architectures,
- "is_generative_model": is_generative_model,
- "is_pooling_model": is_pooling_model,
- }
- raise AssertionError(
- "The model should be a generative or "
- "pooling model when task is set to "
- f"{self.task!r}. Found: {debug_info}"
- )
+ # Neither generative nor pooling model - try to convert if possible
+ if is_pooling_task:
+ runner = "pooling"
+ convert = _task_to_convert(self.task)
+ msg_hint = (
+ "Please replace this option with `--runner pooling "
+ f"--convert {convert}` to continue using this model "
+ "as a pooling model."
+ )
+ else:
+ debug_info = {
+ "architectures": architectures,
+ "is_generative_model": is_generative_model,
+ "is_pooling_model": is_pooling_model,
+ }
+ raise AssertionError(
+ "The model should be a generative or "
+ "pooling model when task is set to "
+ f"{self.task!r}. Found: {debug_info}"
+ )
self.runner = runner
self.convert = convert
@@ -1153,12 +1149,6 @@ class ModelConfig:
self,
parallel_config: ParallelConfig,
) -> None:
- if parallel_config.distributed_executor_backend == "external_launcher":
- assert self.seed is not None, (
- "Seed must be set when using external launcher backend to "
- "make sure sampling results are the same across workers."
- )
-
total_num_attention_heads = getattr(
self.hf_text_config, "num_attention_heads", 0
)
@@ -1369,11 +1359,7 @@ class ModelConfig:
# Coerce to 0 if explicitly set to None
return num_experts or 0
- def get_layers_start_end_indices(
- self, parallel_config: ParallelConfig
- ) -> tuple[int, int]:
- from vllm.distributed.utils import get_pp_indices
-
+ def get_total_num_hidden_layers(self) -> int:
if (
self.hf_text_config.model_type == "deepseek_mtp"
or self.hf_config.model_type == "mimo_mtp"
@@ -1393,6 +1379,15 @@ class ModelConfig:
total_num_hidden_layers = getattr(
self.hf_text_config, "num_hidden_layers", 0
)
+ return total_num_hidden_layers
+
+ def get_layers_start_end_indices(
+ self, parallel_config: ParallelConfig
+ ) -> tuple[int, int]:
+ from vllm.distributed.utils import get_pp_indices
+
+ total_num_hidden_layers = self.get_total_num_hidden_layers()
+
# the layout order is: DP x PP x TP
pp_rank = (
parallel_config.rank // parallel_config.tensor_parallel_size
@@ -1622,8 +1617,8 @@ class ModelConfig:
return uses_mrope(self.hf_config)
@property
- def uses_custom_attention_masks(self) -> bool:
- return uses_custom_attention_masks(self.hf_config)
+ def uses_xdrope_dim(self) -> int:
+ return uses_xdrope_dim(self.hf_config)
@property
def is_multimodal_model(self) -> bool:
@@ -2097,31 +2092,32 @@ def _get_and_verify_max_len(
)
derived_max_model_len = default_max_len
- rope_scaling = getattr(hf_config, "rope_scaling", None)
+ # In Transformers v5 rope_parameters could be TypedDict or dict[str, TypedDict].
+ # To simplify the verification, we convert it to dict[str, TypedDict].
+ rope_parameters = getattr(hf_config, "rope_parameters", None)
+ if rope_parameters and not set(rope_parameters.keys()).issubset(
+ ALLOWED_LAYER_TYPES
+ ):
+ rope_parameters = {"": rope_parameters}
+
# NOTE(woosuk): Gemma3's max_model_len (128K) is already scaled by RoPE
# scaling, so we skip applying the scaling factor again.
- if rope_scaling is not None and "gemma3" not in hf_config.model_type:
- # No need to consider "type" key because of patch_rope_scaling when
- # loading HF config
- rope_type = rope_scaling["rope_type"]
+ if rope_parameters is not None and "gemma3" not in hf_config.model_type:
+ scaling_factor = 1.0
+ for rp in rope_parameters.values():
+ # No need to consider "type" key because of patch_rope_parameters when
+ # loading HF config
+ rope_type = rp["rope_type"]
- if rope_type not in ("su", "longrope", "llama3"):
- if disable_sliding_window:
- # TODO(robertgshaw): Find a model that supports rope_scaling
- # with sliding window to see if this case should be allowed.
- raise NotImplementedError(
- "Disabling sliding window is not supported for models "
- "with rope_scaling. Please raise an issue so we can "
- "investigate."
- )
+ if rope_type not in ("su", "longrope", "llama3"):
+ # NOTE: rope_type == "default" does not define factor https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/modeling_rope_utils.py
+ # NOTE: This assumes all layer types have the same scaling factor.
+ scaling_factor = rp.get("factor", scaling_factor)
- # NOTE: rope_type == "default" does not define factor
- # https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/modeling_rope_utils.py
- scaling_factor = rope_scaling.get("factor", 1.0)
-
- if rope_type == "yarn":
- derived_max_model_len = rope_scaling["original_max_position_embeddings"]
- derived_max_model_len *= scaling_factor
+ if rope_type == "yarn":
+ derived_max_model_len = rp["original_max_position_embeddings"]
+ # Do this outside loop since all layer types should have the same scaling
+ derived_max_model_len *= scaling_factor
if encoder_config and "max_seq_length" in encoder_config:
derived_max_model_len = encoder_config["max_seq_length"]
@@ -2131,7 +2127,9 @@ def _get_and_verify_max_len(
if max_model_len is None:
# For LongRoPE, default to original_max_position_embeddings to avoid
# performance degradation for shorter sequences
- if rope_scaling is not None and rope_scaling["rope_type"] == "longrope":
+ if rope_parameters is not None and any(
+ rp["rope_type"] == "longrope" for rp in rope_parameters.values()
+ ):
max_model_len = int(
getattr(
hf_config, "original_max_position_embeddings", derived_max_model_len
@@ -2148,16 +2146,7 @@ def _get_and_verify_max_len(
# that will be bigger than derived_max_model_len. We compare user input
# with model_max_length and allow this override when it's smaller.
model_max_length = getattr(hf_config, "model_max_length", None)
- if model_max_length is not None and max_model_len <= model_max_length:
- if disable_sliding_window:
- # TODO(robertgshaw): Find a model that has model_max_length
- # with sliding window to see if this case should be allowed.
- raise NotImplementedError(
- "Disabling sliding window is not supported for models "
- "model_max_length in the config. Please raise an issue "
- "so we can investigate."
- )
- else:
+ if model_max_length is None or max_model_len > model_max_length:
msg = (
f"User-specified max_model_len ({max_model_len}) is greater "
f"than the derived max_model_len ({max_len_key}="
diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py
index 9f62b35ed515c..00a81a319bf72 100644
--- a/vllm/config/multimodal.py
+++ b/vllm/config/multimodal.py
@@ -173,6 +173,12 @@ class MultiModalConfig:
# We need to import the real type here (deferred to avoid circular import).
from vllm.attention.backends.registry import AttentionBackendEnum
+ if isinstance(value, str) and value.upper() == "XFORMERS":
+ raise ValueError(
+ "Attention backend 'XFORMERS' has been removed (See PR #29262 for "
+ "details). Please select a supported attention backend."
+ )
+
if value is None or isinstance(value, AttentionBackendEnum):
return value
diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py
index 9a6326d62e82e..913e97250d3d3 100644
--- a/vllm/config/parallel.py
+++ b/vllm/config/parallel.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import hashlib
import os
from typing import TYPE_CHECKING, Any, Literal
@@ -61,6 +60,10 @@ class EPLBConfig:
Log the balancedness each step of expert parallelism.
This is turned off by default since it will cause communication overhead.
"""
+ use_async: bool = False
+ """
+ Whether to use non-blocking EPLB.
+ """
@config
@@ -72,6 +75,8 @@ class ParallelConfig:
"""Number of pipeline parallel groups."""
tensor_parallel_size: int = 1
"""Number of tensor parallel groups."""
+ prefill_context_parallel_size: int = 1
+ """Number of prefill context parallel groups."""
data_parallel_size: int = 1
"""Number of data parallel groups. MoE layers will be sharded according to
the product of the tensor parallel size and data parallel size."""
@@ -136,22 +141,6 @@ class ParallelConfig:
- "deepep_high_throughput": Use deepep high-throughput kernels
- "deepep_low_latency": Use deepep low-latency kernels
- "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""
- num_redundant_experts: int | None = None
- """`num_redundant_experts` is deprecated and has been replaced with
- `eplb_config.num_redundant_experts`. This will be removed in v0.12.0.
- Please use `eplb_config.num_redundant_experts` instead."""
- eplb_window_size: int | None = None
- """`eplb_window_size` is deprecated and has been replaced with
- `eplb_config.window_size`. This will be removed in v0.12.0.
- Please use `eplb_config.window_size` instead."""
- eplb_step_interval: int | None = None
- """`eplb_step_interval` is deprecated and has been replaced with
- `eplb_config.step_interval`. This will be removed in v0.12.0.
- Please use `eplb_config.step_interval` instead."""
- eplb_log_balancedness: bool | None = None
- """`eplb_log_balancedness` is deprecated and has been replaced with
- `eplb_config.log_balancedness`. This will be removed in v0.12.0.
- Please use `eplb_config.log_balancedness` instead."""
max_parallel_loading_workers: int | None = None
"""Maximum number of parallel loading workers when loading model
@@ -240,14 +229,25 @@ class ParallelConfig:
needs to be divisible by dcp_size."""
dcp_kv_cache_interleave_size: int = 1
- """Interleave size of kv_cache storage while using dcp or cp > 1,
- store interleave_size tokens on (d)cp i,
- then store next interleave_size tokens on (d)cp i+1.
- Interleave_size=1: token-level align, token i is stored on rank i % (d)cp_size.
- Interleave_size=block_size: block-level align, first fill the block on first rank,
- token is stored on rank i+1 block j after rank i block j is full.
- Block_size should be greater than or equal to dcp_kv_cache_interleave_size.
- Block_size should be divisible by dcp_kv_cache_interleave_size.
+ """
+ Interleave size of kv_cache storage while using DCP.
+ dcp_kv_cache_interleave_size has been replaced by cp_kv_cache_interleave_size,
+ and will be deprecated when PCP is fully supported.
+
+ """
+ cp_kv_cache_interleave_size: int = 1
+ """Interleave size of kv_cache storage while using DCP or PCP.
+ For `total_cp_rank = pcp_rank * dcp_world_size + dcp_rank`,
+ and `total_cp_world_size = pcp_world_size * dcp_world_size`.
+ store interleave_size tokens on total_cp_rank i,
+ then store next interleave_size tokens on total_cp_rank i+1.
+ Interleave_size=1: token-level alignment, where token `i` is stored on
+ total_cp_rank `i % total_cp_world_size`.
+ Interleave_size=block_size: block-level alignment, where tokens are
+ first populated to the preceding ranks. Tokens are then stored
+ in (rank i+1, block j) only after (rank i, block j) is fully occupied.
+ Block_size should be greater than or equal to cp_kv_cache_interleave_size.
+ Block_size should be divisible by cp_kv_cache_interleave_size.
"""
_api_process_count: int = Field(default=1, gt=0)
@@ -312,6 +312,11 @@ class ParallelConfig:
"num_redundant_experts."
)
+ if self.prefill_context_parallel_size > 1:
+ raise ValueError(
+ "Prefill context parallelism is not fully supported. "
+ "Please set prefill_context_parallel_size to 1."
+ )
return self
@property
@@ -448,19 +453,41 @@ class ParallelConfig:
This hash is also used for DP worker configuration validation
to prevent hangs from mismatched collective communication patterns.
"""
- factors: list[Any] = []
- factors.append(self.pipeline_parallel_size)
- factors.append(self.tensor_parallel_size)
- factors.append(self.enable_expert_parallel)
- factors.append(self.data_parallel_size)
- factors.append(self.all2all_backend)
- factors.append(self.enable_eplb)
- if self.enable_eplb:
- factors.append(self.eplb_config.log_balancedness)
- factors.append(self.eplb_config.window_size)
- factors.append(self.eplb_config.step_interval)
- factors.append(self.eplb_config.num_redundant_experts)
- return hashlib.sha256(str(factors).encode()).hexdigest()
+ ignored_factors = {
+ # Derived/runtime topology, networking, or launch details
+ "data_parallel_rank",
+ "data_parallel_rank_local",
+ "data_parallel_backend",
+ "data_parallel_external_lb",
+ "data_parallel_hybrid_lb",
+ "data_parallel_master_ip",
+ "data_parallel_master_port",
+ "_data_parallel_master_port_list",
+ "data_parallel_rpc_port",
+ "rank",
+ "master_addr",
+ "master_port",
+ "node_rank",
+ "nnodes",
+ "max_parallel_loading_workers",
+ "disable_custom_all_reduce",
+ "ray_workers_use_nsight",
+ "ray_runtime_env",
+ "placement_group",
+ "distributed_executor_backend",
+ "worker_cls",
+ "sd_worker_cls",
+ "worker_extension_cls",
+ "_api_process_count",
+ "_api_process_rank",
+ }
+
+ from vllm.config.utils import get_hash_factors, hash_factors
+
+ factors = get_hash_factors(self, ignored_factors)
+ # Explicitly include backend affecting env factor as before
+ factors["VLLM_ALL2ALL_BACKEND"] = str(envs.VLLM_ALL2ALL_BACKEND)
+ return hash_factors(factors)
def __post_init__(self) -> None:
# Set all2all_backend from env var if not specified, with deprecation warning
@@ -473,42 +500,12 @@ class ParallelConfig:
"--all2all-backend command-line argument instead."
)
- # Forward deprecated fields to their new location
- if self.num_redundant_experts is not None:
- self.eplb_config.num_redundant_experts = self.num_redundant_experts
- logger.warning_once(
- "num_redundant_experts is deprecated and has been replaced "
- "with eplb_config.num_redundant_experts. This will be removed "
- "in v0.12.0. Changing this field after initialization will "
- "have no effect."
- )
- if self.eplb_window_size is not None:
- self.eplb_config.window_size = self.eplb_window_size
- logger.warning_once(
- "eplb_window_size is deprecated and has been replaced "
- "with eplb_config.window_size. This will be removed "
- "in v0.12.0. Changing this field after initialization will "
- "have no effect."
- )
- if self.eplb_step_interval is not None:
- self.eplb_config.step_interval = self.eplb_step_interval
- logger.warning_once(
- "eplb_step_interval is deprecated and has been replaced "
- "with eplb_config.step_interval. This will be removed "
- "in v0.12.0. Changing this field after initialization will "
- "have no effect."
- )
- if self.eplb_log_balancedness is not None:
- self.eplb_config.log_balancedness = self.eplb_log_balancedness
- logger.warning_once(
- "eplb_log_balancedness is deprecated and has been replaced "
- "with eplb_config.log_balancedness. This will be removed "
- "in v0.12.0. Changing this field after initialization will "
- "have no effect."
- )
-
# Continue with the rest of the initialization
- self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size
+ self.world_size = (
+ self.pipeline_parallel_size
+ * self.tensor_parallel_size
+ * self.prefill_context_parallel_size
+ )
if self.distributed_executor_backend == "external_launcher":
logger.info("Using external launcher for distributed inference.")
diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py
index 8194295ffedb6..b6078706daacf 100644
--- a/vllm/config/scheduler.py
+++ b/vllm/config/scheduler.py
@@ -62,15 +62,6 @@ class SchedulerConfig:
"""For chunked prefill, a request is considered long if the prompt is
longer than this number of tokens."""
- num_lookahead_slots: int = Field(default=0, ge=0)
- """The number of slots to allocate per sequence per
- step, beyond the known token ids. This is used in speculative
- decoding to store KV activations of tokens which may or may not be
- accepted.
-
- NOTE: This will be replaced by speculative config in the future; it is
- present to enable correctness tests until then."""
-
enable_chunked_prefill: bool = True
"""If True, prefill requests can be chunked based
on the remaining `max_num_batched_tokens`.
diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py
index 13a8632413d91..d7c019c73d598 100644
--- a/vllm/config/speculative.py
+++ b/vllm/config/speculative.py
@@ -9,6 +9,7 @@ from pydantic import Field, SkipValidation, model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self
+from vllm.config.model import ModelConfig
from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config
from vllm.logger import init_logger
@@ -18,10 +19,8 @@ if TYPE_CHECKING:
from transformers import PretrainedConfig
import vllm.model_executor.layers.quantization as me_quant
- from vllm.config import ModelConfig
else:
PretrainedConfig = Any
- ModelConfig = Any
me_quant = LazyLoader(
"model_executor", globals(), "vllm.model_executor.layers.quantization"
@@ -316,10 +315,6 @@ class SpeculativeConfig:
self.prompt_lookup_min = 0
if self.model is not None:
- # TODO: Move this import to the top once `ModelConfig`
- # lives in `vllm.config.model`.
- from vllm.config import ModelConfig
-
self.draft_model_config = ModelConfig(
model=self.model,
runner="draft",
@@ -634,16 +629,6 @@ class SpeculativeConfig:
return self
- @property
- def num_lookahead_slots(self) -> int:
- """The number of additional slots the scheduler should allocate per
- step, in addition to the slots allocated for each known token.
-
- This is equal to the number of speculative tokens, as each speculative
- token must be scored.
- """
- return self.num_speculative_tokens
-
def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "mtp")
diff --git a/vllm/config/utils.py b/vllm/config/utils.py
index 7e0878d96bbd6..02f2b75f608f1 100644
--- a/vllm/config/utils.py
+++ b/vllm/config/utils.py
@@ -3,14 +3,19 @@
"""Utility functions for vLLM config dataclasses."""
import ast
+import enum
+import hashlib
import inspect
+import json
+import pathlib
import textwrap
-from collections.abc import Iterable
+from collections.abc import Iterable, Mapping, Sequence, Set
from dataclasses import MISSING, Field, field, fields, is_dataclass, replace
from itertools import pairwise
from typing import TYPE_CHECKING, Any, Protocol, TypeVar
import regex as re
+import torch
from pydantic.fields import FieldInfo
from typing_extensions import runtime_checkable
@@ -176,3 +181,115 @@ def update_config(config: ConfigT, overrides: dict[str, Any]) -> ConfigT:
)
processed_overrides[field_name] = value
return replace(config, **processed_overrides)
+
+
+def normalize_value(x):
+ """Return a stable, JSON-serializable canonical form for hashing.
+ Order: primitives, special types (Enum, callable, torch.dtype, Path), then
+ generic containers (Mapping/Set/Sequence) with recursion.
+ """
+ # Fast path
+ if x is None or isinstance(x, (bool, int, float, str)):
+ return x
+
+ # Enums: tag with FQN to avoid primitive collisions.
+ # Ex: Enum(1) vs int(1) -> ("module.QualName", value).
+ if isinstance(x, enum.Enum):
+ enum_type = f"{x.__class__.__module__}.{x.__class__.__qualname__}"
+ return (enum_type, normalize_value(x.value))
+
+ # Classes (types) are accepted and canonicalized by their fully-qualified
+ # name (module.qualname) for a stable identifier.
+ # Instances are only accepted if they expose uuid(); otherwise they are
+ # rejected to avoid under-hashing object state.
+
+ # Callables: accept classes only; reject funcs/lambdas/methods.
+ # Used by LogitsProcessor types and ModelConfig.hf_overrides.
+ if isinstance(x, type):
+ module = getattr(x, "__module__", "")
+ qual = getattr(x, "__qualname__", getattr(x, "__name__", ""))
+ return ".".join([p for p in (module, qual) if p]) or repr(x)
+
+ # Prefer stable uuid identifiers for objects that provide them, even if
+ # they are callable instances (e.g., InductorPass wrappers).
+ if hasattr(x, "uuid") and callable(getattr(x, "uuid", None)):
+ return x.uuid()
+
+ if callable(x):
+ raise TypeError("normalize_value: function or callable instance unsupported")
+
+ # Torch dtype: stringify (torch.float64 -> "torch.float64").
+ # We rely on the string form here; dtype-bearing fields that need additional
+ # disambiguation should encode that at the config layer.
+ if isinstance(x, torch.dtype):
+ return str(x)
+
+ # Bytes
+ if isinstance(x, (bytes, bytearray)):
+ return x.hex()
+
+ # Paths (canonicalize)
+ if isinstance(x, pathlib.Path):
+ try:
+ return str(x.expanduser().resolve())
+ except Exception:
+ return str(x)
+
+ # Dataclasses: represent as (FQN, sorted(field,value) tuple) for stability.
+ if is_dataclass(x):
+ type_fqn = f"{x.__class__.__module__}.{x.__class__.__qualname__}"
+ items = tuple(
+ (f.name, normalize_value(getattr(x, f.name)))
+ for f in sorted(fields(x), key=lambda f: f.name)
+ )
+ return (type_fqn, items)
+
+ # Containers (generic)
+ if isinstance(x, Mapping):
+ return tuple(sorted((str(k), normalize_value(v)) for k, v in x.items()))
+ if isinstance(x, Set):
+ return tuple(sorted(repr(normalize_value(v)) for v in x))
+ if isinstance(x, Sequence) and not isinstance(x, (str, bytes, bytearray)):
+ return tuple(normalize_value(v) for v in x)
+
+ # PretrainedConfig
+ if hasattr(x, "to_json_string") and callable(x.to_json_string):
+ return x.to_json_string()
+
+ # Unsupported type: e.g., modules, generators, open files, or objects
+ # without a stable JSON/UUID representation. Hard-error to avoid
+ # under-hashing.
+ # If you hit this, either reshape your config to use supported primitives
+ # and containers, or extend normalize_value to provide a stable encoding
+ # (e.g., via uuid() or to_json_string()) for this type.
+ raise TypeError(
+ f"normalize_value: unsupported type '{type(x).__name__}'. "
+ "Ensure config values use supported primitives/containers or add a "
+ "stable representation for this type."
+ )
+
+
+def get_hash_factors(config: ConfigT, ignored_factors: set[str]) -> dict[str, object]:
+ """Gets the factors used for hashing a config class.
+ - Includes all dataclass fields not in `ignored_factors`.
+ - Errors on non-normalizable values.
+ """
+ factors: dict[str, object] = {}
+ for dc_field in fields(config):
+ factor = dc_field.name
+ if factor in ignored_factors:
+ continue
+ value = getattr(config, factor, None)
+ try:
+ factors[factor] = normalize_value(value)
+ except TypeError as e:
+ raise TypeError(
+ f"get_hash_factors: unsupported type for key '{factor}' "
+ f"({type(value).__name__})"
+ ) from e
+ return factors
+
+
+def hash_factors(items: dict[str, object]) -> str:
+ """Return a SHA-256 hex digest of the canonical items structure."""
+ return hashlib.sha256(json.dumps(items, sort_keys=True).encode()).hexdigest()
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 672b004c4aa56..8a3599416bc72 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -96,7 +96,7 @@ class VllmConfig:
"""`torch.compile` and cudagraph capture configuration for the model.
As a shorthand, one can append compilation arguments via
- -0.parameter=arguement such as `-O.mode=3` (same as `-O='{"mode":3}'`).
+ -O.parameter=argument such as `-O.mode=3` (same as `-O='{"mode":3}'`).
You can specify the full compilation config like so:
`{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
@@ -481,6 +481,14 @@ class VllmConfig:
"Overriding cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
+ # prefill context parallel do not support full cudagraphs
+ elif self.parallel_config.prefill_context_parallel_size > 1:
+ logger.warning_once(
+ "Prefill context parallel (PCP) is enabled, which is "
+ "incompatible with full CUDA graphs. "
+ "Overriding cudagraph_mode to PIECEWISE."
+ )
+ self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
elif self.model_config is not None:
if self.model_config.pooler_config is not None:
logger.warning_once(
@@ -610,22 +618,34 @@ class VllmConfig:
# If DCP, ensure the block size is right.
if self.parallel_config.decode_context_parallel_size > 1:
+ if self.parallel_config.dcp_kv_cache_interleave_size > 1 and (
+ self.parallel_config.cp_kv_cache_interleave_size
+ != self.parallel_config.dcp_kv_cache_interleave_size
+ ):
+ self.parallel_config.cp_kv_cache_interleave_size = (
+ self.parallel_config.dcp_kv_cache_interleave_size
+ )
+ logger.warning_once(
+ "cp_kv_cache_interleave_size is overridden by dcp_kv_cache"
+ "_interleave_size. And dcp-kv-cache-interleave-size will be "
+ "deprecated when PCP is fully supported."
+ )
assert (
- self.parallel_config.dcp_kv_cache_interleave_size
+ self.parallel_config.cp_kv_cache_interleave_size
<= self.cache_config.block_size
and self.cache_config.block_size
- % self.parallel_config.dcp_kv_cache_interleave_size
+ % self.parallel_config.cp_kv_cache_interleave_size
== 0
), (
f"Block_size({self.cache_config.block_size}) should be greater "
- "than or equal to and divisible by dcp_kv_cache_interleave_size "
- f"({self.parallel_config.dcp_kv_cache_interleave_size})."
+ "than or equal to and divisible by cp_kv_cache_interleave_size "
+ f"({self.parallel_config.cp_kv_cache_interleave_size})."
)
assert (
- self.parallel_config.dcp_kv_cache_interleave_size == 1
+ self.parallel_config.cp_kv_cache_interleave_size == 1
or self.speculative_config is None
- ), "MTP with dcp_kv_cache_interleave_size > 1 is not supported now."
+ ), "MTP with cp_kv_cache_interleave_size > 1 is not supported now."
# Do this after all the updates to compilation_config.mode
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
diff --git a/vllm/distributed/device_communicators/symm_mem.py b/vllm/distributed/device_communicators/symm_mem.py
index eb1f173b11925..7a049b003cf73 100644
--- a/vllm/distributed/device_communicators/symm_mem.py
+++ b/vllm/distributed/device_communicators/symm_mem.py
@@ -131,7 +131,7 @@ class SymmMemCommunicator:
return None
if out is None:
out = torch.empty_like(inp)
- self.buffer[: inp.numel()].copy_(inp.view(-1))
+ self.buffer[: inp.numel()].copy_(inp.reshape(-1))
# Determine which algorithm to use
use_multimem = False
diff --git a/vllm/distributed/eplb/async_worker.py b/vllm/distributed/eplb/async_worker.py
new file mode 100644
index 0000000000000..e4b4fc92eeaaa
--- /dev/null
+++ b/vllm/distributed/eplb/async_worker.py
@@ -0,0 +1,115 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+The async worker that transfers experts in the background.
+"""
+
+import asyncio
+import threading
+from typing import TYPE_CHECKING
+
+import torch
+from torch.distributed import ProcessGroup
+
+from vllm.distributed.parallel_state import get_ep_group
+from vllm.logger import init_logger
+
+from .rebalance_execute import transfer_layer
+
+if TYPE_CHECKING:
+ from .eplb_state import EplbState
+
+logger = init_logger(__name__)
+
+
+def start_async_worker(
+ state: "EplbState",
+ rank_mapping: dict[int, int] | None = None,
+ is_profile: bool = False,
+) -> threading.Thread:
+ ep_group = get_ep_group().device_group
+ rank = ep_group.rank()
+ device_index = state.cuda_device_index
+
+ def thread_target() -> None:
+ assert device_index is not None
+ torch.cuda.set_device(device_index)
+ cuda_stream = torch.cuda.Stream(device=device_index)
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ try:
+ loop.run_until_complete(
+ transfer_run_periodically(
+ state=state,
+ ep_group=ep_group,
+ is_profile=is_profile,
+ rank_mapping=rank_mapping,
+ cuda_stream=cuda_stream,
+ )
+ )
+ except Exception as exc: # pragma: no cover - diagnostic path
+ logger.exception("async loop error (Rank %d): %s", rank, str(exc))
+ finally:
+ loop.close()
+
+ thread = threading.Thread(target=thread_target, daemon=True)
+ thread.start()
+ return thread
+
+
+async def transfer_run_periodically(
+ state: "EplbState",
+ ep_group: ProcessGroup,
+ is_profile: bool = False,
+ rank_mapping: dict[int, int] | None = None,
+ cuda_stream: torch.cuda.Stream = None,
+) -> None:
+ while True:
+ await asyncio.to_thread(state.rearrange_event.wait)
+ logger.info("async worker woke up for EPLB transfer")
+
+ for model_state in state.model_states.values():
+ if not model_state.is_async_enabled:
+ continue
+ current_num_layers = model_state.model.num_moe_layers
+ while (
+ model_state.rebalanced
+ and model_state.layer_to_transfer < current_num_layers
+ ):
+ if (
+ not model_state.ep_buffer_ready
+ and model_state.rebalanced
+ and model_state.new_physical_to_logical_map is not None
+ ):
+ await asyncio.to_thread(model_state.buffer_lock.acquire)
+ try:
+ if model_state.layer_to_transfer >= current_num_layers:
+ break
+
+ (
+ model_state.is_unchanged,
+ model_state.is_received_locally,
+ model_state.experts_recv_loc,
+ ) = await transfer_layer(
+ old_global_expert_indices=model_state.physical_to_logical_map,
+ new_global_expert_indices=model_state.new_physical_to_logical_map,
+ expert_weights=model_state.model.expert_weights,
+ expert_weights_buffer=model_state.expert_buffer,
+ ep_group=ep_group,
+ is_profile=is_profile,
+ layer=model_state.layer_to_transfer,
+ cuda_stream=cuda_stream,
+ rank_mapping=rank_mapping,
+ )
+ event = torch.cuda.Event(blocking=False)
+ cuda_stream.record_event(event)
+ model_state.buffer_ready_event = event
+ model_state.ep_buffer_ready = 1
+ finally:
+ model_state.buffer_lock.release()
+ else:
+ if not model_state.rebalanced:
+ break
+ await asyncio.sleep(0.001)
+
+ state.rearrange_event.clear()
diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py
index 526d3ceac7b8f..9f8798a96a2fc 100644
--- a/vllm/distributed/eplb/eplb_state.py
+++ b/vllm/distributed/eplb/eplb_state.py
@@ -26,6 +26,7 @@ MoE layer. If we have 32 EP ranks, then each GPU will hold 288 / 32 = 9 local
physical experts.
"""
+import threading
import time
from collections.abc import Sequence
from dataclasses import dataclass
@@ -43,8 +44,9 @@ from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MixtureOfExperts
+from .async_worker import start_async_worker
from .rebalance_algo import rebalance_experts
-from .rebalance_execute import rearrange_expert_weights_inplace
+from .rebalance_execute import move_from_buffer, rearrange_expert_weights_inplace
logger = init_logger(__name__)
@@ -132,6 +134,74 @@ class EplbModelState:
"""
model_name: str
model: MixtureOfExperts
+ expert_buffer: list[torch.Tensor]
+ """
+ The buffer to store the expert weights during transfer.
+ """
+ buffer_lock: threading.Lock
+ """
+ The lock to protect the expert buffer.
+ """
+ buffer_ready_event: torch.cuda.Event | None
+ """
+ CUDA event recorded when the async worker finishes filling the buffer.
+ The main thread waits on this before consuming the buffer.
+ """
+ ep_buffer_ready: int
+ """
+ The flag indicates whether the expert buffer is ready for transfer.
+ 0 or 1.
+ """
+ layer_to_transfer: int
+ """
+ The layer index to transfer in async mode.
+ """
+ rebalanced: bool
+ """
+ The flag indicates whether the expert rebalance has been computed.
+ """
+ pending_global_ready_check: bool
+ """
+ Whether the async EPLB needs to poll peers for buffer readiness.
+ """
+ is_unchanged: list[bool]
+ """
+ Intermediate variable between `move_to_buffer` and `move_to_workspace`.
+ The size is the same as the number of physical experts in the current layer.
+ """
+ is_received_locally: list[bool]
+ """
+ Intermediate variable between `move_to_buffer` and `move_to_workspace`.
+ The size is the same as the number of physical experts in the current layer.
+ """
+ experts_recv_loc: dict[int, int]
+ """
+ Intermediate variable between `move_to_buffer` and `move_to_workspace`.
+ The size is the same as the number of physical experts in the current layer.
+ """
+ is_async_enabled: bool
+ """
+ The flag indicates whether the EPLB is running in async mode.
+ """
+ cuda_device_index: int | None
+ """
+ CUDA device index for the async EPLB worker thread.
+ """
+ new_physical_to_logical_map: torch.Tensor | None = None
+ """
+ Intermediate variable between `move_to_buffer` and `move_to_workspace`.
+ The size is the same as physical_to_logical_map.
+ """
+ new_logical_to_physical_map: torch.Tensor | None = None
+ """
+ Intermediate variable between `move_to_buffer` and `move_to_workspace`.
+ The size is the same as logical_to_physical_map.
+ """
+ new_logical_replica_count: torch.Tensor | None = None
+ """
+ Intermediate variable between `move_to_buffer` and `move_to_workspace`.
+ The size is the same as logical_replica_count.
+ """
class EplbState:
@@ -164,12 +234,31 @@ class EplbState:
Otherwise, the rearrangement will hang at collective
communication calls.
"""
- self.expert_rearrangement_step: int = 0
+ self.expert_rearrangement_step_interval: int = 0
"""
Interval for expert rearrangement steps.
This is a constant and is taken from the config.
"""
- self.expert_rearrangement_step_interval: int = 0
+ self.is_async: bool = False
+ """
+ The flag indicates whether the EPLB is running in async mode.
+ """
+ self.rearrange_event = threading.Event()
+ """
+ Event to signal when a new rearrangement is needed for the async thread.
+ """
+ self.async_worker: threading.Thread | None = None
+ """
+ Background thread handling async transfers.
+ """
+ self.cuda_device_index: int | None = None
+ """
+ CUDA device index for the async EPLB worker thread.
+ """
+ if self.device.type == "cuda":
+ self.cuda_device_index = self.device.index
+ if self.cuda_device_index is None and torch.cuda.is_available():
+ self.cuda_device_index = torch.cuda.current_device()
@staticmethod
def build_initial_global_physical_to_logical_map(
@@ -239,6 +328,8 @@ class EplbState:
Build the initial EPLB state.
"""
self.validate_ep_configuration(model)
+ self.is_async = self.parallel_config.eplb_config.use_async
+
physical_to_logical_map_list = (
EplbState.build_initial_global_physical_to_logical_map(
model.num_routed_experts,
@@ -368,7 +459,12 @@ class EplbState:
physical_to_logical_map = new_physical_to_logical_map.to(self.device)
logical_to_physical_map.copy_(new_logical_to_physical_map)
logical_replica_count.copy_(new_logical_replica_count)
+ else:
+ new_physical_to_logical_map = None
+ new_logical_to_physical_map = None
+
+ new_logical_replica_count = None
model.set_eplb_state(
expert_load_pass,
logical_to_physical_map,
@@ -385,15 +481,33 @@ class EplbState:
)
self.expert_rearrangement_step = 0
- self.model_states[model_config.compute_hash()] = EplbModelState(
- physical_to_logical_map,
- logical_to_physical_map,
- logical_replica_count,
- expert_load_pass,
- expert_load_window,
- model_config.model,
- model,
+ expert_buffer = [torch.empty_like(w) for w in model.expert_weights[0]]
+
+ model_state = EplbModelState(
+ physical_to_logical_map=physical_to_logical_map,
+ logical_to_physical_map=logical_to_physical_map,
+ logical_replica_count=logical_replica_count,
+ expert_load_pass=expert_load_pass,
+ expert_load_window=expert_load_window,
+ model_name=model_config.model,
+ model=model,
+ expert_buffer=expert_buffer,
+ buffer_lock=threading.Lock(),
+ buffer_ready_event=None,
+ ep_buffer_ready=0,
+ layer_to_transfer=0,
+ rebalanced=False,
+ pending_global_ready_check=False,
+ is_unchanged=[],
+ is_received_locally=[],
+ experts_recv_loc={},
+ is_async_enabled=self.is_async,
+ cuda_device_index=self.cuda_device_index,
+ new_physical_to_logical_map=new_physical_to_logical_map,
+ new_logical_to_physical_map=new_logical_to_physical_map,
+ new_logical_replica_count=new_logical_replica_count,
)
+ self.model_states[model_config.compute_hash()] = model_state
def step(
self,
@@ -420,7 +534,7 @@ class EplbState:
- `max_tokens`: The maximum load across ranks.
- `balancedness`: The ratio of average load to maximum load.
"""
-
+ ep_group = get_ep_group().device_group
if is_profile:
self.rearrange(is_profile=True)
return
@@ -488,7 +602,49 @@ class EplbState:
# rearrangement step and perform rearrangement to ensure all ranks are
# performing collective communication.
self.expert_rearrangement_step += 1
+
+ if self.is_async:
+ for eplb_model_state in self.model_states.values():
+ if not eplb_model_state.is_async_enabled:
+ continue
+
+ all_ranks_buffer_ready = False
+ if eplb_model_state.pending_global_ready_check:
+ all_ranks_buffer_ready = self._all_ranks_buffer_ready(
+ eplb_model_state
+ )
+ if (
+ eplb_model_state.is_async_enabled
+ and eplb_model_state.ep_buffer_ready
+ and all_ranks_buffer_ready
+ ):
+ self.move_to_workspace(
+ model_state=eplb_model_state,
+ ep_group=ep_group,
+ is_profile=is_profile,
+ )
+ if (
+ eplb_model_state.layer_to_transfer
+ >= eplb_model_state.model.num_moe_layers
+ ):
+ self.post_eplb(eplb_model_state, is_profile)
+ eplb_model_state.rebalanced = False
+ eplb_model_state.layer_to_transfer = 0
+ eplb_model_state.pending_global_ready_check = False
+ logger.info(
+ "finish async transfer for model %s rank %d layer %d",
+ eplb_model_state.model_name,
+ ep_group.rank(),
+ eplb_model_state.model.num_moe_layers,
+ )
+
if self.expert_rearrangement_step >= self.expert_rearrangement_step_interval:
+ if any(
+ eplb_model_state.is_async_enabled and eplb_model_state.rebalanced
+ for eplb_model_state in self.model_states.values()
+ ):
+ # Still performing asynchronous rearrangement
+ return
self.expert_rearrangement_step = 0
self.rearrange()
@@ -524,7 +680,11 @@ class EplbState:
if is_main_rank:
torch.cuda.synchronize()
time_start = time.perf_counter()
- logger.info("Rearranging experts %s...", "(profile)" if is_profile else "")
+ logger.info(
+ "Rearranging experts %s %s...",
+ "(async mode)" if self.is_async else "sync mode",
+ "(profile)" if is_profile else "",
+ )
if global_expert_loads is None:
# Map the physical expert load to global logical experts
@@ -593,6 +753,7 @@ class EplbState:
model = eplb_model_state.model
num_replicas = model.num_physical_experts
num_groups = model.num_expert_groups
+
if rank_mapping is not None and len(rank_mapping) == ep_group.size():
# NOTE(yongji): scale down, we need to rebalance the experts on
# remaining GPUs, transfer the experts while we haven't shutdown
@@ -608,7 +769,7 @@ class EplbState:
num_gpus = ep_group.size()
if num_gpus % num_nodes != 0:
- self.num_nodes = 1
+ num_nodes = 1
logger.warning_once(
f"num_gpus % num_nodes != 0, "
"not using hierarchical rearrangement algorithm.\n"
@@ -631,60 +792,216 @@ class EplbState:
num_gpus,
)
- # Update expert weights
- rearrange_expert_weights_inplace(
- eplb_model_state.physical_to_logical_map,
- new_physical_to_logical_map,
- eplb_model_state.model.expert_weights,
- ep_group,
- is_profile,
- rank_mapping,
- )
+ if not eplb_model_state.is_async_enabled or is_profile:
+ # Update expert weights
+ rearrange_expert_weights_inplace(
+ eplb_model_state.physical_to_logical_map,
+ new_physical_to_logical_map,
+ eplb_model_state.model.expert_weights,
+ ep_group,
+ is_profile,
+ rank_mapping,
+ )
- if not is_profile:
- if (
- eplb_model_state.physical_to_logical_map.shape[1]
- != new_physical_to_logical_map.shape[1]
- ):
- eplb_model_state.physical_to_logical_map = (
- new_physical_to_logical_map.to(
- eplb_model_state.physical_to_logical_map.device
+ if not is_profile:
+ if (
+ eplb_model_state.physical_to_logical_map.shape[1]
+ != new_physical_to_logical_map.shape[1]
+ ):
+ eplb_model_state.physical_to_logical_map = (
+ new_physical_to_logical_map.to(
+ eplb_model_state.physical_to_logical_map.device
+ )
)
+ else:
+ eplb_model_state.physical_to_logical_map.copy_(
+ new_physical_to_logical_map
+ )
+ max_physical_slots = new_logical_to_physical_map.shape[-1]
+ assert (
+ max_physical_slots
+ <= eplb_model_state.logical_to_physical_map.shape[-1]
)
- else:
- eplb_model_state.physical_to_logical_map.copy_(
- new_physical_to_logical_map
+ new_logical_to_physical_map = torch.nn.functional.pad(
+ new_logical_to_physical_map,
+ (
+ 0,
+ eplb_model_state.logical_to_physical_map.shape[-1]
+ - max_physical_slots,
+ ),
+ value=-1,
)
- max_physical_slots = new_logical_to_physical_map.shape[-1]
- assert (
- max_physical_slots
- <= eplb_model_state.logical_to_physical_map.shape[-1]
- )
- new_logical_to_physical_map = torch.nn.functional.pad(
+ eplb_model_state.logical_to_physical_map.copy_(
+ new_logical_to_physical_map
+ )
+ eplb_model_state.logical_replica_count.copy_(
+ new_logical_replica_count
+ )
+ if is_main_rank:
+ assert time_start is not None
+ torch.cuda.synchronize()
+ time_end = time.perf_counter()
+ logger.info(
+ "Rearranged experts%sin %.2f seconds.",
+ " (profile) " if is_profile else " ",
+ time_end - time_start,
+ )
+ else:
+ device = eplb_model_state.physical_to_logical_map.device
+ new_physical = new_physical_to_logical_map.to(device)
+ max_slots = eplb_model_state.logical_to_physical_map.shape[-1]
+ padded_logical = torch.nn.functional.pad(
new_logical_to_physical_map,
- (
- 0,
- eplb_model_state.logical_to_physical_map.shape[-1]
- - max_physical_slots,
- ),
+ (0, max(0, max_slots - new_logical_to_physical_map.shape[-1])),
value=-1,
+ ).to(eplb_model_state.logical_to_physical_map.device)
+ new_replica = new_logical_replica_count.to(
+ eplb_model_state.logical_replica_count.device
)
- eplb_model_state.logical_to_physical_map.copy_(
- new_logical_to_physical_map
- )
- eplb_model_state.logical_replica_count.copy_(new_logical_replica_count)
- if is_main_rank:
- assert time_start is not None
- torch.cuda.synchronize()
- time_end = time.perf_counter()
- logger.info(
- "Rearranged experts%sin %.2f seconds.",
- " (profile) " if is_profile else " ",
- time_end - time_start,
- )
+ eplb_model_state.new_physical_to_logical_map = new_physical
+ eplb_model_state.new_logical_to_physical_map = padded_logical
+ eplb_model_state.new_logical_replica_count = new_replica
+
+ eplb_model_state.rebalanced = True
+ eplb_model_state.layer_to_transfer = 0
+ eplb_model_state.pending_global_ready_check = True
+
+ # Signal async thread to start transferring layers
+ if self.is_async and (not is_profile):
+ self.rearrange_event.set()
return None
+ def start_async_loop(
+ self,
+ rank_mapping: dict[int, int] | None = None,
+ is_profile: bool = False,
+ ):
+ if not self.is_async:
+ return
+ if self.async_worker is None:
+ self.async_worker = start_async_worker(
+ self,
+ rank_mapping=rank_mapping,
+ is_profile=is_profile,
+ )
+
+ def _update_layer_mapping_from_new(
+ self, model_state: EplbModelState, layer: int
+ ) -> None:
+ if (
+ model_state.new_physical_to_logical_map is None
+ or model_state.new_logical_to_physical_map is None
+ or model_state.new_logical_replica_count is None
+ ):
+ return
+
+ target_device = model_state.physical_to_logical_map.device
+ new_physical = model_state.new_physical_to_logical_map
+ if model_state.physical_to_logical_map.shape[1] != new_physical.shape[1]:
+ model_state.physical_to_logical_map = new_physical.to(target_device)
+ else:
+ model_state.physical_to_logical_map[layer].copy_(
+ new_physical[layer].to(target_device)
+ )
+
+ logical_device = model_state.logical_to_physical_map.device
+ new_logical = model_state.new_logical_to_physical_map[layer].to(logical_device)
+ max_slots = model_state.logical_to_physical_map.shape[-1]
+ slot_delta = max_slots - new_logical.shape[-1]
+ if slot_delta > 0:
+ new_logical = torch.nn.functional.pad(
+ new_logical, (0, slot_delta), value=-1
+ )
+ model_state.logical_to_physical_map[layer].copy_(new_logical)
+
+ replica_device = model_state.logical_replica_count.device
+ model_state.logical_replica_count[layer].copy_(
+ model_state.new_logical_replica_count[layer].to(replica_device)
+ )
+
+ def _all_ranks_buffer_ready(self, model_state: EplbModelState) -> bool:
+ parallel_state = get_ep_group()
+ cpu_group = getattr(parallel_state, "cpu_group", None)
+ if cpu_group is not None and cpu_group.size() > 1:
+ flag = torch.tensor(
+ (int(model_state.ep_buffer_ready),), dtype=torch.int32, device="cpu"
+ )
+ all_reduce(flag, group=cpu_group)
+ return int(flag.item()) == cpu_group.size()
+
+ device_group = parallel_state.device_group
+ if device_group.size() <= 1:
+ return bool(model_state.ep_buffer_ready)
+
+ device = getattr(
+ parallel_state, "device", model_state.physical_to_logical_map.device
+ )
+ flag = torch.tensor(
+ (int(model_state.ep_buffer_ready),), dtype=torch.int32, device=device
+ )
+ all_reduce(flag, group=device_group)
+ return int(flag.item()) == device_group.size()
+
+ def move_to_workspace(
+ self,
+ model_state: EplbModelState,
+ ep_group: ProcessGroup,
+ is_profile: bool = False,
+ ):
+ if not model_state.buffer_lock.acquire(blocking=False):
+ return
+ try:
+ assert model_state.new_physical_to_logical_map is not None
+ device_index = model_state.cuda_device_index or self.cuda_device_index
+ if model_state.buffer_ready_event is not None and device_index is not None:
+ stream = torch.cuda.current_stream(device=device_index)
+ stream.wait_event(model_state.buffer_ready_event)
+ model_state.buffer_ready_event = None
+ move_from_buffer(
+ expert_weights=model_state.model.expert_weights[
+ model_state.layer_to_transfer
+ ],
+ expert_weights_buffer=model_state.expert_buffer,
+ is_unchanged=model_state.is_unchanged,
+ is_received_locally=model_state.is_received_locally,
+ experts_recv_loc=model_state.experts_recv_loc,
+ new_indices=model_state.new_physical_to_logical_map[
+ model_state.layer_to_transfer
+ ].tolist(),
+ ep_group=ep_group,
+ )
+ transferred_layer = model_state.layer_to_transfer
+ self._update_layer_mapping_from_new(model_state, transferred_layer)
+ # After the main thread consumes, advance layer_to_transfer
+ model_state.layer_to_transfer += 1
+ model_state.ep_buffer_ready = 0
+ logger.info(
+ "model %s successfully move_to_workspace layer %d",
+ model_state.model_name,
+ transferred_layer,
+ )
+ finally:
+ try:
+ model_state.buffer_lock.release()
+ except Exception as e:
+ logger.error(
+ "Rank %d: buffer_lock release failed in move_to_workspace: %s",
+ ep_group.rank(),
+ str(e),
+ )
+
+ def post_eplb(self, model_state: EplbModelState, is_profile: bool = False) -> None:
+ assert model_state.new_physical_to_logical_map is not None
+ assert model_state.new_logical_to_physical_map is not None
+ assert model_state.new_logical_replica_count is not None
+ if not is_profile:
+ for layer_idx in range(model_state.physical_to_logical_map.shape[0]):
+ self._update_layer_mapping_from_new(model_state, layer_idx)
+ model_state.new_physical_to_logical_map = None
+ model_state.new_logical_to_physical_map = None
+ model_state.new_logical_replica_count = None
+
@staticmethod
def recv_state() -> tuple[list[torch.Tensor], list[torch.Tensor]]:
"""
diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py
index 5c1efbaf03bab..376dad8a72ef1 100644
--- a/vllm/distributed/eplb/rebalance_execute.py
+++ b/vllm/distributed/eplb/rebalance_execute.py
@@ -100,18 +100,19 @@ def get_ep_ranks_with_expert(
return ranks_to_send, ranks_to_recv_actual
-def shuffle_layer(
+def move_to_buffer(
num_local_experts: int,
- ep_rank: int,
old_indices: Sequence[int],
new_indices: Sequence[int],
expert_weights: Iterable[torch.Tensor],
expert_weights_buffer: Sequence[torch.Tensor],
+ cuda_stream: torch.cuda.Stream | None,
ep_group: ProcessGroup,
-) -> None:
+) -> tuple[list[bool], list[bool], dict[int, int]]:
"""
Perform expert weights rearrangement of one layer.
"""
+ ep_rank = ep_group.rank()
local2global = partial(
idx_local_to_global,
local_cnt=num_local_experts,
@@ -137,7 +138,8 @@ def shuffle_layer(
if old_indices[src_global] == new_indices[dst_global]:
is_received_locally[dst] = True
for weight, buffer in zip(expert_weights, expert_weights_buffer):
- buffer[dst].copy_(weight[src])
+ with torch.cuda.stream(cuda_stream):
+ buffer[dst].copy_(weight[src], non_blocking=True)
p2p_ops: list[P2POp] = []
@@ -225,25 +227,115 @@ def shuffle_layer(
]
# 4. Execute the P2P operations. The real communication happens here.
- if p2p_ops:
+ if p2p_ops and cuda_stream is not None:
+ with torch.cuda.stream(cuda_stream):
+ reqs = batch_isend_irecv(p2p_ops)
+ for req in reqs:
+ req.wait()
+ elif p2p_ops:
reqs = batch_isend_irecv(p2p_ops)
for req in reqs:
req.wait()
+ # wait for the communication to finish
+ return is_unchanged, is_received_locally, experts_recv_loc
+
+
+def move_from_buffer(
+ expert_weights: Iterable[torch.Tensor],
+ expert_weights_buffer: list[torch.Tensor],
+ is_unchanged: list[bool],
+ is_received_locally: list[bool],
+ experts_recv_loc: dict[int, int],
+ new_indices: Sequence[int],
+ ep_group: ProcessGroup,
+) -> None:
+ ep_rank = ep_group.rank()
+ num_local_experts = len(is_unchanged)
+
+ local2global = partial(
+ idx_local_to_global, local_cnt=num_local_experts, ep_rank=ep_rank
+ )
- # 5. Copy the weights from the buffer back to the original weights.
for dst in range(num_local_experts):
if is_unchanged[dst]:
continue
if is_received_locally[dst]:
for weight, buffer in zip(expert_weights, expert_weights_buffer):
- weight[dst].copy_(buffer[dst])
+ weight[dst].copy_(buffer[dst], non_blocking=True)
else:
expert = new_indices[local2global(dst)]
if expert == -1:
continue
src = experts_recv_loc[expert]
for weight, buffer in zip(expert_weights, expert_weights_buffer):
- weight[dst].copy_(buffer[src])
+ weight[dst].copy_(buffer[src], non_blocking=True)
+
+
+async def transfer_layer(
+ old_global_expert_indices: torch.Tensor,
+ new_global_expert_indices: torch.Tensor,
+ expert_weights: Sequence[Iterable[torch.Tensor]],
+ expert_weights_buffer: Sequence[torch.Tensor],
+ ep_group: ProcessGroup,
+ is_profile: bool = False,
+ layer: int = 0,
+ cuda_stream: torch.cuda.Stream | None = None,
+ rank_mapping: dict[int, int] | None = None,
+) -> tuple[list[bool], list[bool], dict[int, int]]:
+ """
+ Transfers one layer's expert weights into the buffer according to the new expert indices.
+
+ The value of the indices arguments are logical indices of the experts,
+ while keys are physical.
+
+ Args:
+ old_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
+ new_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
+ expert_weights: A sequence of shape (num_moe_layers)(weight_count)
+ of tensors of shape (num_local_physical_experts, hidden_size_i).
+ For example, a linear layer may have up and down projection,
+ so weight_count = 2. Each weight's hidden size can be different.
+ ep_group: The device process group for expert parallelism.
+ is_profile (bool): If `True`, do not perform any actual weight copy.
+ This is used during profile run, where we only perform dummy
+ communications to reserve enough memory for the buffers.
+ """
+ ep_size = ep_group.size()
+ if rank_mapping is not None:
+ if len(rank_mapping) == ep_group.size():
+ # scale down
+ new_global_expert_indices = _map_new_expert_indices_with_rank_mapping(
+ new_global_expert_indices,
+ rank_mapping,
+ )
+ else:
+ # scale up
+ old_global_expert_indices = _map_old_expert_indices_with_rank_mapping(
+ old_global_expert_indices,
+ rank_mapping,
+ ep_group.size(),
+ )
+
+ assert old_global_expert_indices.shape[1] == new_global_expert_indices.shape[1]
+ num_moe_layers, num_physical_experts = old_global_expert_indices.shape
+ assert len(expert_weights) == num_moe_layers
+ num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
+ assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
+ assert num_physical_experts == ep_size * num_local_physical_experts
+ # A buffer to hold the expert weights in one layer during the exchange.
+ # NOTE: Currently we assume the same weights across different layers
+ # have the same shape.
+
+ is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer(
+ num_local_experts=num_local_physical_experts,
+ old_indices=old_global_expert_indices[layer].tolist(),
+ new_indices=new_global_expert_indices[layer].tolist(),
+ expert_weights=expert_weights[layer],
+ expert_weights_buffer=expert_weights_buffer,
+ cuda_stream=cuda_stream,
+ ep_group=ep_group,
+ )
+ return is_unchanged, is_received_locally, experts_recv_loc
def rearrange_expert_weights_inplace(
@@ -296,7 +388,6 @@ def rearrange_expert_weights_inplace(
num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
- ep_rank = ep_group.rank()
ep_size = ep_group.size()
assert num_physical_experts == ep_size * num_local_physical_experts
@@ -329,14 +420,24 @@ def rearrange_expert_weights_inplace(
torch.cuda.synchronize()
for layer in range(num_moe_layers):
- shuffle_layer(
- num_local_physical_experts,
- ep_rank,
- old_global_expert_indices_cpu[layer].tolist(),
- new_global_expert_indices_cpu[layer].tolist(),
- expert_weights[layer],
- expert_weights_buffer,
- ep_group,
+ is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer(
+ num_local_experts=num_local_physical_experts,
+ old_indices=old_global_expert_indices_cpu[layer].tolist(),
+ new_indices=new_global_expert_indices_cpu[layer].tolist(),
+ expert_weights=expert_weights[layer],
+ expert_weights_buffer=expert_weights_buffer,
+ cuda_stream=None,
+ ep_group=ep_group,
+ )
+
+ move_from_buffer(
+ expert_weights=expert_weights[layer],
+ expert_weights_buffer=expert_weights_buffer,
+ is_unchanged=is_unchanged,
+ is_received_locally=is_received_locally,
+ experts_recv_loc=experts_recv_loc,
+ new_indices=new_global_expert_indices[layer].tolist(),
+ ep_group=ep_group,
)
@@ -428,4 +529,4 @@ def _map_new_expert_indices_with_rank_mapping(
return mapped_expert_indices
-__all__ = ["rearrange_expert_weights_inplace"]
+__all__ = ["transfer_layer", "move_from_buffer"]
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
index f85eb414b2222..74f09278b7bb1 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
@@ -38,7 +38,7 @@ The class provides the following primitives:
import enum
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable
-from typing import TYPE_CHECKING, Any, Literal, Optional
+from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional
import torch
@@ -47,7 +47,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionMetadata
+ from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
@@ -142,6 +142,18 @@ class KVConnectorMetadata(ABC): # noqa: B024
class KVConnectorBase_V1(ABC):
+ """
+ Base class for KV connectors.
+
+ Attributes:
+ prefer_cross_layer_blocks (bool): Indicates whether this connector
+ prefers KV blocks that hold KV data for all layers (for speeding
+ up KV data transfers).
+ Defaults to False.
+ """
+
+ prefer_cross_layer_blocks: ClassVar[bool] = False
+
def __init__(
self,
vllm_config: "VllmConfig",
@@ -226,6 +238,23 @@ class KVConnectorBase_V1(ABC):
"""
return
+ def register_cross_layers_kv_cache(
+ self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"]
+ ):
+ """
+ Initialize with a single KV cache tensor used by all layers.
+ The first dimension should be num_layers.
+ This function will only be called for models with uniform layers,
+ and only if prefer_cross_layer_blocks is set to True.
+ Only one of the functions
+ {register_kv_caches, register_cross_layers_kv_cache} will be called.
+
+ Args:
+ kv_cache: a cross-layers kv cache tensor
+ attn_backend: The attention backend that corresponds to all layers
+ """
+ return
+
def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
"""
Set the xPU-specific ops for copying KV between host and device.
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
index ab2eeed9f6b8a..6acfb73997f25 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
@@ -310,7 +310,6 @@ class LMCacheMPWorkerAdapter:
request_id,
result,
)
- logger.info("Retrieve request for request_id=%s finished", request_id)
# Remove the finished requests from the tracking dicts
for request_id in finished_stores:
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
index 55831dc56c803..d1d3e475cc889 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast
import torch
import zmq
+from lmcache.integration.vllm.utils import mla_enabled
from lmcache.utils import init_logger as lmcache_init_logger
from vllm.config import VllmConfig
@@ -60,17 +61,44 @@ def reformat_block_ids(block_ids: tuple[list[int], ...] | None) -> list[int]:
return block_ids[0]
+def extract_world_size_and_kv_rank(
+ world_size: int,
+ rank: int,
+ vllm_config: VllmConfig,
+) -> tuple[int, int]:
+ """
+ Convert the rank for the MLA.
+ """
+ use_mla = mla_enabled(vllm_config.model_config)
+ if not use_mla:
+ return world_size, rank
+ else:
+ # Tensor parallel does not change the KV caches for MLA models.
+ # So we need to "exclude" the effect of TP on rank and world size
+ tp_size = vllm_config.parallel_config.tensor_parallel_size
+ # vLLM constructs TP groups first, and then construct other
+ # parallel groups on top of TP groups.
+ # for example, TP=4, PP=2,
+ # TP group: [0, 1, 2, 3], [4, 5, 6, 7]
+ # PP group: [0, 4], [1, 5], [2, 6], [3, 7]
+ # So we can "exclude" the effect of TP by rank // tp_size.
+ return world_size // tp_size, rank // tp_size
+
+
def create_scheduler_adapter(
server_url: str, zmq_context: zmq.Context, vllm_config: VllmConfig
) -> LMCacheMPSchedulerAdapter:
- # TODO: have a helper function to calculate the correct rank and
- # world size for the MLA and other models
+ world_size, kv_rank = extract_world_size_and_kv_rank(
+ vllm_config.parallel_config.world_size,
+ vllm_config.parallel_config.rank,
+ vllm_config,
+ )
return LMCacheMPSchedulerAdapter(
server_url,
zmq_context,
vllm_config.model_config.model,
- vllm_config.parallel_config.world_size,
- vllm_config.parallel_config.rank,
+ world_size,
+ kv_rank,
vllm_config.cache_config.block_size,
)
@@ -78,14 +106,17 @@ def create_scheduler_adapter(
def create_worker_adapter(
server_url: str, zmq_context: zmq.Context, vllm_config: VllmConfig
) -> LMCacheMPWorkerAdapter:
- # TODO: have a helper function to calculate the correct rank and
- # world size for the MLA and other models
+ world_size, kv_rank = extract_world_size_and_kv_rank(
+ vllm_config.parallel_config.world_size,
+ vllm_config.parallel_config.rank,
+ vllm_config,
+ )
return LMCacheMPWorkerAdapter(
server_url,
zmq_context,
vllm_config.model_config.model,
- vllm_config.parallel_config.world_size,
- vllm_config.parallel_config.rank,
+ world_size,
+ kv_rank,
vllm_config.cache_config.block_size,
)
@@ -438,9 +469,6 @@ class LMCacheMPConnector(KVConnectorBase_V1):
ops.append(meta.op)
if len(request_ids) > 0:
- logger.info(
- "HERE! SUBMITTING THE BATCHED RETRIEVE REQUESTS %s", request_ids
- )
self.worker_adapter.batched_submit_retrieve_requests(
request_ids, ops, event
)
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 1626f819af8b5..493938d4aad92 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -4,7 +4,6 @@ import contextlib
import copy
import logging
import math
-import os
import queue
import threading
import time
@@ -810,9 +809,6 @@ class NixlConnectorWorker:
self.nixl_backends = vllm_config.kv_transfer_config.get_from_extra_config(
"backends", ["UCX"]
)
- # TODO temporary, once nixl allows for telemetry flag in config
- # (next release), we can remove this env var.
- os.environ["NIXL_TELEMETRY_ENABLE"] = "1"
# Agent.
non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"]
@@ -828,10 +824,11 @@ class NixlConnectorWorker:
if nixl_agent_config is None:
config = None
else:
+ # Enable telemetry by default for NIXL 0.7.1 and above.
config = (
- nixl_agent_config(backends=self.nixl_backends)
+ nixl_agent_config(backends=self.nixl_backends, capture_telemetry=True)
if len(non_ucx_backends) > 0
- else nixl_agent_config(num_threads=num_threads)
+ else nixl_agent_config(num_threads=num_threads, capture_telemetry=True)
)
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config)
@@ -1042,10 +1039,12 @@ class NixlConnectorWorker:
NOT directly supported by NIXL (e.g., tpu)
"""
xfer_buffers: dict[str, torch.Tensor] = {}
+ inv_order = [0, 1, 3, 2, 4]
try:
for layer_name, kv_cache in kv_caches.items():
kv_shape = kv_cache.shape
kv_dtype = kv_cache.dtype
+ permute_shape = False
if (
self.kv_cache_layout == "NHD"
and self.vllm_config.kv_transfer_config is not None
@@ -1059,10 +1058,20 @@ class NixlConnectorWorker:
# Since NHD will not support Decode/Prefill TP_ratio > 1,
# we can leverage host_buffer for permute
self.host_buffer_kv_cache_layout = "HND"
- kv_shape = tuple(kv_shape[i] for i in [0, 1, 3, 2, 4])
+ kv_shape = (
+ tuple(kv_shape[i] for i in inv_order)
+ if not self.use_mla
+ else kv_shape
+ )
+ permute_shape = not self.use_mla
+
xfer_buffers[layer_name] = torch.empty(
kv_shape, dtype=kv_dtype, device="cpu"
)
+ if permute_shape:
+ xfer_buffers[layer_name] = xfer_buffers[layer_name].permute(
+ inv_order
+ )
except MemoryError as e:
logger.error("NIXLConnectorWorker gets %s.", e)
raise
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
index 582e42cc466ae..8cd09014cab11 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
@@ -4,12 +4,12 @@ from collections import defaultdict
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from itertools import islice
-from typing import Any
+from typing import Any, ClassVar
import torch
-from vllm.attention import AttentionMetadata
-from vllm.config import VllmConfig
+from vllm.attention import Attention, AttentionBackend, AttentionMetadata
+from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1,
@@ -42,6 +42,8 @@ class OffloadingConnectorMetadata(KVConnectorMetadata):
class OffloadingConnector(KVConnectorBase_V1):
+ prefer_cross_layer_blocks: ClassVar[bool] = True
+
def __init__(
self,
vllm_config: VllmConfig,
@@ -63,6 +65,12 @@ class OffloadingConnector(KVConnectorBase_V1):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)
    def register_cross_layers_kv_cache(
        self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
    ):
        """Register a single KV-cache tensor shared across all layers.

        Forwards the tensor and its attention backend to the worker-side
        connector. Only valid once the worker-side connector exists.
        """
        assert self.connector_worker is not None
        self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend)
+
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
assert self.connector_worker is not None
assert isinstance(self._connector_metadata, OffloadingConnectorMetadata)
@@ -422,10 +430,35 @@ class OffloadingConnectorWorker:
self._job_counter = job_id + 1
return job_id
- def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
- for src_cls, dst_cls, handler in self.spec.get_handlers(kv_caches):
    def _register_handlers(
        self,
        kv_caches: dict[str, torch.Tensor],
        attn_backends: dict[str, type[AttentionBackend]],
    ):
        """Register one transfer handler per (src, dst) medium pair.

        The offloading spec decides which handlers apply to the given
        caches and attention backends; each is installed on the worker.
        """
        for src_cls, dst_cls, handler in self.spec.get_handlers(
            kv_caches, attn_backends
        ):
            self.worker.register_handler(src_cls, dst_cls, handler)
    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Register per-layer KV caches with the offloading worker.

        Looks up each layer's Attention module from the vLLM config to
        obtain its attention backend, then installs transfer handlers.
        """
        layer_names = list(kv_caches.keys())
        layers = get_layers_from_vllm_config(
            self.spec.vllm_config, Attention, layer_names
        )
        # Map each layer to the backend its Attention module reports.
        attn_backends = {
            layer_name: layers[layer_name].get_attn_backend()
            for layer_name in layer_names
        }
        self._register_handlers(kv_caches, attn_backends)
+
    def register_cross_layers_kv_cache(
        self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
    ):
        """Register one tensor holding KV data for all layers at once.

        The tensor is filed under the sentinel name "ALL_LAYERS" so the
        shared handler-registration path can be reused unchanged.
        """
        cross_layer_name = "ALL_LAYERS"
        kv_caches = {cross_layer_name: kv_cache}
        attn_backends = {cross_layer_name: attn_backend}
        self._register_handlers(kv_caches, attn_backends)
+
def start_load_kv(self, metadata: OffloadingConnectorMetadata):
for req_id, transfer_spec in metadata.reqs_to_load.items():
job_id = self._generate_job_id()
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 852c4c644433f..f81612fd1f4a3 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -1098,6 +1098,12 @@ get_context_model_parallel_group = get_dcp_group
_PP: GroupCoordinator | None = None
+
def get_pp_group() -> GroupCoordinator:
    """Return the pipeline-model-parallel group coordinator.

    Raises AssertionError if initialize_model_parallel() has not yet
    populated the module-level ``_PP`` group.
    """
    assert _PP is not None, "pipeline model parallel group is not initialized"
    return _PP
+
+
_DP: GroupCoordinator | None = None
@@ -1114,9 +1120,12 @@ def get_ep_group() -> GroupCoordinator:
return _EP
-def get_pp_group() -> GroupCoordinator:
- assert _PP is not None, "pipeline model parallel group is not initialized"
- return _PP
+_PCP: GroupCoordinator | None = None
+
+
def get_pcp_group() -> GroupCoordinator:
    """Return the prefill-context-parallel group coordinator.

    Raises AssertionError if initialize_model_parallel() has not yet
    populated the module-level ``_PCP`` group.
    """
    assert _PCP is not None, "prefill context parallel group is not initialized"
    return _PCP
@deprecated(
@@ -1276,6 +1285,7 @@ def init_distributed_environment(
def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
+ prefill_context_model_parallel_size: int = 1,
decode_context_model_parallel_size: int | None = 1,
backend: str | None = None,
) -> None:
@@ -1325,7 +1335,11 @@ def initialize_model_parallel(
# to get group_ranks for each dimension, transpose that dimension to the
# last dimension, then reshape to 2D, then unbind the last dimension
all_ranks = torch.arange(world_size).reshape(
- -1, data_parallel_size, pipeline_model_parallel_size, tensor_model_parallel_size
+ -1,
+ data_parallel_size,
+ pipeline_model_parallel_size,
+ prefill_context_model_parallel_size,
+ tensor_model_parallel_size,
) # noqa
# Build the tensor model-parallel groups.
@@ -1360,11 +1374,23 @@ def initialize_model_parallel(
group_name="dcp",
)
+ global _PCP
+ assert _PCP is None, "prefill context parallel group is already initialized"
+ group_ranks = (
+ all_ranks.transpose(3, 4)
+ .reshape(-1, prefill_context_model_parallel_size)
+ .unbind(0)
+ )
+ group_ranks = [x.tolist() for x in group_ranks]
+ _PCP = init_model_parallel_group(
+ group_ranks, get_world_group().local_rank, backend, group_name="pcp"
+ )
+
# Build the pipeline model-parallel groups.
global _PP
assert _PP is None, "pipeline model parallel group is already initialized"
group_ranks = (
- all_ranks.transpose(2, 3).reshape(-1, pipeline_model_parallel_size).unbind(0)
+ all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
_PP = init_model_parallel_group(
@@ -1373,7 +1399,7 @@ def initialize_model_parallel(
global _DP
assert _DP is None, "data parallel group is already initialized"
- group_ranks = all_ranks.transpose(1, 3).reshape(-1, data_parallel_size).unbind(0)
+ group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_DP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="dp"
@@ -1383,7 +1409,12 @@ def initialize_model_parallel(
assert _EP is None, "expert parallel group is already initialized"
group_ranks = (
all_ranks.transpose(1, 2)
- .reshape(-1, data_parallel_size * tensor_model_parallel_size)
+ .reshape(
+ -1,
+ data_parallel_size
+ * prefill_context_model_parallel_size
+ * tensor_model_parallel_size,
+ )
.unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
@@ -1393,11 +1424,13 @@ def initialize_model_parallel(
logger.info_once(
"rank %s in world size %s is assigned as "
- "DP rank %s, PP rank %s, TP rank %s, EP rank %s",
+ "DP rank %s, PP rank %s, PCP rank %s, "
+ "TP rank %s, EP rank %s",
rank,
world_size,
_DP.rank_in_group,
_PP.rank_in_group,
+ _PCP.rank_in_group,
_TP.rank_in_group,
_EP.rank_in_group,
)
@@ -1406,6 +1439,7 @@ def initialize_model_parallel(
def ensure_model_parallel_initialized(
tensor_model_parallel_size: int,
pipeline_model_parallel_size: int,
+ prefill_context_model_parallel_size: int = 1,
decode_context_model_parallel_size: int | None = 1,
backend: str | None = None,
) -> None:
@@ -1418,6 +1452,7 @@ def ensure_model_parallel_initialized(
initialize_model_parallel(
tensor_model_parallel_size,
pipeline_model_parallel_size,
+ prefill_context_model_parallel_size,
decode_context_model_parallel_size,
backend,
)
@@ -1434,6 +1469,12 @@ def ensure_model_parallel_initialized(
f"got: {pp_world_size=} vs. "
f"wanted: {pipeline_model_parallel_size=}"
)
+ pcp_world_size = get_pcp_group().world_size
+ assert pcp_world_size == prefill_context_model_parallel_size, (
+ "prefill context parallel group already initialized, but of unexpected size: "
+ f"{pcp_world_size=} vs. "
+ f"{prefill_context_model_parallel_size=}"
+ )
def prepare_communication_buffer_for_model(model: torch.nn.Module):
@@ -1445,6 +1486,8 @@ def prepare_communication_buffer_for_model(model: torch.nn.Module):
"""
if _TP is not None:
_TP.prepare_communication_buffer_for_model(model)
+ if _PCP is not None:
+ _PCP.prepare_communication_buffer_for_model(model)
if _PP is not None:
_PP.prepare_communication_buffer_for_model(model)
if _DP is not None:
@@ -1520,16 +1563,21 @@ def destroy_model_parallel():
_TP.destroy()
_TP = None
- global _PP
- if _PP:
- _PP.destroy()
- _PP = None
-
global _DCP
if _DCP:
_DCP.destroy()
_DCP = None
+ global _PCP
+ if _PCP:
+ _PCP.destroy()
+ _PCP = None
+
+ global _PP
+ if _PP:
+ _PP.destroy()
+ _PP = None
+
global _DP
if _DP:
_DP.destroy()
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index ab6e5e594c239..6b5c8ba87ecbf 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -77,7 +77,7 @@ from vllm.config.observability import DetailedTraceModules
from vllm.config.parallel import DistributedExecutorBackend, ExpertPlacementStrategy
from vllm.config.scheduler import SchedulerPolicy
from vllm.config.utils import get_field
-from vllm.logger import init_logger
+from vllm.logger import init_logger, suppress_logging
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
from vllm.ray.lazy_utils import is_in_ray_actor, is_ray_initialized
@@ -247,11 +247,13 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
default = field.default
# Handle pydantic.Field defaults
if isinstance(default, FieldInfo):
- default = (
- default.default
- if default.default_factory is None
- else default.default_factory()
- )
+ if default.default_factory is None:
+ default = default.default
+ else:
+ # VllmConfig's Fields have default_factory set to config classes.
+ # These could emit logs on init, which would be confusing.
+ with suppress_logging():
+ default = default.default_factory()
elif field.default_factory is not MISSING:
default = field.default_factory()
@@ -367,7 +369,7 @@ class EngineArgs:
config_format: str = ModelConfig.config_format
dtype: ModelDType = ModelConfig.dtype
kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
- seed: int | None = ModelConfig.seed
+ seed: int | None = 0
max_model_len: int | None = ModelConfig.max_model_len
cuda_graph_sizes: list[int] | None = CompilationConfig.cudagraph_capture_sizes
cudagraph_capture_sizes: list[int] | None = (
@@ -389,8 +391,10 @@ class EngineArgs:
nnodes: int = ParallelConfig.nnodes
node_rank: int = ParallelConfig.node_rank
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
+ prefill_context_parallel_size: int = ParallelConfig.prefill_context_parallel_size
decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
+ cp_kv_cache_interleave_size: int = ParallelConfig.cp_kv_cache_interleave_size
data_parallel_size: int = ParallelConfig.data_parallel_size
data_parallel_rank: int | None = None
data_parallel_start_rank: int | None = None
@@ -423,7 +427,7 @@ class EngineArgs:
ParallelConfig.max_parallel_loading_workers
)
block_size: BlockSize | None = CacheConfig.block_size
- enable_prefix_caching: bool | None = CacheConfig.enable_prefix_caching
+ enable_prefix_caching: bool | None = None
prefix_caching_hash_algo: PrefixCachingHashAlgo = (
CacheConfig.prefix_caching_hash_algo
)
@@ -482,11 +486,9 @@ class EngineArgs:
fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
max_cpu_loras: int | None = LoRAConfig.max_cpu_loras
lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype
- lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
num_gpu_blocks_override: int | None = CacheConfig.num_gpu_blocks_override
- num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config")
ignore_patterns: str | list[str] = get_field(LoadConfig, "ignore_patterns")
@@ -502,11 +504,6 @@ class EngineArgs:
)
reasoning_parser: str = StructuredOutputsConfig.reasoning_parser
reasoning_parser_plugin: str | None = None
- # Deprecated guided decoding fields
- guided_decoding_backend: str | None = None
- guided_decoding_disable_fallback: bool | None = None
- guided_decoding_disable_any_whitespace: bool | None = None
- guided_decoding_disable_additional_properties: bool | None = None
logits_processor_pattern: str | None = ModelConfig.logits_processor_pattern
@@ -725,19 +722,6 @@ class EngineArgs:
"--reasoning-parser-plugin",
**structured_outputs_kwargs["reasoning_parser_plugin"],
)
- # Deprecated guided decoding arguments
- for arg, type in [
- ("--guided-decoding-backend", str),
- ("--guided-decoding-disable-fallback", bool),
- ("--guided-decoding-disable-any-whitespace", bool),
- ("--guided-decoding-disable-additional-properties", bool),
- ]:
- structured_outputs_group.add_argument(
- arg,
- type=type,
- help=(f"[DEPRECATED] {arg} will be removed in v0.12.0."),
- deprecated=True,
- )
# Parallel arguments
parallel_kwargs = get_kwargs(ParallelConfig)
@@ -770,6 +754,15 @@ class EngineArgs:
"--dcp-kv-cache-interleave-size",
**parallel_kwargs["dcp_kv_cache_interleave_size"],
)
+ parallel_group.add_argument(
+ "--cp-kv-cache-interleave-size",
+ **parallel_kwargs["cp_kv_cache_interleave_size"],
+ )
+ parallel_group.add_argument(
+ "--prefill-context-parallel-size",
+ "-pcp",
+ **parallel_kwargs["prefill_context_parallel_size"],
+ )
parallel_group.add_argument(
"--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
)
@@ -846,30 +839,6 @@ class EngineArgs:
"--expert-placement-strategy",
**parallel_kwargs["expert_placement_strategy"],
)
- parallel_group.add_argument(
- "--num-redundant-experts",
- type=int,
- help="[DEPRECATED] --num-redundant-experts will be removed in v0.12.0.",
- deprecated=True,
- )
- parallel_group.add_argument(
- "--eplb-window-size",
- type=int,
- help="[DEPRECATED] --eplb-window-size will be removed in v0.12.0.",
- deprecated=True,
- )
- parallel_group.add_argument(
- "--eplb-step-interval",
- type=int,
- help="[DEPRECATED] --eplb-step-interval will be removed in v0.12.0.",
- deprecated=True,
- )
- parallel_group.add_argument(
- "--eplb-log-balancedness",
- action=argparse.BooleanOptionalAction,
- help="[DEPRECATED] --eplb-log-balancedness will be removed in v0.12.0.",
- deprecated=True,
- )
parallel_group.add_argument(
"--max-parallel-loading-workers",
@@ -1001,9 +970,6 @@ class EngineArgs:
)
lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"])
- lora_group.add_argument(
- "--lora-extra-vocab-size", **lora_kwargs["lora_extra_vocab_size"]
- )
lora_group.add_argument(
"--lora-dtype",
**lora_kwargs["lora_dtype"],
@@ -1070,9 +1036,6 @@ class EngineArgs:
"--long-prefill-token-threshold",
**scheduler_kwargs["long_prefill_token_threshold"],
)
- scheduler_group.add_argument(
- "--num-lookahead-slots", **scheduler_kwargs["num_lookahead_slots"]
- )
# multi-step scheduling has been removed; corresponding arguments
# are no longer supported.
scheduler_group.add_argument(
@@ -1185,29 +1148,52 @@ class EngineArgs:
if check_gguf_file(self.model):
self.quantization = self.load_format = "gguf"
+ # NOTE(woosuk): In V1, we use separate processes for workers (unless
+ # VLLM_ENABLE_V1_MULTIPROCESSING=0), so setting a seed here
+ # doesn't affect the user process.
+ if self.seed is None:
+ logger.warning_once(
+ "`seed=None` is equivalent to `seed=0` in V1 Engine. "
+ "You will no longer be allowed to pass `None` in v0.13.",
+ scope="local",
+ )
+
+ self.seed = 0
+ if not envs.VLLM_ENABLE_V1_MULTIPROCESSING:
+ logger.warning(
+ "The global random seed is set to %d. Since "
+ "VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may "
+ "affect the random state of the Python process that "
+ "launched vLLM.",
+ self.seed,
+ )
+
if self.disable_mm_preprocessor_cache:
- logger.warning(
+ logger.warning_once(
"`--disable-mm-preprocessor-cache` is deprecated "
"and will be removed in v0.13. "
"Please use `--mm-processor-cache-gb 0` instead.",
+ scope="local",
)
self.mm_processor_cache_gb = 0
elif envs.VLLM_MM_INPUT_CACHE_GIB != 4:
- logger.warning(
+ logger.warning_once(
"VLLM_MM_INPUT_CACHE_GIB` is deprecated "
"and will be removed in v0.13. "
"Please use `--mm-processor-cache-gb %d` instead.",
envs.VLLM_MM_INPUT_CACHE_GIB,
+ scope="local",
)
self.mm_processor_cache_gb = envs.VLLM_MM_INPUT_CACHE_GIB
if self.enable_multimodal_encoder_data_parallel:
- logger.warning(
+ logger.warning_once(
"--enable-multimodal-encoder-data-parallel` is deprecated "
"and will be removed in v0.13. "
- "Please use `--mm-encoder-tp-mode data` instead."
+ "Please use `--mm-encoder-tp-mode data` instead.",
+ scope="local",
)
self.mm_encoder_tp_mode = "data"
@@ -1366,11 +1352,10 @@ class EngineArgs:
# Set default arguments for V1 Engine.
self._set_default_args(usage_context, model_config)
# Disable chunked prefill and prefix caching for:
- # POWER (ppc64le)/ARM/s390x/RISCV CPUs in V1
+ # POWER (ppc64le)/s390x/RISCV CPUs in V1
if current_platform.is_cpu() and current_platform.get_cpu_architecture() in (
CpuArchEnum.POWERPC,
CpuArchEnum.S390X,
- CpuArchEnum.ARM,
CpuArchEnum.RISCV,
):
logger.info(
@@ -1500,7 +1485,7 @@ class EngineArgs:
# Local DP rank = 1, use pure-external LB.
if data_parallel_external_lb:
assert self.data_parallel_rank is not None, (
- "data_parallel_rank or node_rank must be spefified if "
+ "data_parallel_rank or node_rank must be specified if "
"data_parallel_external_lb is enable."
)
assert self.data_parallel_size_local in (1, None), (
@@ -1587,6 +1572,12 @@ class EngineArgs:
model_config.skip_tokenizer_init = True
logger.info("Skipping tokenizer initialization for tokens-only mode.")
+ if self.async_scheduling and not self.disable_nccl_for_dp_synchronization:
+ logger.info(
+ "Disabling NCCL for DP synchronization when using async scheduling."
+ )
+ self.disable_nccl_for_dp_synchronization = True
+
# Forward the deprecated CLI args to the EPLB config.
if self.num_redundant_experts is not None:
self.eplb_config.num_redundant_experts = self.num_redundant_experts
@@ -1600,6 +1591,7 @@ class EngineArgs:
parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size,
tensor_parallel_size=self.tensor_parallel_size,
+ prefill_context_parallel_size=self.prefill_context_parallel_size,
data_parallel_size=self.data_parallel_size,
data_parallel_rank=self.data_parallel_rank or 0,
data_parallel_external_lb=data_parallel_external_lb,
@@ -1631,6 +1623,7 @@ class EngineArgs:
worker_extension_cls=self.worker_extension_cls,
decode_context_parallel_size=self.decode_context_parallel_size,
dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size,
+ cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size,
_api_process_count=self._api_process_count,
_api_process_rank=self._api_process_rank,
)
@@ -1640,18 +1633,11 @@ class EngineArgs:
target_parallel_config=parallel_config,
)
- # make sure num_lookahead_slots is set appropriately depending on
- # whether speculative decoding is enabled
- num_lookahead_slots = self.num_lookahead_slots
- if speculative_config is not None:
- num_lookahead_slots = speculative_config.num_lookahead_slots
-
scheduler_config = SchedulerConfig(
runner_type=model_config.runner_type,
max_num_batched_tokens=self.max_num_batched_tokens,
max_num_seqs=self.max_num_seqs,
max_model_len=model_config.max_model_len,
- num_lookahead_slots=num_lookahead_slots,
enable_chunked_prefill=self.enable_chunked_prefill,
disable_chunked_mm_input=self.disable_chunked_mm_input,
is_multimodal_model=model_config.is_multimodal_model,
@@ -1678,7 +1664,6 @@ class EngineArgs:
max_loras=self.max_loras,
default_mm_loras=self.default_mm_loras,
fully_sharded_loras=self.fully_sharded_loras,
- lora_extra_vocab_size=self.lora_extra_vocab_size,
lora_dtype=self.lora_dtype,
max_cpu_loras=self.max_cpu_loras
if self.max_cpu_loras and self.max_cpu_loras > 0
@@ -1717,21 +1702,6 @@ class EngineArgs:
self.reasoning_parser_plugin
)
- # Forward the deprecated CLI args to the StructuredOutputsConfig
- so_config = self.structured_outputs_config
- if self.guided_decoding_backend is not None:
- so_config.guided_decoding_backend = self.guided_decoding_backend
- if self.guided_decoding_disable_fallback is not None:
- so_config.disable_fallback = self.guided_decoding_disable_fallback
- if self.guided_decoding_disable_any_whitespace is not None:
- so_config.disable_any_whitespace = (
- self.guided_decoding_disable_any_whitespace
- )
- if self.guided_decoding_disable_additional_properties is not None:
- so_config.disable_additional_properties = (
- self.guided_decoding_disable_additional_properties
- )
-
observability_config = ObservabilityConfig(
show_hidden_metrics_for_version=self.show_hidden_metrics_for_version,
otlp_traces_endpoint=self.otlp_traces_endpoint,
@@ -1952,6 +1922,16 @@ class EngineArgs:
default_prefix_caching,
) = self.get_chunked_prefill_prefix_caching_defaults(model_config)
+ if self.prefill_context_parallel_size > 1:
+ default_chunked_prefill = False
+ default_prefix_caching = False
+ logger.warning_once(
+ "--prefill-context-parallel-size > 1 is not compatible with "
+ "chunked prefill and prefix caching now. Chunked prefill "
+ "and prefix caching have been disabled by default.",
+ scope="local",
+ )
+
if self.enable_chunked_prefill is None:
self.enable_chunked_prefill = default_chunked_prefill
@@ -1959,15 +1939,27 @@ class EngineArgs:
"%s chunked prefill by default",
"Enabling" if default_chunked_prefill else "Disabling",
)
+ elif (
+ model_config.runner_type == "generate"
+ and not self.enable_chunked_prefill
+ and default_chunked_prefill
+ ):
+ logger.warning_once(
+ "This model does not officially support disabling chunked prefill. "
+ "Disabling this manually may cause the engine to crash "
+ "or produce incorrect outputs.",
+ scope="local",
+ )
elif (
model_config.runner_type == "pooling"
and self.enable_chunked_prefill
and not default_chunked_prefill
):
- logger.warning(
+ logger.warning_once(
"This model does not officially support chunked prefill. "
"Enabling this manually may cause the engine to crash "
"or produce incorrect outputs.",
+ scope="local",
)
if self.enable_prefix_caching is None:
@@ -1982,10 +1974,11 @@ class EngineArgs:
and self.enable_prefix_caching
and not default_prefix_caching
):
- logger.warning(
+ logger.warning_once(
"This model does not officially support prefix caching. "
"Enabling this manually may cause the engine to crash "
"or produce incorrect outputs.",
+ scope="local",
)
world_size = self.pipeline_parallel_size * self.tensor_parallel_size
diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index 462d2c4e50e73..5e3374f9f6a10 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -149,6 +149,33 @@ class EngineClient(ABC):
"""Load a new LoRA adapter into the engine for future requests."""
...
    @abstractmethod
    async def pause_generation(
        self,
        *,
        wait_for_inflight_requests: bool = False,
        clear_cache: bool = True,
    ) -> None:
        """Pause new generation/encoding requests.

        Args:
            wait_for_inflight_requests: When ``True`` waits for in-flight requests
                to finish before pausing. When ``False`` (default), aborts in-flight
                requests immediately.
            clear_cache: Whether to clear KV and prefix caches after draining.

        Returns:
            None.
        """
        ...
+
    @abstractmethod
    async def resume_generation(self) -> None:
        """Resume accepting generation/encoding requests.

        Counterpart to :meth:`pause_generation`.
        """
        ...
+
    @abstractmethod
    async def is_paused(self) -> bool:
        """Return whether the engine is currently paused.

        ``True`` after :meth:`pause_generation` until
        :meth:`resume_generation` is called.
        """
        ...
+
async def scale_elastic_ep(
self, new_data_parallel_size: int, drain_timeout: int = 300
) -> None:
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index 3b722c2d92770..bf80856c1bbfc 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -94,6 +94,22 @@ class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
"""
class ChatCompletionContentPartAudioEmbedsParam(TypedDict, total=False):
    """Chat content part carrying precomputed audio embeddings."""

    audio_embeds: str | dict[str, str] | None
    """
    The audio embeddings. It can be either:
    - A single base64 string representing a serialized torch tensor.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["audio_embeds"]]
    """The type of the content part."""
    uuid: str | None
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
+
+
class VideoURL(TypedDict, total=False):
url: Required[str]
"""
@@ -211,6 +227,7 @@ ChatCompletionContentPartParam: TypeAlias = (
| CustomChatCompletionContentPILImageParam
| CustomChatCompletionContentSimpleImageParam
| ChatCompletionContentPartImageEmbedsParam
+ | ChatCompletionContentPartAudioEmbedsParam
| CustomChatCompletionContentSimpleAudioParam
| CustomChatCompletionContentSimpleVideoParam
| str
@@ -599,7 +616,7 @@ def resolve_chat_template_content_format(
return detected_format
-ModalityStr = Literal["image", "audio", "video", "image_embeds"]
+ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"]
_T = TypeVar("_T")
@@ -684,6 +701,11 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
mm_uuids["image"] = uuids_by_modality["image_embeds"]
if "image" in uuids_by_modality:
mm_uuids["image"] = uuids_by_modality["image"] # UUIDs of images
+ if "audio_embeds" in uuids_by_modality:
+ audio_embeds_uuids = uuids_by_modality["audio_embeds"]
+ if len(audio_embeds_uuids) > 1:
+ raise ValueError("Only one message can have {'type': 'audio_embeds'}")
+ mm_uuids["audio"] = uuids_by_modality["audio_embeds"]
if "audio" in uuids_by_modality:
mm_uuids["audio"] = uuids_by_modality["audio"] # UUIDs of audios
if "video" in uuids_by_modality:
@@ -703,6 +725,8 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
items_by_modality = dict(self._items_by_modality)
if "image" in items_by_modality and "image_embeds" in items_by_modality:
raise ValueError("Mixing raw image and embedding inputs is not allowed")
+ if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
+ raise ValueError("Mixing raw audio and embedding inputs is not allowed")
if "image_embeds" in items_by_modality:
image_embeds_lst = items_by_modality["image_embeds"]
@@ -711,6 +735,11 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
mm_inputs["image"] = image_embeds_lst[0]
if "image" in items_by_modality:
mm_inputs["image"] = items_by_modality["image"] # A list of images
+ if "audio_embeds" in items_by_modality:
+ audio_embeds_lst = items_by_modality["audio_embeds"]
+ if len(audio_embeds_lst) > 1:
+ raise ValueError("Only one message can have {'type': 'audio_embeds'}")
+ mm_inputs["audio"] = audio_embeds_lst[0]
if "audio" in items_by_modality:
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
if "video" in items_by_modality:
@@ -738,6 +767,8 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
if "image" in items_by_modality and "image_embeds" in items_by_modality:
raise ValueError("Mixing raw image and embedding inputs is not allowed")
+ if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
+ raise ValueError("Mixing raw audio and embedding inputs is not allowed")
if "image_embeds" in items_by_modality:
image_embeds_lst = items_by_modality["image_embeds"]
@@ -746,6 +777,11 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
mm_inputs["image"] = image_embeds_lst[0]
if "image" in items_by_modality:
mm_inputs["image"] = items_by_modality["image"] # A list of images
+ if "audio_embeds" in items_by_modality:
+ audio_embeds_lst = items_by_modality["audio_embeds"]
+ if len(audio_embeds_lst) > 1:
+ raise ValueError("Only one message can have {'type': 'audio_embeds'}")
+ mm_inputs["audio"] = audio_embeds_lst[0]
if "audio" in items_by_modality:
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
if "video" in items_by_modality:
@@ -804,6 +840,14 @@ class BaseMultiModalContentParser(ABC):
) -> None:
raise NotImplementedError
    @abstractmethod
    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Parse an ``audio_embeds`` content part.

        Args:
            audio_embeds: Base64 string, dict of name -> base64 string, or
                None when only a UUID reference is supplied.
            uuid: Optional user-provided UUID identifying the media item.
        """
        raise NotImplementedError
+
@abstractmethod
def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
raise NotImplementedError
@@ -861,6 +905,31 @@ class MultiModalContentParser(BaseMultiModalContentParser):
self._add_placeholder("image", placeholder)
+ def parse_audio_embeds(
+ self,
+ audio_embeds: str | dict[str, str] | None,
+ uuid: str | None = None,
+ ) -> None:
+ mm_config = self.model_config.get_multimodal_config()
+ if not mm_config.enable_mm_embeds:
+ raise ValueError(
+ "You must set `--enable-mm-embeds` to input `audio_embeds`"
+ )
+
+ if isinstance(audio_embeds, dict):
+ embeds = {
+ k: self._connector.fetch_audio_embedding(v)
+ for k, v in audio_embeds.items()
+ }
+ placeholder = self._tracker.add("audio_embeds", embeds, uuid)
+ elif isinstance(audio_embeds, str):
+ embedding = self._connector.fetch_audio_embedding(audio_embeds)
+ placeholder = self._tracker.add("audio_embeds", embedding, uuid)
+ else:
+ placeholder = self._tracker.add("audio_embeds", None, uuid)
+
+ self._add_placeholder("audio", placeholder)
+
def parse_image_pil(
self, image_pil: Image.Image | None, uuid: str | None = None
) -> None:
@@ -950,6 +1019,67 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
placeholder = self._tracker.add("image_embeds", future, uuid)
self._add_placeholder("image", placeholder)
+ def parse_audio_embeds(
+ self,
+ audio_embeds: str | dict[str, str] | None,
+ uuid: str | None = None,
+ ) -> None:
+ mm_config = self.model_config.get_multimodal_config()
+ if not mm_config.enable_mm_embeds:
+ raise ValueError(
+ "You must set `--enable-mm-embeds` to input `audio_embeds`"
+ )
+
+ logger.info(
+ "🎵 Parsing audio_embeds: type=%s, uuid=%s, is_dict=%s, "
+ "is_str=%s, is_none=%s",
+ type(audio_embeds).__name__,
+ uuid,
+ isinstance(audio_embeds, dict),
+ isinstance(audio_embeds, str),
+ audio_embeds is None,
+ )
+
+ future: asyncio.Future[str | dict[str, str] | None] = asyncio.Future()
+
+ if isinstance(audio_embeds, dict):
+ logger.info(
+ "🎵 Processing dict audio_embeds with %d entries",
+ len(audio_embeds),
+ )
+ embeds = {
+ k: self._connector.fetch_audio_embedding(v)
+ for k, v in audio_embeds.items()
+ }
+ future.set_result(embeds)
+ logger.info(
+ "🎵 Successfully loaded %d audio embeddings from dict",
+ len(embeds),
+ )
+
+ if isinstance(audio_embeds, str):
+ base64_size = len(audio_embeds)
+ logger.info(
+ "🎵 Processing base64 audio_embeds: %d chars (%.2f KB)",
+ base64_size,
+ base64_size / 1024,
+ )
+ embedding = self._connector.fetch_audio_embedding(audio_embeds)
+ future.set_result(embedding)
+ logger.info(
+ "🎵 Successfully loaded audio embedding tensor: shape=%s, dtype=%s",
+ embedding.shape,
+ embedding.dtype,
+ )
+
+ if audio_embeds is None:
+ logger.info("🎵 Audio embeds is None (UUID-only reference)")
+ future.set_result(None)
+
+ placeholder = self._tracker.add("audio_embeds", future, uuid)
+ self._add_placeholder("audio", placeholder)
+ logger.info("🎵 Added audio_embeds placeholder with uuid=%s", uuid)
+
def parse_image_pil(
self, image_pil: Image.Image | None, uuid: str | None = None
) -> None:
@@ -1132,6 +1262,7 @@ def _get_full_multimodal_text_prompt(
# No need to validate using Pydantic again
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
+_AudioEmbedsParser = partial(cast, ChatCompletionContentPartAudioEmbedsParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam)
@@ -1152,9 +1283,11 @@ MM_PARSER_MAP: dict[
"text": lambda part: _TextParser(part).get("text", None),
"thinking": lambda part: _ThinkParser(part).get("thinking", None),
"input_text": lambda part: _TextParser(part).get("text", None),
+ "output_text": lambda part: _TextParser(part).get("text", None),
"input_image": lambda part: _ResponsesInputImageParser(part).get("image_url", None),
"image_url": lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
"image_embeds": lambda part: _ImageEmbedsParser(part).get("image_embeds", None),
+ "audio_embeds": lambda part: _AudioEmbedsParser(part).get("audio_embeds", None),
"image_pil": lambda part: _PILImageParser(part).get("image_pil", None),
"audio_url": lambda part: _AudioParser(part).get("audio_url", {}).get("url", None),
"input_audio": lambda part: _InputAudioParser(part).get("input_audio", None),
@@ -1223,8 +1356,17 @@ def _parse_chat_message_content_mm_part(
)
image_embeds = image_params.get("image_embeds", None)
return "image_embeds", image_embeds
+ if "audio_embeds" in part:
+ # "audio_embeds" could be None if UUID is provided.
+ audio_params = cast( # type: ignore[assignment]
+ ChatCompletionContentPartAudioEmbedsParam, part
+ )
+ audio_embeds = audio_params.get("audio_embeds", None)
+ return "audio_embeds", audio_embeds
if "audio_url" in part:
- audio_params = cast(CustomChatCompletionContentSimpleAudioParam, part)
+ audio_params = cast( # type: ignore[assignment]
+ CustomChatCompletionContentSimpleAudioParam, part
+ )
audio_url = audio_params.get("audio_url", None)
if isinstance(audio_url, dict):
# Can potentially happen if user provides a uuid
@@ -1322,7 +1464,7 @@ def _parse_chat_message_content_part(
)
return None
- if part_type in ("text", "input_text", "refusal", "thinking"):
+ if part_type in ("text", "input_text", "output_text", "refusal", "thinking"):
str_content = cast(str, content)
if wrap_dicts:
return {"type": "text", "text": str_content}
@@ -1348,6 +1490,10 @@ def _parse_chat_message_content_part(
content = cast(str | dict[str, str], content) if content is not None else None
mm_parser.parse_image_embeds(content, uuid)
modality = "image"
+ elif part_type == "audio_embeds":
+ content = cast(str | dict[str, str], content) if content is not None else None
+ mm_parser.parse_audio_embeds(content, uuid)
+ modality = "audio"
elif part_type == "audio_url":
str_content = cast(str, content)
mm_parser.parse_audio(str_content, uuid)
@@ -1437,7 +1583,8 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None:
for item in message["tool_calls"]:
# if arguments is None or empty string, set to {}
if content := item["function"].get("arguments"):
- item["function"]["arguments"] = json.loads(content)
+ if not isinstance(content, (dict, list)):
+ item["function"]["arguments"] = json.loads(content)
else:
item["function"]["arguments"] = {}
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index b0786bd355aa6..848916dbd8763 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -339,7 +339,6 @@ class LLM:
log_non_default_args(engine_args)
- # Create the Engine (autoselects V0 vs V1)
self.llm_engine = LLMEngine.from_engine_args(
engine_args=engine_args, usage_context=UsageContext.LLM_CLASS
)
@@ -466,7 +465,7 @@ class LLM:
):
return lora_request
- if not isinstance(prompts, Sequence):
+ if not isinstance(prompts, Sequence) or isinstance(prompts, str):
prompts = [prompts]
optional_loras = (
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 3974f45a7135c..70174250ceabe 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -394,6 +394,84 @@ async def get_server_load_metrics(request: Request):
return JSONResponse(content={"server_load": request.app.state.server_load_metrics})
+@router.post("/pause")
+async def pause_generation(
+ raw_request: Request,
+ wait_for_inflight_requests: bool = Query(False),
+ clear_cache: bool = Query(True),
+) -> JSONResponse:
+ """Pause generation requests to allow weight updates.
+
+ Args:
+ wait_for_inflight_requests: When ``True`` waits for in-flight
+ requests to finish before pausing. When ``False`` (default),
+ aborts any in-flight requests immediately.
+ clear_cache: Whether to clear KV/prefix caches after draining.
+ """
+
+ engine = engine_client(raw_request)
+
+ try:
+ await engine.pause_generation(
+ wait_for_inflight_requests=wait_for_inflight_requests,
+ clear_cache=clear_cache,
+ )
+ return JSONResponse(
+ content={"status": "paused"},
+ status_code=HTTPStatus.OK.value,
+ )
+
+ except ValueError as err:
+ return JSONResponse(
+ content={"error": str(err)},
+ status_code=HTTPStatus.BAD_REQUEST.value,
+ )
+ except Exception as err: # pragma: no cover - defensive
+ logger.exception("Failed to pause generation")
+ return JSONResponse(
+ content={"error": f"Failed to pause generation: {err}"},
+ status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
+ )
+
+
+@router.post("/resume")
+async def resume_generation(raw_request: Request) -> JSONResponse:
+ """Resume generation after a pause."""
+
+ engine = engine_client(raw_request)
+
+ try:
+ await engine.resume_generation()
+ return JSONResponse(
+ content={"status": "resumed"},
+ status_code=HTTPStatus.OK.value,
+ )
+ except Exception as err: # pragma: no cover - defensive
+ logger.exception("Failed to resume generation")
+ return JSONResponse(
+ content={"error": f"Failed to resume generation: {err}"},
+ status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
+ )
+
+
+@router.get("/is_paused")
+async def is_paused(raw_request: Request) -> JSONResponse:
+ """Return the current pause status."""
+
+ engine = engine_client(raw_request)
+
+ try:
+ paused = await engine.is_paused()
+ except Exception as err: # pragma: no cover - defensive
+ logger.exception("Failed to fetch pause status")
+ return JSONResponse(
+ content={"error": f"Failed to fetch pause status: {err}"},
+ status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
+ )
+
+ return JSONResponse(content={"is_paused": paused})
+
+
@router.post(
"/tokenize",
dependencies=[Depends(validate_json_request)],
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 65bd15ba387b9..98a385a1dcd5f 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -377,7 +377,7 @@ class ResponsesRequest(OpenAIBaseModel):
"environments. The salt should be random, protected from "
"access by 3rd parties, and long enough to be "
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
- "to 256 bit). Not supported by vLLM engine V0."
+ "to 256 bit)."
),
)
@@ -559,13 +559,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
) = "none"
reasoning_effort: Literal["low", "medium", "high"] | None = None
include_reasoning: bool = True
+ parallel_tool_calls: bool | None = True
- # NOTE this will be ignored by vLLM -- the model determines the behavior
- parallel_tool_calls: bool | None = False
+ # NOTE this will be ignored by vLLM
user: str | None = None
# --8<-- [start:chat-completion-sampling-params]
- best_of: int | None = None
use_beam_search: bool = False
top_k: int | None = None
min_p: float | None = None
@@ -653,62 +652,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
default=None,
description="Additional kwargs for structured outputs",
)
- guided_json: str | dict | BaseModel | None = Field(
- default=None,
- description=(
- "`guided_json` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `json` to `structured_outputs` instead."
- ),
- )
- guided_regex: str | None = Field(
- default=None,
- description=(
- "`guided_regex` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `regex` to `structured_outputs` instead."
- ),
- )
- guided_choice: list[str] | None = Field(
- default=None,
- description=(
- "`guided_choice` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `choice` to `structured_outputs` instead."
- ),
- )
- guided_grammar: str | None = Field(
- default=None,
- description=(
- "`guided_grammar` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `grammar` to `structured_outputs` instead."
- ),
- )
- structural_tag: str | None = Field(
- default=None,
- description=(
- "`structural_tag` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `structural_tag` to `structured_outputs` instead."
- ),
- )
- guided_decoding_backend: str | None = Field(
- default=None,
- description=(
- "`guided_decoding_backend` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please remove it from your request."
- ),
- )
- guided_whitespace_pattern: str | None = Field(
- default=None,
- description=(
- "`guided_whitespace_pattern` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `whitespace_pattern` to `structured_outputs` instead."
- ),
- )
priority: int = Field(
default=0,
description=(
@@ -718,7 +661,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
),
)
request_id: str = Field(
- default_factory=lambda: f"{random_uuid()}",
+ default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
@@ -764,7 +707,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
"environments. The salt should be random, protected from "
"access by 3rd parties, and long enough to be "
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
- "to 256 bit). Not supported by vLLM engine V0."
+ "to 256 bit)."
),
)
kv_transfer_params: dict[str, Any] | None = Field(
@@ -842,20 +785,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.top_logprobs
- # Forward deprecated guided_* parameters to structured_outputs
- if self.structured_outputs is None:
- kwargs = dict[str, Any](
- json=self.guided_json,
- regex=self.guided_regex,
- choice=self.guided_choice,
- grammar=self.guided_grammar,
- whitespace_pattern=self.guided_whitespace_pattern,
- structural_tag=self.structural_tag,
- )
- kwargs = {k: v for k, v in kwargs.items() if v is not None}
- if len(kwargs) > 0:
- self.structured_outputs = StructuredOutputsParams(**kwargs)
-
response_format = self.response_format
if response_format is not None:
# If structured outputs wasn't already enabled,
@@ -864,24 +793,23 @@ class ChatCompletionRequest(OpenAIBaseModel):
self.structured_outputs = StructuredOutputsParams()
# Set structured output params for response format
- if response_format is not None:
- if response_format.type == "json_object":
- self.structured_outputs.json_object = True
- elif response_format.type == "json_schema":
- json_schema = response_format.json_schema
- assert json_schema is not None
- self.structured_outputs.json = json_schema.json_schema
- elif response_format.type == "structural_tag":
- structural_tag = response_format
- assert structural_tag is not None and isinstance(
- structural_tag,
- (
- LegacyStructuralTagResponseFormat,
- StructuralTagResponseFormat,
- ),
- )
- s_tag_obj = structural_tag.model_dump(by_alias=True)
- self.structured_outputs.structural_tag = json.dumps(s_tag_obj)
+ if response_format.type == "json_object":
+ self.structured_outputs.json_object = True
+ elif response_format.type == "json_schema":
+ json_schema = response_format.json_schema
+ assert json_schema is not None
+ self.structured_outputs.json = json_schema.json_schema
+ elif response_format.type == "structural_tag":
+ structural_tag = response_format
+ assert structural_tag is not None and isinstance(
+ structural_tag,
+ (
+ LegacyStructuralTagResponseFormat,
+ StructuralTagResponseFormat,
+ ),
+ )
+ s_tag_obj = structural_tag.model_dump(by_alias=True)
+ self.structured_outputs.structural_tag = json.dumps(s_tag_obj)
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
if self.kv_transfer_params:
@@ -889,7 +817,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
extra_args["kv_transfer_params"] = self.kv_transfer_params
return SamplingParams.from_optional(
n=self.n,
- best_of=self.best_of,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=repetition_penalty,
@@ -1088,7 +1015,6 @@ class CompletionRequest(OpenAIBaseModel):
# https://platform.openai.com/docs/api-reference/completions/create
model: str | None = None
prompt: list[int] | list[list[int]] | str | list[str] | None = None
- best_of: int | None = None
echo: bool | None = False
frequency_penalty: float | None = 0.0
logit_bias: dict[str, float] | None = None
@@ -1143,58 +1069,6 @@ class CompletionRequest(OpenAIBaseModel):
default=None,
description="Additional kwargs for structured outputs",
)
- guided_json: str | dict | BaseModel | None = Field(
- default=None,
- description=(
- "`guided_json` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `json` to `structured_outputs` instead."
- ),
- )
- guided_regex: str | None = Field(
- default=None,
- description=(
- "`guided_regex` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `regex` to `structured_outputs` instead."
- ),
- )
- guided_choice: list[str] | None = Field(
- default=None,
- description=(
- "`guided_choice` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `choice` to `structured_outputs` instead."
- ),
- )
- guided_grammar: str | None = Field(
- default=None,
- description=(
- "`guided_grammar` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `grammar` to `structured_outputs` instead."
- ),
- )
- structural_tag: str | None = Field(
- default=None,
- description=("If specified, the output will follow the structural tag schema."),
- )
- guided_decoding_backend: str | None = Field(
- default=None,
- description=(
- "`guided_decoding_backend` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please remove it from your request."
- ),
- )
- guided_whitespace_pattern: str | None = Field(
- default=None,
- description=(
- "`guided_whitespace_pattern` is deprecated. "
- "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
- "Please pass `whitespace_pattern` to `structured_outputs` instead."
- ),
- )
priority: int = Field(
default=0,
description=(
@@ -1204,7 +1078,7 @@ class CompletionRequest(OpenAIBaseModel):
),
)
request_id: str = Field(
- default_factory=lambda: f"{random_uuid()}",
+ default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
@@ -1252,7 +1126,7 @@ class CompletionRequest(OpenAIBaseModel):
"environments. The salt should be random, protected from "
"access by 3rd parties, and long enough to be "
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
- "to 256 bit). Not supported by vLLM engine V0."
+ "to 256 bit)."
),
)
@@ -1339,35 +1213,31 @@ class CompletionRequest(OpenAIBaseModel):
echo_without_generation = self.echo and self.max_tokens == 0
- guided_json_object = None
- if self.response_format is not None:
- if self.response_format.type == "json_object":
- guided_json_object = True
- elif self.response_format.type == "json_schema":
- json_schema = self.response_format.json_schema
+ response_format = self.response_format
+ if response_format is not None:
+ # If structured outputs wasn't already enabled,
+ # we must enable it for these features to work
+ if self.structured_outputs is None:
+ self.structured_outputs = StructuredOutputsParams()
+
+ # Set structured output params for response format
+ if response_format.type == "json_object":
+ self.structured_outputs.json_object = True
+ elif response_format.type == "json_schema":
+ json_schema = response_format.json_schema
assert json_schema is not None
- self.guided_json = json_schema.json_schema
- elif self.response_format.type == "structural_tag":
- structural_tag = self.response_format
+ self.structured_outputs.json = json_schema.json_schema
+ elif response_format.type == "structural_tag":
+ structural_tag = response_format
assert structural_tag is not None and isinstance(
- structural_tag, StructuralTagResponseFormat
+ structural_tag,
+ (
+ LegacyStructuralTagResponseFormat,
+ StructuralTagResponseFormat,
+ ),
)
s_tag_obj = structural_tag.model_dump(by_alias=True)
- self.structural_tag = json.dumps(s_tag_obj)
-
- # Forward deprecated guided_* parameters to structured_outputs
- if self.structured_outputs is None:
- kwargs = dict[str, Any](
- json=self.guided_json,
- json_object=guided_json_object,
- regex=self.guided_regex,
- choice=self.guided_choice,
- grammar=self.guided_grammar,
- whitespace_pattern=self.guided_whitespace_pattern,
- )
- kwargs = {k: v for k, v in kwargs.items() if v is not None}
- if len(kwargs) > 0:
- self.structured_outputs = StructuredOutputsParams(**kwargs)
+ self.structured_outputs.structural_tag = json.dumps(s_tag_obj)
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
if self.kv_transfer_params:
@@ -1375,7 +1245,6 @@ class CompletionRequest(OpenAIBaseModel):
extra_args["kv_transfer_params"] = self.kv_transfer_params
return SamplingParams.from_optional(
n=self.n,
- best_of=self.best_of,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=repetition_penalty,
@@ -1506,7 +1375,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
),
)
request_id: str = Field(
- default_factory=lambda: f"{random_uuid()}",
+ default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
@@ -1601,7 +1470,7 @@ class EmbeddingChatRequest(OpenAIBaseModel):
),
)
request_id: str = Field(
- default_factory=lambda: f"{random_uuid()}",
+ default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
@@ -2023,7 +1892,7 @@ class ClassificationCompletionRequest(OpenAIBaseModel):
),
)
request_id: str = Field(
- default_factory=lambda: f"{random_uuid()}",
+ default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
@@ -2114,7 +1983,7 @@ class ClassificationChatRequest(OpenAIBaseModel):
)
request_id: str = Field(
- default_factory=lambda: f"{random_uuid()}",
+ default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
@@ -3225,7 +3094,7 @@ class TranslationResponseVerbose(OpenAIBaseModel):
####### Tokens IN <> Tokens OUT #######
class GenerateRequest(BaseModel):
request_id: str = Field(
- default_factory=lambda: f"{random_uuid()}",
+ default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
@@ -3282,7 +3151,7 @@ class GenerateResponseChoice(BaseModel):
class GenerateResponse(BaseModel):
request_id: str = Field(
- default_factory=lambda: f"{random_uuid()}",
+ default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 59e1c8d531793..9a7051e0920af 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -55,6 +55,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_l
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall
+from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.logger import init_logger
@@ -273,6 +274,11 @@ class OpenAIServingChat(OpenAIServing):
try:
for i, engine_prompt in enumerate(engine_prompts):
prompt_text, _, _ = self._get_prompt_components(request_prompts[i])
+ # If we are creating sub requests for multiple prompts, ensure that they
+ # have unique request ids.
+ sub_request_id = (
+ request_id if len(engine_prompts) == 1 else f"{request_id}_{i}"
+ )
if self.default_sampling_params is None:
self.default_sampling_params = {}
@@ -301,7 +307,7 @@ class OpenAIServingChat(OpenAIServing):
)
self._log_inputs(
- request_id,
+ sub_request_id,
request_prompts[i],
params=sampling_params,
lora_request=lora_request,
@@ -316,13 +322,14 @@ class OpenAIServingChat(OpenAIServing):
if isinstance(sampling_params, BeamSearchParams):
generator = self.beam_search(
prompt=engine_prompt,
- request_id=request_id,
+ request_id=sub_request_id,
params=sampling_params,
lora_request=lora_request,
+ trace_headers=trace_headers,
)
else:
engine_request, tokenization_kwargs = await self._process_inputs(
- request_id,
+ sub_request_id,
engine_prompt,
sampling_params,
lora_request=lora_request,
@@ -333,7 +340,7 @@ class OpenAIServingChat(OpenAIServing):
generator = self.engine_client.generate(
engine_request,
sampling_params,
- request_id,
+ sub_request_id,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
@@ -1200,6 +1207,7 @@ class OpenAIServingChat(OpenAIServing):
finish_reason_sent[i] = True
+ choice_data = maybe_filter_parallel_tool_calls(choice_data, request)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
@@ -1525,6 +1533,7 @@ class OpenAIServingChat(OpenAIServing):
as_list(output.token_ids) if request.return_token_ids else None
),
)
+ choice_data = maybe_filter_parallel_tool_calls(choice_data, request)
choices.append(choice_data)
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index a114b77ebc16b..9681aa8c71e6d 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -216,6 +216,7 @@ class OpenAIServingCompletion(OpenAIServing):
request_id=request_id,
params=sampling_params,
lora_request=lora_request,
+ trace_headers=trace_headers,
)
else:
engine_request, tokenization_kwargs = await self._process_inputs(
@@ -249,14 +250,8 @@ class OpenAIServingCompletion(OpenAIServing):
model_name = self.models.model_name(lora_request)
num_prompts = len(engine_prompts)
- # Similar to the OpenAI API, when n != best_of, we do not stream the
- # results. Noting that best_of is only supported in V0. In addition,
- # we do not stream the results when use beam search.
- stream = (
- request.stream
- and (request.best_of is None or request.n == request.best_of)
- and not request.use_beam_search
- )
+ # We do not stream the results when using beam search.
+ stream = request.stream and not request.use_beam_search
# Streaming response
if stream:
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index c50b0c4a23e17..d9feee917ff4e 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -10,6 +10,7 @@ from concurrent.futures import ThreadPoolExecutor
from http import HTTPStatus
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar
+import numpy as np
import torch
from fastapi import Request
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
@@ -295,11 +296,7 @@ class OpenAIServing:
parser = None
if not enable_auto_tools or tool_parser_name is None:
return parser
- logger.info(
- '"auto" tool choice has been enabled please note that while'
- " the parallel_tool_calls client option is preset for "
- "compatibility reasons, it will be ignored."
- )
+ logger.info('"auto" tool choice has been enabled.')
try:
if tool_parser_name == "pythonic" and self.model_config.model.startswith(
@@ -342,6 +339,7 @@ class OpenAIServing:
request_id: str,
params: BeamSearchParams,
lora_request: LoRARequest | None = None,
+ trace_headers: Mapping[str, str] | None = None,
) -> AsyncGenerator[RequestOutput, None]:
beam_width = params.beam_width
max_tokens = params.max_tokens
@@ -389,8 +387,9 @@ class OpenAIServing:
sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
+ logprobs_num = 2 * beam_width
beam_search_params = SamplingParams(
- logprobs=2 * beam_width,
+ logprobs=logprobs_num,
max_tokens=1,
temperature=temperature,
)
@@ -435,6 +434,7 @@ class OpenAIServing:
beam_search_params,
request_id_item,
lora_request=lora_req,
+ trace_headers=trace_headers,
)
)
)
@@ -443,40 +443,75 @@ class OpenAIServing:
output = [x[0] for x in await asyncio.gather(*tasks)]
new_beams = []
- for i, current_beam in enumerate(all_beams):
- result = output[i]
-
+ # Store all new tokens generated by beam
+ all_beams_token_id = []
+ # Store the cumulative probability of all tokens
+ # generated by beam search
+ all_beams_logprob = []
+ # Iterate through all beam inference results
+ for i, result in enumerate(output):
+ current_beam = all_beams[i]
if result.outputs[0].logprobs is not None:
logprobs = result.outputs[0].logprobs[0]
- for token_id, logprob_obj in logprobs.items():
- if token_id == eos_token_id and not ignore_eos:
- completed.append(
- BeamSearchSequence(
- tokens=current_beam.tokens + [token_id]
- if include_stop_str_in_output
- else current_beam.tokens,
- logprobs=current_beam.logprobs + [logprobs],
- cum_logprob=current_beam.cum_logprob
- + logprob_obj.logprob,
- finish_reason="stop",
- stop_reason=eos_token_id,
- )
- )
- else:
- new_beams.append(
- BeamSearchSequence(
- tokens=current_beam.tokens + [token_id],
- logprobs=current_beam.logprobs + [logprobs],
- lora_request=current_beam.lora_request,
- cum_logprob=current_beam.cum_logprob
- + logprob_obj.logprob,
- multi_modal_data=current_beam.multi_modal_data,
- mm_processor_kwargs=current_beam.mm_processor_kwargs,
- )
- )
+ all_beams_token_id.extend(list(logprobs.keys()))
+ all_beams_logprob.extend(
+ [
+ current_beam.cum_logprob + obj.logprob
+ for obj in logprobs.values()
+ ]
+ )
- sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
- all_beams = sorted_beams[:beam_width]
+ # Handle the token for the end of sentence (EOS)
+ all_beams_token_id = np.array(all_beams_token_id)
+ all_beams_logprob = np.array(all_beams_logprob)
+
+ if not ignore_eos:
+ # Get the index position of eos token in all generated results
+ eos_idx = np.where(all_beams_token_id == eos_token_id)[0]
+ for idx in eos_idx:
+ current_beam = all_beams[idx // logprobs_num]
+ result = output[idx // logprobs_num]
+ assert result.outputs[0].logprobs is not None
+ logprobs_entry = result.outputs[0].logprobs[0]
+ completed.append(
+ BeamSearchSequence(
+ tokens=current_beam.tokens + [eos_token_id]
+ if include_stop_str_in_output
+ else current_beam.tokens,
+ logprobs=current_beam.logprobs + [logprobs_entry],
+ cum_logprob=float(all_beams_logprob[idx]),
+ finish_reason="stop",
+ stop_reason=eos_token_id,
+ )
+ )
+ # After processing, set the log probability of the eos condition
+ # to negative infinity.
+ all_beams_logprob[eos_idx] = -np.inf
+
+ # Processing non-EOS tokens
+ # Get indices of the top beam_width probabilities
+ topn_idx = np.argpartition(np.negative(all_beams_logprob), beam_width)[
+ :beam_width
+ ]
+
+ for idx in topn_idx:
+ current_beam = all_beams[idx // logprobs_num]
+ result = output[idx // logprobs_num]
+ token_id = int(all_beams_token_id[idx])
+ assert result.outputs[0].logprobs is not None
+ logprobs_entry = result.outputs[0].logprobs[0]
+ new_beams.append(
+ BeamSearchSequence(
+ tokens=current_beam.tokens + [token_id],
+ logprobs=current_beam.logprobs + [logprobs_entry],
+ lora_request=current_beam.lora_request,
+ cum_logprob=float(all_beams_logprob[idx]),
+ multi_modal_data=current_beam.multi_modal_data,
+ mm_processor_kwargs=current_beam.mm_processor_kwargs,
+ )
+ )
+
+ all_beams = new_beams
completed.extend(all_beams)
sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
@@ -1203,16 +1238,19 @@ class OpenAIServing:
):
prompt_text, _, _ = self._get_prompt_components(request_prompt)
orig_priority = priority
+ sub_request = 0
while True:
+ # Ensure that each sub-request has a unique request id.
+ sub_request_id = f"{request_id}_{sub_request}"
self._log_inputs(
- request_id,
+ sub_request_id,
request_prompt,
params=sampling_params,
lora_request=lora_request,
)
trace_headers = kwargs.get("trace_headers")
engine_request, tokenization_kwargs = await self._process_inputs(
- request_id,
+ sub_request_id,
engine_prompt,
sampling_params,
lora_request=lora_request,
@@ -1223,7 +1261,7 @@ class OpenAIServing:
generator = self.engine_client.generate(
engine_request,
sampling_params,
- request_id,
+ sub_request_id,
lora_request=lora_request,
priority=priority,
prompt_text=prompt_text,
@@ -1256,6 +1294,7 @@ class OpenAIServing:
sampling_params.max_tokens = self.max_model_len - len(prompt_token_ids)
# OPTIMIZATION
priority = orig_priority - 1
+ sub_request += 1
def _get_prompt_components(
self,
@@ -1306,11 +1345,12 @@ class OpenAIServing:
raw_request: Request | None, default: str | None = None
) -> str | None:
"""Pulls the request id to use from a header, if provided"""
- default = default or random_uuid()
- if raw_request is None:
- return default
+ if raw_request is not None and (
+ (req_id := raw_request.headers.get("X-Request-Id")) is not None
+ ):
+ return req_id
- return raw_request.headers.get("X-Request-Id", default)
+ return random_uuid() if default is None else default
@staticmethod
def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py
index 06efb43ecb7b8..f546dbda7fef5 100644
--- a/vllm/entrypoints/openai/serving_responses.py
+++ b/vllm/entrypoints/openai/serving_responses.py
@@ -94,7 +94,7 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.responses_utils import (
- construct_chat_message_with_tool_call,
+ construct_input_messages,
convert_tool_responses_to_completions_format,
extract_tool_types,
)
@@ -504,7 +504,12 @@ class OpenAIServingResponses(OpenAIServing):
for tool in request.tools
]
# Construct the input messages.
- messages = self._construct_input_messages(request, prev_response)
+ messages = construct_input_messages(
+ request_instructions=request.instructions,
+ request_input=request.input,
+ prev_msg=self.msg_store.get(prev_response.id) if prev_response else None,
+ prev_response_output=prev_response.output if prev_response else None,
+ )
_, request_prompts, engine_prompts = await self._preprocess_chat(
request,
tokenizer,
@@ -869,47 +874,6 @@ class OpenAIServingResponses(OpenAIServing):
output_items.extend(last_items)
return output_items
- def _construct_input_messages(
- self,
- request: ResponsesRequest,
- prev_response: ResponsesResponse | None = None,
- ) -> list[ChatCompletionMessageParam]:
- messages: list[ChatCompletionMessageParam] = []
- if request.instructions:
- messages.append(
- {
- "role": "system",
- "content": request.instructions,
- }
- )
-
- # Prepend the conversation history.
- if prev_response is not None:
- # Add the previous messages.
- prev_msg = self.msg_store[prev_response.id]
- messages.extend(prev_msg)
-
- # Add the previous output.
- for output_item in prev_response.output:
- # NOTE: We skip the reasoning output.
- if isinstance(output_item, ResponseOutputMessage):
- for content in output_item.content:
- messages.append(
- {
- "role": "assistant",
- "content": content.text,
- }
- )
-
- # Append the new input.
- # Responses API supports simple text inputs without chat format.
- if isinstance(request.input, str):
- messages.append({"role": "user", "content": request.input})
- else:
- for item in request.input:
- messages.append(construct_chat_message_with_tool_call(item))
- return messages
-
def _construct_harmony_system_input_message(
self, request: ResponsesRequest, with_custom_tools: bool, tool_types: set[str]
) -> OpenAIHarmonyMessage:
diff --git a/vllm/entrypoints/openai/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text.py
index b9b9b1ab30ad8..3dece07748cc4 100644
--- a/vllm/entrypoints/openai/speech_to_text.py
+++ b/vllm/entrypoints/openai/speech_to_text.py
@@ -201,10 +201,10 @@ class OpenAISpeechToText(OpenAIServing):
self.engine_client.generate(
prompt,
sampling_params,
- request_id,
+ f"{request_id}_{i}",
lora_request=lora_request,
)
- for prompt in prompts
+ for i, prompt in enumerate(prompts)
]
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
diff --git a/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py
index 120e63b929b16..389e9754b34da 100644
--- a/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py
@@ -78,7 +78,7 @@ class Glm4MoeModelToolParser(ToolParser):
.get("type", None)
)
return arg_type == "string"
- logger.warning("No tool named '%s'.", tool_name)
+ logger.debug("No tool named '%s'.", tool_name)
return False
def _deserialize(value: str) -> Any:
diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
index 02fc9b8a4d34e..e1fe6e90dfd0b 100644
--- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
@@ -9,6 +9,7 @@ import regex as re
from partial_json_parser.core.options import Allow
from transformers import PreTrainedTokenizerBase
+import vllm.envs as envs
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
@@ -56,12 +57,10 @@ class Llama3JsonToolParser(ToolParser):
self.bot_token_id = tokenizer.encode(self.bot_token, add_special_tokens=False)[
0
]
- # Updated regex to match multiple JSONs separated by semicolons
- # This pattern is more robust and can handle nested JSON objects
- self.tool_call_regex = re.compile(
- r"{[^{}]*(?:{[^{}]*}[^{}]*)*}(?:\s*;\s*{[^{}]*(?:{[^{}]*}[^{}]*)*})*",
- re.DOTALL,
- )
+ # Simple regex to find opening braces - we'll use JSON decoder for parsing
+ # This handles arbitrary nesting depth correctly
+ self.tool_call_start_regex = re.compile(r"\{")
+ self.json_decoder = json.JSONDecoder()
def extract_tool_calls(
self, model_output: str, request: ChatCompletionRequest
@@ -77,49 +76,84 @@ class Llama3JsonToolParser(ToolParser):
tools_called=False, tool_calls=[], content=model_output
)
- # Find JSON object(s) in the text using regex
- match = self.tool_call_regex.search(model_output)
- if not match:
+ # Keep track of the end index of the last parsed JSON object
+ # so we don't parse inner brackets
+ end_index = -1
+ tool_calls: list[ToolCall] = []
+
+ try:
+ for match in self.tool_call_start_regex.finditer(
+ model_output, timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS
+ ):
+ start_index = match.start()
+ # Skip if this brace is inside a previously parsed JSON object
+ if start_index <= end_index:
+ continue
+
+ try:
+ obj, json_end_index = self.json_decoder.raw_decode(
+ model_output[start_index:]
+ )
+ end_index = start_index + json_end_index
+
+ # raise KeyError if missing
+ name = obj["name"]
+ arguments_or_params = (
+ obj["arguments"] if "arguments" in obj else obj["parameters"]
+ )
+
+ tool_calls.append(
+ ToolCall(
+ type="function",
+ function=FunctionCall(
+ name=name,
+ # function call args are JSON but as a string
+ arguments=json.dumps(
+ arguments_or_params, ensure_ascii=False
+ ),
+ ),
+ )
+ )
+ except KeyError as e:
+ # Missing required key
+ missing_key = str(e).strip("'\"")
+ logger.exception(
+ "Couldn't extract tool call from JSON response. "
+ "Required key '%s' not present. "
+ "Returning output in content with empty tool calls.",
+ missing_key,
+ )
+ return ExtractedToolCallInformation(
+ tools_called=False, tool_calls=[], content=model_output
+ )
+ except Exception:
+ # Any other error during parsing
+ logger.exception(
+ "Error in extracting tool call from response. "
+ "Returning output in content with empty tool calls"
+ )
+ return ExtractedToolCallInformation(
+ tools_called=False, tool_calls=[], content=model_output
+ )
+ except TimeoutError:
+ logger.warning("Regex timeout occurred when matching tool call pattern.")
+ logger.debug(
+ "Regex timeout occurred when matching user input: %s", model_output
+ )
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
- try:
- json_str = match.group(0)
- # Split by semicolon and strip whitespace
- json_objects = [obj.strip() for obj in json_str.split(";")]
-
- tool_calls: list[ToolCall] = []
- for json_obj in json_objects:
- if not json_obj: # Skip empty strings
- continue
- obj = json.loads(json_obj)
- tool_calls.append(
- ToolCall(
- type="function",
- function=FunctionCall(
- name=obj["name"],
- # function call args are JSON but as a string
- arguments=json.dumps(
- obj["arguments"]
- if "arguments" in obj
- else obj["parameters"],
- ensure_ascii=False,
- ),
- ),
- )
- )
-
+ # If we have valid tool calls, return them normally
+ if tool_calls:
return ExtractedToolCallInformation(
tools_called=True, tool_calls=tool_calls, content=None
)
- except Exception:
- logger.exception("Error in extracting tool call from response.")
- # return information to just treat the tool call as regular JSON
- return ExtractedToolCallInformation(
- tools_called=False, tool_calls=[], content=model_output
- )
+ # No valid tool calls found
+ return ExtractedToolCallInformation(
+ tools_called=False, tool_calls=[], content=model_output
+ )
def extract_tool_calls_streaming(
self,
diff --git a/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py
index 26261c0065ead..9d4c079eba188 100644
--- a/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py
@@ -128,7 +128,7 @@ class Qwen3CoderToolParser(ToolParser):
return params
else:
return {}
- logger.warning("Tool '%s' is not defined in the tools list.", func_name)
+ logger.debug("Tool '%s' is not defined in the tools list.", func_name)
return {}
def _convert_param_value(
@@ -141,7 +141,7 @@ class Qwen3CoderToolParser(ToolParser):
if param_name not in param_config:
if param_config != {}:
- logger.warning(
+ logger.debug(
"Parsed parameter '%s' is not defined in the tool "
"parameters for tool '%s', directly returning the "
"string value.",
@@ -169,7 +169,7 @@ class Qwen3CoderToolParser(ToolParser):
try:
return int(param_value)
except (ValueError, TypeError):
- logger.warning(
+ logger.debug(
"Parsed value '%s' of parameter '%s' is not an "
"integer in tool '%s', degenerating to string.",
param_value,
@@ -186,7 +186,7 @@ class Qwen3CoderToolParser(ToolParser):
else int(float_param_value)
)
except (ValueError, TypeError):
- logger.warning(
+ logger.debug(
"Parsed value '%s' of parameter '%s' is not a float "
"in tool '%s', degenerating to string.",
param_value,
@@ -197,7 +197,7 @@ class Qwen3CoderToolParser(ToolParser):
elif param_type in ["boolean", "bool", "binary"]:
param_value = param_value.lower()
if param_value not in ["true", "false"]:
- logger.warning(
+ logger.debug(
"Parsed value '%s' of parameter '%s' is not a boolean "
"(`true` or `false`) in tool '%s', degenerating to "
"false.",
@@ -216,7 +216,7 @@ class Qwen3CoderToolParser(ToolParser):
param_value = json.loads(param_value)
return param_value
except (json.JSONDecodeError, TypeError, ValueError):
- logger.warning(
+ logger.debug(
"Parsed value '%s' of parameter '%s' cannot be "
"parsed with json.loads in tool '%s', will try "
"other methods to parse it.",
@@ -227,7 +227,7 @@ class Qwen3CoderToolParser(ToolParser):
try:
param_value = ast.literal_eval(param_value) # safer
except (ValueError, SyntaxError, TypeError):
- logger.warning(
+ logger.debug(
"Parsed value '%s' of parameter '%s' cannot be "
"converted via Python `ast.literal_eval()` in tool "
"'%s', degenerating to string.",
diff --git a/vllm/entrypoints/openai/utils.py b/vllm/entrypoints/openai/utils.py
new file mode 100644
index 0000000000000..6f37f6adff4c2
--- /dev/null
+++ b/vllm/entrypoints/openai/utils.py
@@ -0,0 +1,37 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import TypeVar
+
+from vllm.entrypoints.openai.protocol import (
+ ChatCompletionRequest,
+ ChatCompletionResponseChoice,
+ ChatCompletionResponseStreamChoice,
+)
+
+# Used internally
+_ChatCompletionResponseChoiceT = TypeVar(
+ "_ChatCompletionResponseChoiceT",
+ ChatCompletionResponseChoice,
+ ChatCompletionResponseStreamChoice,
+)
+
+
+def maybe_filter_parallel_tool_calls(
+ choice: _ChatCompletionResponseChoiceT, request: ChatCompletionRequest
+) -> _ChatCompletionResponseChoiceT:
+ """Filter to first tool call only when parallel_tool_calls is False."""
+
+ if request.parallel_tool_calls:
+ return choice
+
+ if isinstance(choice, ChatCompletionResponseChoice) and choice.message.tool_calls:
+ choice.message.tool_calls = choice.message.tool_calls[:1]
+ elif (
+ isinstance(choice, ChatCompletionResponseStreamChoice)
+ and choice.delta.tool_calls
+ ):
+ choice.delta.tool_calls = [
+ tool_call for tool_call in choice.delta.tool_calls if tool_call.index == 0
+ ]
+
+ return choice
diff --git a/vllm/entrypoints/responses_utils.py b/vllm/entrypoints/responses_utils.py
index d966f58804b67..b02c43c7f8246 100644
--- a/vllm/entrypoints/responses_utils.py
+++ b/vllm/entrypoints/responses_utils.py
@@ -9,7 +9,9 @@ from openai.types.chat import (
from openai.types.chat.chat_completion_message_tool_call_param import (
Function as FunctionCallTool,
)
-from openai.types.responses import ResponseFunctionToolCall
+from openai.types.responses import ResponseFunctionToolCall, ResponseOutputItem
+from openai.types.responses.response_output_message import ResponseOutputMessage
+from openai.types.responses.response_reasoning_item import ResponseReasoningItem
from openai.types.responses.tool import Tool
from vllm import envs
@@ -19,6 +21,49 @@ from vllm.entrypoints.openai.protocol import (
)
+def construct_input_messages(
+ *,
+ request_instructions: str | None = None,
+ request_input: str | list[ResponseInputOutputItem],
+ prev_msg: list[ChatCompletionMessageParam] | None = None,
+ prev_response_output: list[ResponseOutputItem] | None = None,
+) -> list[ChatCompletionMessageParam]:
+ messages: list[ChatCompletionMessageParam] = []
+ if request_instructions:
+ messages.append(
+ {
+ "role": "system",
+ "content": request_instructions,
+ }
+ )
+
+ # Prepend the conversation history.
+ if prev_msg is not None:
+ # Add the previous messages.
+ messages.extend(prev_msg)
+ if prev_response_output is not None:
+ # Add the previous output.
+ for output_item in prev_response_output:
+ # NOTE: We skip the reasoning output.
+ if isinstance(output_item, ResponseOutputMessage):
+ for content in output_item.content:
+ messages.append(
+ {
+ "role": "assistant",
+ "content": content.text,
+ }
+ )
+
+ # Append the new input.
+ # Responses API supports simple text inputs without chat format.
+ if isinstance(request_input, str):
+ messages.append({"role": "user", "content": request_input})
+ else:
+ for item in request_input:
+ messages.append(construct_chat_message_with_tool_call(item))
+ return messages
+
+
def construct_chat_message_with_tool_call(
item: ResponseInputOutputItem,
) -> ChatCompletionMessageParam:
@@ -37,6 +82,18 @@ def construct_chat_message_with_tool_call(
)
],
)
+ elif isinstance(item, ResponseReasoningItem):
+ reasoning_content = ""
+ if item.encrypted_content:
+ raise ValueError("Encrypted content is not supported.")
+ if len(item.summary) == 1:
+ reasoning_content = item.summary[0].text
+ elif item.content and len(item.content) == 1:
+ reasoning_content = item.content[0].text
+ return {
+ "role": "assistant",
+ "reasoning": reasoning_content,
+ }
elif item.get("type") == "function_call_output":
# Append the function call output as a tool message.
return ChatCompletionToolMessageParam(
diff --git a/vllm/env_override.py b/vllm/env_override.py
index 14dae2850c354..9ae1af3af46cf 100644
--- a/vllm/env_override.py
+++ b/vllm/env_override.py
@@ -95,7 +95,7 @@ def memory_plan_reuse_patched(self):
# ===================================================
# This change monkeypatches get_graph_partition_signature in pytorch 2.9.0 to
# fix inductor partition + attention-nvfp4 quant fusion, tested in
-# `tests/compile/test_fusions_e2e.py::test_attn_quant`.
+# `tests/compile/distributed/test_fusions_e2e.py::test_attn_quant`.
# For more context, see https://github.com/pytorch/pytorch/pull/165815.
diff --git a/vllm/envs.py b/vllm/envs.py
index 6d92d5afee501..56558548d3981 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -2,8 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
-import hashlib
import json
+import logging
import os
import sys
import tempfile
@@ -42,6 +42,8 @@ if TYPE_CHECKING:
VLLM_LOGGING_PREFIX: str = ""
VLLM_LOGGING_STREAM: str = "ext://sys.stdout"
VLLM_LOGGING_CONFIG_PATH: str | None = None
+ VLLM_LOGGING_COLOR: str = "auto"
+ NO_COLOR: bool = False
VLLM_LOG_STATS_INTERVAL: float = 10.0
VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: str | None = None
@@ -53,7 +55,7 @@ if TYPE_CHECKING:
VLLM_CPU_SGL_KERNEL: bool = False
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
VLLM_XLA_CHECK_RECOMPILATION: bool = False
- VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
+ VLLM_FUSED_MOE_CHUNK_SIZE: int = 16 * 1024
VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING: bool = True
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto"
VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
@@ -90,11 +92,14 @@ if TYPE_CHECKING:
VLLM_TORCH_PROFILER_DIR: str | None = None
VLLM_TORCH_PROFILER_RECORD_SHAPES: bool = False
VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False
+ VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM: bool = False
VLLM_USE_AOT_COMPILE: bool = False
VLLM_USE_BYTECODE_HOOK: bool = False
VLLM_FORCE_AOT_LOAD: bool = False
VLLM_TORCH_PROFILER_WITH_STACK: bool = True
VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False
+ VLLM_PROFILER_DELAY_ITERS: int = 0
+ VLLM_PROFILER_MAX_ITERS: int = 0
VLLM_USE_TRITON_AWQ: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False
@@ -157,7 +162,9 @@ if TYPE_CHECKING:
VLLM_USE_FLASHINFER_MOE_FP16: bool = False
VLLM_USE_FLASHINFER_MOE_FP8: bool = False
VLLM_USE_FLASHINFER_MOE_FP4: bool = False
- VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency"] = "latency"
+ VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency", "masked_gemm"] = (
+ "latency"
+ )
VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE: int = 394 * 1024 * 1024
VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
@@ -224,6 +231,7 @@ if TYPE_CHECKING:
VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
+ VLLM_USE_V2_MODEL_RUNNER: bool = False
def get_default_cache_root():
@@ -426,6 +434,8 @@ def get_vllm_port() -> int | None:
# --8<-- [start:env-vars-definition]
+logger = logging.getLogger(__name__)
+
environment_variables: dict[str, Callable[[], Any]] = {
# ================== Installation Time Env Vars ==================
# Target device of vLLM, supporting [cuda (by default),
@@ -612,6 +622,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_LOGGING_STREAM": lambda: os.getenv("VLLM_LOGGING_STREAM", "ext://sys.stdout"),
# if set, VLLM_LOGGING_PREFIX will be prepended to all log messages
"VLLM_LOGGING_PREFIX": lambda: os.getenv("VLLM_LOGGING_PREFIX", ""),
+ # Controls colored logging output. Options: "auto" (default, colors when terminal),
+ # "1" (always use colors), "0" (never use colors)
+ "VLLM_LOGGING_COLOR": lambda: os.getenv("VLLM_LOGGING_COLOR", "auto"),
+ # Standard unix flag for disabling ANSI color codes
+ "NO_COLOR": lambda: os.getenv("NO_COLOR", "0") != "0",
# If set, vllm will log stats at this interval in seconds
# If not set, vllm will log stats every 10 seconds.
"VLLM_LOG_STATS_INTERVAL": lambda: val
@@ -625,7 +640,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Example options:
# - "TORCH_SDPA": use torch.nn.MultiheadAttention
# - "FLASH_ATTN": use FlashAttention
- # - "XFORMERS": use XFormers
# - "FLASHINFER": use flashinfer
# - "FLASHMLA": use FlashMLA
# - "FLASH_ATTN_MLA": use FlashAttention for MLA
@@ -770,7 +784,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Enable SPMD mode for TPU backend.
"VLLM_XLA_USE_SPMD": lambda: bool(int(os.getenv("VLLM_XLA_USE_SPMD", "0"))),
"VLLM_FUSED_MOE_CHUNK_SIZE": lambda: int(
- os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")
+ os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(16 * 1024))
),
# Control whether to use fused MoE activation chunking. Current chunking
# logic is incompatible with torch.compile and causes IMA. See issue
@@ -861,6 +875,19 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_TORCH_PROFILER_WITH_FLOPS": lambda: bool(
os.getenv("VLLM_TORCH_PROFILER_WITH_FLOPS", "0") != "0"
),
+ # Disable torch profiling of the AsyncLLMEngine process.
+ # If set to 1, will not profile the engine process.
+ "VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM": lambda: bool(
+ os.getenv("VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM", "0") != "0"
+ ),
+    # Number of iterations to wait before starting profiling when using
+    # the torch profiler or torch CUDA profiler. If 0, start profiling immediately.
+ "VLLM_PROFILER_DELAY_ITERS": lambda: int(
+ os.getenv("VLLM_PROFILER_DELAY_ITERS", "0")
+ ),
+ # Maximum number of iterations to profile when using the torch/torch CUDA profiler.
+ # If set to 0, will not limit the number of iterations.
+ "VLLM_PROFILER_MAX_ITERS": lambda: int(os.getenv("VLLM_PROFILER_MAX_ITERS", "0")),
# If set, vLLM will use Triton implementations of AWQ.
"VLLM_USE_TRITON_AWQ": lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))),
# If set, allow loading or unloading lora adapters in runtime,
@@ -1236,7 +1263,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - "latency":
# Uses TensorRT-LLM kernels optimized for low-latency inference.
"VLLM_FLASHINFER_MOE_BACKEND": env_with_choices(
- "VLLM_FLASHINFER_MOE_BACKEND", "latency", ["throughput", "latency"]
+ "VLLM_FLASHINFER_MOE_BACKEND",
+ "latency",
+ ["throughput", "latency", "masked_gemm"],
),
# Control the workspace buffer size for the FlashInfer backend.
"VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE": lambda: int(
@@ -1261,7 +1290,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# MoE routing strategy selector.
# See `RoutingSimulator.get_available_strategies()` # for available
# strategies.
- # Cutstom routing strategies can be registered by
+ # Custom routing strategies can be registered by
# RoutingSimulator.register_strategy()
# Note: custom strategies may not produce correct model outputs
"VLLM_MOE_ROUTING_SIMULATION_STRATEGY": lambda: os.environ.get(
@@ -1493,6 +1522,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices(
"VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"]
),
+ # Flag to enable v2 model runner.
+ "VLLM_USE_V2_MODEL_RUNNER": lambda: bool(
+ int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0"))
+ ),
}
# --8<-- [end:env-vars-definition]
@@ -1540,85 +1573,90 @@ def is_set(name: str):
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
-def compute_hash() -> str:
- """
- WARNING: Whenever a new key is added to this environment
- variables, ensure that it is included in the factors list if
- it affects the computation graph. For example, different values
- of VLLM_PP_LAYER_PARTITION will generate different computation
- graphs, so it is included in the factors list. The env vars that
- affect the choice of different kernels or attention backends should
- also be included in the factors list.
- """
+def compile_factors() -> dict[str, object]:
+ """Return env vars used for torch.compile cache keys.
- # The values of envs may affects the computation graph.
- # TODO(DefTruth): hash all environment variables?
- # for key in environment_variables:
- # factorize(key)
- environment_variables_to_hash = [
- "VLLM_PP_LAYER_PARTITION",
- "VLLM_MLA_DISABLE",
- "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
- "VLLM_USE_TRITON_AWQ",
- "VLLM_DP_RANK",
- "VLLM_DP_SIZE",
- "VLLM_USE_STANDALONE_COMPILE",
- "VLLM_FUSED_MOE_CHUNK_SIZE",
- "VLLM_FLASHINFER_MOE_BACKEND",
- "VLLM_V1_USE_PREFILL_DECODE_ATTENTION",
- "VLLM_ATTENTION_BACKEND",
- "VLLM_USE_FLASHINFER_SAMPLER",
- "VLLM_DISABLED_KERNELS",
- "VLLM_USE_DEEP_GEMM",
- "VLLM_MOE_USE_DEEP_GEMM",
- "VLLM_USE_DEEP_GEMM_E8M0",
- "VLLM_USE_FUSED_MOE_GROUPED_TOPK",
- "VLLM_USE_FLASHINFER_MOE_FP16",
- "VLLM_USE_FLASHINFER_MOE_FP8",
- "VLLM_USE_FLASHINFER_MOE_FP4",
- "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8",
- "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS",
- "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16",
- "VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE",
- "VLLM_USE_CUDNN_PREFILL",
- "VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL",
- "VLLM_USE_TRTLLM_ATTENTION",
- "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION",
- "VLLM_ROCM_USE_AITER",
- "VLLM_ROCM_USE_AITER_PAGED_ATTN",
- "VLLM_ROCM_USE_AITER_LINEAR",
- "VLLM_ROCM_USE_AITER_MOE",
- "VLLM_ROCM_USE_AITER_RMSNORM",
- "VLLM_ROCM_USE_AITER_MLA",
- "VLLM_ROCM_USE_AITER_MHA",
- "VLLM_ROCM_USE_AITER_FP4_ASM_GEMM",
- "VLLM_ROCM_USE_AITER_TRITON_ROPE",
- "VLLM_ROCM_USE_AITER_FP8BMM",
- "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION",
- "VLLM_ROCM_USE_AITER_TRITON_GEMM",
- "VLLM_ROCM_USE_SKINNY_GEMM",
- "VLLM_ROCM_FP8_PADDING",
- "VLLM_ROCM_MOE_PADDING",
- "VLLM_ROCM_CUSTOM_PAGED_ATTN",
- "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION",
- "VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16",
- "VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB",
- "VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
- "VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE",
- "VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING",
- "VLLM_NVFP4_GEMM_BACKEND",
- "VLLM_USE_FBGEMM",
- "VLLM_DEEPEP_HIGH_THROUGHPUT_FORCE_INTRA_NODE",
- "VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL",
- ]
- for key in environment_variables_to_hash:
- # if this goes out of sync with environment_variables,
- # it's not a user error, it's a bug
- assert key in environment_variables, (
- "Please update environment_variables_to_hash in envs.py"
- )
+ Start with every known vLLM env var; drop entries in `ignored_factors`;
+ hash everything else. This keeps the cache key aligned across workers."""
- factors = [environment_variables[key]() for key in environment_variables_to_hash]
+ ignored_factors: set[str] = {
+ "MAX_JOBS",
+ "VLLM_RPC_BASE_PATH",
+ "VLLM_USE_MODELSCOPE",
+ "VLLM_RINGBUFFER_WARNING_INTERVAL",
+ "VLLM_DEBUG_DUMP_PATH",
+ "VLLM_PORT",
+ "VLLM_CACHE_ROOT",
+ "LD_LIBRARY_PATH",
+ "VLLM_SERVER_DEV_MODE",
+ "VLLM_DP_MASTER_IP",
+ "VLLM_DP_MASTER_PORT",
+ "VLLM_RANDOMIZE_DP_DUMMY_INPUTS",
+ "VLLM_CI_USE_S3",
+ "VLLM_MODEL_REDIRECT_PATH",
+ "VLLM_HOST_IP",
+ "S3_ACCESS_KEY_ID",
+ "S3_SECRET_ACCESS_KEY",
+ "S3_ENDPOINT_URL",
+ "VLLM_USAGE_STATS_SERVER",
+ "VLLM_NO_USAGE_STATS",
+ "VLLM_DO_NOT_TRACK",
+ "VLLM_LOGGING_LEVEL",
+ "VLLM_LOGGING_PREFIX",
+ "VLLM_LOGGING_STREAM",
+ "VLLM_LOGGING_CONFIG_PATH",
+ "VLLM_LOGGING_COLOR",
+ "VLLM_LOG_STATS_INTERVAL",
+ "VLLM_DEBUG_LOG_API_SERVER_RESPONSE",
+ "VLLM_TUNED_CONFIG_FOLDER",
+ "VLLM_ENGINE_ITERATION_TIMEOUT_S",
+ "VLLM_HTTP_TIMEOUT_KEEP_ALIVE",
+ "VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS",
+ "VLLM_KEEP_ALIVE_ON_ENGINE_DEATH",
+ "VLLM_SLEEP_WHEN_IDLE",
+ "VLLM_IMAGE_FETCH_TIMEOUT",
+ "VLLM_VIDEO_FETCH_TIMEOUT",
+ "VLLM_AUDIO_FETCH_TIMEOUT",
+ "VLLM_MEDIA_URL_ALLOW_REDIRECTS",
+ "VLLM_MEDIA_LOADING_THREAD_COUNT",
+ "VLLM_MAX_AUDIO_CLIP_FILESIZE_MB",
+ "VLLM_VIDEO_LOADER_BACKEND",
+ "VLLM_MEDIA_CONNECTOR",
+ "VLLM_ASSETS_CACHE",
+ "VLLM_ASSETS_CACHE_MODEL_CLEAN",
+ "VLLM_MM_INPUT_CACHE_GIB",
+ "VLLM_WORKER_MULTIPROC_METHOD",
+ "VLLM_ENABLE_V1_MULTIPROCESSING",
+ "VLLM_V1_OUTPUT_PROC_CHUNK_SIZE",
+ "VLLM_CPU_KVCACHE_SPACE",
+ "VLLM_CPU_OMP_THREADS_BIND",
+ "VLLM_CPU_NUM_OF_RESERVED_CPU",
+ "VLLM_CPU_MOE_PREPACK",
+ "VLLM_CPU_SGL_KERNEL",
+ "VLLM_TEST_FORCE_LOAD_FORMAT",
+ "LOCAL_RANK",
+ "CUDA_VISIBLE_DEVICES",
+ "NO_COLOR",
+ }
+
+ from vllm.config.utils import normalize_value
+
+ factors: dict[str, object] = {}
+ for factor, getter in environment_variables.items():
+ if factor in ignored_factors:
+ continue
+
+ try:
+ raw = getter()
+ except Exception as exc: # pragma: no cover - defensive logging
+ logger.warning(
+ "Skipping environment variable %s while hashing compile factors: %s",
+ factor,
+ exc,
+ )
+ continue
+
+ factors[factor] = normalize_value(raw)
ray_noset_env_vars = [
# Refer to
@@ -1641,8 +1679,8 @@ def compute_hash() -> str:
"RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR",
"RAY_EXPERIMENTAL_NOSET_RBLN_RT_VISIBLE_DEVICES",
]
- factors.extend([os.getenv(var) for var in ray_noset_env_vars])
- hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
+ for var in ray_noset_env_vars:
+ factors[var] = normalize_value(os.getenv(var))
- return hash_str
+ return factors
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 25fb7181a8f29..7cb490e391abb 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -153,7 +153,7 @@ class DPMetadata:
@contextmanager
def sp_local_sizes(self, sequence_parallel_size: int):
"""
- Context mamager for setting self.local_sizes. Same as self.chunked_sizes
+ Context manager for setting self.local_sizes. Same as self.chunked_sizes
but without any chunking.
"""
self.local_sizes = _compute_sp_num_tokens(
diff --git a/vllm/logger.py b/vllm/logger.py
index 9341008296843..ad3123c0f0149 100644
--- a/vllm/logger.py
+++ b/vllm/logger.py
@@ -7,7 +7,8 @@ import json
import logging
import os
import sys
-from collections.abc import Hashable
+from collections.abc import Generator, Hashable
+from contextlib import contextmanager
from functools import lru_cache, partial
from logging import Logger
from logging.config import dictConfig
@@ -17,18 +18,25 @@ from typing import Any, Literal, cast
import vllm.envs as envs
-VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING
-VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH
-VLLM_LOGGING_LEVEL = envs.VLLM_LOGGING_LEVEL
-VLLM_LOGGING_PREFIX = envs.VLLM_LOGGING_PREFIX
-VLLM_LOGGING_STREAM = envs.VLLM_LOGGING_STREAM
-
_FORMAT = (
- f"{VLLM_LOGGING_PREFIX}%(levelname)s %(asctime)s "
+ f"{envs.VLLM_LOGGING_PREFIX}%(levelname)s %(asctime)s "
"[%(fileinfo)s:%(lineno)d] %(message)s"
)
_DATE_FORMAT = "%m-%d %H:%M:%S"
+
+def _use_color() -> bool:
+ if envs.NO_COLOR or envs.VLLM_LOGGING_COLOR == "0":
+ return False
+ if envs.VLLM_LOGGING_COLOR == "1":
+ return True
+ if envs.VLLM_LOGGING_STREAM == "ext://sys.stdout": # stdout
+ return hasattr(sys.stdout, "isatty") and sys.stdout.isatty()
+ elif envs.VLLM_LOGGING_STREAM == "ext://sys.stderr": # stderr
+ return hasattr(sys.stderr, "isatty") and sys.stderr.isatty()
+ return False
+
+
DEFAULT_LOGGING_CONFIG = {
"formatters": {
"vllm": {
@@ -36,13 +44,19 @@ DEFAULT_LOGGING_CONFIG = {
"datefmt": _DATE_FORMAT,
"format": _FORMAT,
},
+ "vllm_color": {
+ "class": "vllm.logging_utils.ColoredFormatter",
+ "datefmt": _DATE_FORMAT,
+ "format": _FORMAT,
+ },
},
"handlers": {
"vllm": {
"class": "logging.StreamHandler",
- "formatter": "vllm",
- "level": VLLM_LOGGING_LEVEL,
- "stream": VLLM_LOGGING_STREAM,
+ # Choose formatter based on color setting.
+ "formatter": "vllm_color" if _use_color() else "vllm",
+ "level": envs.VLLM_LOGGING_LEVEL,
+ "stream": envs.VLLM_LOGGING_STREAM,
},
},
"loggers": {
@@ -144,7 +158,7 @@ _METHODS_TO_PATCH = {
def _configure_vllm_root_logger() -> None:
logging_config = dict[str, Any]()
- if not VLLM_CONFIGURE_LOGGING and VLLM_LOGGING_CONFIG_PATH:
+ if not envs.VLLM_CONFIGURE_LOGGING and envs.VLLM_LOGGING_CONFIG_PATH:
raise RuntimeError(
"VLLM_CONFIGURE_LOGGING evaluated to false, but "
"VLLM_LOGGING_CONFIG_PATH was given. VLLM_LOGGING_CONFIG_PATH "
@@ -152,16 +166,22 @@ def _configure_vllm_root_logger() -> None:
"VLLM_CONFIGURE_LOGGING or unset VLLM_LOGGING_CONFIG_PATH."
)
- if VLLM_CONFIGURE_LOGGING:
+ if envs.VLLM_CONFIGURE_LOGGING:
logging_config = DEFAULT_LOGGING_CONFIG
- if VLLM_LOGGING_CONFIG_PATH:
- if not path.exists(VLLM_LOGGING_CONFIG_PATH):
+ vllm_handler = logging_config["handlers"]["vllm"]
+ # Refresh these values in case env vars have changed.
+ vllm_handler["level"] = envs.VLLM_LOGGING_LEVEL
+ vllm_handler["stream"] = envs.VLLM_LOGGING_STREAM
+ vllm_handler["formatter"] = "vllm_color" if _use_color() else "vllm"
+
+ if envs.VLLM_LOGGING_CONFIG_PATH:
+ if not path.exists(envs.VLLM_LOGGING_CONFIG_PATH):
raise RuntimeError(
"Could not load logging config. File does not exist: %s",
- VLLM_LOGGING_CONFIG_PATH,
+ envs.VLLM_LOGGING_CONFIG_PATH,
)
- with open(VLLM_LOGGING_CONFIG_PATH, encoding="utf-8") as file:
+ with open(envs.VLLM_LOGGING_CONFIG_PATH, encoding="utf-8") as file:
custom_config = json.loads(file.read())
if not isinstance(custom_config, dict):
@@ -193,6 +213,17 @@ def init_logger(name: str) -> _VllmLogger:
     return cast(_VllmLogger, logger)
+@contextmanager
+def suppress_logging(level: int = logging.INFO) -> Generator[None, Any, None]:
+    """Temporarily disable log records up to *level*; always restore on exit."""
+    current_level = logging.root.manager.disable
+    logging.disable(level)
+    try:
+        yield
+    finally:
+        logging.disable(current_level)
+
+
# The root logger is initialized when the module is imported.
# This is thread-safe as the module is only imported once,
# guaranteed by the Python GIL.
diff --git a/vllm/logging_utils/__init__.py b/vllm/logging_utils/__init__.py
index 7202259ca21aa..8d3354df215b1 100644
--- a/vllm/logging_utils/__init__.py
+++ b/vllm/logging_utils/__init__.py
@@ -1,10 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from vllm.logging_utils.formatter import NewLineFormatter
+from vllm.logging_utils.formatter import ColoredFormatter, NewLineFormatter
+from vllm.logging_utils.lazy import lazy
from vllm.logging_utils.log_time import logtime
__all__ = [
"NewLineFormatter",
+ "ColoredFormatter",
+ "lazy",
"logtime",
]
diff --git a/vllm/logging_utils/formatter.py b/vllm/logging_utils/formatter.py
index 02ba308e18796..3ad4ef8d119ad 100644
--- a/vllm/logging_utils/formatter.py
+++ b/vllm/logging_utils/formatter.py
@@ -75,3 +75,53 @@ class NewLineFormatter(logging.Formatter):
parts = msg.split(record.message)
msg = msg.replace("\n", "\r\n" + parts[0])
return msg
+
+
+class ColoredFormatter(NewLineFormatter):
+ """Adds ANSI color codes to log levels for terminal output.
+
+ This formatter adds colors by injecting them into the format string for
+ static elements (timestamp, filename, line number) and modifying the
+ levelname attribute for dynamic color selection.
+ """
+
+ # ANSI color codes
+ COLORS = {
+ "DEBUG": "\033[37m", # White
+ "INFO": "\033[32m", # Green
+ "WARNING": "\033[33m", # Yellow
+ "ERROR": "\033[31m", # Red
+ "CRITICAL": "\033[35m", # Magenta
+ }
+ GREY = "\033[90m" # Grey for timestamp and file info
+ RESET = "\033[0m"
+
+ def __init__(self, fmt, datefmt=None, style="%"):
+ # Inject grey color codes into format string for timestamp and file info
+ if fmt:
+ # Wrap %(asctime)s with grey
+ fmt = fmt.replace("%(asctime)s", f"{self.GREY}%(asctime)s{self.RESET}")
+ # Wrap [%(fileinfo)s:%(lineno)d] with grey
+ fmt = fmt.replace(
+ "[%(fileinfo)s:%(lineno)d]",
+ f"{self.GREY}[%(fileinfo)s:%(lineno)d]{self.RESET}",
+ )
+
+ # Call parent __init__ with potentially modified format string
+ super().__init__(fmt, datefmt, style)
+
+ def format(self, record):
+ # Store original levelname to restore later (in case record is reused)
+ orig_levelname = record.levelname
+
+ # Only modify levelname - it needs dynamic color based on severity
+ if (color_code := self.COLORS.get(record.levelname)) is not None:
+ record.levelname = f"{color_code}{record.levelname}{self.RESET}"
+
+ # Call parent format which will handle everything else
+ msg = super().format(record)
+
+ # Restore original levelname
+ record.levelname = orig_levelname
+
+ return msg
diff --git a/vllm/logging_utils/lazy.py b/vllm/logging_utils/lazy.py
new file mode 100644
index 0000000000000..3ade798962857
--- /dev/null
+++ b/vllm/logging_utils/lazy.py
@@ -0,0 +1,20 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Callable
+from typing import Any
+
+
+class lazy:
+ """Wrap a zero-argument callable evaluated only during log formatting."""
+
+ __slots__ = ("_factory",)
+
+ def __init__(self, factory: Callable[[], Any]) -> None:
+ self._factory = factory
+
+ def __str__(self) -> str:
+ return str(self._factory())
+
+ def __repr__(self) -> str:
+ return str(self)
diff --git a/vllm/lora/layers/__init__.py b/vllm/lora/layers/__init__.py
index 8a4f5ff175d4f..25364a5881364 100644
--- a/vllm/lora/layers/__init__.py
+++ b/vllm/lora/layers/__init__.py
@@ -11,7 +11,7 @@ from vllm.lora.layers.column_parallel_linear import (
QKVParallelLinearWithLoRA,
QKVParallelLinearWithShardedLoRA,
)
-from vllm.lora.layers.fused_moe import FusedMoEWithLoRA
+from vllm.lora.layers.fused_moe import FusedMoE3DWithLoRA, FusedMoEWithLoRA
from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA
from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA
from vllm.lora.layers.row_parallel_linear import (
@@ -38,4 +38,5 @@ __all__ = [
"ReplicatedLinearWithLoRA",
"LoRAMapping",
"FusedMoEWithLoRA",
+ "FusedMoE3DWithLoRA",
]
diff --git a/vllm/lora/layers/base.py b/vllm/lora/layers/base.py
index 0c7e806848892..3bfb88c007622 100644
--- a/vllm/lora/layers/base.py
+++ b/vllm/lora/layers/base.py
@@ -42,9 +42,8 @@ class BaseLayerWithLoRA(nn.Module):
def set_lora(
self,
index: int,
- lora_a: torch.Tensor,
- lora_b: torch.Tensor,
- embeddings_tensor: torch.Tensor | None,
+ lora_a: torch.Tensor | list[torch.Tensor],
+ lora_b: torch.Tensor | list[torch.Tensor],
):
"""Overwrites lora tensors at index."""
...
diff --git a/vllm/lora/layers/base_linear.py b/vllm/lora/layers/base_linear.py
index 3db4165e20176..06ecc8d2f634c 100644
--- a/vllm/lora/layers/base_linear.py
+++ b/vllm/lora/layers/base_linear.py
@@ -94,14 +94,15 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
def set_lora(
self,
index: int,
- lora_a: torch.Tensor,
- lora_b: torch.Tensor,
- embeddings_tensor: torch.Tensor | None,
+ lora_a: torch.Tensor | list[torch.Tensor],
+ lora_b: torch.Tensor | list[torch.Tensor],
):
# Except for QKVParallelLinearWithLoRA and
# MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
# store weights in a tuple of size 1. These two layers will
# override this function.
+ assert isinstance(lora_a, torch.Tensor)
+ assert isinstance(lora_b, torch.Tensor)
assert (
len(self.lora_a_stacked) == len(self.lora_b_stacked) == self.n_slices == 1
)
diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py
index 637ded9b2a0f0..3e21d426c304a 100644
--- a/vllm/lora/layers/column_parallel_linear.py
+++ b/vllm/lora/layers/column_parallel_linear.py
@@ -246,9 +246,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def set_lora(
self,
index: int,
- lora_a: torch.Tensor,
- lora_b: torch.Tensor,
- embeddings_tensor: torch.Tensor | None,
+ lora_a: torch.Tensor | list[torch.Tensor],
+ lora_b: torch.Tensor | list[torch.Tensor],
):
self.reset_lora(index)
diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py
index 8fb3efa220f6d..0eb6562bec6cd 100644
--- a/vllm/lora/layers/fused_moe.py
+++ b/vllm/lora/layers/fused_moe.py
@@ -12,6 +12,7 @@ from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
+from vllm.distributed.utils import divide
from vllm.lora.layers.base import BaseLayerWithLoRA
from vllm.lora.ops.triton_ops.utils import get_lora_op_configs
from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -41,6 +42,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.device = base_layer.w2_weight.device
+ self._w13_slices = 2
self._inject_lora_into_fused_moe()
def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]:
@@ -59,8 +61,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
def _get_lora_moe_configs(
self,
op_prefix: str,
- lora_a_stacked: torch.Tensor,
- lora_b_stacked: torch.Tensor,
+ num_loras: int,
+ rank: int,
num_slices: int,
M: int,
layer: FusedMoE,
@@ -68,23 +70,25 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
config_dtype: str,
):
if envs.VLLM_TUNED_CONFIG_FOLDER:
+ hidden_size = layer.hidden_size
+ intermediate_size = layer.intermediate_size_per_partition
shrink_config = get_lora_op_configs(
op_type=f"fused_moe_lora_{op_prefix}_shrink",
- max_loras=lora_a_stacked.shape[0],
+ max_loras=num_loras,
batch=M,
- hidden_size=lora_a_stacked.shape[-1],
- rank=lora_a_stacked.shape[-2],
+ hidden_size=hidden_size,
+ rank=rank,
num_slices=num_slices,
- moe_intermediate_size=lora_b_stacked.shape[-2],
+ moe_intermediate_size=intermediate_size,
)
expand_config = get_lora_op_configs(
op_type=f"fused_moe_lora_{op_prefix}_expand",
- max_loras=lora_a_stacked.shape[0],
+ max_loras=num_loras,
batch=M,
- hidden_size=lora_a_stacked.shape[-1],
- rank=lora_a_stacked.shape[-2],
+                hidden_size=hidden_size,
+ rank=rank,
num_slices=num_slices,
- moe_intermediate_size=lora_b_stacked.shape[-2],
+                moe_intermediate_size=intermediate_size,
)
else: # fall back to the default config
get_config_func = functools.partial(
@@ -151,12 +155,12 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
num_tokens = hidden_states.size(0)
M = min(num_tokens, CHUNK_SIZE)
-
+ max_lora_rank = self.w13_lora_a_stacked[0].shape[-2]
shrink_config, expand_config = self._get_lora_moe_configs(
op_prefix="w13",
- lora_a_stacked=self.w1_lora_a_stacked,
- lora_b_stacked=self.w1_lora_b_stacked,
- num_slices=2,
+ num_loras=self.max_loras,
+ rank=max_lora_rank,
+ num_slices=self._w13_slices,
M=M,
layer=layer,
top_k=top_k,
@@ -164,7 +168,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
)
# get the block size of m from customized config or default config
- max_loras = self.w1_lora_a_stacked.shape[0]
(
sorted_token_ids_lora,
expert_ids_lora,
@@ -174,7 +177,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
num_tokens,
shrink_config["BLOCK_SIZE_M"],
self.base_layer.local_num_experts,
- max_loras,
+ self.max_loras,
self.adapter_enabled,
expert_map,
)
@@ -185,17 +188,15 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
num_tokens_post_padded_lora
)
- w13_lora_a_stacked = [self.w1_lora_a_stacked, self.w3_lora_a_stacked]
- w13_lora_b_stacked = [self.w1_lora_b_stacked, self.w3_lora_b_stacked]
- max_lora_rank = self.w1_lora_a_stacked.shape[-2]
- expert_ids_lora = expert_ids_lora.view(max_loras, -1)
- sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
+ expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
+ sorted_token_ids_lora = sorted_token_ids_lora.view(self.max_loras, -1)
+
self.punica_wrapper.add_lora_fused_moe(
input.view(-1, top_k, input.shape[-1]),
hidden_states,
- w13_lora_a_stacked,
- w13_lora_b_stacked,
+ self.w13_lora_a_stacked,
+ self.w13_lora_b_stacked,
topk_weights,
sorted_token_ids_lora,
expert_ids_lora,
@@ -205,6 +206,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
shrink_config, ## pass the shrink config
expand_config, ## pass the expand config
self.adapter_enabled,
+ fully_sharded=self.fully_sharded,
)
result = func(*args, **kwargs)
@@ -228,11 +230,11 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
num_tokens = hidden_states.size(0)
M = min(num_tokens, CHUNK_SIZE)
-
+ max_lora_rank = self.w2_lora_a_stacked[0].shape[-2]
shrink_config, expand_config = self._get_lora_moe_configs(
op_prefix="w2",
- lora_a_stacked=self.w2_lora_a_stacked,
- lora_b_stacked=self.w2_lora_b_stacked,
+ num_loras=self.max_loras,
+ rank=max_lora_rank,
num_slices=1,
M=M,
layer=layer,
@@ -245,17 +247,19 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
num_tokens_post_padded_lora = moe_state_dict[
"num_tokens_post_padded_lora"
]
- max_loras = self.w1_lora_a_stacked.shape[0]
- expert_ids_lora = expert_ids_lora.view(max_loras, -1)
- sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
+
+ expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
+ sorted_token_ids_lora = sorted_token_ids_lora.view(self.max_loras, -1)
intermediate_cache2 = moe_state_dict["intermediate_cache2"]
intermediate_cache3 = args[0]
- max_lora_rank = self.w1_lora_a_stacked.shape[-2]
+
+ shard_size_w2 = divide(self.base_layer.hidden_size, self.tp_size)
+
self.punica_wrapper.add_lora_fused_moe(
intermediate_cache3,
intermediate_cache2,
- [self.w2_lora_a_stacked],
- [self.w2_lora_b_stacked],
+ self.w2_lora_a_stacked,
+ self.w2_lora_b_stacked,
topk_weights,
sorted_token_ids_lora,
expert_ids_lora,
@@ -266,6 +270,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
expand_config, ## pass the expand config
self.adapter_enabled,
True,
+ fully_sharded=self.fully_sharded,
+ offset=shard_size_w2 * self.tp_rank if self.fully_sharded else 0,
)
result = func(*args, **kwargs)
@@ -282,11 +288,72 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
fused_experts.moe_sum = moe_sum_decorator(
self.base_layer, fused_experts.moe_sum
)
-
self.base_layer.quant_method = FusedMoEModularMethod(
self.base_layer.quant_method, m_fused_moe_fn
)
+ def _create_lora_a_weights(
+ self,
+ max_loras: int,
+ lora_config: LoRAConfig,
+ ):
+ self.w13_lora_a_stacked: tuple[torch.Tensor, ...] = tuple(
+ torch.zeros(
+ (
+ max_loras,
+ self.base_layer.local_num_experts,
+ lora_config.max_lora_rank
+ if not self.fully_sharded
+ else divide(lora_config.max_lora_rank, self.tp_size),
+ self.base_layer.hidden_size,
+ ),
+ dtype=lora_config.lora_dtype,
+ device=self.device,
+ )
+ for _ in range(self._w13_slices)
+ )
+ self.w2_lora_a_stacked: tuple[torch.Tensor, ...] = (
+ torch.zeros(
+ (
+ max_loras,
+ self.base_layer.local_num_experts,
+ lora_config.max_lora_rank,
+ self.base_layer.intermediate_size_per_partition,
+ ),
+ dtype=lora_config.lora_dtype,
+ device=self.device,
+ ),
+ )
+
+ def _create_lora_b_weights(self, max_loras: int, lora_config: LoRAConfig):
+ self.w13_lora_b_stacked: tuple[torch.Tensor, ...] = tuple(
+ torch.zeros(
+ (
+ max_loras,
+ self.base_layer.local_num_experts,
+ self.base_layer.intermediate_size_per_partition,
+ lora_config.max_lora_rank,
+ ),
+ dtype=lora_config.lora_dtype,
+ device=self.device,
+ )
+ for _ in range(self._w13_slices)
+ )
+ self.w2_lora_b_stacked: tuple[torch.Tensor, ...] = (
+ torch.zeros(
+ (
+ max_loras,
+ self.base_layer.local_num_experts,
+ self.base_layer.hidden_size
+ if not self.fully_sharded
+ else divide(self.base_layer.hidden_size, self.tp_size),
+ lora_config.max_lora_rank,
+ ),
+ dtype=lora_config.lora_dtype,
+ device=self.device,
+ ),
+ )
+
def create_lora_weights(
self,
max_loras: int,
@@ -294,108 +361,63 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
model_config: PretrainedConfig | None = None,
) -> None:
"""Initializes lora matrices."""
+ self.max_loras = lora_config.max_loras
+ self.fully_sharded = lora_config.fully_sharded_loras
self.adapter_enabled = torch.tensor(
[0] * (max_loras + 1), dtype=torch.int, device=self.device
)
- self.w1_lora_a_stacked = torch.zeros(
- (
- max_loras,
- self.base_layer.local_num_experts,
- lora_config.max_lora_rank,
- self.base_layer.hidden_size,
- ),
- dtype=lora_config.lora_dtype,
- device=self.device,
- )
- self.w1_lora_b_stacked = torch.zeros(
- (
- max_loras,
- self.base_layer.local_num_experts,
- self.base_layer.intermediate_size_per_partition,
- lora_config.max_lora_rank,
- ),
- dtype=lora_config.lora_dtype,
- device=self.device,
- )
-
- self.w2_lora_a_stacked = torch.zeros(
- (
- max_loras,
- self.base_layer.local_num_experts,
- lora_config.max_lora_rank,
- self.base_layer.intermediate_size_per_partition,
- ),
- dtype=lora_config.lora_dtype,
- device=self.device,
- )
- self.w2_lora_b_stacked = torch.zeros(
- (
- max_loras,
- self.base_layer.local_num_experts,
- self.base_layer.hidden_size,
- lora_config.max_lora_rank,
- ),
- dtype=lora_config.lora_dtype,
- device=self.device,
- )
-
- self.w3_lora_a_stacked = torch.zeros(
- (
- max_loras,
- self.base_layer.local_num_experts,
- lora_config.max_lora_rank,
- self.base_layer.hidden_size,
- ),
- dtype=lora_config.lora_dtype,
- device=self.device,
- )
- self.w3_lora_b_stacked = torch.zeros(
- (
- max_loras,
- self.base_layer.local_num_experts,
- self.base_layer.intermediate_size_per_partition,
- lora_config.max_lora_rank,
- ),
- dtype=lora_config.lora_dtype,
- device=self.device,
- )
-
+ self._create_lora_a_weights(max_loras, lora_config)
+ self._create_lora_b_weights(max_loras, lora_config)
# They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
# to create a dummy LoRA weights.
+ # TODO Optimize this section
self.lora_a_stacked = []
self.lora_b_stacked = []
for lora_id in range(max_loras):
for experts_id in range(self.base_layer.local_num_experts):
# gate_proj,down_proj,up_proj
- self.lora_a_stacked.append(self.w1_lora_a_stacked[lora_id][experts_id])
- self.lora_a_stacked.append(self.w2_lora_a_stacked[lora_id][experts_id])
- self.lora_a_stacked.append(self.w3_lora_a_stacked[lora_id][experts_id])
+ self.lora_a_stacked.append(
+ self.w13_lora_a_stacked[0][lora_id][experts_id]
+ )
+ self.lora_a_stacked.append(
+ self.w2_lora_a_stacked[0][lora_id][experts_id]
+ )
- self.lora_b_stacked.append(self.w1_lora_b_stacked[lora_id][experts_id])
- self.lora_b_stacked.append(self.w2_lora_b_stacked[lora_id][experts_id])
- self.lora_b_stacked.append(self.w3_lora_b_stacked[lora_id][experts_id])
+ self.lora_b_stacked.append(
+ self.w13_lora_b_stacked[0][lora_id][experts_id]
+ )
+ self.lora_b_stacked.append(
+ self.w2_lora_b_stacked[0][lora_id][experts_id]
+ )
+
+ self.lora_a_stacked.append(
+ self.w13_lora_a_stacked[1][lora_id][experts_id]
+ )
+ self.lora_b_stacked.append(
+ self.w13_lora_b_stacked[1][lora_id][experts_id]
+ )
def reset_lora(self, index: int):
"""Resets the lora weights at index back to 0."""
- self.w1_lora_a_stacked[index] = 0
- self.w1_lora_b_stacked[index] = 0
- self.w3_lora_a_stacked[index] = 0
- self.w3_lora_b_stacked[index] = 0
- self.w2_lora_a_stacked[index] = 0
- self.w2_lora_b_stacked[index] = 0
+ for pos in range(self._w13_slices):
+ self.w13_lora_a_stacked[pos][index] = 0
+ self.w13_lora_b_stacked[pos][index] = 0
+
+ self.w2_lora_a_stacked[0][index] = 0
+ self.w2_lora_b_stacked[0][index] = 0
self.adapter_enabled[index] = 0
def set_lora(
self,
index: int,
- lora_a: torch.Tensor,
- lora_b: torch.Tensor,
- embeddings_tensor: torch.Tensor | None,
- bias: torch.Tensor | None = None,
+ lora_a: torch.Tensor | list[torch.Tensor],
+ lora_b: torch.Tensor | list[torch.Tensor],
):
"""Overwrites lora tensors at index."""
+ assert isinstance(lora_a, list)
+ assert isinstance(lora_b, list)
self.reset_lora(index)
self.adapter_enabled[index] = 1
for eid in range(len(lora_a) // 3):
@@ -419,39 +441,44 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
w3_lora_b = w3_lora_b[start_idx:end_idx, :]
w2_lora_a = w2_lora_a[:, start_idx:end_idx]
- self.w1_lora_a_stacked[
+ if self.fully_sharded:
+ # Based on S-LoRA, we slice W1 and W3 A along the rank dim,
+ # and W2 B along the hidden_size dim.
+ w13_shard_size = self.w13_lora_a_stacked[0][index, eid].shape[0]
+ w13_start_idx = self.tp_rank * w13_shard_size
+ w13_end_idx = (self.tp_rank + 1) * w13_shard_size
+ w1_lora_a = w1_lora_a[w13_start_idx:w13_end_idx, :]
+ w3_lora_a = w3_lora_a[w13_start_idx:w13_end_idx, :]
+
+ w2_shard_size = self.w2_lora_b_stacked[0][index, eid].shape[0]
+ w2_start_idx = self.tp_rank * w2_shard_size
+ w2_end_idx = (self.tp_rank + 1) * w2_shard_size
+ w2_lora_b = w2_lora_b[w2_start_idx:w2_end_idx, :]
+ # w1 lora_a
+ self.w13_lora_a_stacked[0][
index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1]
].copy_(w1_lora_a, non_blocking=True)
-
- self.w3_lora_a_stacked[
+ # w3 lora_a
+ self.w13_lora_a_stacked[1][
index, eid, : w3_lora_a.shape[0], : w3_lora_a.shape[1]
].copy_(w3_lora_a, non_blocking=True)
- self.w2_lora_b_stacked[
- index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1]
- ].copy_(w2_lora_b, non_blocking=True)
-
- self.w1_lora_b_stacked[
+ # w1 lora_b
+ self.w13_lora_b_stacked[0][
index, eid, : w1_lora_b.shape[0], : w1_lora_b.shape[1]
].copy_(w1_lora_b, non_blocking=True)
- self.w3_lora_b_stacked[
+ # w3 lora_b
+ self.w13_lora_b_stacked[1][
index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1]
].copy_(w3_lora_b, non_blocking=True)
- self.w2_lora_a_stacked[
+
+ self.w2_lora_a_stacked[0][
index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1]
].copy_(w2_lora_a, non_blocking=True)
- @classmethod
- def can_replace_layer(
- cls,
- source_layer: nn.Module,
- lora_config: LoRAConfig,
- packed_modules_list: list,
- model_config: PretrainedConfig | None,
- ) -> bool:
- """Returns True if the layer can be replaced by this LoRA layer."""
- # return type(source_layer) is FusedMoE
- return isinstance(source_layer, FusedMoE)
+ self.w2_lora_b_stacked[0][
+ index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1]
+ ].copy_(w2_lora_b, non_blocking=True)
def forward(self, *args, **kwargs):
return self.base_layer.forward(*args, **kwargs)
@@ -470,3 +497,220 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
@property
def is_internal_router(self) -> bool:
return self.base_layer.is_internal_router
+
+ @classmethod
+ def can_replace_layer(
+ cls,
+ source_layer: nn.Module,
+ lora_config: LoRAConfig,
+ packed_modules_list: list,
+ model_config: PretrainedConfig | None,
+ ) -> bool:
+ """Returns True if the layer can be replaced by this LoRA layer."""
+        # Two packed modules (gate_proj/up_proj) => standard 2-slice MoE LoRA.
+
+ return type(source_layer) is FusedMoE and len(packed_modules_list) == 2
+
+
+class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
+ def __init__(self, base_layer):
+ super().__init__(base_layer)
+ self._w13_slices = 1
+
+ def _create_lora_b_weights(self, max_loras, lora_config):
+ self.w13_lora_b_stacked: tuple[torch.Tensor] = tuple(
+ torch.zeros(
+ (
+ max_loras,
+ self.base_layer.local_num_experts,
+ self.base_layer.intermediate_size_per_partition * 2,
+ lora_config.max_lora_rank,
+ ),
+ dtype=lora_config.lora_dtype,
+ device=self.device,
+ )
+ for _ in range(self._w13_slices)
+ )
+ self.w2_lora_b_stacked: tuple[torch.Tensor] = (
+ torch.zeros(
+ (
+ max_loras,
+ self.base_layer.local_num_experts,
+ self.base_layer.hidden_size
+ if not self.fully_sharded
+ else divide(self.base_layer.hidden_size, self.tp_size),
+ lora_config.max_lora_rank,
+ ),
+ dtype=lora_config.lora_dtype,
+ device=self.device,
+ ),
+ )
+
+ def create_lora_weights(
+ self,
+ max_loras: int,
+ lora_config: LoRAConfig,
+ model_config: PretrainedConfig | None = None,
+ ) -> None:
+ """Initializes lora matrices."""
+ self.max_loras = lora_config.max_loras
+ self.fully_sharded = lora_config.fully_sharded_loras
+
+ self.adapter_enabled = torch.tensor(
+ [0] * (max_loras + 1), dtype=torch.int, device=self.device
+ )
+
+ self._create_lora_a_weights(max_loras, lora_config)
+ self._create_lora_b_weights(max_loras, lora_config)
+
+ def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
+ if self.tp_size == 1 or not self.fully_sharded:
+ return w13_lora_a
+
+ # w13_lora_a shape (num_experts,rank,input_size)
+ current_lora_rank = w13_lora_a.shape[1]
+ assert current_lora_rank % self.tp_size == 0
+
+ sliced_rank = current_lora_rank // self.tp_size
+ start_idx = self.tp_rank * sliced_rank
+ end_idx = (self.tp_rank + 1) * sliced_rank
+ return w13_lora_a[:, start_idx:end_idx, :]
+
+ def _slice_w13_b(self, w13_lora_b: torch.Tensor, is_interleave: bool = True):
+ if self.tp_size == 1:
+ return w13_lora_b
+
+ # w13_lora_b shape (num_experts,output_size,rank)
+ shard_size = self.base_layer.intermediate_size_per_partition
+ start_idx = self.tp_rank * shard_size
+ end_idx = (self.tp_rank + 1) * shard_size
+ if is_interleave:
+ # For models like GPT-OSS, the weights of w1 (gate_proj) and w3 (up_proj)
+ # in the interleaved order, and corresponding LoRA need to be processed.
+ w1_lora_b = w13_lora_b[:, ::2, :]
+ w3_lora_b = w13_lora_b[:, 1::2, :]
+ sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
+ sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]
+
+ return torch.stack([sliced_w1_lora_b, sliced_w3_lora_b], dim=2).flatten(
+ 1, 2
+ )
+ else:
+ slice_size = w13_lora_b.shape[1] // 2
+ w1_lora_b = w13_lora_b[:, :slice_size, :]
+ w3_lora_b = w13_lora_b[:, slice_size:, :]
+ sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
+ sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]
+
+ return torch.cat([sliced_w1_lora_b, sliced_w3_lora_b], dim=1)
+
+ def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
+ if self.tp_size == 1:
+ return w2_lora_a
+ # w2_lora_a shape (num_experts,rank,input_size)
+ shard_size = self.base_layer.intermediate_size_per_partition
+ start_idx = self.tp_rank * shard_size
+ end_idx = (self.tp_rank + 1) * shard_size
+
+ return w2_lora_a[:, :, start_idx:end_idx]
+
+ def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
+ if self.tp_size == 1 or not self.fully_sharded:
+ return w2_lora_b
+ # Based on S-LoRA, we slice W2 B along the hidden_size dim.
+ # w2_lora_b shape (num_experts,output_size,rank)
+ current_lora_size = w2_lora_b.shape[1]
+
+ sliced_size = current_lora_size // self.tp_size
+ start_idx = self.tp_rank * sliced_size
+ end_idx = (self.tp_rank + 1) * sliced_size
+ return w2_lora_b[:, start_idx:end_idx, :]
+
+ def set_lora(
+ self,
+ index: int,
+ lora_a: torch.Tensor | list[torch.Tensor],
+ lora_b: torch.Tensor | list[torch.Tensor],
+ ):
+ """Overwrites lora tensors at index."""
+ # Make mypy happy
+ assert isinstance(lora_a, list)
+ assert isinstance(lora_b, list)
+ assert len(lora_a) == len(lora_b) == 2
+
+ self.reset_lora(index)
+ self.adapter_enabled[index] = 1
+
+ num_experts = self.w13_lora_a_stacked[0].shape[1]
+ w13_lora_a, w2_lora_a = lora_a
+ w13_lora_b, w2_lora_b = lora_b
+
+ # (num_experts,rank,input_size)
+ w13_lora_a = w13_lora_a.reshape(num_experts, -1, w13_lora_a.shape[-1])
+ w2_lora_a = w2_lora_a.reshape(num_experts, -1, w2_lora_a.shape[-1])
+ # (output_size,num_experts,rank)
+ w13_lora_b = w13_lora_b.reshape(w13_lora_b.shape[0], num_experts, -1)
+ w2_lora_b = w2_lora_b.reshape(w2_lora_b.shape[0], num_experts, -1)
+ # (num_experts,output_size,rank)
+ w13_lora_b = w13_lora_b.permute(1, 0, 2)
+ w2_lora_b = w2_lora_b.permute(1, 0, 2)
+
+ sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
+ sliced_w13_lora_b = self._slice_w13_b(w13_lora_b, is_interleave=True)
+
+ sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
+ sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)
+
+ self.w13_lora_a_stacked[0][
+ index, :, : sliced_w13_lora_a.shape[1], : sliced_w13_lora_a.shape[2]
+ ].copy_(sliced_w13_lora_a, non_blocking=True)
+ self.w2_lora_a_stacked[0][
+ index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
+ ].copy_(sliced_w2_lora_a, non_blocking=True)
+
+ self.w13_lora_b_stacked[0][
+ index, :, : sliced_w13_lora_b.shape[1], : sliced_w13_lora_b.shape[2]
+ ].copy_(sliced_w13_lora_b, non_blocking=True)
+ self.w2_lora_b_stacked[0][
+ index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
+ ].copy_(sliced_w2_lora_b, non_blocking=True)
+
+ @property
+ def w13_input_size(self):
+ """
+ Full size
+ """
+ return self.w13_lora_a_stacked[0].shape[-1]
+
+ @property
+ def w13_output_size(self):
+ """
+ Full size
+ """
+ return self.w13_lora_b_stacked[0].shape[-2] * self.tp_size
+
+ @property
+ def w2_input_size(self):
+ """
+ Full size
+ """
+ return self.w2_lora_a_stacked[0].shape[-1] * self.tp_size
+
+ @property
+ def w2_output_size(self):
+ """
+ Full size
+ """
+ return self.w2_lora_a_stacked[0].shape[-2]
+
+ @classmethod
+ def can_replace_layer(
+ cls,
+ source_layer: nn.Module,
+ lora_config: LoRAConfig,
+ packed_modules_list: list,
+ model_config: PretrainedConfig | None,
+ ) -> bool:
+ """Returns True if the layer can be replaced by this LoRA layer."""
+
+ return type(source_layer) is FusedMoE and len(packed_modules_list) == 1
diff --git a/vllm/lora/layers/logits_processor.py b/vllm/lora/layers/logits_processor.py
index adc5e861f57fb..c01984db4e64c 100644
--- a/vllm/lora/layers/logits_processor.py
+++ b/vllm/lora/layers/logits_processor.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import math
import torch
import torch.nn as nn
@@ -108,22 +107,13 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
(
max_loras,
1,
- # Pad for kernel compatibility
- math.ceil(
- self.base_layer.vocab_size / lora_config.lora_vocab_padding_size
- )
- * lora_config.lora_vocab_padding_size,
+ self.base_layer.vocab_size,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
- self.embeddings_tensors = torch.full(
- (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
- fill_value=float("-inf"),
- dtype=self.dtype,
- device=self.device,
- )
+
if self.sharded_to_full_mapping is not None:
self.sharded_to_full_mapping_gpu = torch.tensor(
self.sharded_to_full_mapping, device=self.device, dtype=torch.long
@@ -134,15 +124,15 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
- self.embeddings_tensors[index] = float("-inf")
def set_lora(
self,
index: int,
- lora_a: torch.Tensor,
- lora_b: torch.Tensor,
- embeddings_tensor: torch.Tensor | None,
+ lora_a: torch.Tensor | list[torch.Tensor],
+ lora_b: torch.Tensor | list[torch.Tensor],
):
+ assert isinstance(lora_a, torch.Tensor)
+ assert isinstance(lora_b, torch.Tensor)
self.reset_lora(index)
self.lora_a_stacked[index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
lora_a, non_blocking=True
@@ -150,12 +140,6 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
self.lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
lora_b, non_blocking=True
)
- if embeddings_tensor is not None:
- self.embeddings_tensors[
- index,
- : embeddings_tensor.shape[0],
- : embeddings_tensor.shape[1],
- ] = embeddings_tensor
def _get_logits(
self,
@@ -193,39 +177,6 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
# token_id: [0, 1, 2, 3, 4, 5, -1, -1]
logits = logits[:, self.sharded_to_full_mapping_gpu]
- lora_logits = torch.empty(
- self.embeddings_tensors.shape[0] + 1,
- self.embeddings_tensors.shape[1],
- hidden_states.shape[0],
- dtype=self.embeddings_tensors.dtype,
- device=self.embeddings_tensors.device,
- )
- torch.matmul(self.embeddings_tensors, hidden_states.T, out=lora_logits[:-1])
-
- neg_inf, pos_inf = current_platform.get_infinity_values(lora_logits.dtype)
-
- lora_logits[-1] = neg_inf
- lora_logits = lora_logits.mT
- indices_padded = self.punica_wrapper.sampler_indices_padded
-
- if current_platform.is_tpu() or current_platform.is_xpu():
- indices_padded = indices_padded[: logits.size(0)]
-
- lora_logits = (
- lora_logits.reshape(
- lora_logits.shape[0] * lora_logits.shape[1],
- lora_logits.shape[2],
- )
- .index_select(0, indices_padded)
- .nan_to_num_(nan=neg_inf, posinf=pos_inf, neginf=neg_inf)
- )
-
- logits[
- :,
- self.base_layer.org_vocab_size : self.base_layer.org_vocab_size
- + lora_logits.shape[1],
- ] = lora_logits
-
lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_logits(
logits, hidden_states, self.lora_a_stacked, self.lora_b_stacked, 1.0
)
diff --git a/vllm/lora/layers/row_parallel_linear.py b/vllm/lora/layers/row_parallel_linear.py
index 2ef1bd98fc612..95517b1aee263 100644
--- a/vllm/lora/layers/row_parallel_linear.py
+++ b/vllm/lora/layers/row_parallel_linear.py
@@ -63,23 +63,18 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
input_parallel = splitted_input[self.tp_rank].contiguous()
# Matrix multiply.
- output_parallel = self.apply(input_parallel)
+ bias_ = (
+ None
+ if (self.tp_rank > 0 or self.base_layer.skip_bias_add)
+ else self.base_layer.bias
+ )
+ output_parallel = self.apply(input_parallel, bias_)
if self.base_layer.reduce_results and self.tp_size > 1:
- output_ = tensor_model_parallel_all_reduce(output_parallel)
+ output = tensor_model_parallel_all_reduce(output_parallel)
else:
- output_ = output_parallel
-
- if not self.base_layer.skip_bias_add:
- output = (
- output_ + self.base_layer.bias
- if self.base_layer.bias is not None
- else output_
- )
- output_bias = None
- else:
- output = output_
- output_bias = self.base_layer.bias
+ output = output_parallel
+ output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
if not self.base_layer.return_bias:
return output
@@ -120,7 +115,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
return lora_b
def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
- output = self.base_layer.quant_method.apply(self.base_layer, x)
+ output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
diff --git a/vllm/lora/layers/vocal_parallel_embedding.py b/vllm/lora/layers/vocal_parallel_embedding.py
index ca4ad8012e9c3..c87ca9e24dece 100644
--- a/vllm/lora/layers/vocal_parallel_embedding.py
+++ b/vllm/lora/layers/vocal_parallel_embedding.py
@@ -46,19 +46,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self.embeddings_slice = None
self.embeddings_weights = None
- self.embeddings_tensors = torch.zeros(
- (
- max_loras,
- lora_config.lora_extra_vocab_size,
- self.base_layer.embedding_dim,
- ),
- dtype=self.base_layer.weight.dtype,
- device=self.base_layer.weight.device,
- )
self.lora_a_stacked = torch.zeros(
(
max_loras,
- self.base_layer.org_vocab_size + lora_config.lora_extra_vocab_size,
+ self.base_layer.org_vocab_size,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
@@ -82,54 +73,37 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
- self.embeddings_tensors[index] = 0
def set_lora(
self,
index: int,
- lora_a: torch.Tensor,
- lora_b: torch.Tensor,
- embeddings_tensor: torch.Tensor | None,
+ lora_a: torch.Tensor | list[torch.Tensor],
+ lora_b: torch.Tensor | list[torch.Tensor],
):
+ assert isinstance(lora_a, torch.Tensor)
+ assert isinstance(lora_b, torch.Tensor)
self.reset_lora(index)
# NOTE self.lora_a_stacked is row-major, and lora_a is col-major,
# so we need transpose here
+
self.lora_a_stacked[index, : lora_a.shape[1], : lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True
)
self.lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
lora_b, non_blocking=True
)
- if embeddings_tensor is not None:
- self.embeddings_tensors[
- index,
- : embeddings_tensor.shape[0],
- : embeddings_tensor.shape[1],
- ].copy_(embeddings_tensor, non_blocking=True)
- if self.embeddings_slice is not None:
- # TODO(yard1): Optimize this copy, we don't need to copy
- # everything, just the modified part
- embeddings = self.embeddings_tensors.view(
- self.embeddings_tensors.shape[0] * self.embeddings_tensors.shape[1],
- self.embeddings_tensors.shape[2],
- )[self.embeddings_slice[0] : self.embeddings_slice[1]]
- assert self.embeddings_weights is not None
- self.embeddings_weights[: embeddings.shape[0]].copy_(embeddings)
def forward(self, x: torch.Tensor) -> torch.Tensor:
- added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, 1, 0)
-
# NB: Don't use torch.narrow here. torch.narrow triggers some
# Dynamic Shape specialization in torch.compile
num_tokens = x.shape[0]
indices_1 = self.punica_wrapper._embeddings_indices[1][:num_tokens]
- indices_0 = self.punica_wrapper._embeddings_indices[0][:num_tokens]
full_lora_a_embeddings = F.embedding(
x + indices_1,
self.lora_a_stacked_2d,
)
- full_output = self.base_layer.forward(x + (indices_0 * added_tokens_mask))
+ full_output = self.base_layer.forward(x)
full_output_org = full_output
if full_output.ndim == 3:
diff --git a/vllm/lora/lora_weights.py b/vllm/lora/lora_weights.py
index 7691481d5039e..f0d8e22194050 100644
--- a/vllm/lora/lora_weights.py
+++ b/vllm/lora/lora_weights.py
@@ -21,7 +21,6 @@ class LoRALayerWeights:
lora_alpha: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
- embeddings_tensor: torch.Tensor | None = None,
scaling: float | None = None,
) -> None:
self.module_name = module_name
@@ -29,7 +28,6 @@ class LoRALayerWeights:
self.lora_alpha = lora_alpha
self.lora_a = lora_a
self.lora_b = lora_b
- self.embeddings_tensor = embeddings_tensor
if scaling is None:
self.scaling = self.lora_alpha / self.rank
@@ -56,18 +54,11 @@ class LoRALayerWeights:
def is_packed(self) -> bool:
return False
- @property
- def extra_vocab_size(self) -> int:
- return (
- self.embeddings_tensor.shape[0] if self.embeddings_tensor is not None else 0
- )
-
@classmethod
def from_config(
cls,
module_name: str,
peft_helper: PEFTHelper,
- embeddings_tensor: torch.Tensor | None = None,
) -> "LoRALayerWeights":
# lora_a and lora_b are set to None for config-based construction
return cls(
@@ -76,7 +67,6 @@ class LoRALayerWeights:
peft_helper.lora_alpha,
None,
None,
- embeddings_tensor,
peft_helper.vllm_lora_scaling_factor,
)
@@ -89,7 +79,6 @@ class LoRALayerWeights:
rank: int,
dtype: torch.dtype,
device: torch.types.Device,
- embeddings_tensor_dim: int | None = None,
) -> "LoRALayerWeights":
pin_memory = str(device) == "cpu" and is_pin_memory_available()
lora_a = torch.zeros(
@@ -99,24 +88,12 @@ class LoRALayerWeights:
[output_dim, rank], dtype=dtype, device=device, pin_memory=pin_memory
)
- embeddings_tensor = (
- torch.rand(
- 10,
- embeddings_tensor_dim,
- dtype=dtype,
- device=device,
- pin_memory=pin_memory,
- )
- if embeddings_tensor_dim
- else None
- )
return cls(
module_name,
rank=rank,
lora_alpha=1,
lora_a=lora_a,
lora_b=lora_b,
- embeddings_tensor=embeddings_tensor,
)
@@ -139,7 +116,6 @@ class PackedLoRALayerWeights(LoRALayerWeights):
lora_a=lora_a,
lora_b=lora_b,
scaling=scaling, # type: ignore
- embeddings_tensor=None,
)
self.lora_alphas = lora_alphas
if scaling is None:
diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 02c252f15bfab..636f062feb7b0 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -21,11 +21,14 @@ from vllm.lora.utils import (
from_layer,
from_layer_logits_processor,
get_supported_lora_modules,
+ is_base_embeddding_weights,
+ is_moe_model,
is_regex_target_modules,
parse_fine_tuned_lora_name,
process_packed_modules_mapping,
replace_submodule,
)
+from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.interfaces import is_pooling_model
@@ -93,14 +96,6 @@ class LoRAModel:
loras=self.loras.copy(),
)
- @property
- def extra_vocab_size(self) -> int:
- return (
- max(lora.extra_vocab_size for lora in self.loras.values())
- if self.loras
- else 0
- )
-
def get_lora(self, module_name: str) -> LoRALayerWeights | None:
"""Get LoRA for a given module by name"""
return self.loras.get(module_name, None)
@@ -117,7 +112,6 @@ class LoRAModel:
peft_helper: PEFTHelper,
device: str = "cuda",
dtype: torch.dtype | None = None,
- embeddings: dict[str, torch.Tensor] | None = None,
target_embedding_padding: int | None = None,
embedding_modules: dict[str, str] | None = None,
embedding_padding_modules: list[str] | None = None,
@@ -127,24 +121,14 @@ class LoRAModel:
pin_memory = str(device) == "cpu" and is_pin_memory_available()
loras: dict[str, LoRALayerWeights] = {}
for tensor_name, tensor in tensors.items():
+ if is_base_embeddding_weights(tensor_name):
+ continue
module_name, is_lora_a = parse_fine_tuned_lora_name(
tensor_name, weights_mapper
)
if module_name not in loras:
- lora_embeddings_tensor = None
- if embeddings:
- assert embedding_modules is not None
- embeddings_module = next(
- (k for k in embedding_modules if k in module_name), None
- )
- if embeddings_module:
- lora_embeddings_tensor = embeddings[
- embedding_modules[embeddings_module]
- ].to(device=device, dtype=dtype)
- if pin_memory:
- lora_embeddings_tensor = lora_embeddings_tensor.pin_memory()
loras[module_name] = LoRALayerWeights.from_config(
- module_name, peft_helper, lora_embeddings_tensor
+ module_name, peft_helper
)
if is_lora_a:
@@ -206,15 +190,17 @@ class LoRAModel:
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")
- new_embeddings_tensor_path = os.path.join(
- lora_dir, "new_embeddings.safetensors"
- )
- new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin")
+ # new_embeddings_tensor_path = os.path.join(
+ # lora_dir, "new_embeddings.safetensors"
+ # )
+ # new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin")
tensors: dict[str, torch.Tensor] = {}
unexpected_modules: list[list[str] | str] = []
def check_unexpected_modules(modules: dict):
for lora_module in modules.keys(): # noqa
+ if is_base_embeddding_weights(lora_module):
+ continue
module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
# Handle FSDP file format where experts.base_layer is the
# gate_up_proj and experts is the down_proj
@@ -300,21 +286,12 @@ class LoRAModel:
else:
raise ValueError(f"{lora_dir} doesn't contain tensors")
- embeddings = None
- if os.path.isfile(new_embeddings_tensor_path):
- embeddings = safetensors.torch.load_file(new_embeddings_tensor_path)
- elif os.path.isfile(new_embeddings_bin_file_path):
- embeddings = torch.load(
- new_embeddings_bin_file_path, map_location=device, weights_only=True
- )
-
return cls.from_lora_tensors(
lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id,
tensors=tensors,
peft_helper=peft_helper,
device=device,
dtype=dtype,
- embeddings=embeddings,
target_embedding_padding=target_embedding_padding,
embedding_modules=embedding_modules,
embedding_padding_modules=embedding_padding_modules,
@@ -381,7 +358,11 @@ class LoRAModelManager:
self.modules: dict[str, BaseLayerWithLoRA] = {}
# Dict instead of a set for compatibility with LRUCache.
self._last_mapping: LoRAMapping | None = None
+ self._is_3d_moe_model = is_moe_model(self.model) and hasattr(
+ self.model, "is_3d_moe_weight"
+ )
self._create_lora_modules()
+
self.model.lora_manager = self
def __len__(self) -> int:
@@ -425,22 +406,36 @@ class LoRAModelManager:
self.lora_index_to_id[index] = lora_model.id
for module_name, module in self.modules.items():
module_lora = self._get_lora_layer_weights(lora_model, module_name)
- if module_lora:
- # Note (gnovack) - If MOE lora weights are not split into
- # num_experts chunks, we split them here
- if isinstance(module, FusedMoEWithLoRA) and torch.is_tensor(
- module_lora.lora_a
- ):
- # Handle FSDP file format where experts.base_layer is the
- # gate_up_proj and experts is the down_proj
- gate_up_proj_lora = self._get_lora_layer_weights(
- lora_model, module_name + ".base_layer"
- )
-
- assert gate_up_proj_lora is not None
- assert module_lora is not None
-
- down_proj_lora = module_lora
+ if not module_lora:
+ module.reset_lora(index)
+ continue
+ # Note (gnovack) - If MOE lora weights are not split into
+ # num_experts chunks, we split them here
+ if isinstance(module, FusedMoEWithLoRA) and torch.is_tensor(
+ module_lora.lora_a
+ ):
+ # Handle PEFT file format where experts.base_layer is the
+ # gate_up_proj and experts is the down_proj
+ gate_up_proj_lora = self._get_lora_layer_weights(
+ lora_model, module_name + ".base_layer"
+ )
+ down_proj_lora = module_lora
+ # FIXME Edge case where LoRA is not added to gate_up_proj
+ # or down_proj
+ assert gate_up_proj_lora is not None
+ assert down_proj_lora is not None
+ if self._is_3d_moe_model:
+ module_lora.lora_a = [
+ gate_up_proj_lora.lora_a,
+ down_proj_lora.lora_a,
+ ]
+ module_lora.lora_b = [
+ gate_up_proj_lora.lora_b,
+ down_proj_lora.lora_b,
+ ]
+ else:
+ # Some 3D MoE models haven't added the `is_3d_moe_weight`
+ # attribute yet, so fall back here
num_experts = module_lora.lora_a.shape[0] // module_lora.rank
gate_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
@@ -469,15 +464,12 @@ class LoRAModelManager:
module_lora.lora_a = lora_a
module_lora.lora_b = lora_b
+ module.set_lora(
+ index,
+ module_lora.lora_a,
+ module_lora.lora_b,
+ )
- module.set_lora(
- index,
- module_lora.lora_a,
- module_lora.lora_b,
- module_lora.embeddings_tensor,
- )
- else:
- module.reset_lora(index)
return True
def _deactivate_adapter(self, lora_id: int):
@@ -505,7 +497,6 @@ class LoRAModelManager:
self.lora_index_to_id,
self.lora_slots + 1,
self.vocab_size,
- self.lora_config.lora_extra_vocab_size,
)
def remove_all_adapters(self):
@@ -539,6 +530,13 @@ class LoRAModelManager:
continue
parts = module_name.split(".")[-1]
packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
+ if isinstance(module, FusedMoE):
+ # packed_moduled_lst is used here to just determine whether to
+ # instantiate FusedMoE3DWithLoRA or FusedMoEWithLoRA, and the
+ # difference between these two LoRA layers is whether the
+ # LoRA weights of w1 and w3 have already been fused on disk.
+
+ packed_moduled_lst = ["w13"] if self._is_3d_moe_model else ["w1", "w3"]
new_module = replace_submodule(
self.model,
module_name,
@@ -587,6 +585,7 @@ class LoRAModelManager:
self._register_packed_modules(module_name)
# All lora layers share the same punica_wrapper based on reference.
new_module.set_mapping(self.punica_wrapper)
+ pass
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
assert isinstance(module, BaseLayerWithLoRA), (
@@ -616,7 +615,6 @@ class LoRAModelManager:
if parts[-1] in embedding_modules:
input_dim = (
module.base_layer.org_vocab_size
- + self.lora_config.lora_extra_vocab_size
if hasattr(module.base_layer, "org_vocab_size")
else module.base_layer.weight.shape[1]
)
@@ -625,11 +623,6 @@ class LoRAModelManager:
if hasattr(module.base_layer, "embedding_dim")
else module.base_layer.weight.shape[0]
)
- embeddings_tensor_dim = (
- module.base_layer.embedding_dim
- if hasattr(module.base_layer, "embedding_dim")
- else module.base_layer.weight.shape[1]
- )
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name,
input_dim,
@@ -637,8 +630,31 @@ class LoRAModelManager:
rank,
module.lora_a_stacked[0].dtype,
"cpu",
- embeddings_tensor_dim=embeddings_tensor_dim,
)
+ model.loras[module_name] = lora
+ elif module.__class__.__name__ == "FusedMoE3DWithLoRA":
+ # Case for 3D moe model
+ # w2
+ lora = LoRALayerWeights.create_dummy_lora_weights(
+ module_name,
+ module.w2_input_size,
+ module.w2_output_size,
+ rank * module.w2_lora_a_stacked[0].shape[1], # rank*num_experts
+ module.w2_lora_a_stacked[0].dtype,
+ "cpu",
+ )
+ model.loras[module_name] = lora
+ # w13
+ lora = LoRALayerWeights.create_dummy_lora_weights(
+ module_name,
+ module.w13_input_size,
+ module.w13_output_size,
+ rank
+ * module.w13_lora_a_stacked[0].shape[1], # rank*num_experts
+ module.w13_lora_a_stacked[0].dtype,
+ "cpu",
+ )
+ model.loras[module_name + ".base_layer"] = lora
else:
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name,
@@ -648,6 +664,7 @@ class LoRAModelManager:
module.lora_a_stacked[0].dtype,
"cpu",
)
+ model.loras[module_name] = lora
else:
parts = module_name.split(".")
replacements = self.packed_modules_mapping[parts[-1]]
@@ -663,7 +680,7 @@ class LoRAModelManager:
)
subloras.append(lora)
lora = PackedLoRALayerWeights.pack(subloras)
- model.loras[module_name] = lora
+ model.loras[module_name] = lora
return model
def _match_target_modules(self, module_name: str):
diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
index e2dd47dbb4e64..413ee8ecbbf96 100644
--- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
+++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
@@ -3,6 +3,10 @@
import torch
+from vllm.distributed import (
+ tensor_model_parallel_all_gather,
+ tensor_model_parallel_all_reduce,
+)
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op
@@ -311,6 +315,7 @@ def _fused_moe_lora_expand(
num_stages: int,
split_k: int,
mul_routed_weight: bool = False,
+ offset: int = 0,
) -> None:
b_ptr = _get_ptr(lora_b_stacked, device)
K = max_lora_rank
@@ -380,7 +385,7 @@ def _fused_moe_lora_expand(
**expand_config,
)
for i in range(num_slices):
- output[:, :, i * N : (i + 1) * N] += b_intermediate_cache1[i]
+ output[:, :, i * N + offset : (i + 1) * N + offset] += b_intermediate_cache1[i]
@torch.inference_mode()
@@ -416,6 +421,8 @@ def _fused_moe_lora(
expand_num_stages: int,
expand_split_k: int,
mul_routed_weight: bool = False,
+ fully_sharded: bool = False,
+ offset: int = 0,
) -> None:
assert len(lora_a_stacked) == len(lora_b_stacked) > 0
assert (
@@ -430,7 +437,6 @@ def _fused_moe_lora(
== expert_ids.shape[0]
== num_tokens_post_padded.shape[0]
)
- assert len(lora_b_stacked) * lora_b_stacked[0].shape[-2] == output.shape[-1]
assert output.shape[0] == topk_weights.shape[0]
assert top_k_num == topk_weights.shape[1]
device = qcurr_hidden_states.device
@@ -480,6 +486,19 @@ def _fused_moe_lora(
mul_routed_weight,
)
+ if fully_sharded:
+ if max_lora_rank == w1_lora_b_stacked.shape[-1]:
+ a_intermediate_cache1 = tensor_model_parallel_all_reduce(
+ a_intermediate_cache1
+ )
+ else:
+ a_intermediate_cache1 = tensor_model_parallel_all_gather(
+ a_intermediate_cache1
+ )
+
+ # reset max_lora_rank to the full rank after allgather
+ max_lora_rank = a_intermediate_cache1.shape[-1]
+
_fused_moe_lora_expand(
output,
a_intermediate_cache1,
@@ -510,6 +529,7 @@ def _fused_moe_lora(
expand_num_stages,
expand_split_k,
mul_routed_weight,
+ offset,
)
diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py
index b6186e8561529..ce38751e4b6a7 100644
--- a/vllm/lora/punica_wrapper/punica_base.py
+++ b/vllm/lora/punica_wrapper/punica_base.py
@@ -31,7 +31,6 @@ class PunicaWrapperABC(ABC):
lora_index_to_id: list[int | None],
max_loras: int,
vocab_size: int,
- extra_vocab_size: int,
**kwargs,
) -> None:
"""
@@ -172,8 +171,11 @@ class PunicaWrapperBase(PunicaWrapperABC):
lora_index_to_id: list[int | None],
max_loras: int,
vocab_size: int,
- extra_vocab_size: int,
):
+ # NOTE We have removed LoRA extra vocab support for now. So we set
+ # extra_vocab_size always to 0, and extra_vocab_size will be removed.
+
+ extra_vocab_size = 0
(
base_indices,
sampler_indices,
@@ -285,12 +287,9 @@ class PunicaWrapperBase(PunicaWrapperABC):
lora_index_to_id: list[int | None],
max_loras: int,
vocab_size: int,
- extra_vocab_size: int,
**kwargs,
):
- self._update_base_metadata(
- mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size
- )
+ self._update_base_metadata(mapping, lora_index_to_id, max_loras, vocab_size)
if mapping.is_prefill:
# Update metadata required for prefill-related operators.
@@ -471,8 +470,8 @@ class PunicaWrapperBase(PunicaWrapperABC):
self,
y: torch.Tensor,
x: torch.Tensor,
- lora_a_stacked: list[torch.Tensor],
- lora_b_stacked: list[torch.Tensor],
+ lora_a_stacked: tuple[torch.Tensor, ...],
+ lora_b_stacked: tuple[torch.Tensor, ...],
topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
@@ -483,6 +482,8 @@ class PunicaWrapperBase(PunicaWrapperABC):
expand_config,
adapter_enabled: torch.Tensor,
mul_routed_weight=False,
+ fully_sharded: bool = False,
+ offset: int = 0,
):
"""
Performs a fused forward computation for LoRA of
diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py
index ede50a48af985..ef4b4ab7c3497 100644
--- a/vllm/lora/punica_wrapper/punica_gpu.py
+++ b/vllm/lora/punica_wrapper/punica_gpu.py
@@ -65,13 +65,10 @@ class PunicaWrapperGPU(PunicaWrapperBase):
lora_index_to_id: list[int | None],
max_loras: int,
vocab_size: int,
- extra_vocab_size: int,
**kwargs,
):
self.is_prefill = mapping.is_prefill
- self._update_base_metadata(
- mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size
- )
+ self._update_base_metadata(mapping, lora_index_to_id, max_loras, vocab_size)
# Prepare cuda kernel metadata tensors
self.token_mapping_meta.prepare_tensors(self.token_lora_indices)
@@ -363,8 +360,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
self,
y: torch.Tensor,
x: torch.Tensor,
- lora_a_stacked: list[torch.Tensor],
- lora_b_stacked: list[torch.Tensor],
+ lora_a_stacked: tuple[torch.Tensor, ...],
+ lora_b_stacked: tuple[torch.Tensor, ...],
topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
@@ -375,6 +372,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
expand_config,
adapter_enabled: torch.Tensor,
mul_routed_weight=False,
+ fully_sharded: bool = False,
+ offset: int = 0,
):
"""
Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer.
@@ -408,4 +407,6 @@ class PunicaWrapperGPU(PunicaWrapperBase):
expand_config.get("NUM_STAGES", 3),
expand_config.get("SPLIT_K", 1),
mul_routed_weight,
+ fully_sharded,
+ offset,
)
diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py
index 090878dcd2546..0888772db54e7 100644
--- a/vllm/lora/punica_wrapper/punica_tpu.py
+++ b/vllm/lora/punica_wrapper/punica_tpu.py
@@ -292,7 +292,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
lora_index_to_id: list[int | None],
max_loras: int,
vocab_size: int,
- extra_vocab_size: int,
):
# Make sure we don't accidentally collect outside operations
torch_xla.sync()
@@ -313,7 +312,7 @@ class PunicaWrapperTPU(PunicaWrapperBase):
lora_index_to_id,
max_loras,
vocab_size,
- extra_vocab_size,
+ 0, # extra_vocab_size
"cpu",
)
self._token_lora_indices = self._pad_to_shape(
diff --git a/vllm/lora/punica_wrapper/punica_xpu.py b/vllm/lora/punica_wrapper/punica_xpu.py
index b95087d0ff834..00c00782896cf 100644
--- a/vllm/lora/punica_wrapper/punica_xpu.py
+++ b/vllm/lora/punica_wrapper/punica_xpu.py
@@ -43,13 +43,10 @@ class PunicaWrapperXPU(PunicaWrapperBase):
lora_index_to_id: list[int | None],
max_loras: int,
vocab_size: int,
- extra_vocab_size: int,
**kwargs,
):
self.is_prefill = mapping.is_prefill
- self._update_base_metadata(
- mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size
- )
+ self._update_base_metadata(mapping, lora_index_to_id, max_loras, vocab_size)
def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor:
return torch.narrow(self._token_lora_indices, 0, 0, x.size(0))
diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py
index 0f43ff06d8f2b..12524994d4968 100644
--- a/vllm/lora/utils.py
+++ b/vllm/lora/utils.py
@@ -23,6 +23,7 @@ from vllm.lora.layers import (
BaseLayerWithLoRA,
ColumnParallelLinearWithLoRA,
ColumnParallelLinearWithShardedLoRA,
+ FusedMoE3DWithLoRA,
FusedMoEWithLoRA,
LogitsProcessorWithLoRA,
MergedColumnParallelLinearWithLoRA,
@@ -62,6 +63,7 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = {
MergedQKVParallelLinearWithShardedLoRA,
RowParallelLinearWithShardedLoRA,
FusedMoEWithLoRA,
+ FusedMoE3DWithLoRA,
}
@@ -166,6 +168,16 @@ def parse_fine_tuned_lora_name(
raise ValueError(f"{name} is unsupported LoRA weight")
+def is_base_embeddding_weights(name: str) -> bool:
+ # hardcoded suffixes for input & output embedding weights
+ input_embedding_subfix = ".embed_tokens.base_layer.weight"
+ output_embedding_subfix = ".lm_head.base_layer.weight"
+
+ return name.endswith(input_embedding_subfix) or name.endswith(
+ output_embedding_subfix
+ )
+
+
def is_regex_target_modules(
load_modules: str | list[str], expected_lora_modules: list[str]
) -> bool:
@@ -278,10 +290,12 @@ def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]:
# the expert indices are expanded based on the configured number
# of routed experts.
packed_modules_mapping = get_packed_modules_mapping(model)
-
- packed_modules_mapping["experts"] = [
- weight_name.rstrip(".") for _, weight_name, _, _ in moe_packed_mapping
- ]
+ if not hasattr(model, "is_3d_moe_weight"):
+ # 3D MoE LoRA does not need `packed_modules_mapping`
+ packed_modules_mapping["experts"] = [
+ weight_name.rstrip(".")
+ for _, weight_name, _, _ in moe_packed_mapping
+ ]
return packed_modules_mapping
else:
diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py
index b85151f2c7592..4cc201a6414f1 100644
--- a/vllm/lora/worker_manager.py
+++ b/vllm/lora/worker_manager.py
@@ -121,8 +121,7 @@ class WorkerLoRAManager:
lora_model_id=lora_request.lora_int_id,
device="cpu",
dtype=self.lora_config.lora_dtype,
- target_embedding_padding=self.vocab_size
- + self.lora_config.lora_extra_vocab_size,
+ target_embedding_padding=self.vocab_size,
embedding_modules=self.embedding_modules,
embedding_padding_modules=self.embedding_padding_modules,
tensorizer_config_dict=lora_request.tensorizer_config_dict,
@@ -143,12 +142,6 @@ class WorkerLoRAManager:
# For BadRequestError
raise e
- if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
- raise ValueError(
- f"LoRA added vocab size {lora.extra_vocab_size} "
- f"is greater than lora_extra_vocab_size "
- f"{self.lora_config.lora_extra_vocab_size}."
- )
return lora
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py
index 7920d117de5e0..be7f673e5618f 100644
--- a/vllm/model_executor/layers/batch_invariant.py
+++ b/vllm/model_executor/layers/batch_invariant.py
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from collections.abc import Callable
-from functools import cache
from typing import Any
import torch
@@ -785,16 +784,19 @@ def enable_batch_invariant_mode():
torch.backends.cuda.preferred_blas_library(backend="cublaslt")
-@cache
-def vllm_is_batch_invariant():
- env_key = "VLLM_BATCH_INVARIANT"
- is_overridden = False
- val = os.getenv(env_key, "0")
+def _read_vllm_batch_invariant() -> bool:
+ val = os.getenv("VLLM_BATCH_INVARIANT", "0")
try:
- is_overridden = int(val) != 0
+ return int(val) != 0
except ValueError:
- is_overridden = False
- return is_overridden
+ return False
+
+
+VLLM_BATCH_INVARIANT: bool = _read_vllm_batch_invariant()
+
+
+def vllm_is_batch_invariant() -> bool:
+ return VLLM_BATCH_INVARIANT
def override_envs_for_invariance():
@@ -803,26 +805,26 @@ def override_envs_for_invariance():
"FLASH_ATTN", # best supported backend
"FLASHINFER",
"FLASH_ATTN_MLA",
- "FLASHINFER_MLA",
- "TRITON_MLA",
# Not yet supported MLA backends
# "FLASHMLA",
# "FLEX_ATTENTION", # IMA issue even if we disable batch invariance
+ # "FLASHINFER_MLA", https://github.com/vllm-project/vllm/pull/28967
+ # "TRITON_MLA",
]
if curr_attn_backend not in supported_backends:
- warning = (
- "Forcibly updating attention backend to"
- f" {supported_backends[0]} for batch_invariant. "
- f" Supported backends: {supported_backends}."
+ error = (
+ "VLLM batch_invariant mode requires an attention backend in "
+ f"{supported_backends}, but got '{curr_attn_backend}'. "
+ "Please set the 'VLLM_ATTENTION_BACKEND' environment variable "
+ "to one of the supported backends before enabling batch_invariant."
)
- logger.warning_once(warning)
- os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0]
+ raise RuntimeError(error)
if os.environ["VLLM_ATTENTION_BACKEND"] != supported_backends[0]:
warning = (
"You are using a decode-invariant form of batch invariance. "
"This will not be invariant between prefill and decode."
)
- logger.warning_once(warning)
+ logger.warning_once(warning, scope="local")
os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
@@ -850,5 +852,6 @@ def init_batch_invariance():
enable_batch_invariant_mode()
# Disable TF32 for batch invariance - it causes non-deterministic rounding
- torch.backends.cuda.matmul.allow_tf32 = False
- torch.backends.cudnn.allow_tf32 = False
+ torch.backends.cuda.matmul.fp32_precision = "ieee"
+ torch.backends.cudnn.conv.fp32_precision = "ieee"
+ torch.backends.cudnn.rnn.fp32_precision = "ieee"
diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py
index 2dd625054339c..86c50f39f0076 100644
--- a/vllm/model_executor/layers/fused_moe/all2all_utils.py
+++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py
@@ -67,6 +67,7 @@ def maybe_roundup_layer_hidden_size(
def maybe_make_prepare_finalize(
moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig | None,
+ routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> FusedMoEPrepareAndFinalize | None:
if not moe.moe_parallel_config.use_all2all_kernels:
return None
@@ -134,6 +135,13 @@ def maybe_make_prepare_finalize(
elif moe.use_deepep_ll_kernels:
assert quant_config is not None
+ global_to_physical = physical_to_global = local_expert_global_ids = None
+ if routing_tables is not None:
+ (
+ global_to_physical,
+ physical_to_global,
+ local_expert_global_ids,
+ ) = routing_tables
all_to_all_args = dict(
max_num_tokens_per_dp_rank=moe.max_num_tokens,
token_hidden_size=moe.hidden_dim,
@@ -155,6 +163,9 @@ def maybe_make_prepare_finalize(
max_tokens_per_rank=moe.max_num_tokens,
num_dispatchers=all2all_manager.world_size,
use_fp8_dispatch=use_fp8_dispatch,
+ global_to_physical=global_to_physical,
+ physical_to_global=physical_to_global,
+ local_expert_global_ids=local_expert_global_ids,
)
return prepare_finalize
diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index a7bd64b1c65e9..1826fafa8c4f5 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -8,7 +8,11 @@ import torch
import vllm.envs as envs
from vllm.config import ParallelConfig
-from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank
+from vllm.distributed import (
+ get_dp_group,
+ get_pcp_group,
+ get_tensor_model_parallel_rank,
+)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_DTYPES,
@@ -24,10 +28,11 @@ logger = init_logger(__name__)
if has_triton_kernels():
try:
from triton_kernels.matmul_ogs import PrecisionConfig
- except ImportError:
+ except (ImportError, AttributeError) as e:
logger.error(
"Failed to import Triton kernels. Please make sure your triton "
- "version is compatible."
+ "version is compatible. Error: %s",
+ e,
)
@@ -684,9 +689,11 @@ FUSED_MOE_UNQUANTIZED_CONFIG: FusedMoEQuantConfig = FusedMoEQuantConfig.make()
@dataclass
class FusedMoEParallelConfig:
tp_size: int
+ pcp_size: int
dp_size: int
ep_size: int
tp_rank: int
+ pcp_rank: int
dp_rank: int
ep_rank: int
@@ -713,19 +720,22 @@ class FusedMoEParallelConfig:
return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"
@staticmethod
- def flatten_tp_across_dp(
- tp_size: int, dp_size: int, dp_rank: int
+ def flatten_tp_across_dp_and_pcp(
+ tp_size: int, dp_size: int, dp_rank: int, pcp_size: int, pcp_rank: int
) -> tuple[int, int]:
tp_rank = 0 if tp_size == 1 else get_tensor_model_parallel_rank()
- # There are actually dp_size * tp_size devices. Update tp_size
- # and tp_rank so we shard across all devices.
- flatten_tp_size = dp_size * tp_size
- flatten_tp_rank = dp_rank * tp_size + tp_rank
+ # There are actually dp_size * pcp_size * tp_size devices.
+ # Update tp_size and tp_rank so we shard across all devices.
+ flatten_tp_size = dp_size * pcp_size * tp_size
+ flatten_tp_rank = dp_rank * pcp_size * tp_size + pcp_rank * tp_size + tp_rank
return flatten_tp_size, flatten_tp_rank
@staticmethod
def make(
- tp_size_: int, dp_size_: int, vllm_parallel_config: ParallelConfig
+ tp_size_: int,
+ pcp_size_: int,
+ dp_size_: int,
+ vllm_parallel_config: ParallelConfig,
) -> "FusedMoEParallelConfig":
"""
Determine MoE parallel configuration. Based on the input `tp_size_`,
@@ -734,19 +744,22 @@ class FusedMoEParallelConfig:
Args:
tp_size_ (int): `tp_size` passed into the FusedMoE constructor.
+ pcp_size_ (int): `pcp_size` passed into the FusedMoE constructor.
dp_size_ (int): `dp_size` passed into the FusedMoE constructor.
vllm_parallel_config (ParallelConfig): vLLM's parallel config
object which contains the `enable_expert_parallel` flag.
Examples:
When there is no parallelism requested,
- i.e. `tp_size_` = `dp_size_` = 1, we simply return the sizes
+ i.e. `tp_size_` = `pcp_size_` = `dp_size_` = 1, we simply return the sizes
unaltered and the ranks set to 0.
- Expert Parallelism is considered only when either `dp_size_` or
+ Expert Parallelism is considered only when either `dp_size_`, `pcp_size_` or
`tp_size_` is non trivial.
- When TP = 2, DP = 1 and EP = False, the configuration on different
+ Note that PCP serves the same function as DP here.
+
+ When TP = 2, DP(PCP) = 1 and EP = False, the configuration on different
devices:
- device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} //
@@ -754,7 +767,7 @@ class FusedMoEParallelConfig:
- device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0}
- Comment : Tensors are sharded across 2 devices.
- When TP = 1, DP = 2 and EP = False, the configuration on different
+ When TP = 1, DP(PCP) = 2 and EP = False, the configuration on different
devices:
- device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0}
@@ -762,7 +775,7 @@ class FusedMoEParallelConfig:
- Comment: There are 2 engine instances and the tensors are sharded
across 2 decvices.
- When TP = 2, DP = 2 and EP = False, the configuration on different
+ When TP = 2, DP(PCP) = 2 and EP = False, the configuration on different
devices:
- device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0}
@@ -772,14 +785,14 @@ class FusedMoEParallelConfig:
- Comment: There are 2 engine instances and the tensors are sharded
across 4 devices.
- When, TP = 2, DP = 1 and EP = True, the configuration on different
+ When, TP = 2, DP(PCP) = 1 and EP = True, the configuration on different
devices:
- device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0}
- device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1}
- Comment: The experts are split between the 2 devices.
- When, TP = 1, DP = 2 and EP = True, the configuration on different
+ When, TP = 1, DP(PCP) = 2 and EP = True, the configuration on different
devices:
- device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0}
@@ -787,7 +800,7 @@ class FusedMoEParallelConfig:
- Comment: There are 2 engine instances and the experts are split
between the 2 devices.
- When TP = 2, DP = 2 and EP = True, the configuration on different
+ When TP = 2, DP(PCP) = 2 and EP = True, the configuration on different
devices:
- device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0}
@@ -798,18 +811,25 @@ class FusedMoEParallelConfig:
between the 4 devices.
"""
- use_ep = dp_size_ * tp_size_ > 1 and vllm_parallel_config.enable_expert_parallel
+ use_ep = (
+ dp_size_ * pcp_size_ * tp_size_ > 1
+ and vllm_parallel_config.enable_expert_parallel
+ )
dp_size = dp_size_
dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
- tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
- tp_size_, dp_size_, dp_rank
+ pcp_size = pcp_size_
+ pcp_rank = get_pcp_group().rank_in_group if pcp_size > 1 else 0
+ tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
+ tp_size_, dp_size_, dp_rank, pcp_size_, pcp_rank
)
if not use_ep:
return FusedMoEParallelConfig(
tp_size=tp_size,
tp_rank=tp_rank,
+ pcp_size=pcp_size,
+ pcp_rank=pcp_rank,
dp_size=dp_size,
dp_rank=dp_rank,
ep_size=1,
@@ -826,6 +846,8 @@ class FusedMoEParallelConfig:
return FusedMoEParallelConfig(
tp_size=1,
tp_rank=0,
+ pcp_size=pcp_size,
+ pcp_rank=pcp_rank,
dp_size=dp_size,
dp_rank=dp_rank,
ep_size=ep_size,
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
new file mode 100644
index 0000000000000..54fe5374cb95d
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
@@ -0,0 +1,147 @@
+{
+ "triton_version": "3.5.0",
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "24": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "32": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "48": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "64": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "96": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "128": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "256": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "512": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 2
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 2
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 2
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 3
+ }
+}
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8.json
new file mode 100644
index 0000000000000..8b78f87e7f73b
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8.json
@@ -0,0 +1,147 @@
+{
+ "triton_version": "3.5.0",
+ "1": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 2
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "4": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 8,
+ "num_stages": 2
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "24": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "32": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "48": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "64": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "96": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "128": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "256": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "512": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 3
+ }
+}
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json
index 6fcf408755f5d..532c16e899269 100644
--- a/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json
+++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json
@@ -1,11 +1,11 @@
{
"1": {
"BLOCK_SIZE_M": 16,
- "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 1,
+ "GROUP_SIZE_M": 16,
"num_warps": 4,
- "num_stages": 5
+ "num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
@@ -13,82 +13,82 @@
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
- "num_stages": 3
+ "num_stages": 4
},
"4": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"8": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "24": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "32": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "48": {
+ "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
- "16": {
- "BLOCK_SIZE_M": 16,
- "BLOCK_SIZE_N": 256,
- "BLOCK_SIZE_K": 64,
- "GROUP_SIZE_M": 32,
- "num_warps": 4,
- "num_stages": 3
- },
- "24": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 16,
- "num_warps": 4,
- "num_stages": 3
- },
- "32": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 1,
- "num_warps": 4,
- "num_stages": 3
- },
- "48": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 32,
- "num_warps": 4,
- "num_stages": 3
- },
"64": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 32,
+ "GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"96": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 32,
- "num_warps": 4,
- "num_stages": 3
- },
- "128": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
+ "128": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
"256": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
@@ -96,10 +96,10 @@
"num_stages": 3
},
"512": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 32,
+ "GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
@@ -109,7 +109,7 @@
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
- "num_stages": 3
+ "num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
@@ -117,21 +117,21 @@
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
- "num_stages": 3
+ "num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 64,
+ "GROUP_SIZE_M": 32,
"num_warps": 4,
- "num_stages": 3
+ "num_stages": 4
},
"3072": {
- "BLOCK_SIZE_M": 128,
- "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 16,
+ "GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
@@ -139,7 +139,7 @@
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 64,
+ "GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
}
diff --git a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
index 572307052b489..659a2d4ee5b39 100644
--- a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
@@ -6,22 +6,7 @@ import torch
from torch.nn import functional as F
from vllm import _custom_ops as ops
-
-
-def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
- d = x.shape[-1] // 2
- return F.silu(x[..., :d]) * x[..., d:]
-
-
-def swigluoai_and_mul(
- x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0
-) -> torch.Tensor:
- d = x.shape[-1] // 2
- gate, up = x[..., :d], x[..., d:]
- gate = gate.clamp(max=limit)
- up = up.clamp(min=-limit, max=limit)
- glu = gate * torch.sigmoid(alpha * gate)
- return (up + 1) * glu
+from vllm.model_executor.layers.activation import SiluAndMul, SwigluOAIAndMul
def grouped_topk(
@@ -227,6 +212,11 @@ class CPUFusedMOE:
layer.w13_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
layer.w2_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
+ self.act_to_impl = {
+ "silu": SiluAndMul(),
+ "swigluoai": SwigluOAIAndMul(),
+ }
+
def __call__(
self,
layer: torch.nn.Module,
@@ -246,7 +236,7 @@ class CPUFusedMOE:
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
- assert activation in {"silu", "swigluoai"}, f"{activation} is not supported."
+ assert activation in self.act_to_impl, f"{activation} is not supported."
assert not apply_router_weight_on_input
topk_weights, topk_ids = select_experts(
hidden_states=x,
@@ -283,10 +273,7 @@ class CPUFusedMOE:
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
gate_up = layer.gate_up_linear[i](tokens_for_this_expert)
- if activation == "swigluoai":
- gate_up = swigluoai_and_mul(gate_up)
- else:
- gate_up = silu_and_mul(gate_up)
+ gate_up = self.act_to_impl[activation].forward_native(gate_up)
expert_out = layer.down_linear[i](gate_up)
outputs.append(expert_out)
start_idx = end_idx
diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
index 06c9df317f7c7..fea9f49c04b89 100644
--- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
@@ -6,6 +6,7 @@ import deep_ep
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
@@ -27,6 +28,8 @@ logger = init_logger(__name__)
DEEPEP_QUANT_BLOCK_SIZE = 128
DEEPEP_QUANT_BLOCK_SHAPE = [DEEPEP_QUANT_BLOCK_SIZE, DEEPEP_QUANT_BLOCK_SIZE]
+logger = init_logger(__name__)
+
def dequant_fp8(
expert_x_fp8: torch.Tensor, expert_x_scales: torch.Tensor
@@ -85,6 +88,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
max_tokens_per_rank: int,
num_dispatchers: int,
use_fp8_dispatch: bool = False,
+ global_to_physical: torch.Tensor | None = None,
+ physical_to_global: torch.Tensor | None = None,
+ local_expert_global_ids: torch.Tensor | None = None,
):
super().__init__()
@@ -97,6 +103,17 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
self.handles: list[tuple | None] = [None, None]
self.num_dispatchers_ = num_dispatchers
+ topk_indices_dtype = self.topk_indices_dtype()
+
+ def _maybe_cast(tensor: torch.Tensor | None) -> torch.Tensor | None:
+ if tensor is None or topk_indices_dtype is None:
+ return tensor
+ return tensor.to(dtype=topk_indices_dtype)
+
+ self.global_to_physical = _maybe_cast(global_to_physical)
+ self.physical_to_global = _maybe_cast(physical_to_global)
+ self.local_expert_global_ids = _maybe_cast(local_expert_global_ids)
+
# We don't have enough information to determine if we should dispatch
# activation scales in a packed ue8m0 format during object construction
# time. This setting is handled by post_init_setup.
@@ -136,6 +153,16 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def topk_indices_dtype(self) -> torch.dtype | None:
return torch.int64
+ def _map_global_to_physical_ids(self, topk_ids: torch.Tensor) -> torch.Tensor:
+ if self.global_to_physical is None:
+ return topk_ids
+ return self.global_to_physical[topk_ids]
+
+ def _map_local_to_global_ids(self, expert_topk_ids: torch.Tensor) -> torch.Tensor:
+ if self.local_expert_global_ids is None:
+ return expert_topk_ids
+ return self.local_expert_global_ids[expert_topk_ids]
+
def _do_quant(
self,
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
@@ -163,16 +190,25 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# TODO (varun): Optimization - Use a batched version of quant
x = x.view((-1, hidden_dim))
+ q_dtype = quant_config.quant_dtype
+
+ if envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm":
+ logger.info_once(
+ "Skip quantization when using FlashInfer CUTEDSL(masked_gemm) "
+ "for ModelOptNvFp4FusedMoE."
+ )
+ q_dtype = None
+
x, x_scales = moe_kernel_quantize_input(
x,
quant_config.a1_scale,
- quant_config.quant_dtype,
+ q_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
)
x = x.view((num_experts, -1, hidden_dim))
- if quant_config.quant_dtype is not None:
+ if q_dtype is not None:
assert x_scales is not None
x_scales = normalize_batched_scales_shape(x_scales, num_experts)
@@ -226,9 +262,10 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1 = a1 * topk_weights.to(a1.dtype)
# Dispatch
+ dispatch_topk_ids = self._map_global_to_physical_ids(topk_ids)
expert_x, expert_num_tokens, handle, _, hook = self.buffer.low_latency_dispatch(
a1,
- topk_ids,
+ dispatch_topk_ids,
self.max_tokens_per_rank,
num_experts,
use_fp8=self.use_fp8_dispatch,
@@ -313,11 +350,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# weights have already been applied.
combine_topk_weights = torch.ones_like(topk_weights)
+ combine_topk_ids = self._map_global_to_physical_ids(topk_ids)
# TODO (varun) : Enable zero copy mode
dbo_maybe_run_recv_hook()
_, _, recv_hook = self.buffer.low_latency_combine(
fused_expert_output,
- topk_ids,
+ combine_topk_ids,
combine_topk_weights,
handle,
async_finish=False,
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
new file mode 100644
index 0000000000000..2747ef04a3499
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
@@ -0,0 +1,346 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+ TopKWeightAndReduceDelegate,
+)
+from vllm.utils.flashinfer import (
+ flashinfer_cutedsl_grouped_gemm_nt_masked,
+ has_flashinfer_cutedsl_grouped_gemm_nt_masked,
+ scaled_fp4_grouped_quantize,
+ silu_and_mul_scaled_nvfp4_experts_quantize,
+)
+
+logger = init_logger(__name__)
+
+
+def is_valid_flashinfer_cutedsl_fused_moe(
+ hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor
+) -> bool:
+ """
+ Check if the given problem size is supported by the FlashInfer CuteDSL MoE
+ kernel.
+ """
+ if not has_flashinfer_cutedsl_grouped_gemm_nt_masked():
+ logger.debug_once(
+ "FlashInferCuteDSLExperts disabled: "
+ "flashinfer_cutedsl_fused_moe not available."
+ )
+ return False
+ # Data type checks
+ if (
+ w1.dtype != torch.uint8
+ or w2.dtype != torch.uint8
+ or hidden_states.dtype not in [torch.float32, torch.float16, torch.bfloat16]
+ ):
+ logger.debug_once(
+ "FlashInferCuteDSLExperts disabled: w1/w2 must be torch.uint8 "
+ f"(got w1={w1.dtype}, w2={w2.dtype}), hidden_states must be "
+ f"float32, float16, or bfloat16 (got {hidden_states.dtype})."
+ )
+ return False
+ return True
+
+
+class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
+ def __init__(
+ self,
+ out_dtype: torch.dtype,
+ quant_config: FusedMoEQuantConfig,
+ ):
+ super().__init__(quant_config)
+ assert quant_config.quant_dtype == "nvfp4", (
+ "Only nvfp4 quantization are currently supported."
+ )
+ self.out_dtype = out_dtype
+
+ @property
+ def activation_formats(
+ self,
+ ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+ return (
+ mk.FusedMoEActivationFormat.BatchedExperts,
+ mk.FusedMoEActivationFormat.BatchedExperts,
+ )
+
+ def supports_expert_map(self) -> bool:
+ return False
+
+ def supports_chunking(self) -> bool:
+ # This refers to TP chunking; DP chunking is handled separately.
+ # TODO(shuw@nvidia.com): Set to False to be consistent with
+ # batched_deep_gemm_moe
+ return False
+
+ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+ # Let PrepareAndFinalize::finalize() decide the impl.
+ return TopKWeightAndReduceDelegate()
+
+ def workspace_shapes(
+ self,
+ M: int,
+ N: int,
+ K: int,
+ topk: int,
+ global_num_experts: int,
+ local_num_experts: int,
+ expert_tokens_meta: mk.ExpertTokensMetadata | None,
+ ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
+ # We use global_num_experts due to how moe_align_block_size handles
+ # expert_maps.
+ """
+ Compute the shapes for the temporary and final outputs of the two gemms
+ and activation in the fused expert function. Since the gemms are
+ independent, the workspace for the first gemm can be shared with the
+ workspace for the last gemm.
+
+ Returns a tuple of:
+ - workspace13 shape tuple: must be large enough to hold the
+ result of either expert gemm.
+ - workspace2 shape tuple: must be large enough to hold the
+ result of the activation function.
+ - output shape tuple: must be exact size of the final gemm output.
+ - Workspace type: The dtype to use for the workspace tensors.
+ - Note: in order for activation chunking to work, the first dimension
+ of each tuple must be the number of tokens.
+ """
+ output_shape = (local_num_experts, M, K)
+ workspace2 = (local_num_experts, M, N)
+ workspace1 = output_shape
+ return (workspace1, workspace2, output_shape)
+
+ def apply(
+ self,
+ output: torch.Tensor,
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ activation: str,
+ global_num_experts: int,
+ expert_map: torch.Tensor | None,
+ a1q_scale: torch.Tensor | None,
+ a2_scale: torch.Tensor | None, # Not used
+ workspace13: torch.Tensor | None,
+ workspace2: torch.Tensor | None,
+ expert_tokens_meta: mk.ExpertTokensMetadata | None,
+ apply_router_weight_on_input: bool | None,
+ ):
+ assert self.quant_dtype == "nvfp4", (
+ "Only nvfp4 quantization are currently supported."
+ )
+ # Ensure w1_scale and w2_scale are not None before calling view
+ assert self.w1_scale is not None and self.w2_scale is not None, (
+ "w1_scale and w2_scale must not be None for FlashInferExperts"
+ )
+ assert expert_tokens_meta is not None
+ expert_num_tokens = expert_tokens_meta.expert_num_tokens
+ assert hidden_states.ndim == 3
+ assert self.w1_scale.ndim == 3
+ assert self.w2_scale.ndim == 3
+ flashinfer_cutedsl_moe_masked(
+ hidden_states=hidden_states,
+ input_global_scale=self.a1_gscale,
+ w1=w1,
+ w1_blockscale=self.w1_scale,
+ w1_alpha=self.g1_alphas,
+ w2=w2,
+ a2_global_scale=self.a2_gscale,
+ w2_blockscale=self.w2_scale,
+ w2_alpha=self.g2_alphas,
+ masked_m=expert_num_tokens,
+ workspace=workspace2,
+ out=output,
+ )
+
+
+def get_cute_dtype(input: torch.Tensor) -> str:
+ if input.dtype == torch.bfloat16:
+ return "bfloat16"
+ elif input.dtype == torch.float16:
+ return "float16"
+ elif input.dtype == torch.float32:
+ return "float32"
+ else:
+ raise ValueError(f"Unsupported cute dtype {input.dtype}")
+
+
+def flashinfer_cutedsl_moe_masked(
+ hidden_states: torch.Tensor,
+ input_global_scale: torch.Tensor,
+ w1: torch.Tensor,
+ w1_blockscale: torch.Tensor,
+ w1_alpha,
+ w2: torch.Tensor,
+ a2_global_scale: torch.Tensor,
+ w2_blockscale: torch.Tensor,
+ w2_alpha,
+ masked_m: torch.Tensor,
+ workspace: torch.Tensor,
+ out: torch.Tensor,
+):
+ """
+ Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
+ kernels.
+
+ Args:
+ hidden_states (torch.Tensor): [num_experts, m, k], bf16
+ input_global_scale (torch.Tensor): (l,)
+ w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
+ w1_blockscale (torch.Tensor): blockscale factors, e4m3,
+ w1_alpha (torch.Tensor): (l,)
+ w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8
+ a2_global_scale (torch.Tensor): (l,)
+ w2_blockscale (torch.Tensor): blockscale factors, e4m3,
+ w2_alpha (torch.Tensor): (l,)
+ masked_m (torch.Tensor): Masked dimension indices
+ workspace (torch.Tensor): For gateup_output
+
+ Notes:
+ - Assumes max(masked_m) <= m.
+ """
+
+ # === Assertions on dtypes ===
+ assert input_global_scale.dtype == torch.float32, (
+ f"input_global_scale must be float32, got {input_global_scale.dtype}"
+ )
+ assert w1.dtype == torch.uint8, f"w1 must be uint8, got {w1.dtype}"
+ assert w1_blockscale.dtype == torch.float8_e4m3fn, (
+ f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"
+ )
+ assert w1_alpha.dtype == torch.float32, (
+ f"w1_alpha must be float32, got {w1_alpha.dtype}"
+ )
+ assert w2.dtype == torch.uint8, f"w2 must be uint8, got {w2.dtype}"
+ assert a2_global_scale.dtype == torch.float32, (
+ f"a2_global_scale must be float32, got {a2_global_scale.dtype}"
+ )
+ assert w2_blockscale.dtype == torch.float8_e4m3fn, (
+ f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}"
+ )
+ assert w2_alpha.dtype == torch.float32, (
+ f"w2_alpha must be float32, got {w2_alpha.dtype}"
+ )
+
+ # === Assertions on shapes ===
+ n = w2.shape[-1] * 2 # intermediate dimension
+ num_experts, m, k = hidden_states.shape
+
+ assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
+ assert w1.shape[-1] * 2 == k, (
+ f"w1 last dim * 2 must equal k, got {w1.shape[-1]} vs k={k}"
+ )
+ assert w2.shape[-2:] == (
+ k,
+ n // 2,
+ ), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n // 2)}"
+
+ assert input_global_scale.shape == (num_experts,), (
+ f"input_global_scale must be (l,), got {input_global_scale.shape}"
+ )
+ assert w1_alpha.shape == (num_experts,), (
+ f"w1_alpha must be (l,), got {w1_alpha.shape}"
+ )
+ assert a2_global_scale.shape == (num_experts,), (
+ f"a2_global_scale must be (l,), got {a2_global_scale.shape}"
+ )
+ assert w2_alpha.shape == (num_experts,), (
+ f"w2_alpha must be (l,), got {w2_alpha.shape}"
+ )
+
+ aq, aq_sf = scaled_fp4_grouped_quantize(
+ hidden_states,
+ masked_m,
+ input_global_scale,
+ )
+
+ workspace = workspace.permute(1, 2, 0) # requirement of kernel
+ sf_vec_size = 16
+ assert aq_sf.dtype == torch.float8_e4m3fn
+ assert aq.dtype == torch.uint8
+ ab_dtype = "float4_e2m1fn"
+ sf_dtype = "float8_e4m3fn"
+
+ c_dtype = get_cute_dtype(hidden_states)
+
+ # Gemm1
+ flashinfer_cutedsl_grouped_gemm_nt_masked(
+ (aq, aq_sf),
+ (w1.permute(1, 2, 0), w1_blockscale),
+ workspace,
+ masked_m,
+ ab_dtype=ab_dtype,
+ sf_dtype=sf_dtype,
+ c_dtype=c_dtype,
+ sf_vec_size=sf_vec_size,
+ alpha=w1_alpha.view(1, 1, num_experts),
+ alpha_dtype=get_cute_dtype(w1_alpha),
+ ) # in logical [m, n, l]
+
+ # SILU and quantization
+ diq, diq_sf = silu_and_mul_scaled_nvfp4_experts_quantize(
+ workspace.permute(2, 0, 1),
+ masked_m,
+ a2_global_scale,
+ )
+
+ # Gemm2
+ out = out.permute(1, 2, 0) # requirement of kernel
+ flashinfer_cutedsl_grouped_gemm_nt_masked(
+ (diq, diq_sf),
+ (w2.permute(1, 2, 0), w2_blockscale),
+ out,
+ masked_m,
+ ab_dtype=ab_dtype,
+ sf_dtype=sf_dtype,
+ c_dtype=c_dtype,
+ sf_vec_size=sf_vec_size,
+ alpha=w2_alpha.view(1, 1, num_experts),
+ alpha_dtype=get_cute_dtype(w2_alpha),
+ ) # in logical [m, k, l]
+ out = out.permute(2, 0, 1)
+
+
+def flashinfer_cutedsl_moe_fp4(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ quant_config: FusedMoEQuantConfig,
+ inplace: bool = False,
+ activation: str = "silu",
+ global_num_experts: int = -1,
+ expert_map: torch.Tensor | None = None,
+ apply_router_weight_on_input: bool = False,
+) -> torch.Tensor:
+ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
+ create_flashinfer_prepare_finalize,
+ )
+
+ fused_experts = mk.FusedMoEModularKernel(
+ create_flashinfer_prepare_finalize(use_dp=False), # could be swapped later
+ FlashInferCuteDSLExperts(
+ out_dtype=hidden_states.dtype,
+ quant_config=quant_config,
+ ),
+ )
+
+ return fused_experts(
+ hidden_states=hidden_states,
+ w1=w1,
+ w2=w2,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ inplace=inplace,
+ activation=activation,
+ global_num_experts=global_num_experts,
+ expert_map=expert_map,
+ apply_router_weight_on_input=apply_router_weight_on_input,
+ )
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 2e042d85fcfcf..df208eae2e71c 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -872,8 +872,10 @@ def get_moe_configs(
for config_file_path in config_file_paths:
if os.path.exists(config_file_path):
with open(config_file_path) as f:
- logger.info(
- "Using configuration from %s for MoE layer.", config_file_path
+ logger.info_once(
+ "Using configuration from %s for MoE layer.",
+ config_file_path,
+ scope="global",
)
# If a configuration has been found, return it
tuned_config = json.load(f)
@@ -1246,7 +1248,6 @@ def eplb_map_to_physical_and_record(
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
- indices_type: torch.dtype | None = None,
) -> torch.Tensor:
"""
Map the logical expert ids to physical expert ids
@@ -1260,7 +1261,6 @@ def eplb_map_to_physical_and_record(
expert_load_view: The expert load view.
logical_to_physical_map: The logical to physical map.
logical_replica_count: The logical replica count.
- indices_type: The indices type.
Returns:
The physical expert ids.
@@ -1310,9 +1310,6 @@ def eplb_map_to_physical_and_record(
index=topk_ids_flatten.long(),
src=torch.ones_like(topk_ids_flatten).to(expert_load_view),
)
-
- if indices_type is not None:
- topk_ids = topk_ids.to(dtype=indices_type)
return topk_ids
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
index 87f8c8d75a9b5..ef7090c349fc6 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
@@ -50,10 +50,15 @@ class FusedMoEMethodBase(QuantizeMethodBase):
"""
return False
- def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
+ def maybe_make_prepare_finalize(
+ self,
+ routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
+ ) -> FusedMoEPrepareAndFinalize | None:
from .all2all_utils import maybe_make_prepare_finalize
- return maybe_make_prepare_finalize(self.moe, self.moe_quant_config)
+ return maybe_make_prepare_finalize(
+ self.moe, self.moe_quant_config, routing_tables
+ )
def select_gemm_impl(
self,
@@ -85,10 +90,14 @@ class FusedMoEMethodBase(QuantizeMethodBase):
def allow_inplace(self) -> bool:
return False
+ @property
+ def method_name(self) -> str:
+ return self.__class__.__name__
+
@abstractmethod
def apply(
self,
- layer: torch.nn.Module,
+ layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
index 43974ba917e42..c23c41df226f0 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
@@ -50,6 +50,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
prepare_finalize,
old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
shared_experts,
+ getattr(moe_layer, "shared_experts_stream", None),
),
)
@@ -65,6 +66,10 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
def allow_inplace(self) -> bool:
return self.old_quant_method.allow_inplace
+ @property
+ def method_name(self) -> str:
+ return self.old_quant_method.method_name
+
def create_weights(
self,
layer: torch.nn.Module,
@@ -83,7 +88,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
def apply(
self,
- layer: torch.nn.Module,
+ layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -104,42 +109,9 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- # Is getattr needed?
- zero_expert_num = getattr(layer, "zero_expert_num", 0)
- zero_expert_type = getattr(layer, "zero_expert_type", None)
-
- if enable_eplb:
- if self.supports_eplb:
- assert expert_load_view is not None
- assert logical_to_physical_map is not None
- assert logical_replica_count is not None
- else:
- raise NotImplementedError(
- "EPLB is not supported for "
- f"{self.old_quant_method.__class__.__name__}."
- )
-
topk_weights, topk_ids, zero_expert_result = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
- enable_eplb=enable_eplb,
- expert_map=expert_map,
- expert_load_view=expert_load_view,
- logical_to_physical_map=logical_to_physical_map,
- logical_replica_count=logical_replica_count,
- global_num_experts=global_num_experts,
- zero_expert_num=zero_expert_num,
- zero_expert_type=zero_expert_type,
)
result = self.fused_experts(
@@ -155,7 +127,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
expert_map=None if self.disable_expert_map else expert_map,
)
- if zero_expert_num != 0 and zero_expert_type is not None:
+ if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
assert not isinstance(result, tuple), (
"Shared + zero experts are mutually exclusive not yet supported"
)
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 023132acfed3f..bb30f1292a5fa 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -5,7 +5,7 @@ from collections.abc import Callable, Iterable
from contextlib import nullcontext
from enum import Enum
from functools import partial
-from typing import Literal, get_args, overload
+from typing import Literal, cast, get_args, overload
import torch
import torch.nn.functional as F
@@ -18,6 +18,7 @@ from vllm.config.parallel import ExpertPlacementStrategy
from vllm.distributed import (
get_dp_group,
get_ep_group,
+ get_pcp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
@@ -67,7 +68,6 @@ else:
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
- indices_type: torch.dtype | None,
) -> torch.Tensor:
# CPU fallback: no EPLB so just return as is
return topk_ids
@@ -192,6 +192,42 @@ def determine_expert_map(
return (local_num_experts, expert_map, expert_mask)
+def determine_expert_placement_strategy(
+ expert_placement_strategy: ExpertPlacementStrategy,
+ moe_parallel_config: FusedMoEParallelConfig,
+ num_expert_group: int | None,
+ num_redundant_experts: int,
+ enable_eplb: bool,
+) -> ExpertPlacementStrategy:
+ if expert_placement_strategy == "round_robin":
+ round_robin_supported = (
+ (num_expert_group is not None and num_expert_group > 1)
+ and num_redundant_experts == 0
+ and not enable_eplb
+ )
+
+ if not round_robin_supported:
+ logger.warning(
+ "Round-robin expert placement is only supported for "
+ "models with multiple expert groups and no redundant "
+ "experts. Falling back to linear expert placement."
+ )
+ return "linear"
+ if (
+ moe_parallel_config.use_all2all_kernels
+ and not moe_parallel_config.use_deepep_ll_kernels
+ ):
+ logger.warning(
+ "Round-robin expert placement currently only supports "
+ "the DeepEP low-latency backend, but '%s' was configured. "
+ "Falling back to linear expert placement.",
+ moe_parallel_config.all2all_backend,
+ )
+ return "linear"
+
+ return expert_placement_strategy
+
+
def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
"""
Compresses the expert map by removing any -1 entries.
@@ -307,6 +343,7 @@ class FusedMoE(CustomOp):
tp_size: int | None = None,
ep_size: int | None = None,
dp_size: int | None = None,
+ pcp_size: int | None = None,
prefix: str = "",
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
@@ -335,8 +372,8 @@ class FusedMoE(CustomOp):
logger.info_once("Disabling MoE shared_experts cuda stream")
self.shared_experts_stream = None
else:
- # TODO(rob): enable shared expert overlap with non-cuda.
- # aux_stream() returns None on non-cuda platforms.
+ # TODO(rob): enable shared expert overlap with non-cuda-alike.
+ # aux_stream() returns None on non-cuda-alike platforms.
self.shared_experts_stream = aux_stream()
if self.shared_experts_stream is not None:
logger.info_once("Enabled separate cuda stream for MoE shared_experts")
@@ -362,12 +399,14 @@ class FusedMoE(CustomOp):
tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
)
dp_size_ = dp_size if dp_size is not None else get_dp_group().world_size
+ pcp_size_ = pcp_size if pcp_size is not None else get_pcp_group().world_size
self.is_sequence_parallel = is_sequence_parallel
self.sp_size = tp_size_ if is_sequence_parallel else 1
self.moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
tp_size_=tp_size_,
+ pcp_size_=pcp_size_,
dp_size_=dp_size_,
vllm_parallel_config=vllm_config.parallel_config,
)
@@ -400,6 +439,9 @@ class FusedMoE(CustomOp):
self.expert_load_view: torch.Tensor | None = None
self.logical_to_physical_map: torch.Tensor | None = None
self.logical_replica_count: torch.Tensor | None = None
+ self.expert_placement_strategy: ExpertPlacementStrategy = (
+ vllm_config.parallel_config.expert_placement_strategy
+ )
# ROCm aiter shared experts fusion
self.rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
@@ -433,38 +475,27 @@ class FusedMoE(CustomOp):
"Redundant experts are only supported with EPLB."
)
- expert_placement_strategy = (
- vllm_config.parallel_config.expert_placement_strategy
+ self.expert_placement_strategy = determine_expert_placement_strategy(
+ expert_placement_strategy=self.expert_placement_strategy,
+ moe_parallel_config=self.moe_parallel_config,
+ num_expert_group=num_expert_group,
+ num_redundant_experts=num_redundant_experts,
+ enable_eplb=self.enable_eplb,
)
- if expert_placement_strategy == "round_robin":
- # TODO(Bruce): will support round robin expert placement with
- # EPLB enabled in the future.
- round_robin_supported = (
- (num_expert_group is not None and num_expert_group > 1)
- and num_redundant_experts == 0
- and not self.enable_eplb
- )
-
- if not round_robin_supported:
- logger.warning(
- "Round-robin expert placement is only supported for "
- "models with multiple expert groups and no redundant "
- "experts. Falling back to linear expert placement."
- )
- expert_placement_strategy = "linear"
self.expert_map: torch.Tensor | None
local_num_experts, expert_map, expert_mask = determine_expert_map(
ep_size=self.ep_size,
ep_rank=self.ep_rank,
global_num_experts=self.global_num_experts,
- expert_placement_strategy=expert_placement_strategy,
+ expert_placement_strategy=self.expert_placement_strategy,
num_fused_shared_experts=self.num_fused_shared_experts,
return_expert_mask=self.rocm_aiter_fmoe_enabled,
)
self.local_num_experts = local_num_experts
self.register_buffer("expert_map", expert_map)
self.register_buffer("expert_mask", expert_mask)
+ self._maybe_init_expert_routing_tables()
logger.info_once(
"[EP Rank %s/%s] Expert parallelism is enabled. Expert "
"placement strategy: %s. Local/global"
@@ -472,7 +503,7 @@ class FusedMoE(CustomOp):
" %s.",
self.ep_rank,
self.ep_size,
- expert_placement_strategy,
+ self.expert_placement_strategy,
self.local_num_experts,
self.global_num_experts,
get_compressed_expert_map(self.expert_map),
@@ -542,6 +573,9 @@ class FusedMoE(CustomOp):
is_act_and_mul=is_act_and_mul,
is_lora_enabled=vllm_config.lora_config is not None,
)
+ self.moe_config_use_flashinfer_cutlass_kernels = (
+ self.moe_config.use_flashinfer_cutlass_kernels
+ )
self.quant_config = quant_config
@@ -621,7 +655,12 @@ class FusedMoE(CustomOp):
# should be safe to swap out the quant_method.
def maybe_init_modular_kernel(self) -> None:
self.ensure_moe_quant_config_init()
- prepare_finalize = self.quant_method.maybe_make_prepare_finalize()
+ # routing_tables only needed for round-robin expert placement with
+ # DeepEP all2all backend.
+ routing_tables = self._maybe_init_expert_routing_tables()
+ prepare_finalize = self.quant_method.maybe_make_prepare_finalize(
+ routing_tables=routing_tables
+ )
if prepare_finalize is not None:
logger.debug(
"%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self)
@@ -646,6 +685,10 @@ class FusedMoE(CustomOp):
def dp_size(self):
return self.moe_parallel_config.dp_size
+ @property
+ def pcp_size(self):
+ return self.moe_parallel_config.pcp_size
+
@property
def ep_size(self):
return self.moe_parallel_config.ep_size
@@ -658,6 +701,10 @@ class FusedMoE(CustomOp):
def dp_rank(self):
return self.moe_parallel_config.dp_rank
+ @property
+ def pcp_rank(self):
+ return self.moe_parallel_config.pcp_rank
+
@property
def ep_rank(self):
return self.moe_parallel_config.ep_rank
@@ -683,7 +730,7 @@ class FusedMoE(CustomOp):
return (
self.moe_quant_config is not None
and self.moe_quant_config.quant_dtype == "nvfp4"
- and self.moe_config.use_flashinfer_cutlass_kernels
+ and self.moe_config_use_flashinfer_cutlass_kernels
)
@property
@@ -703,6 +750,84 @@ class FusedMoE(CustomOp):
# By default, router/gate is called before FusedMoE forward pass
return False
+ def _maybe_init_expert_routing_tables(
+ self,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None:
+ # Currently routing_tables only needed for round-robin expert placement
+ # with DeepEP-ll all2all backend.
+ if (
+ self.expert_placement_strategy != "round_robin"
+ or not self.use_deepep_ll_kernels
+ ):
+ return None
+
+ if hasattr(self, "expert_global_to_physical"):
+ return cast(
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+ (
+ self.expert_global_to_physical,
+ self.expert_physical_to_global,
+ self.expert_local_to_global,
+ ),
+ )
+
+ if self.expert_map is None:
+ return None
+
+ routing_tables = self.ensure_round_robin_expert_routing_tables(
+ global_num_experts=self.global_num_experts,
+ ep_size=self.ep_size,
+ ep_rank=self.ep_rank,
+ local_num_experts=self.local_num_experts,
+ device=self.expert_map.device,
+ )
+
+ global_to_physical, physical_to_global, local_global = routing_tables
+ self.register_buffer("expert_global_to_physical", global_to_physical)
+ self.register_buffer("expert_physical_to_global", physical_to_global)
+ self.register_buffer("expert_local_to_global", local_global)
+
+ return routing_tables
+
+ @staticmethod
+ def ensure_round_robin_expert_routing_tables(
+ global_num_experts: int,
+ ep_size: int,
+ ep_rank: int,
+ local_num_experts: int,
+ device: torch.device | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ device_kwargs = {"device": device} if device is not None else {}
+ global_indices = torch.arange(
+ global_num_experts, dtype=torch.long, **device_kwargs
+ )
+ owner = torch.remainder(global_indices, ep_size)
+ local_index = torch.div(global_indices, ep_size, rounding_mode="floor")
+ base = global_num_experts // ep_size
+ remainder = global_num_experts % ep_size
+ physical_offset = owner * base
+ if remainder > 0:
+ remainder_tensor = torch.tensor(
+ remainder, dtype=torch.long, **device_kwargs
+ )
+ physical_offset = physical_offset + torch.minimum(owner, remainder_tensor)
+
+ global_to_physical = physical_offset + local_index
+ physical_to_global = torch.empty_like(global_to_physical)
+ physical_to_global[global_to_physical] = global_indices
+
+ local_global = torch.arange(
+ ep_rank,
+ global_num_experts,
+ ep_size,
+ dtype=torch.long,
+ **device_kwargs,
+ )
+ if local_global.numel() != local_num_experts:
+ local_global = local_global[:local_num_experts]
+
+ return (global_to_physical, physical_to_global, local_global)
+
def update_expert_map(self):
# ep_size and ep_rank should already be updated
assert self.expert_map is not None
@@ -711,18 +836,59 @@ class FusedMoE(CustomOp):
ep_size=self.ep_size,
ep_rank=self.ep_rank,
global_num_experts=self.global_num_experts,
+ expert_placement_strategy=self.expert_placement_strategy,
num_fused_shared_experts=self.num_fused_shared_experts,
return_expert_mask=self.rocm_aiter_fmoe_enabled,
)
self.local_num_experts = local_num_experts
self.register_buffer("expert_map", expert_map)
self.register_buffer("expert_mask", expert_mask)
+ self._maybe_init_expert_routing_tables()
if self.aiter_fmoe_shared_expert_enabled:
self._init_aiter_shared_experts_topK_buffer(
vllm_config=get_current_vllm_config(),
dp_size=get_dp_group().world_size,
)
+ def _maybe_setup_shared_experts_stream(
+ self,
+ hidden_states: torch.Tensor,
+ has_separate_shared_experts: bool,
+ use_chunked_impl: bool,
+ ) -> tuple[bool, torch.Tensor | None]:
+ use_shared_experts_stream = (
+ has_separate_shared_experts
+ and not use_chunked_impl
+ and self.shared_experts_stream is not None
+ and (
+ hidden_states.shape[0]
+ <= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
+ )
+ )
+
+ hidden_states_clone: torch.Tensor | None = None
+ if use_shared_experts_stream:
+ assert self.shared_experts_stream is not None
+
+ # Clone BEFORE switching streams to avoid race condition
+ # where routed_expert kernel may mutate hidden_states.
+ hidden_states_clone = hidden_states.clone()
+
+ # Record that the clone will be used by shared_experts_stream
+ # to avoid gc issue from deallocation of hidden_states_clone
+ # For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
+ # NOTE: We dont need shared_output.record_stream(current_stream())
+ # because we synch the streams before using shared_output.
+ hidden_states_clone.record_stream(self.shared_experts_stream)
+
+ # Mark sync start point for the separate shared experts
+ # stream here since we want to run in parallel with the
+ # router/gate (next op below)
+ assert self.shared_experts_stream is not None
+ self.shared_experts_stream.wait_stream(current_stream())
+
+ return use_shared_experts_stream, hidden_states_clone
+
def _load_per_tensor_weight_scale(
self,
shard_id: str,
@@ -1225,7 +1391,48 @@ class FusedMoE(CustomOp):
yield param_name
def get_expert_weights(self) -> Iterable[torch.Tensor]:
+ def _maybe_make_contiguous(
+ name: str, p: torch.nn.Parameter
+ ) -> torch.nn.Parameter:
+ """
+ In some cases, the last 2 dimensions (the non-expert dimensions)
+ of the weight scale tensor are transposed. This function
+ transforms the tensor (view update) so the tensor is contiguous().
+ Example: A non-contiguous scale tensor,
+ `x` of shape (E, 32, 16) and stride (512, 1, 32) is transformed to
+ `x_` of shape (E, 16, 32) and stride (512, 32, 1).
+ Note that we specifically use torch.transpose() so `x_` refers
+ to the same underlying memory. The tensors `x` and `x_`, pointing
+ to the same underlying memory make this transformation safe in the
+ context of EPLB. i.e. It is the same memory and just the view
+ is different.
+ Note: This function handles the "weight_scale" tensors specifically.
+ This could however be generalized to handle similar tensors.
+ """
+ if p.ndim != 3:
+ return p
+ if p.is_contiguous():
+ # Already contiguous. do nothing.
+ return p
+ # p is non-contiguous. We only handle the case where the last 2
+ # dimensions of the scales tensor is transposed. We can handle
+ # other cases when they become relevant.
+ is_transposed_12 = p.stride(1) == 1 and p.stride(2) != 1
+ if "weight_scale" not in name or not is_transposed_12:
+ # do nothing.
+ return p
+
+ # Do not update the layer parameter as the layer's MoE operations would
+ # expect the parameter's tensor to have the same shape / stride. Instead,
+ # make a new torch.nn.Parameter that is used just in the context of
+ # EPLB.
+ return torch.nn.Parameter(
+ torch.transpose(p.data, 1, 2), requires_grad=False
+ )
+
weights = list(self.named_parameters())
+ weights = [(name, _maybe_make_contiguous(name, p)) for name, p in weights]
+
assert all(
weight.is_contiguous()
for name, weight in weights
@@ -1303,30 +1510,11 @@ class FusedMoE(CustomOp):
logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
)
- @staticmethod
def select_experts(
+ self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
- top_k: int,
- use_grouped_topk: bool,
- renormalize: bool,
- topk_group: int | None = None,
- num_expert_group: int | None = None,
- custom_routing_function: Callable | None = None,
- scoring_func: str = "softmax",
- routed_scaling_factor: float = 1.0,
- e_score_correction_bias: torch.Tensor | None = None,
- indices_type: torch.dtype | None = None,
- enable_eplb: bool = False,
- expert_map: torch.Tensor | None = None,
- expert_load_view: torch.Tensor | None = None,
- logical_to_physical_map: torch.Tensor | None = None,
- logical_replica_count: torch.Tensor | None = None,
- global_num_experts: int | None = None,
- zero_expert_num: int | None = None,
- zero_expert_type: str | None = None,
- num_fused_shared_experts: int = 0,
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
"""
Route the input hidden states to the top-k experts based on the
router logits.
@@ -1345,6 +1533,27 @@ class FusedMoE(CustomOp):
fused_topk_bias,
)
+ if self.enable_eplb:
+ if self.quant_method.supports_eplb:
+ if self.expert_load_view is None:
+ raise ValueError(
+ "enable_eplb=True requires expert_load_view != None"
+ )
+ if self.logical_to_physical_map is None:
+ raise ValueError(
+ "enable_eplb=True requires logical_to_physical_map != None"
+ )
+ if self.logical_replica_count is None:
+ raise ValueError(
+ "enable_eplb=True requires logical_replica_count != None"
+ )
+ else:
+ raise NotImplementedError(
+ f"EPLB is not supported for {self.quant_method.method_name}."
+ )
+
+ indices_type = self.quant_method.topk_indices_dtype
+
# Check if we should use a routing simulation strategy
routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
if routing_strategy != "":
@@ -1352,20 +1561,20 @@ class FusedMoE(CustomOp):
hidden_states=hidden_states,
router_logits=router_logits,
strategy_name=routing_strategy,
- top_k=top_k,
+ top_k=self.top_k,
indices_type=indices_type,
)
# DeepSeekv2 uses grouped_top_k
- elif use_grouped_topk:
- assert topk_group is not None
- assert num_expert_group is not None
+ elif self.use_grouped_topk:
+ assert self.topk_group is not None
+ assert self.num_expert_group is not None
if rocm_aiter_ops.is_fused_moe_enabled():
if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
- assert num_fused_shared_experts == 0
+ assert self.num_fused_shared_experts == 0
grouped_topk_impl = partial(
rocm_aiter_grouped_topk,
- num_fused_shared_experts=num_fused_shared_experts,
+ num_fused_shared_experts=self.num_fused_shared_experts,
)
else:
grouped_topk_impl = grouped_topk
@@ -1373,71 +1582,65 @@ class FusedMoE(CustomOp):
topk_weights, topk_ids = grouped_topk_impl(
hidden_states=hidden_states,
gating_output=router_logits,
- topk=top_k,
- renormalize=renormalize,
- num_expert_group=num_expert_group,
- topk_group=topk_group,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
+ topk=self.top_k,
+ renormalize=self.renormalize,
+ num_expert_group=self.num_expert_group,
+ topk_group=self.topk_group,
+ scoring_func=self.scoring_func,
+ routed_scaling_factor=self.routed_scaling_factor,
+ e_score_correction_bias=self.e_score_correction_bias,
)
- if indices_type is not None:
- topk_ids = topk_ids.to(dtype=indices_type)
- elif e_score_correction_bias is not None:
+ elif self.e_score_correction_bias is not None:
topk_weights, topk_ids = fused_topk_bias(
hidden_states=hidden_states,
gating_output=router_logits,
- e_score_correction_bias=e_score_correction_bias.data,
- topk=top_k,
- renormalize=renormalize,
+ e_score_correction_bias=self.e_score_correction_bias.data,
+ topk=self.top_k,
+ renormalize=self.renormalize,
)
- if routed_scaling_factor is not None:
- topk_weights *= routed_scaling_factor
- elif custom_routing_function is None:
+ if self.routed_scaling_factor != 1.0:
+ topk_weights *= self.routed_scaling_factor
+ elif self.custom_routing_function is None:
topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states=hidden_states,
gating_output=router_logits,
- topk=top_k,
- renormalize=renormalize,
+ topk=self.top_k,
+ renormalize=self.renormalize,
indices_type=indices_type,
)
else:
- topk_weights, topk_ids = custom_routing_function(
+ topk_weights, topk_ids = self.custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
- topk=top_k,
- renormalize=renormalize,
+ topk=self.top_k,
+ renormalize=self.renormalize,
)
- if indices_type is not None:
- topk_ids = topk_ids.to(dtype=indices_type)
-
- if enable_eplb:
- assert expert_load_view is not None
- assert logical_to_physical_map is not None
- assert logical_replica_count is not None
+ if self.enable_eplb:
topk_ids = eplb_map_to_physical_and_record(
topk_ids=topk_ids,
- expert_load_view=expert_load_view,
- logical_to_physical_map=logical_to_physical_map,
- logical_replica_count=logical_replica_count,
- indices_type=indices_type,
+ expert_load_view=self.expert_load_view,
+ logical_to_physical_map=self.logical_to_physical_map,
+ logical_replica_count=self.logical_replica_count,
)
+ if (indices_type is not None) and topk_ids.dtype != indices_type:
+ topk_ids = topk_ids.to(dtype=indices_type)
+
assert topk_ids.dtype == indices_type or indices_type is None
# Compute zero expert result if needed
if (
- zero_expert_num is not None
- and zero_expert_num > 0
- and zero_expert_type is not None
- and global_num_experts is not None
+ self.zero_expert_num is not None
+ and self.zero_expert_num > 0
+ and self.zero_expert_type is not None
+ and self.global_num_experts is not None
):
zero_expert_result = zero_experts_compute_triton(
expert_indices=topk_ids,
expert_scales=topk_weights,
- num_experts=global_num_experts,
- zero_expert_type=zero_expert_type,
+ num_experts=self.global_num_experts,
+ zero_expert_type=self.zero_expert_type,
hidden_states=hidden_states,
)
else:
@@ -1487,6 +1690,10 @@ class FusedMoE(CustomOp):
)
def reduce_output(states: torch.Tensor) -> torch.Tensor:
+ # Slice before all_reduce to enable possible fusion
+ if self.hidden_size != og_hidden_states:
+ states = states[..., :og_hidden_states]
+
if (
not self.is_sequence_parallel
and not self.use_dp_chunking
@@ -1509,11 +1716,12 @@ class FusedMoE(CustomOp):
if self.zero_expert_num is not None and self.zero_expert_num > 0:
assert isinstance(fused_output, tuple)
fused_output, zero_expert_result = fused_output
- return (reduce_output(fused_output) + zero_expert_result)[
- ..., :og_hidden_states
- ]
+ return (
+ reduce_output(fused_output)
+ + zero_expert_result[..., :og_hidden_states]
+ )
else:
- return reduce_output(fused_output)[..., :og_hidden_states]
+ return reduce_output(fused_output)
else:
if current_platform.is_tpu():
# TODO: Once the OOM issue for the TPU backend is resolved, we
@@ -1526,8 +1734,8 @@ class FusedMoE(CustomOp):
hidden_states, router_logits, self.layer_name
)
return (
- reduce_output(shared_output)[..., :og_hidden_states],
- reduce_output(fused_output)[..., :og_hidden_states],
+ reduce_output(shared_output),
+ reduce_output(fused_output),
)
def forward_cuda(
@@ -1694,36 +1902,12 @@ class FusedMoE(CustomOp):
use_chunked_impl = self.use_dp_chunking
- use_shared_experts_stream = (
- has_separate_shared_experts
- and not use_chunked_impl
- and self.shared_experts_stream is not None
- and (
- hidden_states.shape[0]
- <= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
+ use_shared_experts_stream, hidden_states_clone = (
+ self._maybe_setup_shared_experts_stream(
+ hidden_states, has_separate_shared_experts, use_chunked_impl
)
)
- if use_shared_experts_stream:
- assert self.shared_experts_stream is not None
-
- # Clone BEFORE switching streams to avoid race condition
- # where routed_expert kernel may mutate hidden_states.
- hidden_states_clone = hidden_states.clone()
-
- # Record that the clone will be used by shared_experts_stream
- # to avoid gc issue from deallocation of hidden_states_clone
- # For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
- # NOTE: We dont need shared_output.record_stream(current_stream())
- # because we synch the streams before using shared_output.
- hidden_states_clone.record_stream(self.shared_experts_stream)
-
- # Mark sync start point for the separate shared experts
- # stream here since we want to run in parallel with the
- # router/gate (next op below)
- assert self.shared_experts_stream is not None
- self.shared_experts_stream.wait_stream(current_stream())
-
# If router/gate provided, then apply it here.
# (Note: This code runs only when "overlapped mode" is on to allow
# parallel execution of shared experts with the FusedMoE via
@@ -1752,6 +1936,24 @@ class FusedMoE(CustomOp):
hidden_states_combined, router_logits = get_ep_group().dispatch(
hidden_states, router_logits, self.is_sequence_parallel
)
+ # Run shared experts before the matrix multiply,
+ # because the matrix multiply may modify hidden_states.
+ if has_separate_shared_experts and not use_shared_experts_stream:
+ assert self.shared_experts is not None
+ shared_output = self.shared_experts(hidden_states)
+
+ # NOTE: Similar to DP, PCP also needs dispatch and combine. For
+ # simplicity, AgRsAll2All was added separately for PCP here. Maybe
+ # we should modify All2AllManager abstract to better support PCP.
+ if self.pcp_size > 1:
+ hidden_states = get_pcp_group().all_gather(
+ hidden_states,
+ dim=0,
+ )
+ router_logits = get_pcp_group().all_gather(
+ router_logits,
+ dim=0,
+ )
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
@@ -1795,8 +1997,6 @@ class FusedMoE(CustomOp):
# conflict with the main stream
shared_output = self.shared_experts(hidden_states_clone)
current_stream().wait_stream(self.shared_experts_stream)
- else:
- shared_output = self.shared_experts(hidden_states)
final_hidden_states = (
shared_output,
@@ -1809,6 +2009,13 @@ class FusedMoE(CustomOp):
def combine_output(states: torch.Tensor) -> torch.Tensor:
if do_naive_dispatch_combine:
states = get_ep_group().combine(states, self.is_sequence_parallel)
+
+ if self.pcp_size > 1:
+ states = get_pcp_group().reduce_scatter(
+ states,
+ dim=0,
+ )
+
return states
if self.shared_experts is not None:
diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py
index 093affe51f503..b2af58cdca887 100644
--- a/vllm/model_executor/layers/fused_moe/modular_kernel.py
+++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -10,12 +10,16 @@ from typing import final
import torch
import vllm.envs as envs
+from vllm.config import get_current_vllm_config
+from vllm.forward_context import get_forward_context, is_forward_context_available
+from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache,
count_expert_num_tokens,
disable_inplace,
)
+from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.v1.worker.ubatching import (
dbo_current_ubatch_id,
@@ -25,6 +29,8 @@ from vllm.v1.worker.ubatching import (
dbo_yield,
)
+logger = init_logger(__name__)
+
#
# This file defines a set of base classes used to make MoE kernels more modular.
# The goal is to be able to utilize different communication mechanisms with
@@ -709,11 +715,13 @@ class FusedMoEModularKernel(torch.nn.Module):
prepare_finalize: FusedMoEPrepareAndFinalize,
fused_experts: FusedMoEPermuteExpertsUnpermute,
shared_experts: torch.nn.Module | None = None,
+ shared_experts_stream: torch.cuda.Stream | None = None,
):
super().__init__()
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts
self.shared_experts = shared_experts
+ self.shared_experts_stream = shared_experts_stream
self._post_init_setup()
assert (
@@ -795,6 +803,42 @@ class FusedMoEModularKernel(torch.nn.Module):
buffers = self.shared_buffers[ubatch_idx]
workspace_dtype = self.fused_experts.workspace_dtype(out_dtype)
+ # Force worst-case allocation in profiling run for
+ # "mk.FusedMoEModularKernel.Standard" formats where this is only bounded
+ # by `VLLM_FUSED_MOE_CHUNK_SIZE` and may not be seen during profiling with
+ # DP+EP due to the random token routing.
+ is_profile_run = (
+ is_forward_context_available()
+ and get_forward_context().attn_metadata is None
+ )
+ if is_profile_run and self.fused_experts.supports_chunking():
+ parallel_config = get_current_vllm_config().parallel_config
+ is_dp_ep = (
+ parallel_config.data_parallel_size > 1
+ and parallel_config.enable_expert_parallel
+ )
+ if is_dp_ep:
+ max_workspace_13, max_workspace_2, max_fused_out_shape = (
+ self.fused_experts.workspace_shapes(
+ envs.VLLM_FUSED_MOE_CHUNK_SIZE,
+ N,
+ K,
+ top_k,
+ global_num_experts,
+ local_num_experts,
+ expert_tokens_meta,
+ )
+ )
+ buffers.workspace13.get(
+ max_workspace_13, device=device, dtype=workspace_dtype
+ )
+ buffers.workspace2.get(
+ max_workspace_2, device=device, dtype=workspace_dtype
+ )
+ buffers.fused_out.get(
+ max_fused_out_shape, device=device, dtype=workspace_dtype
+ )
+
# Get intermediate workspace shapes based off the chunked M size.
workspace13_shape, workspace2_shape, _ = self.fused_experts.workspace_shapes(
M_chunk,
@@ -890,6 +934,34 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_num_tokens_cpu=c_expert_num_tokens_cpu,
)
+ def _maybe_setup_shared_experts_stream(
+ self, hidden_states: torch.Tensor
+ ) -> tuple[bool, torch.Tensor | None]:
+ # decide whether to run shared experts on a separate CUDA stream to
+ # overlap with the main fused MoE kernel.
+ use_shared_experts_stream = (
+ self.shared_experts is not None
+ and self.shared_experts_stream is not None
+ and hidden_states.is_cuda
+ and (
+ hidden_states.shape[0]
+ <= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
+ )
+ )
+
+ hidden_states_clone: torch.Tensor | None = None
+ if use_shared_experts_stream and self.shared_experts_stream is not None:
+ # TODO: Optimize this (complicated)
+ # Note: this clone adds overhead but is required
+ # for correctness with multiple CUDA streams and CUDA graph capture.
+ hidden_states_clone = hidden_states.clone()
+ # record that the clone will be used by the separate stream so its
+ # lifetime is correctly tracked.
+ hidden_states_clone.record_stream(self.shared_experts_stream)
+ self.shared_experts_stream.wait_stream(torch.cuda.current_stream())
+
+ return use_shared_experts_stream, hidden_states_clone
+
def _prepare(
self,
hidden_states: torch.Tensor,
@@ -1077,12 +1149,30 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
+ hidden_states_clone: torch.Tensor | None = None,
+ use_shared_experts_stream: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
The _finalize method is a wrapper around self.prepare_finalize.finalize
that handles DBO, async and shared expert overlap.
"""
- shared_output: torch.Tensor | None = None
+
+ def maybe_run_shared_experts() -> torch.Tensor | None:
+ if self.shared_experts is None:
+ return None
+
+ if (
+ not use_shared_experts_stream
+ or self.shared_experts_stream is not None
+ and (not hidden_states.is_cuda or not torch.cuda.is_available())
+ ):
+ # fall back to running on the current stream
+ return self.shared_experts(hidden_states)
+
+ assert hidden_states_clone is not None
+ # launch shared experts on the dedicated stream.
+ with torch.cuda.stream(self.shared_experts_stream):
+ return self.shared_experts(hidden_states_clone)
if not self.prepare_finalize.supports_async():
assert not dbo_enabled()
@@ -1095,8 +1185,7 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl(),
)
- if self.shared_experts is not None:
- shared_output = self.shared_experts(hidden_states)
+ shared_output = maybe_run_shared_experts()
else:
finalize_ret = self.prepare_finalize.finalize_async(
output,
@@ -1107,8 +1196,7 @@ class FusedMoEModularKernel(torch.nn.Module):
self.fused_experts.finalize_weight_and_reduce_impl(),
)
- if self.shared_experts is not None:
- shared_output = self.shared_experts(hidden_states)
+ shared_output = maybe_run_shared_experts()
# TODO(lucas): refactor this in the alternative schedules followup
# currently unpack if we have hook + receiver pair or just
@@ -1131,12 +1219,28 @@ class FusedMoEModularKernel(torch.nn.Module):
receiver()
+ self._wait_for_shared_experts_stream(hidden_states, use_shared_experts_stream)
+
if self.shared_experts is None:
return output
else:
assert shared_output is not None
return shared_output, output
+ def _wait_for_shared_experts_stream(
+ self, hidden_states: torch.Tensor, use_shared_experts_stream: bool
+ ) -> None:
+ # ensure that any work enqueued on the shared_experts_stream is
+ # completed before the shared_output tensor is consumed
+ if (
+ self.shared_experts is not None
+ and use_shared_experts_stream
+ and self.shared_experts_stream is not None
+ and hidden_states.is_cuda
+ and current_platform.is_cuda()
+ ):
+ torch.cuda.current_stream().wait_stream(self.shared_experts_stream)
+
def forward(
self,
hidden_states: torch.Tensor,
@@ -1183,6 +1287,10 @@ class FusedMoEModularKernel(torch.nn.Module):
else:
output = torch.zeros_like(hidden_states)
+ use_shared_experts_stream, hidden_states_clone = (
+ self._maybe_setup_shared_experts_stream(hidden_states)
+ )
+
local_num_experts = w1.size(0)
if global_num_experts == -1:
global_num_experts = local_num_experts
@@ -1219,4 +1327,6 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights,
topk_ids,
apply_router_weight_on_input,
+ hidden_states_clone=hidden_states_clone,
+ use_shared_experts_stream=use_shared_experts_stream,
)
diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py
index 9bb976fb9ec93..e27e2eb32da0f 100644
--- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py
@@ -45,7 +45,8 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
- a1.mul_(topk_weights.to(a1.dtype))
+ # Note: do not use inplace for shared experts overlap
+ a1 = a1 * topk_weights.to(a1.dtype)
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
index 2e0376553b913..48e5a8907f926 100644
--- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
+++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
@@ -108,11 +108,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def allow_inplace(self) -> bool:
return True
- def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
+ def maybe_make_prepare_finalize(
+ self,
+ routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
+ ) -> FusedMoEPrepareAndFinalize | None:
if self.rocm_aiter_moe_enabled:
return None
else:
- return super().maybe_make_prepare_finalize()
+ return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl(
self,
@@ -328,7 +331,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_cuda(
self,
- layer: torch.nn.Module,
+ layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
@@ -349,31 +352,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- zero_expert_num = getattr(layer, "zero_expert_num", 0)
- zero_expert_type = getattr(layer, "zero_expert_type", None)
-
topk_weights, topk_ids, zero_expert_result = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
- enable_eplb=enable_eplb,
- expert_map=expert_map,
- expert_load_view=expert_load_view,
- logical_to_physical_map=logical_to_physical_map,
- logical_replica_count=logical_replica_count,
- global_num_experts=global_num_experts,
- zero_expert_num=zero_expert_num,
- zero_expert_type=zero_expert_type,
- num_fused_shared_experts=layer.num_fused_shared_experts,
)
if self.rocm_aiter_moe_enabled:
@@ -412,7 +393,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map=expert_map,
)
- if zero_expert_num != 0 and zero_expert_type is not None:
+ if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
assert not isinstance(result, tuple), (
"Shared + zero experts are mutually exclusive not yet supported"
)
@@ -422,7 +403,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_cpu(
self,
- layer: torch.nn.Module,
+ layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
@@ -471,7 +452,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_xpu(
self,
- layer: torch.nn.Module,
+ layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
@@ -512,7 +493,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_tpu(
self,
- layer: torch.nn.Module,
+ layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
diff --git a/vllm/model_executor/layers/kda.py b/vllm/model_executor/layers/kda.py
index 2e7500bac7188..27cc3884517f9 100644
--- a/vllm/model_executor/layers/kda.py
+++ b/vllm/model_executor/layers/kda.py
@@ -5,7 +5,6 @@ import torch
from einops import rearrange
from torch import nn
-from vllm.attention import AttentionBackend
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed import (
@@ -83,12 +82,7 @@ direct_register_custom_op(
class KimiDeltaAttention(nn.Module, MambaBase):
@property
def mamba_type(self) -> str:
- return "linear_attention"
-
- def get_attn_backend(self) -> type["AttentionBackend"]:
- from vllm.v1.attention.backends.gdn_attn import GDNAttentionBackend
-
- return GDNAttentionBackend
+ return "gdn_attention"
def get_state_dtype(
self,
diff --git a/vllm/model_executor/layers/lightning_attn.py b/vllm/model_executor/layers/lightning_attn.py
index 99853680eac6c..ffccdc12241cb 100644
--- a/vllm/model_executor/layers/lightning_attn.py
+++ b/vllm/model_executor/layers/lightning_attn.py
@@ -198,7 +198,7 @@ def _fwd_kv_parallel(
)
# Load the decay factors for the current head and block
- k_decay_ptr = K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)[None, :]
+ k_decay_ptr = K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)
kv_index = tl.arange(0, CBLOCK)
@@ -228,6 +228,12 @@ def _fwd_kv_parallel(
# Load decay factor and compute weighted key-value outer product
k_decay = tl.load(k_decay_ptr)
+
+ # NOTE: Need to add the extra dim here due to AMD MLIR lowering error.
+ # Please don't move it back until the issue is resolved.
+ # Issue: https://github.com/ROCm/triton/issues/907
+ k_decay = k_decay[None, :]
+
kv += tl.dot(k_trans * k_decay, v)
# Move to the next sub-block
diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py
index e68b09b4d81f5..aa919d6fdc35c 100644
--- a/vllm/model_executor/layers/mamba/abstract.py
+++ b/vllm/model_executor/layers/mamba/abstract.py
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING
import torch
+from vllm.attention.selector import get_mamba_attn_backend
from vllm.config import VllmConfig
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
@@ -38,11 +39,6 @@ class MambaBase(AttentionLayerBase):
def mamba_type(self) -> str:
pass
- @abstractmethod
- def get_attn_backend(self) -> type["AttentionBackend"]:
- """Get the attention backend class for this Mamba layer."""
- pass
-
@abstractmethod
def get_state_dtype(self) -> tuple[torch.dtype, ...]:
pass
@@ -69,3 +65,7 @@ class MambaBase(AttentionLayerBase):
else 0
),
)
+
+ def get_attn_backend(self) -> type["AttentionBackend"]:
+ """Get the attention backend class for this Mamba layer."""
+ return get_mamba_attn_backend(self.mamba_type)
diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py
index 0a2742ff49a44..d85b3e61c5d61 100644
--- a/vllm/model_executor/layers/mamba/linear_attn.py
+++ b/vllm/model_executor/layers/mamba/linear_attn.py
@@ -2,12 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionBackend
-
-from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F
@@ -37,9 +31,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
-if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionBackend
-
class MiniMaxText01RMSNormTP(CustomOp):
name = "MiniMaxText01RMSNormTP"
@@ -123,11 +114,6 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
def mamba_type(self) -> str:
return "linear_attention"
- def get_attn_backend(self) -> type["AttentionBackend"]:
- from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend
-
- return LinearAttentionBackend
-
def get_state_dtype(self) -> tuple[torch.dtype]:
assert self.model_config is not None
assert self.cache_config is not None
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py
index b6345b8af7f0a..90e520e244416 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer.py
@@ -1,10 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import TYPE_CHECKING, NamedTuple
-
-if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionBackend
+from typing import NamedTuple
import torch
from torch import nn
@@ -452,11 +449,6 @@ class MambaMixer(MambaBase, CustomOp):
def mamba_type(self) -> str:
return "mamba1"
- def get_attn_backend(self) -> type["AttentionBackend"]:
- from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
-
- return Mamba1AttentionBackend
-
def _time_proj_bias(self) -> torch.Tensor | None:
if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
return self.dt_proj.bias.float()
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index 57313990b8206..0ea5805305eda 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -1,10 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionBackend
import torch
from torch import nn
@@ -594,7 +590,6 @@ class MambaMixer2(MambaBase, CustomOp):
hidden_states, _B, _C = self.split_hidden_states_B_C_fn(hidden_states_B_C)
return hidden_states
- # NOTE: V0 put prefill before decode, v1 puts decode before prefill
num_prefills = attn_metadata.num_prefills # request count
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
@@ -908,11 +903,6 @@ class MambaMixer2(MambaBase, CustomOp):
def mamba_type(self) -> str:
return "mamba2"
- def get_attn_backend(self) -> type["AttentionBackend"]:
- from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
-
- return Mamba2AttentionBackend
-
def mamba_mixer2(
projected_states: torch.Tensor,
diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py
index 04efa8a8b3734..0bbad17d7ebc7 100644
--- a/vllm/model_executor/layers/mamba/short_conv.py
+++ b/vllm/model_executor/layers/mamba/short_conv.py
@@ -1,10 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionBackend
import torch
@@ -232,11 +228,6 @@ class ShortConv(MambaBase, CustomOp):
def mamba_type(self) -> str:
return "short_conv"
- def get_attn_backend(self) -> type["AttentionBackend"]:
- from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend
-
- return ShortConvAttentionBackend
-
def short_conv(
hidden_states: torch.Tensor,
diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py
index c4c44b83ae6bf..6ebfa47a9dc3f 100644
--- a/vllm/model_executor/layers/mla.py
+++ b/vllm/model_executor/layers/mla.py
@@ -24,6 +24,7 @@ class MLAModules:
q_b_proj: torch.nn.Module | None
q_proj: torch.nn.Module | None
indexer: torch.nn.Module | None
+ indexer_rotary_emb: torch.nn.Module | None
is_sparse: bool
topk_indices_buffer: torch.Tensor | None
@@ -80,6 +81,7 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
self.rotary_emb = mla_modules.rotary_emb
self.o_proj = mla_modules.o_proj
self.indexer = mla_modules.indexer
+ self.indexer_rope_emb = mla_modules.indexer_rotary_emb
self.is_sparse = mla_modules.is_sparse
if self.indexer is not None:
@@ -153,7 +155,9 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
)
if self.indexer and self.is_sparse:
- _topk_indices = self.indexer(hidden_states, q_c, positions, self.rotary_emb)
+ _topk_indices = self.indexer(
+ hidden_states, q_c, positions, self.indexer_rope_emb
+ )
attn_out = self.mla_attn(
q,
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index 3f6ea68072b40..66945e2d2a7c8 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -597,7 +597,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -618,24 +618,11 @@ class AWQMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError("EPLB not supported for `AWQMoEMethod` yet.")
-
assert activation == "silu", "Only SiLU activation is supported."
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
return fused_marlin_moe(
diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py
index e5a741e639ad9..1e57fa218b797 100644
--- a/vllm/model_executor/layers/quantization/bitsandbytes.py
+++ b/vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -495,7 +495,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -518,25 +518,11 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `BitsAndBytesMoEMethod` yet."
- )
-
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
+ # TODO(bnell): Do these need to be called on the hot path?
if self.quant_config.load_in_8bit:
w13, w2 = self._apply_8bit_dequant(layer)
else:
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 06ee96d55419c..149e4419c64a4 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -8,6 +8,7 @@ from enum import Enum
import torch
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import ActivationOrdering, QuantizationStrategy
+from torch.nn.parameter import Parameter
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -50,9 +51,15 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
+ flashinfer_trtllm_fp4_moe,
+ prepare_static_weights_for_trtllm_fp4_moe,
reorder_w1w3_to_w3w1,
select_nvfp4_gemm_impl,
)
+from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+ FlashinferMoeBackend,
+ get_flashinfer_moe_backend,
+)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
expert_weight_is_col_major,
requant_weight_ue8m0_inplace,
@@ -193,6 +200,13 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin
self.group_size = 16
+ self.flashinfer_moe_backend = None
+ if self.allow_flashinfer:
+ self.flashinfer_moe_backend = get_flashinfer_moe_backend()
+ logger.info_once(
+ f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
+ " for CompressedTensorsW4A4MoeMethod."
+ )
def create_weights(
self,
@@ -344,21 +358,20 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
if self.use_marlin:
prepare_moe_fp4_layer_for_marlin(layer)
return
-
- # swizzle weight scales
- layer.w13_weight_scale = torch.nn.Parameter(
- swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
- )
-
- layer.w2_weight_scale = torch.nn.Parameter(
- swizzle_blockscale(layer.w2_weight_scale), requires_grad=False
- )
-
# w13
- w13_input_global_scale = layer.w13_input_global_scale.max(dim=1).values.to(
- torch.float32
- )
-
+ if (
+ self.allow_flashinfer
+ and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+ ):
+ w13_input_global_scale = (
+ layer.w13_input_global_scale.min()
+ .to(torch.float32)
+ .expand(layer.num_experts)
+ )
+ else:
+ w13_input_global_scale = layer.w13_input_global_scale.min(dim=1).values.to(
+ torch.float32
+ )
layer.g1_alphas = torch.nn.Parameter(
((1 / w13_input_global_scale) * layer.w13_weight_scale_2),
requires_grad=False,
@@ -369,22 +382,95 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
)
# w2
+ if (
+ self.allow_flashinfer
+ and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+ ):
+ w2_input_global_scale = (
+ layer.w2_input_global_scale.min()
+ .to(torch.float32)
+ .expand(layer.num_experts)
+ )
+ else:
+ w2_input_global_scale = layer.w2_input_global_scale
+
layer.g2_alphas = torch.nn.Parameter(
- ((1 / layer.w2_input_global_scale) * layer.w2_weight_scale_2).to(
- torch.float32
- ),
+ ((1 / w2_input_global_scale) * layer.w2_weight_scale_2).to(torch.float32),
requires_grad=False,
)
layer.w2_input_scale_quant = torch.nn.Parameter(
- (layer.w2_input_global_scale), requires_grad=False
+ (w2_input_global_scale), requires_grad=False
)
- def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
- if self.use_marlin:
+ # TensorRT-LLM specific processing
+ if (
+ self.allow_flashinfer
+ and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+ ):
+ # Prepare static weights for TRT-LLM kernel
+ # alternate: prepare_static_weight_layouts_for_trtllm_moe
+ (
+ gemm1_weights_fp4_shuffled,
+ gemm1_scales_fp4_shuffled,
+ gemm2_weights_fp4_shuffled,
+ gemm2_scales_fp4_shuffled,
+ ) = prepare_static_weights_for_trtllm_fp4_moe(
+ layer.w13_weight,
+ layer.w2_weight,
+ layer.w13_weight_scale,
+ layer.w2_weight_scale,
+ layer.w2_weight.size(-2), # hidden_size
+ layer.w13_weight.size(-2) // 2, # intermediate_size
+ layer.w13_weight.size(0), # num_experts
+ )
+ logger.debug_once("Finished shuffling weights for TRT-LLM MOE")
+
+ layer.gemm1_weights_fp4_shuffled = Parameter(
+ gemm1_weights_fp4_shuffled, requires_grad=False
+ )
+ layer.gemm2_weights_fp4_shuffled = Parameter(
+ gemm2_weights_fp4_shuffled, requires_grad=False
+ )
+ layer.gemm1_scales_fp4_shuffled = Parameter(
+ gemm1_scales_fp4_shuffled, requires_grad=False
+ )
+ layer.gemm2_scales_fp4_shuffled = Parameter(
+ gemm2_scales_fp4_shuffled, requires_grad=False
+ )
+
+ # Additional parameter needed for TRT-LLM
+ layer.g1_scale_c = Parameter(
+ (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
+ requires_grad=False,
+ )
+
+ # Clean up weights that won't be used by TRT-LLM
+ del layer.w2_weight
+ del layer.w2_weight_scale
+ del layer.w13_weight
+ del layer.w13_weight_scale
+ else:
+ # swizzle weight scales
+ layer.w13_weight_scale = torch.nn.Parameter(
+ swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
+ )
+
+ layer.w2_weight_scale = torch.nn.Parameter(
+ swizzle_blockscale(layer.w2_weight_scale), requires_grad=False
+ )
+
+ def maybe_make_prepare_finalize(
+ self,
+ routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
+ ) -> mk.FusedMoEPrepareAndFinalize | None:
+ if self.use_marlin or (
+ self.allow_flashinfer
+ and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+ ):
return None
elif not self.allow_flashinfer:
- return super().maybe_make_prepare_finalize()
+ return super().maybe_make_prepare_finalize(routing_tables)
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
@@ -408,7 +494,10 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
- if self.use_marlin:
+ if (
+ self.use_marlin
+ or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+ ):
return None
return nvfp4_moe_quant_config(
@@ -422,7 +511,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -443,25 +532,32 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `CompressedTensorsW4A4MoeMethod` yet."
- )
assert activation == "silu", "Only SiLU activation is supported."
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ if (
+ self.allow_flashinfer
+ and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+ ):
+ if enable_eplb:
+ raise NotImplementedError(
+ "EPLB not supported for `CompressedTensorsW4A4MoeMethod` yet."
+ )
+
+ return flashinfer_trtllm_fp4_moe(
+ layer=layer,
+ x=x,
+ router_logits=router_logits,
+ top_k=top_k,
+ global_num_experts=global_num_experts,
+ num_expert_group=num_expert_group,
+ topk_group=topk_group,
+ custom_routing_function=custom_routing_function,
+ e_score_correction_bias=e_score_correction_bias,
+ )
+
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
if self.use_marlin:
@@ -890,11 +986,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight_scale
)
- def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
+ def maybe_make_prepare_finalize(
+ self,
+ routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
+ ) -> mk.FusedMoEPrepareAndFinalize | None:
if self.use_marlin or self.rocm_aiter_moe_enabled:
return None
else:
- return super().maybe_make_prepare_finalize()
+ return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl(
self,
@@ -1001,7 +1100,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -1022,31 +1121,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- assert expert_load_view is not None
- assert logical_to_physical_map is not None
- assert logical_replica_count is not None
- assert isinstance(layer, FusedMoE)
-
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
- num_fused_shared_experts=layer.num_fused_shared_experts,
- enable_eplb=enable_eplb,
- expert_map=expert_map,
- expert_load_view=expert_load_view,
- logical_to_physical_map=logical_to_physical_map,
- logical_replica_count=logical_replica_count,
)
per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
@@ -1269,7 +1346,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -1290,26 +1367,11 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `CompressedTensorsW8A8Int8MoEMethod` yet."
- )
-
from vllm.model_executor.layers.fused_moe import fused_experts
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
return fused_experts(
@@ -1630,7 +1692,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -1651,26 +1713,11 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `CompressedTensorsWNA16MarlinMoEMethod` yet."
- )
-
assert activation == "silu", f"{activation} not supported for Marlin MoE."
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
return fused_marlin_moe(
@@ -1893,7 +1940,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -1914,26 +1961,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `CompressedTensorsWNA16MoEMethod` yet."
- )
-
from vllm.model_executor.layers.fused_moe import fused_experts
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
return fused_experts(
@@ -1950,6 +1982,10 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
quant_config=self.moe_quant_config,
)
+ @property
+ def supports_eplb(self) -> bool:
+ return True
+
class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
"""
diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py
index 5241f9a2301be..7ebe40ec84687 100644
--- a/vllm/model_executor/layers/quantization/experts_int8.py
+++ b/vllm/model_executor/layers/quantization/experts_int8.py
@@ -137,7 +137,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -158,26 +158,11 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `ExpertsInt8MoEMethod` yet."
- )
-
from vllm.model_executor.layers.fused_moe import fused_experts
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
return fused_experts(
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 0479bec338408..e033032903e87 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -28,6 +28,7 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEParallelConfig,
FusedMoEQuantConfig,
RoutingMethodType,
fp8_w8a8_moe_quant_config,
@@ -118,7 +119,9 @@ class Fp8MoeBackend(Enum):
TRITON = 6
-def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
+def get_fp8_moe_backend(
+ block_quant: bool, moe_parallel_config: FusedMoEParallelConfig
+) -> Fp8MoeBackend:
"""
Select the primary FP8 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
@@ -159,12 +162,25 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
logger.info_once("Using Marlin backend for FP8 MoE")
return Fp8MoeBackend.MARLIN
- # deepGEMM on supported platforms with block-quantized weights
- if envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM and block_quant:
+ # Determine if we should use DeepGEMM with block-quantized weights:
+ # - If explicitly set by user, respect their choice
+ # - If not explicitly set (default), disable when TP size is >= 8
+ moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM
+ if not envs.is_set("VLLM_MOE_USE_DEEP_GEMM") and moe_parallel_config.tp_size >= 8:
+ moe_use_deep_gemm = False
+ logger.info_once(
+ "DeepGEMM MoE is disabled by default when TP size is >= 8. "
+ "Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.",
+ scope="local",
+ )
+
+ if envs.VLLM_USE_DEEP_GEMM and moe_use_deep_gemm and block_quant:
if not has_deep_gemm():
- logger.warning_once("DeepGEMM backend requested but not available.")
+ logger.warning_once(
+ "DeepGEMM backend requested but not available.", scope="local"
+ )
elif is_deep_gemm_supported():
- logger.info_once("Using DeepGEMM backend for FP8 MoE")
+ logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local")
return Fp8MoeBackend.DEEPGEMM
# CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights
@@ -173,7 +189,9 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
and current_platform.is_device_capability(100)
and block_quant
):
- logger.info_once("Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE")
+ logger.info_once(
+ "Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE", scope="local"
+ )
return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
# default to Triton
@@ -637,7 +655,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.quant_config = quant_config
self.weight_block_size = self.quant_config.weight_block_size
self.block_quant: bool = self.weight_block_size is not None
- self.fp8_backend = get_fp8_moe_backend(self.block_quant)
+ self.fp8_backend = get_fp8_moe_backend(
+ self.block_quant, layer.moe_parallel_config
+ )
self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
@@ -1018,7 +1038,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del layer.w13_input_scale
del layer.w2_input_scale
- def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
+ def maybe_make_prepare_finalize(
+ self,
+ routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
+ ) -> mk.FusedMoEPrepareAndFinalize | None:
if (
self.rocm_aiter_moe_enabled
or self.use_marlin
@@ -1039,7 +1062,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
else:
- return super().maybe_make_prepare_finalize()
+ return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl(
self,
@@ -1133,7 +1156,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -1209,31 +1232,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=apply_router_weight_on_input,
)
- zero_expert_num = getattr(layer, "zero_expert_num", 0)
- zero_expert_type = getattr(layer, "zero_expert_type", None)
-
- select_result = FusedMoE.select_experts(
+ select_result = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
- enable_eplb=enable_eplb,
- expert_map=expert_map,
- expert_load_view=expert_load_view,
- logical_to_physical_map=logical_to_physical_map,
- logical_replica_count=logical_replica_count,
- global_num_experts=global_num_experts,
- zero_expert_num=zero_expert_num,
- zero_expert_type=zero_expert_type,
- num_fused_shared_experts=layer.num_fused_shared_experts,
)
topk_weights, topk_ids, zero_expert_result = select_result
@@ -1315,7 +1316,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.allow_cutlass_block_scaled_grouped_gemm
),
)
- if zero_expert_num != 0 and zero_expert_type is not None:
+
+ if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
assert not isinstance(result, tuple), (
"Shared + zero experts are mutually exclusive not yet supported"
)
diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py
index 42d7a67371ae8..bcdfafb50fc5a 100644
--- a/vllm/model_executor/layers/quantization/gguf.py
+++ b/vllm/model_executor/layers/quantization/gguf.py
@@ -621,7 +621,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -642,9 +642,6 @@ class GGUFMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError("EPLB not supported for `GGUFMoEMethod` yet.")
-
assert activation == "silu", "Only SiLU activation is supported."
if apply_router_weight_on_input:
raise NotImplementedError(
@@ -652,19 +649,9 @@ class GGUFMoEMethod(FusedMoEMethodBase):
"fused GGUF MoE method."
)
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
return fused_moe_gguf(
x,
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 68a122fd46c6b..77b15db373a3a 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -722,7 +722,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -743,26 +743,11 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `GPTQMarlinMoEMethod` yet."
- )
-
assert activation == "silu", "Only SiLU activation is supported."
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
return fused_marlin_moe(
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 476521813f464..8165673135910 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
+from fnmatch import fnmatch
from typing import TYPE_CHECKING, Any, Optional
import torch
@@ -13,9 +14,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
- FusedMoEConfig,
FusedMoEQuantConfig,
- RoutingMethodType,
fp8_w8a8_moe_quant_config,
nvfp4_moe_quant_config,
)
@@ -38,6 +37,8 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
+ flashinfer_trtllm_fp4_moe,
+ prepare_static_weights_for_trtllm_fp4_moe,
reorder_w1w3_to_w3w1,
select_nvfp4_gemm_impl,
)
@@ -86,45 +87,218 @@ QUANT_ALGOS = ["FP8", "NVFP4"]
KV_CACHE_QUANT_ALGOS = ["FP8"]
-class ModelOptFp8Config(QuantizationConfig):
+class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
+ """
+ Supports loading kv-cache scaling factors from FP8 checkpoints.
+ """
+
+ def __init__(self, quant_config: "ModelOptQuantConfigBase"):
+ super().__init__(quant_config)
+
+
+class ModelOptQuantConfigBase(QuantizationConfig):
+ LinearMethodCls: type = LinearMethodBase
+ FusedMoEMethodCls: type = FusedMoEMethodBase
+ KVCacheMethodCls: type = BaseKVCacheMethod
+
+ def __init__(
+ self,
+ exclude_modules: list[str],
+ ):
+ super().__init__()
+ self.exclude_modules: list[str] = exclude_modules
+
+ def is_layer_excluded(self, prefix: str) -> bool:
+ """
+ Check if a layer should be excluded from quantization.
+
+ Handles both exact matching (for fused layers) and ModelOpt wildcard matching.
+
+ The ModelOpt exclude_modules list is a list of wildcards.
+ """
+ if len(self.exclude_modules) == 0:
+ return False
+
+ # First check exact matching with fused layer support
+ if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
+ return True
+
+        # TODO: This special hard-coded logic is not needed for quantized
+        # checkpoints generated by ModelOpt >= 0.39.0, where these cases are
+        # handled naturally by the exclude_modules config, but it must be kept
+        # to load quantized checkpoints generated by older versions. Then check
+        # substring matching for patterns not caught by the exact match above.
+ for exclude_module in self.exclude_modules:
+ # Skip exact matches already handled above
+ if exclude_module != prefix and (
+ exclude_module in prefix
+ or (
+ prefix.startswith("language_model.")
+ and exclude_module in prefix.removeprefix("language_model.")
+ )
+ ):
+ return True
+
+        # ModelOpt exclude_modules entries are not plain strings; they are
+        # wildcard patterns, so match them with fnmatch.
+ for wildcard_pattern in self.exclude_modules:
+ if fnmatch(prefix, wildcard_pattern):
+ return True
+
+ return False
+
+ def get_quant_method(
+ self, layer: torch.nn.Module, prefix: str
+ ) -> Optional["QuantizeMethodBase"]:
+ from vllm.attention.layer import Attention # Avoid circular import
+
+ # handle kv-cache first so we can focus only on weight quantization thereafter
+ if isinstance(layer, Attention):
+ return self.KVCacheMethodCls(self)
+
+ # handle exclusion
+ if self.is_layer_excluded(prefix):
+ if isinstance(layer, LinearBase):
+ return UnquantizedLinearMethod()
+ return None
+
+        # TODO: This special hard-coded logic is not needed for quantized
+        # checkpoints generated by ModelOpt >= 0.39.0, where these cases are
+        # handled naturally by the exclude_modules config, but it must be kept
+        # to load quantized checkpoints generated by older versions: vision
+        # layers are left unquantized even if absent from exclude_modules.
+ if "vision_tower" in prefix or "vision_model" in prefix:
+ return UnquantizedLinearMethod()
+
+ # now, the layer is quantized, handle it here
+ if isinstance(layer, LinearBase):
+ return self.LinearMethodCls(self)
+ elif isinstance(layer, FusedMoE):
+ return self.FusedMoEMethodCls(quant_config=self, layer=layer)
+
+ return None
+
+ def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
+ if len(self.exclude_modules) > 0:
+ self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)
+
+ @staticmethod
+ def get_config_filenames() -> list[str]:
+ return ["hf_quant_config.json"]
+
+ @classmethod
+ def _from_config(
+ cls,
+ *,
+ quant_method: str,
+ kv_cache_quant_method: str | None,
+ exclude_modules: list[str],
+ original_config: dict[str, Any],
+ group_size: int | None,
+ ) -> "ModelOptQuantConfigBase":
+ raise NotImplementedError("Please implement this function in sub classes")
+
+ @classmethod
+ def from_config(cls, config: dict[str, Any]) -> "ModelOptQuantConfigBase":
+ # Handle both ModelOpt format and compressed-tensors style format
+ if "quantization" in config:
+ # Traditional ModelOpt format:
+ # {"quantization": {"quant_algo": "..."}}
+ quant_config = cls.get_from_keys(config, ["quantization"])
+ if not isinstance(quant_config, dict):
+ raise ValueError("Expected 'quantization' to be a dictionary in config")
+
+ quant_method = quant_config.get("quant_algo")
+
+ # Handle kv_cache_quant_algo with proper type validation
+ kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
+
+ # Handle group_size with proper type validation
+ group_size_raw = quant_config.get("group_size")
+
+ # "exclude_modules" is the key in the legacy hf_quant_config.json
+ exclude_modules = quant_config.get("exclude_modules", [])
+ else:
+ # Compressed-tensors style format:
+ # {"quant_algo": "...", "quant_method": "modelopt"}
+ quant_method = config.get("quant_algo")
+ kv_cache_quant_method = config.get("kv_cache_quant_algo")
+ # "ignore" is the key in config.json
+ exclude_modules = config.get("ignore", [])
+ group_size_raw = config.get("group_size")
+
+ if not quant_method:
+ raise ValueError("Missing 'quant_algo' in quantization config")
+
+ if kv_cache_quant_method is None:
+ # No KV cache quantization, keep this branch just to have this comment
+ pass
+ elif not isinstance(kv_cache_quant_method, str):
+ raise ValueError(
+ f"kv_cache_quant_algo must be a string, got "
+ f"{type(kv_cache_quant_method)}"
+ )
+
+ if not isinstance(exclude_modules, list):
+ raise ValueError(
+ f"exclude_modules must be a list, got {type(exclude_modules)}"
+ )
+
+ if group_size_raw is None:
+ group_size = None
+ elif isinstance(group_size_raw, int):
+ group_size = group_size_raw
+ else:
+ try:
+ group_size = int(group_size_raw)
+ except (ValueError, TypeError):
+ raise ValueError(
+ f"group_size must be an integer, got {type(group_size_raw)}"
+ ) from None
+
+ if quant_method not in QUANT_ALGOS:
+ raise ValueError(
+ f"ModelOpt currently only supports: {QUANT_ALGOS} "
+ "quantizations in vLLM. Please check the "
+ "`hf_quant_config.json` file for your model's "
+ "quant configuration."
+ )
+ return cls._from_config(
+ quant_method=quant_method,
+ kv_cache_quant_method=kv_cache_quant_method,
+ exclude_modules=exclude_modules,
+ group_size=group_size,
+ original_config=config,
+ )
+
+
+class ModelOptFp8Config(ModelOptQuantConfigBase):
"""Config class for ModelOpt FP8."""
def __init__(
self,
- is_checkpoint_fp8_serialized: bool = False,
- kv_cache_quant_method: str | None = None,
- exclude_modules: list[str] | None = None,
+ is_checkpoint_fp8_serialized: bool,
+ kv_cache_quant_method: str | None,
+ exclude_modules: list[str],
) -> None:
- super().__init__()
+ super().__init__(exclude_modules)
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
self.kv_cache_quant_method = kv_cache_quant_method
- self.exclude_modules = exclude_modules or []
if is_checkpoint_fp8_serialized:
logger.warning(
"Detected ModelOpt fp8 checkpoint. Please note that"
" the format is experimental and could change."
)
- @classmethod
- def get_name(cls) -> QuantizationMethods:
+ def get_name(self) -> QuantizationMethods:
return "modelopt"
- @classmethod
- def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+ def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.bfloat16, torch.half]
@classmethod
def get_min_capability(cls) -> int:
return 89
- @classmethod
- def get_config_filenames(cls) -> list[str]:
- return ["hf_quant_config.json"]
-
- def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
- if self.exclude_modules is not None:
- self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)
-
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
@@ -158,88 +332,19 @@ class ModelOptFp8Config(QuantizationConfig):
return None
@classmethod
- def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
- # Handle both ModelOpt format and compressed-tensors style format
- if "quantization" in config:
- # ModelOpt format: {"quantization": {"quant_algo": "..."}}
- quant_config = cls.get_from_keys(config, ["quantization"])
- if not isinstance(quant_config, dict):
- raise ValueError("Expected 'quantization' to be a dictionary in config")
- quant_method = quant_config.get("quant_algo", "")
- if not quant_method:
- raise ValueError("Missing 'quant_algo' in quantization config")
- kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
- # "exclude_modules" is the key in the legacy hf_quant_config.json
- exclude_modules = quant_config.get("exclude_modules")
- else:
- # Compressed-tensors style format:
- # {"quant_algo": "...", "quant_method": "modelopt"}
- quant_method = config.get("quant_algo", "")
- kv_cache_quant_method = config.get("kv_cache_quant_algo")
- # "ignore" is the key in config.json
- exclude_modules = config.get("ignore")
-
- if quant_method not in QUANT_ALGOS:
- raise ValueError(
- f"ModelOpt currently only supports: {QUANT_ALGOS} "
- "quantizations in vLLM. Please check the "
- "`hf_quant_config.json` file for your model's "
- "quant configuration."
- )
+ def _from_config(
+ cls,
+ *,
+ quant_method: str,
+ kv_cache_quant_method: str | None,
+ exclude_modules: list[str],
+ original_config: dict[str, Any],
+ **kwargs: Any,
+ ) -> "ModelOptFp8Config":
is_checkpoint_fp8_serialized = "FP8" in quant_method
return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, exclude_modules)
- def is_layer_excluded(self, prefix: str) -> bool:
- """
- Check if a layer should be excluded from quantization.
- Handles both exact matching (for fused layers) and substring matching.
-
- This method handles both regular models and multimodal models that use
- the language_model prefix. For multimodal models, it checks if the
- module name (without the language_model prefix) is in the exclude list.
- """
- if self.exclude_modules is None:
- return False
-
- # First check exact matching with fused layer support
- if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
- return True
-
- # Then check substring matching for patterns not caught by exact match
- for module in self.exclude_modules:
- # Skip exact matches already handled above
- if module != prefix and (
- module in prefix
- or (
- prefix.startswith("language_model.")
- and module in prefix.removeprefix("language_model.")
- )
- ):
- return True
- return False
-
- def get_quant_method(
- self, layer: torch.nn.Module, prefix: str
- ) -> Optional["QuantizeMethodBase"]:
- from vllm.attention.layer import ( # Avoid circular import
- Attention,
- MLAAttention,
- )
-
- if isinstance(layer, LinearBase):
- if self.is_layer_excluded(prefix):
- return UnquantizedLinearMethod()
- # Check if this is a vision model layer that should not be quantized
- if "vision_tower" in prefix or "vision_model" in prefix:
- return UnquantizedLinearMethod()
- return ModelOptFp8LinearMethod(self)
- elif isinstance(layer, (Attention, MLAAttention)):
- return ModelOptFp8KVCacheMethod(self)
- elif isinstance(layer, FusedMoE):
- return ModelOptFp8MoEMethod(self, layer)
- return None
-
class ModelOptFp8LinearMethod(LinearMethodBase):
"""Linear method for Model Optimizer static quantization.
@@ -344,7 +449,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def __init__(
self,
quant_config: ModelOptFp8Config,
- layer: torch.nn.Module,
+ layer: FusedMoE,
) -> None:
super().__init__(layer.moe_config)
self.layer = layer
@@ -373,6 +478,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def maybe_make_prepare_finalize(
self,
+ routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
# TRT LLM not supported with all2all yet.
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
@@ -384,7 +490,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
else:
- return super().maybe_make_prepare_finalize()
+ return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl(
self,
@@ -590,7 +696,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -611,12 +717,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `ModelOptFp8MoEMethod` yet."
- )
-
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
+ if layer.enable_eplb:
+ raise NotImplementedError(
+ "EPLB not supported for `ModelOptFp8MoEMethod` yet."
+ )
assert activation == "silu", (
f"Expected 'silu' activation but got {activation}"
)
@@ -634,19 +739,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
)
# Expert selection
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
@@ -685,7 +780,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
)
-class ModelOptNvFp4Config(QuantizationConfig):
+ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod
+ModelOptFp8Config.FusedMoEMethodCls = ModelOptFp8MoEMethod
+ModelOptFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
+
+
+class ModelOptNvFp4Config(ModelOptQuantConfigBase):
"""Config class for ModelOpt FP4."""
def __init__(
@@ -695,7 +795,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
exclude_modules: list[str],
group_size: int = 16,
) -> None:
- super().__init__()
+ super().__init__(exclude_modules)
self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
if is_checkpoint_nvfp4_serialized:
logger.warning(
@@ -705,28 +805,17 @@ class ModelOptNvFp4Config(QuantizationConfig):
self.group_size = group_size
self.kv_cache_quant_algo = kv_cache_quant_algo
- self.exclude_modules = exclude_modules
- @classmethod
- def get_name(cls) -> QuantizationMethods:
+ def get_name(self) -> QuantizationMethods:
return "modelopt_fp4"
- @classmethod
- def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+ def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.bfloat16, torch.half, torch.float8_e4m3fn]
@classmethod
def get_min_capability(cls) -> int:
return 80
- @classmethod
- def get_config_filenames(cls) -> list[str]:
- return ["hf_quant_config.json"]
-
- def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
- if self.exclude_modules is not None:
- self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)
-
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
@@ -760,105 +849,25 @@ class ModelOptNvFp4Config(QuantizationConfig):
return None
@classmethod
- def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
- # Handle both traditional ModelOpt format and compressed-tensors
- # style format
- if "quantization" in config:
- # Traditional ModelOpt format:
- # {"quantization": {"quant_algo": "..."}}
- quant_config = cls.get_from_keys(config, ["quantization"])
- if not isinstance(quant_config, dict):
- raise ValueError("Expected 'quantization' to be a dictionary in config")
-
- quant_method = quant_config.get("quant_algo", "")
- if not quant_method:
- raise ValueError("Missing 'quant_algo' in quantization config")
-
- # Handle kv_cache_quant_algo with proper type validation
- kv_cache_quant_algo_raw = quant_config.get("kv_cache_quant_algo")
- if kv_cache_quant_algo_raw is None:
- # No KV cache quantization by default
- kv_cache_quant_algo = None
- elif isinstance(kv_cache_quant_algo_raw, str):
- kv_cache_quant_algo = kv_cache_quant_algo_raw
- else:
- raise ValueError(
- f"kv_cache_quant_algo must be a string, got "
- f"{type(kv_cache_quant_algo_raw)}"
- )
-
- # Handle group_size with proper type validation
- group_size_raw = quant_config.get("group_size")
- if group_size_raw is None:
- group_size = 16 # Default value
- elif isinstance(group_size_raw, int):
- group_size = group_size_raw
- else:
- try:
- group_size = int(group_size_raw)
- except (ValueError, TypeError):
- raise ValueError(
- f"group_size must be an integer, got {type(group_size_raw)}"
- ) from None
-
- # "exclude_modules" is the key in the legacy hf_quant_config.json
- exclude_modules = quant_config.get("exclude_modules", [])
- if not isinstance(exclude_modules, list):
- raise ValueError(
- f"exclude_modules must be a list, got {type(exclude_modules)}"
- )
- else:
- # Compressed-tensors style format:
- # {"quant_algo": "...", "quant_method": "modelopt"}
- quant_method = config.get("quant_algo", "")
-
- # Handle kv_cache_quant_algo with proper type validation
- kv_cache_quant_algo_raw = config.get("kv_cache_quant_algo")
- if kv_cache_quant_algo_raw is None:
- # No KV cache quantization by default
- kv_cache_quant_algo = None
- elif isinstance(kv_cache_quant_algo_raw, str):
- kv_cache_quant_algo = kv_cache_quant_algo_raw
- else:
- raise ValueError(
- f"kv_cache_quant_algo must be a string, got "
- f"{type(kv_cache_quant_algo_raw)}"
- )
-
- # Handle group_size with proper type validation
- group_size_raw = config.get("group_size")
- if group_size_raw is None:
- group_size = 16 # Default value
- elif isinstance(group_size_raw, int):
- group_size = group_size_raw
- else:
- try:
- group_size = int(group_size_raw)
- except (ValueError, TypeError):
- raise ValueError(
- f"group_size must be an integer, got {type(group_size_raw)}"
- ) from None
-
- # "ignore" is the key in config.json
- exclude_modules = config.get("ignore", [])
- if not isinstance(exclude_modules, list):
- raise ValueError(
- f"exclude_modules must be a list, got {type(exclude_modules)}"
- )
-
- if quant_method not in QUANT_ALGOS:
- raise ValueError(
- f"ModelOpt currently only supports: {QUANT_ALGOS} "
- "quantizations in vLLM. Please check the "
- "`hf_quant_config.json` file for your model's "
- "quant configuration."
- )
+ def _from_config(
+ cls,
+ *,
+ quant_method: str,
+ kv_cache_quant_method: str | None,
+ exclude_modules: list[str],
+ original_config: dict[str, Any],
+ group_size: int | None,
+ **kwargs: Any,
+ ) -> "ModelOptNvFp4Config":
is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
+ if group_size is None:
+ group_size = 16 # Default value
+
# For FP4, these fields are required
- if is_checkpoint_nvfp4_serialized and "quantization" in config:
+ if is_checkpoint_nvfp4_serialized and "quantization" in original_config:
# Check if required fields are present in the quantization config
- quant_config = config["quantization"]
+ quant_config = original_config["quantization"]
required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"]
missing_fields = [
field for field in required_fields if field not in quant_config
@@ -871,64 +880,11 @@ class ModelOptNvFp4Config(QuantizationConfig):
return cls(
is_checkpoint_nvfp4_serialized,
- kv_cache_quant_algo,
+ kv_cache_quant_method,
exclude_modules,
group_size,
)
- def is_layer_excluded(self, prefix: str) -> bool:
- """
- Check if a layer should be excluded from quantization.
- Handles both exact matching (for fused layers) and pattern matching.
- """
- # First check exact matching with fused layer support
- if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
- return True
-
- # Check regex pattern matching for patterns not caught by exact match
- import regex as re
-
- for pattern in self.exclude_modules:
- # Skip patterns that would be caught by exact matching
- if "*" in pattern or "." in pattern:
- regex_str = pattern.replace(".", r"\.").replace("*", r".*")
- if re.fullmatch(regex_str, prefix):
- return True
- return False
-
- def get_quant_method(
- self, layer: torch.nn.Module, prefix: str
- ) -> Optional["QuantizeMethodBase"]:
- from vllm.attention.layer import ( # Avoid circular import
- Attention,
- MLAAttention,
- )
-
- skip_layer = self.is_layer_excluded(prefix)
- if isinstance(layer, LinearBase):
- if skip_layer:
- return UnquantizedLinearMethod()
- # Check if this is a vision model layer that should not be quantized
- if "vision_tower" in prefix or "vision_model" in prefix:
- return UnquantizedLinearMethod()
- return ModelOptNvFp4LinearMethod(self)
- elif isinstance(layer, (Attention, MLAAttention)):
- return ModelOptFp8KVCacheMethod(self)
- elif isinstance(layer, FusedMoE):
- if skip_layer:
- return None
- return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer)
- return None
-
-
-class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
- """
- Supports loading kv-cache scaling factors from FP8 checkpoints.
- """
-
- def __init__(self, quant_config: ModelOptFp8Config | ModelOptNvFp4Config):
- super().__init__(quant_config)
-
class ModelOptNvFp4LinearMethod(LinearMethodBase):
"""Linear method for Model Optimizer NVFP4.
@@ -1156,14 +1112,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def __init__(
self,
quant_config: ModelOptNvFp4Config,
- moe: FusedMoEConfig,
- layer: torch.nn.Module,
+ layer: FusedMoE,
) -> None:
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (
detect_nvfp4_moe_support, # noqa: E501
)
- super().__init__(moe)
+ super().__init__(layer.moe_config)
self.quant_config = quant_config
self.layer = layer
_nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
@@ -1171,7 +1126,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin
self.flashinfer_moe_backend = None
- self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
if self.allow_flashinfer:
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
logger.info_once(
@@ -1179,7 +1133,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
" for ModelOptNvFp4FusedMoE."
)
- def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
+ def maybe_make_prepare_finalize(
+ self,
+ routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
+ ) -> mk.FusedMoEPrepareAndFinalize | None:
if self.use_marlin or (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
@@ -1196,7 +1153,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
else:
- return super().maybe_make_prepare_finalize()
+ return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl(
self,
@@ -1335,136 +1292,15 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
)
layer.register_parameter("w2_input_scale", w2_input_scale)
- def prepare_static_weights_for_trtllm_fp4_moe(
- self,
- # args_dequant,
- # args,
- gemm1_weights,
- gemm2_weights,
- gemm1_scales_linear_fp4_bytes,
- gemm2_scales_linear_fp4_bytes,
- hidden_size,
- intermediate_size,
- num_experts,
- ):
- from flashinfer import nvfp4_block_scale_interleave
- from flashinfer.fused_moe.core import (
- _maybe_get_cached_w3_w1_permute_indices,
- get_w2_permute_indices_with_cache,
- )
-
- """Prepare quantized weights for kernel (done offline with weights)."""
- epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
-
- # Convert quantized weights to proper formats
- gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
- num_experts, 2 * intermediate_size, hidden_size // 2
- ) # packed fp4
- gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
- torch.float8_e4m3fn
- ).reshape(
- num_experts, 2 * intermediate_size, hidden_size // 16
- ) # fp8 scaling factors
-
- gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
- num_experts, hidden_size, intermediate_size // 2
- ) # packed fp4
- gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
- torch.float8_e4m3fn
- ).reshape(
- num_experts, hidden_size, intermediate_size // 16
- ) # fp8 scaling factors
-
- gemm1_weights_fp4_shuffled = []
- gemm1_scales_fp4_shuffled = []
- gemm2_weights_fp4_shuffled = []
- gemm2_scales_fp4_shuffled = []
- for i in range(num_experts):
- # Calculate the permute indices for the following:
- # 1. Reorder rows of W1 and scales for fused gated activation
- # 2. Shuffle weights and scaling factors for transposed mma output
- # for both w3_w1 and w2 weights and scale factors
- permute_indices = _maybe_get_cached_w3_w1_permute_indices(
- self._cache_permute_indices,
- gemm1_weights_fp4[i].view(torch.uint8),
- epilogue_tile_m,
- )
- gemm1_weights_fp4_shuffled.append(
- gemm1_weights_fp4[i]
- .view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
- .contiguous()
- )
-
- permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
- self._cache_permute_indices,
- gemm1_scales_linear_fp4[i].view(torch.uint8),
- epilogue_tile_m,
- num_elts_per_sf=16,
- )
- gemm1_scales_fp4_shuffled.append(
- nvfp4_block_scale_interleave(
- gemm1_scales_linear_fp4[i]
- .view(torch.uint8)[
- permute_sf_indices.to(gemm1_scales_linear_fp4.device)
- ]
- .contiguous()
- )
- )
-
- permute_indices = get_w2_permute_indices_with_cache(
- self._cache_permute_indices,
- gemm2_weights_fp4[i].view(torch.uint8),
- epilogue_tile_m,
- )
- gemm2_weights_fp4_shuffled.append(
- gemm2_weights_fp4[i]
- .view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
- .contiguous()
- )
-
- permute_sf_indices = get_w2_permute_indices_with_cache(
- self._cache_permute_indices,
- gemm2_scales_linear_fp4[i].view(torch.uint8),
- epilogue_tile_m,
- num_elts_per_sf=16,
- )
- gemm2_scales_fp4_shuffled.append(
- nvfp4_block_scale_interleave(
- gemm2_scales_linear_fp4[i]
- .view(torch.uint8)[
- permute_sf_indices.to(gemm2_scales_linear_fp4.device)
- ]
- .contiguous()
- )
- )
-
- # Stack weights for all experts
- gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
- gemm1_scales_fp4_shuffled = (
- torch.stack(gemm1_scales_fp4_shuffled)
- .view(torch.float8_e4m3fn)
- .reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
- )
-
- gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
- gemm2_scales_fp4_shuffled = (
- torch.stack(gemm2_scales_fp4_shuffled)
- .view(torch.float8_e4m3fn)
- .reshape(num_experts, hidden_size, intermediate_size // 16)
- )
- return (
- gemm1_weights_fp4_shuffled,
- gemm1_scales_fp4_shuffled,
- gemm2_weights_fp4_shuffled,
- gemm2_scales_fp4_shuffled,
- )
-
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# GEMM 1 processing
gemm1_weight = layer.w13_weight.data
gemm1_weight_scale = layer.w13_weight_scale.data
- if self.allow_flashinfer:
+ if self.allow_flashinfer and (
+ self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
+ or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+ ):
gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
gemm1_weight, gemm1_weight_scale, dim=-2
)
@@ -1537,7 +1373,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
gemm1_scales_fp4_shuffled,
gemm2_weights_fp4_shuffled,
gemm2_scales_fp4_shuffled,
- ) = self.prepare_static_weights_for_trtllm_fp4_moe(
+ ) = prepare_static_weights_for_trtllm_fp4_moe(
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale,
@@ -1612,7 +1448,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -1633,92 +1469,31 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
- )
assert activation == "silu", "Only SiLU activation is supported."
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
- import flashinfer
-
- from vllm.model_executor.models.llama4 import Llama4MoE
-
- a1_gscale = layer.w13_input_scale_quant
- (hidden_states_fp4, hidden_states_scale_linear_fp4) = (
- flashinfer.fp4_quantize(
- x,
- a1_gscale,
- is_sf_swizzled_layout=False,
+ if enable_eplb:
+ raise NotImplementedError(
+ "EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
)
- )
- use_llama4_routing = (
- custom_routing_function is Llama4MoE.custom_routing_function
- )
- routing_method_type = layer.routing_method_type
- if use_llama4_routing:
- routing_method_type = RoutingMethodType.Llama4
- router_logits = (
- router_logits.to(torch.float32)
- if routing_method_type == RoutingMethodType.DeepSeekV3
- else router_logits
- )
- routing_bias = e_score_correction_bias
- if routing_bias is not None:
- routing_bias = routing_bias.to(torch.bfloat16)
- out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
- routing_logits=router_logits,
- routing_bias=routing_bias,
- hidden_states=hidden_states_fp4,
- hidden_states_scale=hidden_states_scale_linear_fp4.view(
- torch.float8_e4m3fn
- ).flatten(),
- gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
- gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
- torch.float8_e4m3fn
- ),
- gemm1_bias=None,
- gemm1_alpha=None,
- gemm1_beta=None,
- gemm1_clamp_limit=None,
- gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
- gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
- torch.float8_e4m3fn
- ),
- gemm2_bias=None,
- output1_scale_scalar=layer.g1_scale_c.data,
- output1_scale_gate_scalar=layer.g1_alphas.data,
- output2_scale_scalar=layer.g2_alphas.data,
- num_experts=global_num_experts,
+ return flashinfer_trtllm_fp4_moe(
+ layer=layer,
+ x=x,
+ router_logits=router_logits,
top_k=top_k,
- n_group=num_expert_group,
+ global_num_experts=global_num_experts,
+ num_expert_group=num_expert_group,
topk_group=topk_group,
- intermediate_size=layer.intermediate_size_per_partition,
- local_expert_offset=layer.ep_rank * layer.local_num_experts,
- local_num_experts=layer.local_num_experts,
- routed_scaling_factor=None,
- tile_tokens_dim=None,
- routing_method_type=routing_method_type,
- do_finalize=True,
- )[0]
- return out
+ custom_routing_function=custom_routing_function,
+ e_score_correction_bias=e_score_correction_bias,
+ )
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
if self.use_marlin:
@@ -1742,17 +1517,26 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
workspace=layer.workspace,
)
- elif (
- self.allow_flashinfer
- and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
- ):
- from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
- flashinfer_cutlass_moe_fp4,
+ elif self.allow_flashinfer:
+ assert self.flashinfer_moe_backend in (
+ FlashinferMoeBackend.CUTLASS,
+ FlashinferMoeBackend.CUTEDSL,
)
+ if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
+ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
+ flashinfer_cutlass_moe_fp4,
+ )
+
+ flashinfer_fn_moe_fp4 = flashinfer_cutlass_moe_fp4
+ else:
+ from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import ( # noqa: E501
+ flashinfer_cutedsl_moe_fp4,
+ )
+
+ flashinfer_fn_moe_fp4 = flashinfer_cutedsl_moe_fp4
assert self.moe_quant_config is not None
-
- return flashinfer_cutlass_moe_fp4(
+ return flashinfer_fn_moe_fp4(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
@@ -1786,3 +1570,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
k=x.shape[1],
e=layer.w13_weight.shape[0],
)
+
+
+ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
+ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE
+ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py
index 2090c86f78dc8..cf348290a2716 100644
--- a/vllm/model_executor/layers/quantization/moe_wna16.py
+++ b/vllm/model_executor/layers/quantization/moe_wna16.py
@@ -359,7 +359,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -380,25 +380,12 @@ class MoeWNA16Method(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError("EPLB not supported for `MoeWNA16Method` yet.")
-
from vllm.model_executor.layers.fused_moe import fused_experts
assert activation == "silu", "Only SiLU activation is supported."
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
return fused_experts(
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index b95d1a6b3a1f5..198feb03be3e4 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -132,12 +132,15 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
)
# If FlashInfer is not available, try either Marlin or Triton
- if (
- envs.VLLM_MXFP4_USE_MARLIN
- or current_platform.get_device_capability()[0] < 9
- or not has_triton_kernels()
- or not is_torch_equal_or_newer("2.8.0")
- ):
+ triton_kernels_supported = (
+ has_triton_kernels()
+ and is_torch_equal_or_newer("2.8.0")
+ # NOTE: triton_kernels are only confirmed to work on SM90 and SM100
+ # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
+ # SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
+ and (9, 0) <= current_platform.get_device_capability() < (11, 0)
+ )
+ if envs.VLLM_MXFP4_USE_MARLIN or not triton_kernels_supported:
logger.info_once("Using Marlin backend")
return Mxfp4Backend.MARLIN
else:
@@ -755,8 +758,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.w13_weight = w13_weight
self.w2_weight = w2_weight
- layer.w13_weight = Parameter(w13_weight.storage.data, requires_grad=False)
- layer.w2_weight = Parameter(w2_weight.storage.data, requires_grad=False)
+ del layer.w13_weight
+ del layer.w2_weight
+ layer.w13_weight = w13_weight
+ layer.w2_weight = w2_weight
else:
raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
@@ -860,7 +865,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -885,18 +890,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
raise NotImplementedError("EPLB is not supported for mxfp4")
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
)
return fused_marlin_moe(
@@ -987,17 +983,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
):
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- e_score_correction_bias=e_score_correction_bias,
)
# Backend-specific preparation
@@ -1065,8 +1053,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
return triton_kernel_moe_forward(
hidden_states=x,
- w1=self.w13_weight,
- w2=self.w2_weight,
+ w1=layer.w13_weight,
+ w2=layer.w2_weight,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index 30772c3665b06..8be0299eaa66f 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -334,7 +334,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -355,24 +355,9 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet."
- )
-
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
if self.rocm_aiter_moe_enabled:
@@ -609,7 +594,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -630,24 +615,9 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `QuarkOCP_MX_MoEMethod` yet."
- )
-
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
if not self.emulate:
diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py
index 007e78e68d5cd..33e9f9806b27e 100644
--- a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py
+++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py
@@ -10,6 +10,7 @@ import torch
import torch.nn.functional as F
from vllm import envs
+from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
dequant_mxfp4,
@@ -49,7 +50,10 @@ def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
try:
from aiter.ops.shuffle import shuffle_weight
- from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
+ from aiter.ops.triton.gemm_afp4wfp4 import (
+ gemm_afp4wfp4,
+ gemm_afp4wfp4_preshuffled_weight_scales,
+ )
from aiter.ops.triton.quant import dynamic_mxfp4_quant
from vllm.utils.torch_utils import direct_register_custom_op
@@ -66,23 +70,56 @@ try:
x_scales: torch.Tensor | None = None,
) -> torch.Tensor:
M = x.shape[0]
+ N = weight.shape[0]
+ K = weight.shape[1]
if rocm_use_aiter_fp4_asm_gemm:
- if x_scales is None:
- # use hip quant kernel for performance
- x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
+ if M <= 64 and rocm_aiter_ops.is_triton_gemm_afp4wfp4_presh_ws_tuned(N, K):
+ if x_scales is None:
+ # use hip quant kernel for performance
+ if M >= 32:
+ x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
+ else:
+ x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=False)
+ else:
+ x_q = x
+ x_s = x_scales
+
+ if M >= 32:
+ x_s = x_s.view(torch.uint8).view(x_s.shape[0] // 32, -1)
+ else:
+ x_s = x_s[:M, ...].view(torch.uint8)
+
+ y = torch.empty(M, N, device=x_q.device, dtype=out_dtype)
+ gemm_afp4wfp4_preshuffled_weight_scales(
+ x_q.view(torch.uint8),
+ weight.view(torch.uint8).view(weight.shape[0] // 16, -1),
+ x_s,
+ weight_scale.view(torch.uint8).view(
+ weight_scale.shape[0] // 32, -1
+ ),
+ out_dtype,
+ y,
+ )
else:
- x_q = x
- x_s = x_scales
+ if x_scales is None:
+ # use hip quant kernel for performance
+ x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
+ else:
+ x_q = x
+ x_s = x_scales
- # 32 alignment is enough for dim0 padding of output for
- # gemm_a4w4 kernel
- y = torch.empty(
- (M + 31) // 32 * 32, weight.shape[0], device=x_q.device, dtype=out_dtype
- )
+ # 32 alignment is enough for dim0 padding of output for
+ # gemm_a4w4 kernel
+ y = torch.empty(
+ (M + 31) // 32 * 32,
+ weight.shape[0],
+ device=x_q.device,
+ dtype=out_dtype,
+ )
- gemm_a4w4(
- x_q, weight, x_s, weight_scale.view(x_s.dtype), y, bpreshuffle=True
- )
+ gemm_a4w4(
+ x_q, weight, x_s, weight_scale.view(x_s.dtype), y, bpreshuffle=True
+ )
return y[:M]
else:
if x_scales is None:
diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py
index 52656263a601b..7b51b828009fc 100644
--- a/vllm/model_executor/layers/quantization/rtn.py
+++ b/vllm/model_executor/layers/quantization/rtn.py
@@ -356,7 +356,7 @@ class RTNMoEMethod(FusedMoEMethodBase):
def apply(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@@ -377,22 +377,9 @@ class RTNMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if enable_eplb:
- raise NotImplementedError("EPLB not supported for `RTNMoEMethod` yet.")
-
- topk_weights, topk_ids, _ = FusedMoE.select_experts(
+ topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- indices_type=self.topk_indices_dtype,
)
return fused_marlin_moe(
diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=10240,K=5120,device_name=NVIDIA_L40S,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=10240,K=5120,device_name=NVIDIA_L40S,dtype=fp8_w8a8,block_shape=[128,128].json
new file mode 100644
index 0000000000000..6b2c1dc1312bf
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/utils/configs/N=10240,K=5120,device_name=NVIDIA_L40S,dtype=fp8_w8a8,block_shape=[128,128].json
@@ -0,0 +1,146 @@
+{
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 5
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "24": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "32": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "48": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "64": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "96": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "128": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "256": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "512": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 4
+ }
+}
diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=5120,K=25600,device_name=NVIDIA_L40S,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=5120,K=25600,device_name=NVIDIA_L40S,dtype=fp8_w8a8,block_shape=[128,128].json
new file mode 100644
index 0000000000000..b0eaf02a541ad
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/utils/configs/N=5120,K=25600,device_name=NVIDIA_L40S,dtype=fp8_w8a8,block_shape=[128,128].json
@@ -0,0 +1,146 @@
+{
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "24": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "32": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "48": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "64": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "96": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "128": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "256": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "512": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 3
+ }
+}
diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=5120,K=8192,device_name=NVIDIA_L40S,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=5120,K=8192,device_name=NVIDIA_L40S,dtype=fp8_w8a8,block_shape=[128,128].json
new file mode 100644
index 0000000000000..4cd357d5086ca
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/utils/configs/N=5120,K=8192,device_name=NVIDIA_L40S,dtype=fp8_w8a8,block_shape=[128,128].json
@@ -0,0 +1,146 @@
+{
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "24": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "32": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "48": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "64": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "96": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "128": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "256": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 2
+ },
+ "512": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 3
+ }
+}
diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=51200,K=5120,device_name=NVIDIA_L40S,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=51200,K=5120,device_name=NVIDIA_L40S,dtype=fp8_w8a8,block_shape=[128,128].json
new file mode 100644
index 0000000000000..ca2179ddf3d2f
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/utils/configs/N=51200,K=5120,device_name=NVIDIA_L40S,dtype=fp8_w8a8,block_shape=[128,128].json
@@ -0,0 +1,146 @@
+{
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "4": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 8,
+ "num_stages": 5
+ },
+ "8": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "16": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 5
+ },
+ "24": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "32": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "48": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "64": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "96": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "128": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "256": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "512": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 8,
+ "num_stages": 3
+ }
+}
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
index fdf330329e20c..eda40657b1e39 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
@@ -9,6 +9,10 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
+ RoutingMethodType,
+)
+from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
+ FlashInferCuteDSLExperts,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
@@ -17,10 +21,14 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize im
create_flashinfer_prepare_finalize,
)
from vllm.platforms import current_platform
-from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
+from vllm.utils.flashinfer import (
+ has_flashinfer_cutedsl_grouped_gemm_nt_masked,
+ has_flashinfer_cutlass_fused_moe,
+)
__all__ = [
"is_flashinfer_fp4_cutlass_moe_available",
+ "is_flashinfer_fp4_cutedsl_moe_available",
"reorder_w1w3_to_w3w1",
"build_flashinfer_fp4_cutlass_moe_prepare_finalize",
]
@@ -36,6 +44,16 @@ def is_flashinfer_fp4_cutlass_moe_available() -> bool:
)
+def is_flashinfer_fp4_cutedsl_moe_available() -> bool:
+ """Return ``True`` when FlashInfer CUTEDSL NV-FP4 kernels can be used."""
+ return (
+ envs.VLLM_USE_FLASHINFER_MOE_FP4
+ and has_flashinfer_cutedsl_grouped_gemm_nt_masked()
+ and current_platform.is_cuda()
+ and current_platform.is_device_capability(100)
+ )
+
+
def reorder_w1w3_to_w3w1(
weight: torch.Tensor, scale: torch.Tensor, dim: int = -2
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -72,18 +90,244 @@ def select_nvfp4_gemm_impl(
"""Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""
if allow_flashinfer:
- return FlashInferExperts(
- out_dtype=moe.in_dtype,
- quant_config=moe_quant_config,
- ep_rank=moe.moe_parallel_config.ep_rank,
- ep_size=moe.moe_parallel_config.ep_size,
- tp_rank=moe.moe_parallel_config.tp_rank,
- tp_size=moe.moe_parallel_config.tp_size,
- use_dp=moe.moe_parallel_config.dp_size > 1,
- )
+ if envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm":
+ return FlashInferCuteDSLExperts(
+ out_dtype=moe.in_dtype,
+ quant_config=moe_quant_config,
+ )
+ elif envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput":
+ return FlashInferExperts(
+ out_dtype=moe.in_dtype,
+ quant_config=moe_quant_config,
+ ep_rank=moe.moe_parallel_config.ep_rank,
+ ep_size=moe.moe_parallel_config.ep_size,
+ tp_rank=moe.moe_parallel_config.tp_rank,
+ tp_size=moe.moe_parallel_config.tp_size,
+ use_dp=moe.moe_parallel_config.dp_size > 1,
+ )
# native cutlass experts currently don't support DP; TP case won't call this
raise ValueError(
"CutlassExpertsFp4 doesn't support DP. Use flashinfer CUTLASS "
"Fused MoE backend instead (set VLLM_USE_FLASHINFER_MOE_FP4=1)"
)
+
+
+def prepare_static_weights_for_trtllm_fp4_moe(
+ # args_dequant,
+ # args,
+ gemm1_weights,
+ gemm2_weights,
+ gemm1_scales_linear_fp4_bytes,
+ gemm2_scales_linear_fp4_bytes,
+ hidden_size,
+ intermediate_size,
+ num_experts,
+):
+ from flashinfer import nvfp4_block_scale_interleave
+ from flashinfer.fused_moe.core import (
+ _maybe_get_cached_w3_w1_permute_indices,
+ get_w2_permute_indices_with_cache,
+ )
+
+ _cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
+    # Prepare quantized weights for the kernel (done offline with weights).
+ epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
+
+ # Convert quantized weights to proper formats
+ gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
+ num_experts, 2 * intermediate_size, hidden_size // 2
+ ) # packed fp4
+ gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
+ torch.float8_e4m3fn
+ ).reshape(
+ num_experts, 2 * intermediate_size, hidden_size // 16
+ ) # fp8 scaling factors
+
+ gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
+ num_experts, hidden_size, intermediate_size // 2
+ ) # packed fp4
+ gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
+ torch.float8_e4m3fn
+ ).reshape(num_experts, hidden_size, intermediate_size // 16) # fp8 scaling factors
+
+ gemm1_weights_fp4_shuffled = []
+ gemm1_scales_fp4_shuffled = []
+ gemm2_weights_fp4_shuffled = []
+ gemm2_scales_fp4_shuffled = []
+ for i in range(num_experts):
+ # Calculate the permute indices for the following:
+ # 1. Reorder rows of W1 and scales for fused gated activation
+ # 2. Shuffle weights and scaling factors for transposed mma output
+ # for both w3_w1 and w2 weights and scale factors
+ permute_indices = _maybe_get_cached_w3_w1_permute_indices(
+ _cache_permute_indices,
+ gemm1_weights_fp4[i].view(torch.uint8),
+ epilogue_tile_m,
+ )
+ gemm1_weights_fp4_shuffled.append(
+ gemm1_weights_fp4[i]
+ .view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
+ .contiguous()
+ )
+
+ permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
+ _cache_permute_indices,
+ gemm1_scales_linear_fp4[i].view(torch.uint8),
+ epilogue_tile_m,
+ num_elts_per_sf=16,
+ )
+ gemm1_scales_fp4_shuffled.append(
+ nvfp4_block_scale_interleave(
+ gemm1_scales_linear_fp4[i]
+ .view(torch.uint8)[
+ permute_sf_indices.to(gemm1_scales_linear_fp4.device)
+ ]
+ .contiguous()
+ )
+ )
+
+ permute_indices = get_w2_permute_indices_with_cache(
+ _cache_permute_indices,
+ gemm2_weights_fp4[i].view(torch.uint8),
+ epilogue_tile_m,
+ )
+ gemm2_weights_fp4_shuffled.append(
+ gemm2_weights_fp4[i]
+ .view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
+ .contiguous()
+ )
+
+ permute_sf_indices = get_w2_permute_indices_with_cache(
+ _cache_permute_indices,
+ gemm2_scales_linear_fp4[i].view(torch.uint8),
+ epilogue_tile_m,
+ num_elts_per_sf=16,
+ )
+ gemm2_scales_fp4_shuffled.append(
+ nvfp4_block_scale_interleave(
+ gemm2_scales_linear_fp4[i]
+ .view(torch.uint8)[
+ permute_sf_indices.to(gemm2_scales_linear_fp4.device)
+ ]
+ .contiguous()
+ )
+ )
+
+ # Stack weights for all experts
+ gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
+ gemm1_scales_fp4_shuffled = (
+ torch.stack(gemm1_scales_fp4_shuffled)
+ .view(torch.float8_e4m3fn)
+ .reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
+ )
+
+ gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
+ gemm2_scales_fp4_shuffled = (
+ torch.stack(gemm2_scales_fp4_shuffled)
+ .view(torch.float8_e4m3fn)
+ .reshape(num_experts, hidden_size, intermediate_size // 16)
+ )
+ return (
+ gemm1_weights_fp4_shuffled,
+ gemm1_scales_fp4_shuffled,
+ gemm2_weights_fp4_shuffled,
+ gemm2_scales_fp4_shuffled,
+ )
+
+
+def flashinfer_trtllm_fp4_moe(
+ layer: torch.nn.Module,
+ x: torch.Tensor,
+ router_logits: torch.Tensor,
+ top_k: int,
+ global_num_experts: int,
+ num_expert_group: int | None,
+ topk_group: int | None,
+ custom_routing_function: object | None,
+ e_score_correction_bias: torch.Tensor | None,
+) -> torch.Tensor:
+ """
+ Apply FlashInfer TensorRT-LLM FP4 MoE kernel.
+
+ Args:
+ layer: The MoE layer with weights and scales
+ x: Input tensor
+ router_logits: Router logits for expert selection
+ top_k: Number of experts to select per token
+ global_num_experts: Total number of experts across all ranks
+ num_expert_group: Number of expert groups (for grouped routing)
+ topk_group: Top-k within each group
+ custom_routing_function: Custom routing function (e.g., Llama4)
+ e_score_correction_bias: Optional routing bias correction
+
+ Returns:
+ Output tensor from the MoE layer
+ """
+ import flashinfer
+
+ from vllm.model_executor.models.llama4 import Llama4MoE
+
+ # Quantize input to FP4
+ a1_gscale = layer.w13_input_scale_quant
+ (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
+ x,
+ a1_gscale,
+ is_sf_swizzled_layout=False,
+ )
+
+ # Determine routing method type
+ use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function
+ routing_method_type = layer.routing_method_type
+ if use_llama4_routing:
+ routing_method_type = flashinfer.RoutingMethodType.Llama4
+
+ # Prepare routing bias
+ routing_bias = e_score_correction_bias
+ if routing_bias is not None:
+ routing_bias = routing_bias.to(torch.bfloat16)
+
+ router_logits = (
+ router_logits.to(torch.float32)
+ if routing_method_type == RoutingMethodType.DeepSeekV3
+ else router_logits
+ )
+
+ # Call TRT-LLM FP4 block-scale MoE kernel
+ out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
+ routing_logits=router_logits,
+ routing_bias=routing_bias,
+ hidden_states=hidden_states_fp4,
+ hidden_states_scale=hidden_states_scale_linear_fp4.view(
+ torch.float8_e4m3fn
+ ).flatten(),
+ gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
+ gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
+ torch.float8_e4m3fn
+ ),
+ gemm1_bias=None,
+ gemm1_alpha=None,
+ gemm1_beta=None,
+ gemm1_clamp_limit=None,
+ gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
+ gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
+ torch.float8_e4m3fn
+ ),
+ gemm2_bias=None,
+ output1_scale_scalar=layer.g1_scale_c.data,
+ output1_scale_gate_scalar=layer.g1_alphas.data,
+ output2_scale_scalar=layer.g2_alphas.data,
+ num_experts=global_num_experts,
+ top_k=top_k,
+ n_group=num_expert_group if num_expert_group is not None else 0,
+ topk_group=topk_group if topk_group is not None else 0,
+ intermediate_size=layer.intermediate_size_per_partition,
+ local_expert_offset=layer.ep_rank * layer.local_num_experts,
+ local_num_experts=layer.local_num_experts,
+ routed_scaling_factor=None,
+ tile_tokens_dim=None,
+ routing_method_type=routing_method_type,
+ do_finalize=True,
+ )[0]
+
+ return out
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
index f22e17945d1f6..eef7a0896c375 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
@@ -25,6 +25,7 @@ logger = init_logger(__name__)
class FlashinferMoeBackend(Enum):
TENSORRT_LLM = "TensorRT-LLM"
CUTLASS = "CUTLASS"
+ CUTEDSL = "CUTEDSL"
def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
@@ -273,19 +274,31 @@ def flashinfer_cutlass_moe_fp8(
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
- flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
- # Prefer CUTLASS on SM90 to cover both SM90/SM100 generations
- if flashinfer_moe_backend == "throughput" or current_platform.is_device_capability(
- 90
- ):
- return FlashinferMoeBackend.CUTLASS
- elif flashinfer_moe_backend == "latency":
- return FlashinferMoeBackend.TENSORRT_LLM
+ backend_map = {
+ "throughput": FlashinferMoeBackend.CUTLASS,
+ "latency": FlashinferMoeBackend.TENSORRT_LLM,
+ "masked_gemm": FlashinferMoeBackend.CUTEDSL,
+ }
+
+ flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
+ if flashinfer_moe_backend in backend_map:
+ if (
+ flashinfer_moe_backend == "latency"
+ and not current_platform.is_device_capability(100)
+ ):
+ logger.info_once(
+ "Flashinfer TRTLLM MOE backend is only supported on "
+ "SM100 and later, using CUTLASS backend instead",
+ scope="local",
+ )
+ return FlashinferMoeBackend.CUTLASS
+ return backend_map[flashinfer_moe_backend]
+ elif current_platform.is_device_capability(90):
+ return FlashinferMoeBackend.CUTLASS
- allowed_backends = ["throughput", "latency"]
raise ValueError(
- f"Unknown flashinfer moe backend: {flashinfer_moe_backend}"
- f" expected one of {allowed_backends}"
+ f"Unknown flashinfer moe backend: {flashinfer_moe_backend!r}. "
+ f"Expected one of {list(backend_map.keys())}."
)
diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
index cbc46810a26a6..d0c8b3d1a3093 100644
--- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
@@ -39,15 +39,15 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
value_layout = StridedLayout
scale_layout = StridedLayout
elif current_platform.is_rocm():
- from triton_kernels.tensor_details.layout import (
- GFX950MXScaleLayout,
- StridedLayout,
- )
-
from vllm.platforms.rocm import on_gfx950
value_layout = StridedLayout
- scale_layout = GFX950MXScaleLayout if on_gfx950() else StridedLayout
+ if on_gfx950():
+ from triton_kernels.tensor_details.layout import GFX950MXScaleLayout
+
+ scale_layout = GFX950MXScaleLayout
+ else:
+ scale_layout = StridedLayout
else:
value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(
mx_axis=1
diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py b/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py
index c3f26cc774118..44c5b027daf4f 100644
--- a/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py
+++ b/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py
@@ -5,6 +5,7 @@ from dataclasses import dataclass
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
+ is_flashinfer_fp4_cutedsl_moe_available,
is_flashinfer_fp4_cutlass_moe_available,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
@@ -32,7 +33,10 @@ def detect_nvfp4_moe_support(class_name: str = "") -> NvFp4Support:
"""Detect platform support for NV-FP4 fused-MoE path"""
cutlass_supported = cutlass_fp4_supported()
- allow_flashinfer = cutlass_supported and is_flashinfer_fp4_cutlass_moe_available()
+ allow_flashinfer = cutlass_supported and (
+ is_flashinfer_fp4_cutlass_moe_available()
+ or is_flashinfer_fp4_cutedsl_moe_available()
+ )
if allow_flashinfer:
_logger.info_once(
diff --git a/vllm/model_executor/layers/rotary_embedding/__init__.py b/vllm/model_executor/layers/rotary_embedding/__init__.py
index 56c165f9c041a..0f10bff6ac4f5 100644
--- a/vllm/model_executor/layers/rotary_embedding/__init__.py
+++ b/vllm/model_executor/layers/rotary_embedding/__init__.py
@@ -17,6 +17,7 @@ from .llama4_vision_rope import Llama4VisionRotaryEmbedding
from .mrope import MRotaryEmbedding
from .ntk_scaling_rope import NTKScalingRotaryEmbedding
from .phi3_long_rope_scaled_rope import Phi3LongRoPEScaledRotaryEmbedding
+from .xdrope import XDRotaryEmbedding
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding
_ROPE_DICT: dict[tuple, RotaryEmbedding] = {}
@@ -26,23 +27,23 @@ def get_rope(
head_size: int,
rotary_dim: int,
max_position: int,
- base: float,
is_neox_style: bool = True,
- rope_scaling: dict[str, Any] | None = None,
+ rope_parameters: dict[str, Any] | None = None,
dtype: torch.dtype | None = None,
partial_rotary_factor: float = 1.0,
dual_chunk_attention_config: dict[str, Any] | None = None,
) -> RotaryEmbedding:
if dtype is None:
dtype = torch.get_default_dtype()
- if rope_scaling is not None:
+ if rope_parameters is not None:
# Transforms every value that is a list into a tuple for caching calls
- rope_scaling_tuple = {
- k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items()
+ rope_parameters_tuple = {
+ k: tuple(v) if isinstance(v, list) else v
+ for k, v in rope_parameters.items()
}
- rope_scaling_args = tuple(rope_scaling_tuple.items())
+ rope_parameters_args = tuple(rope_parameters_tuple.items())
else:
- rope_scaling_args = None
+ rope_parameters_args = None
if dual_chunk_attention_config is not None:
dual_chunk_attention_tuple = {
@@ -60,15 +61,15 @@ def get_rope(
head_size,
rotary_dim,
max_position,
- base,
is_neox_style,
- rope_scaling_args,
+ rope_parameters_args,
dual_chunk_attention_args,
dtype,
)
if key in _ROPE_DICT:
return _ROPE_DICT[key]
+ base = rope_parameters["rope_theta"] if rope_parameters else 10000
if dual_chunk_attention_config is not None:
extra_kwargs = {
k: v
@@ -84,18 +85,18 @@ def get_rope(
dtype,
**extra_kwargs,
)
- elif not rope_scaling:
+ elif not rope_parameters:
rotary_emb = RotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype
)
else:
- scaling_type = rope_scaling["rope_type"]
+ scaling_type = rope_parameters["rope_type"]
if scaling_type == "llama3":
- scaling_factor = rope_scaling["factor"]
- low_freq_factor = rope_scaling["low_freq_factor"]
- high_freq_factor = rope_scaling["high_freq_factor"]
- original_max_position = rope_scaling["original_max_position_embeddings"]
+ scaling_factor = rope_parameters["factor"]
+ low_freq_factor = rope_parameters["low_freq_factor"]
+ high_freq_factor = rope_parameters["high_freq_factor"]
+ original_max_position = rope_parameters["original_max_position_embeddings"]
rotary_emb = Llama3RotaryEmbedding(
head_size,
rotary_dim,
@@ -113,7 +114,7 @@ def get_rope(
head_size, rotary_dim, max_position, base, is_neox_style, dtype
)
elif scaling_type == "default":
- if "mrope_section" in rope_scaling:
+ if "mrope_section" in rope_parameters:
rotary_emb = MRotaryEmbedding(
head_size,
rotary_dim,
@@ -121,8 +122,8 @@ def get_rope(
base,
is_neox_style,
dtype,
- mrope_section=rope_scaling["mrope_section"],
- mrope_interleaved=rope_scaling.get("mrope_interleaved", False),
+ mrope_section=rope_parameters["mrope_section"],
+ mrope_interleaved=rope_parameters.get("mrope_interleaved", False),
)
else:
rotary_emb = RotaryEmbedding(
@@ -134,7 +135,7 @@ def get_rope(
dtype,
)
elif scaling_type == "linear":
- scaling_factor = rope_scaling["factor"]
+ scaling_factor = rope_parameters["factor"]
rotary_emb = LinearScalingRotaryEmbedding(
head_size,
rotary_dim,
@@ -145,8 +146,8 @@ def get_rope(
dtype,
)
elif scaling_type == "ntk":
- scaling_factor = rope_scaling["factor"]
- mixed_b = rope_scaling.get("mixed_b", None)
+ scaling_factor = rope_parameters["factor"]
+ mixed_b = rope_parameters.get("mixed_b")
rotary_emb = NTKScalingRotaryEmbedding(
head_size,
rotary_dim,
@@ -158,8 +159,8 @@ def get_rope(
mixed_b,
)
elif scaling_type == "dynamic":
- if "alpha" in rope_scaling:
- scaling_alpha = rope_scaling["alpha"]
+ if "alpha" in rope_parameters:
+ scaling_alpha = rope_parameters["alpha"]
rotary_emb = DynamicNTKAlphaRotaryEmbedding(
head_size,
rotary_dim,
@@ -169,8 +170,8 @@ def get_rope(
scaling_alpha,
dtype,
)
- elif "factor" in rope_scaling:
- scaling_factor = rope_scaling["factor"]
+ elif "factor" in rope_parameters:
+ scaling_factor = rope_parameters["factor"]
rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size,
rotary_dim,
@@ -184,12 +185,24 @@ def get_rope(
raise ValueError(
"Dynamic rope scaling must contain either 'alpha' or 'factor' field"
)
+ elif scaling_type == "xdrope":
+ scaling_alpha = rope_parameters["alpha"]
+ rotary_emb = XDRotaryEmbedding(
+ head_size,
+ rotary_dim,
+ max_position,
+ base,
+ is_neox_style,
+ scaling_alpha,
+ dtype,
+ xdrope_section=rope_parameters["xdrope_section"],
+ )
elif scaling_type == "yarn":
- scaling_factor = rope_scaling["factor"]
- original_max_position = rope_scaling["original_max_position_embeddings"]
+ scaling_factor = rope_parameters["factor"]
+ original_max_position = rope_parameters["original_max_position_embeddings"]
extra_kwargs = {
k: v
- for k, v in rope_scaling.items()
+ for k, v in rope_parameters.items()
if k
in (
"extrapolation_factor",
@@ -197,9 +210,10 @@ def get_rope(
"beta_fast",
"beta_slow",
"apply_yarn_scaling",
+ "truncate",
)
}
- if "mrope_section" in rope_scaling:
+ if "mrope_section" in rope_parameters:
extra_kwargs.pop("apply_yarn_scaling", None)
rotary_emb = MRotaryEmbedding(
head_size,
@@ -208,8 +222,8 @@ def get_rope(
base,
is_neox_style,
dtype,
- mrope_section=rope_scaling["mrope_section"],
- mrope_interleaved=rope_scaling.get("mrope_interleaved", False),
+ mrope_section=rope_parameters["mrope_section"],
+ mrope_interleaved=rope_parameters.get("mrope_interleaved", False),
scaling_factor=scaling_factor,
**extra_kwargs,
)
@@ -225,12 +239,12 @@ def get_rope(
**extra_kwargs,
)
elif scaling_type == "deepseek_yarn":
- scaling_factor = rope_scaling["factor"]
- original_max_position = rope_scaling["original_max_position_embeddings"]
+ scaling_factor = rope_parameters["factor"]
+ original_max_position = rope_parameters["original_max_position_embeddings"]
# assert max_position == original_max_position * scaling_factor
extra_kwargs = {
k: v
- for k, v in rope_scaling.items()
+ for k, v in rope_parameters.items()
if k
in (
"extrapolation_factor",
@@ -252,12 +266,12 @@ def get_rope(
**extra_kwargs,
)
elif scaling_type == "longrope":
- short_factor = rope_scaling["short_factor"]
- long_factor = rope_scaling["long_factor"]
- original_max_position = rope_scaling["original_max_position_embeddings"]
+ short_factor = rope_parameters["short_factor"]
+ long_factor = rope_parameters["long_factor"]
+ original_max_position = rope_parameters["original_max_position_embeddings"]
extra_kwargs = {
k: v
- for k, v in rope_scaling.items()
+ for k, v in rope_parameters.items()
if k in ("short_mscale", "long_mscale")
}
rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py
index 196533b617959..13f8d15cc0f72 100644
--- a/vllm/model_executor/layers/rotary_embedding/common.py
+++ b/vllm/model_executor/layers/rotary_embedding/common.py
@@ -117,13 +117,13 @@ def yarn_find_correction_range(
dim: int,
base: float = 10000,
max_position_embeddings: int = 2048,
-) -> tuple[int, int]:
- low = math.floor(
- yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
- )
- high = math.ceil(
- yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
- )
+ truncate: bool = True,
+) -> tuple[float | int, float | int]:
+ low = yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
+ high = yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
+ if truncate:
+ low = math.floor(low)
+ high = math.ceil(high)
return max(low, 0), min(high, dim - 1) # Clamp values just in case
diff --git a/vllm/model_executor/layers/rotary_embedding/xdrope.py b/vllm/model_executor/layers/rotary_embedding/xdrope.py
new file mode 100644
index 0000000000000..2432273faf195
--- /dev/null
+++ b/vllm/model_executor/layers/rotary_embedding/xdrope.py
@@ -0,0 +1,102 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import numpy as np
+import torch
+
+from .common import apply_rotary_emb_dispatch
+from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
+
+
+class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding):
+ """DynamicNTKAlphaRotaryEmbedding extended with MultiModal(XD) Sections.
+
+ Based on the original DynamicNTKAlphaRotaryEmbedding implementation.
+ """
+
+ def __init__(
+ self,
+ head_size: int,
+ rotary_dim: int,
+ max_position_embeddings: int,
+ base: float,
+ is_neox_style: bool,
+ scaling_alpha: float,
+ dtype: torch.dtype,
+ xdrope_section: list[int],
+ ) -> None:
+ self.xdrope_section = xdrope_section
+ super().__init__(
+ head_size,
+ rotary_dim,
+ max_position_embeddings,
+ base,
+ is_neox_style,
+ scaling_alpha,
+ dtype,
+ )
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ query: torch.Tensor,
+ key: torch.Tensor | None = None,
+ offsets: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        """Apply XD-section rotary embedding to query and key.
+
+ Args:
+ positions:
+ [4, num_tokens] (P/W/H/T positions with multimodal inputs)
+ query: [num_tokens, num_heads * head_size]
+ key: [num_tokens, num_kv_heads * head_size]
+ """
+ assert positions.ndim == 2
+ assert key is not None
+
+ num_tokens = positions.shape[-1]
+ cos_sin = self.cos_sin_cache[positions]
+ cos, sin = cos_sin.chunk(2, dim=-1)
+ cos = torch.cat(
+ [m[i] for i, m in enumerate(cos.split(self.xdrope_section, dim=-1))], dim=-1
+ )
+ sin = torch.cat(
+ [m[i] for i, m in enumerate(sin.split(self.xdrope_section, dim=-1))], dim=-1
+ )
+
+ query_shape = query.shape
+ query = query.view(num_tokens, -1, self.head_size)
+ query_rot = query[..., : self.rotary_dim]
+ query_pass = query[..., self.rotary_dim :]
+ query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
+ query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
+
+ key_shape = key.shape
+ key = key.view(num_tokens, -1, self.head_size)
+ key_rot = key[..., : self.rotary_dim]
+ key_pass = key[..., self.rotary_dim :]
+ key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
+ key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
+ return query, key
+
+ @staticmethod
+ def get_next_input_positions(
+ context_len: int,
+ seq_len: int,
+ xd_sections: int = 4,
+ ) -> list[list[int]]:
+ return [list(range(context_len, seq_len)) for _ in range(xd_sections)]
+
+ @staticmethod
+ def get_next_input_positions_tensor(
+ out: np.ndarray,
+ out_offset: int,
+ context_len: int,
+ num_new_tokens: int,
+ ):
+ values = np.arange(
+ context_len,
+ context_len + num_new_tokens,
+ dtype=out.dtype,
+ )
+ out[:, out_offset : out_offset + num_new_tokens] = values
diff --git a/vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py
index ff46ad74b302e..f01ca1e231211 100644
--- a/vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py
+++ b/vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py
@@ -28,12 +28,14 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
beta_fast: int = 32,
beta_slow: int = 1,
apply_yarn_scaling: bool = True,
+ truncate: bool = True,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
+ self.truncate = truncate
# Get n-d magnitude scaling corrected for interpolation
self.mscale = (
float(yarn_get_mscale(self.scaling_factor) * attn_factor)
@@ -57,6 +59,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
self.rotary_dim,
self.base,
self.max_position_embeddings,
+ self.truncate,
)
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (
diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py
index 301f2d00bf404..052d2cfc1099e 100644
--- a/vllm/model_executor/model_loader/__init__.py
+++ b/vllm/model_executor/model_loader/__init__.py
@@ -30,6 +30,7 @@ logger = init_logger(__name__)
# if a new load format is added here
LoadFormats = Literal[
"auto",
+ "hf",
"bitsandbytes",
"dummy",
"fastsafetensors",
@@ -45,6 +46,7 @@ LoadFormats = Literal[
]
_LOAD_FORMAT_TO_MODEL_LOADER: dict[str, type[BaseModelLoader]] = {
"auto": DefaultModelLoader,
+ "hf": DefaultModelLoader,
"bitsandbytes": BitsAndBytesModelLoader,
"dummy": DummyModelLoader,
"fastsafetensors": DefaultModelLoader,
diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py
index b80026741781f..7401a7a0e2dbb 100644
--- a/vllm/model_executor/model_loader/default_loader.py
+++ b/vllm/model_executor/model_loader/default_loader.py
@@ -31,6 +31,7 @@ from vllm.model_executor.model_loader.weight_utils import (
safetensors_weights_iterator,
)
from vllm.platforms import current_platform
+from vllm.transformers_utils.config import list_filtered_repo_files
logger = init_logger(__name__)
@@ -96,8 +97,25 @@ class DefaultModelLoader(BaseModelLoader):
load_format = self.load_config.load_format
use_safetensors = False
index_file = SAFE_WEIGHTS_INDEX_NAME
- # Some quantized models use .pt files for storing the weights.
+
+ # First check for 'auto' format that mistral files format are present.
+ # This is to load mistral models with official format by default.
if load_format == "auto":
+ load_format = (
+ "mistral"
+ if len(
+ list_filtered_repo_files(
+ model_name_or_path=model_name_or_path,
+ allow_patterns=["consolidated*.safetensors"],
+ revision=revision,
+ )
+ )
+ > 0
+ else "hf"
+ )
+
+ # Some quantized models use .pt files for storing the weights.
+ if load_format == "hf":
allow_patterns = ["*.safetensors", "*.bin"]
elif load_format == "safetensors" or load_format == "fastsafetensors":
use_safetensors = True
@@ -279,7 +297,7 @@ class DefaultModelLoader(BaseModelLoader):
if (
hasattr(quant_config, "is_checkpoint_torchao_serialized")
and quant_config.is_checkpoint_torchao_serialized
- and torchao_version_at_least("0.14.0")
+ and torchao_version_at_least("0.15.0")
):
self.load_config.safetensors_load_strategy = "torchao"
diff --git a/vllm/model_executor/model_loader/sharded_state_loader.py b/vllm/model_executor/model_loader/sharded_state_loader.py
index d94dbd9f06e0b..1538f0c2af655 100644
--- a/vllm/model_executor/model_loader/sharded_state_loader.py
+++ b/vllm/model_executor/model_loader/sharded_state_loader.py
@@ -4,6 +4,7 @@
import collections
import glob
import os
+import time
from collections.abc import Generator
from typing import Any
@@ -132,6 +133,7 @@ class ShardedStateLoader(BaseModelLoader):
f"pre-sharded checkpoints are currently supported!"
)
state_dict = self._filter_subtensors(model.state_dict())
+ counter_before_loading_weights = time.perf_counter()
for key, tensor in self.iterate_over_files(filepaths):
# If loading with LoRA enabled, additional padding may
# be added to certain parameters. We only load into a
@@ -150,6 +152,12 @@ class ShardedStateLoader(BaseModelLoader):
)
param_data.copy_(tensor)
state_dict.pop(key)
+ counter_after_loading_weights = time.perf_counter()
+ logger.info_once(
+ "Loading weights took %.2f seconds",
+ counter_after_loading_weights - counter_before_loading_weights,
+ scope="local",
+ )
if state_dict:
raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py
index e74434e9d12cb..2021b68b8a60b 100644
--- a/vllm/model_executor/model_loader/utils.py
+++ b/vllm/model_executor/model_loader/utils.py
@@ -19,12 +19,6 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
-from vllm.model_executor.models.adapters import (
- as_embedding_model,
- as_reward_model,
- as_seq_cls_model,
- try_create_mm_pooling_model_cls,
-)
from vllm.model_executor.models.interfaces import SupportsQuant, supports_multimodal
from vllm.utils.platform_utils import is_pin_memory_available
@@ -172,6 +166,13 @@ _MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]()
def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
+ from vllm.model_executor.models.adapters import (
+ as_embedding_model,
+ as_reward_model,
+ as_seq_cls_model,
+ try_create_mm_pooling_model_cls,
+ )
+
architectures = getattr(model_config.hf_config, "architectures", [])
model_cls, arch = model_config.registry.resolve_model_cls(
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index 89634cbf41241..4572ebe2ea11b 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -595,6 +595,9 @@ def safetensors_weights_iterator(
if safetensors_load_strategy == "eager":
loading_desc += " (eager)"
+ state_dict = {}
+ leftover_state_dict: dict[str, torch.Tensor] = {}
+
for st_file in tqdm(
hf_weights_files,
desc=loading_desc,
@@ -606,9 +609,11 @@ def safetensors_weights_iterator(
state_dict = load(f.read())
yield from state_dict.items()
elif safetensors_load_strategy == "torchao":
- if not torchao_version_at_least("0.14.0"):
+ # we can't load flattened torchao tensor subclasses directly into the model
+ # instead we reconstruct the subclasses here before returning
+ if not torchao_version_at_least("0.15.0"):
raise ValueError(
- "Please use torchao version >= 0.14.0 \
+ "Please use torchao version >= 0.15.0 \
to load torchao safetensors checkpoint"
)
from torchao.prototype.safetensors.safetensors_support import (
@@ -616,12 +621,20 @@ def safetensors_weights_iterator(
)
with safe_open(st_file, framework="pt") as f:
- state_dict = {}
for name in f.keys(): # noqa: SIM118
state_dict[name] = f.get_tensor(name)
+
+ # update with leftover tensor data from previous iteration, if any
+ state_dict.update(leftover_state_dict)
metadata = f.metadata()
- updated_state_dict = unflatten_tensor_state_dict(state_dict, metadata)
- yield from updated_state_dict.items()
+ # due to sharded checkpoints, we are not guaranteed that we have all
+ # tensor subclass data on one file
+ # state_dict has the leftover data from this step and we wait for
+ # missing information to be provided in a future iteration
+ unflattened_state_dict, leftover_state_dict = (
+ unflatten_tensor_state_dict(state_dict, metadata)
+ )
+ yield from unflattened_state_dict.items()
else:
with safe_open(st_file, framework="pt") as f:
for name in f.keys(): # noqa: SIM118
diff --git a/vllm/model_executor/models/afmoe.py b/vllm/model_executor/models/afmoe.py
index 6f654f47495f7..4eb5665a71fc8 100644
--- a/vllm/model_executor/models/afmoe.py
+++ b/vllm/model_executor/models/afmoe.py
@@ -5,7 +5,6 @@
import typing
from collections.abc import Callable, Iterable
from itertools import islice
-from typing import Any
import torch
from torch import nn
@@ -171,8 +170,6 @@ class AfmoeAttention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
max_position_embeddings: int = 131072,
head_dim: int | None = None,
rms_norm_eps: float = 1e-05,
@@ -202,7 +199,6 @@ class AfmoeAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
# Check if this is a local attention layer
@@ -246,8 +242,7 @@ class AfmoeAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
- base=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config["rope_parameters"],
is_neox_style=True,
)
else:
@@ -303,14 +298,6 @@ class AfmoeDecoderLayer(nn.Module):
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
- if rope_scaling is not None and getattr(
- config, "original_max_position_embeddings", None
- ):
- rope_scaling["original_max_position_embeddings"] = (
- config.original_max_position_embeddings
- )
max_position_embeddings = getattr(config, "max_position_embeddings", 131072)
# DecoderLayers are created with `make_layers` which passes the prefix
@@ -323,8 +310,6 @@ class AfmoeDecoderLayer(nn.Module):
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
head_dim=config.head_dim,
rms_norm_eps=config.rms_norm_eps,
diff --git a/vllm/model_executor/models/apertus.py b/vllm/model_executor/models/apertus.py
index 0a8f21abb0a35..b75e91319bbad 100644
--- a/vllm/model_executor/models/apertus.py
+++ b/vllm/model_executor/models/apertus.py
@@ -27,7 +27,6 @@
from collections.abc import Iterable
from itertools import islice
-from typing import Any
import torch
from torch import nn
@@ -118,8 +117,6 @@ class ApertusAttention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
max_position_embeddings: int = 8192,
quant_config: QuantizationConfig | None = None,
bias: bool = False,
@@ -155,7 +152,6 @@ class ApertusAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
@@ -176,9 +172,7 @@ class ApertusAttention(nn.Module):
prefix=f"{prefix}.o_proj",
)
- self._init_rotary_emb(
- config, rope_scaling=rope_scaling, quant_config=quant_config
- )
+ self._init_rotary_emb(config, quant_config=quant_config)
sliding_window = None
if layer_types := getattr(config, "layer_types", None):
@@ -224,7 +218,6 @@ class ApertusAttention(nn.Module):
def _init_rotary_emb(
self,
config: ApertusConfig,
- rope_scaling: dict[str, Any] | None,
quant_config: QuantizationConfig | None,
) -> None:
is_neox_style = True
@@ -236,8 +229,7 @@ class ApertusAttention(nn.Module):
self.head_dim,
rotary_dim=int(self.partial_rotary_factor * self.head_dim),
max_position=self.max_position_embeddings,
- base=self.rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style,
partial_rotary_factor=self.partial_rotary_factor,
)
@@ -253,14 +245,6 @@ class ApertusDecoderLayer(nn.Module):
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
- if rope_scaling is not None and getattr(
- config, "original_max_position_embeddings", None
- ):
- rope_scaling["original_max_position_embeddings"] = (
- config.original_max_position_embeddings
- )
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
@@ -288,8 +272,6 @@ class ApertusDecoderLayer(nn.Module):
num_kv_heads=getattr(
config, "num_key_value_heads", config.num_attention_heads
),
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
diff --git a/vllm/model_executor/models/arcee.py b/vllm/model_executor/models/arcee.py
index 20c3ff0754506..b3887b16f4d74 100644
--- a/vllm/model_executor/models/arcee.py
+++ b/vllm/model_executor/models/arcee.py
@@ -103,15 +103,6 @@ class ArceeDecoderLayer(nn.Module):
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
- # Rotary embedding parameters (reuse LLaMA defaults)
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
- if rope_scaling is not None and getattr(
- config, "original_max_position_embeddings", None
- ):
- rope_scaling["original_max_position_embeddings"] = (
- config.original_max_position_embeddings
- )
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
# Determine if attention bias is needed (some variants use bias terms)
attention_bias = getattr(config, "attention_bias", False) or getattr(
@@ -133,8 +124,6 @@ class ArceeDecoderLayer(nn.Module):
num_kv_heads=getattr(
config, "num_key_value_heads", config.num_attention_heads
),
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py
index b5cc07a56535d..b75a254761d4e 100644
--- a/vllm/model_executor/models/arctic.py
+++ b/vllm/model_executor/models/arctic.py
@@ -292,7 +292,6 @@ class ArcticAttention(nn.Module):
self.kv_size = self.num_kv_heads * self.head_dim
self.max_position_embeddings = config.max_position_embeddings
- self.rope_theta = config.rope_theta
self.scaling = self.head_dim**-0.5
self.qkv_proj = QKVParallelLinear(
@@ -317,7 +316,7 @@ class ArcticAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
- base=int(self.rope_theta),
+ rope_parameters=config.rope_parameters,
is_neox_style=True,
)
diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py
index 8991ef4c606b6..edf47270e5277 100644
--- a/vllm/model_executor/models/baichuan.py
+++ b/vllm/model_executor/models/baichuan.py
@@ -136,7 +136,7 @@ class BaiChuanAttention(nn.Module):
hidden_size: int,
num_heads: int,
position_embedding: str,
- rope_theta: float = 10000,
+ rope_parameters: dict,
max_position_embeddings: int = 8192,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
@@ -150,7 +150,6 @@ class BaiChuanAttention(nn.Module):
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
self.head_dim = hidden_size // self.total_num_heads
self.position_embedding = position_embedding
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
# pylint: disable=invalid-name
@@ -192,7 +191,7 @@ class BaiChuanAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
- base=self.rope_theta,
+ rope_parameters=rope_parameters,
)
self.scaling = self.head_dim**-0.5
self.attn = Attention(
@@ -229,13 +228,12 @@ class BaiChuanDecoderLayer(nn.Module):
):
super().__init__()
self.hidden_size = config.hidden_size
- rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.self_attn = BaiChuanAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
position_embedding=position_embedding,
- rope_theta=rope_theta,
+ rope_parameters=config.rope_parameters,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py
index 024425bb24406..cc10e936a2d3d 100644
--- a/vllm/model_executor/models/bailing_moe.py
+++ b/vllm/model_executor/models/bailing_moe.py
@@ -135,9 +135,8 @@ class BailingAttention(nn.Module):
self.head_dim,
rotary_dim=self.rotary_dim,
max_position=config.max_position_embeddings,
- base=config.rope_theta,
+ rope_parameters=config.rope_parameters,
is_neox_style=True,
- rope_scaling=config.rope_scaling,
partial_rotary_factor=self.partial_rotary_factor,
)
diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py
index c6cc83487fec2..4422bb5da98f4 100644
--- a/vllm/model_executor/models/bamba.py
+++ b/vllm/model_executor/models/bamba.py
@@ -156,8 +156,6 @@ class BambaAttentionDecoderLayer(nn.Module):
prefix: str = "",
) -> None:
super().__init__()
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
@@ -178,7 +176,6 @@ class BambaAttentionDecoderLayer(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
if hasattr(config, "partial_rotary_factor"):
@@ -192,8 +189,7 @@ class BambaAttentionDecoderLayer(nn.Module):
head_size=self.head_dim,
rotary_dim=rotary_dim,
max_position=max_position_embeddings,
- rope_scaling=rope_scaling,
- base=rope_theta,
+ rope_parameters=config.rope_parameters,
is_neox_style=True,
dtype=torch.get_default_dtype(), # see impl of get_rope
)
diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py
index 3c87bbfefab3d..b5a6d00dc309f 100644
--- a/vllm/model_executor/models/chameleon.py
+++ b/vllm/model_executor/models/chameleon.py
@@ -265,8 +265,7 @@ class ChameleonAttention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
+ rope_parameters: dict[str, Any],
max_position_embeddings: int = 4096,
quant_config: QuantizationConfig | None = None,
bias: bool = False,
@@ -293,7 +292,6 @@ class ChameleonAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
@@ -318,8 +316,7 @@ class ChameleonAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
- base=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=rope_parameters,
)
self.attn = Attention(
@@ -369,14 +366,6 @@ class ChameleonDecoderLayer(nn.Module):
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
- if rope_scaling is not None and getattr(
- config, "original_max_position_embeddings", None
- ):
- rope_scaling["original_max_position_embeddings"] = (
- config.original_max_position_embeddings
- )
max_position_embeddings = getattr(config, "max_position_embeddings", 4096)
self.self_attn = ChameleonAttention(
@@ -385,8 +374,7 @@ class ChameleonDecoderLayer(nn.Module):
num_kv_heads=getattr(
config, "num_key_value_heads", config.num_attention_heads
),
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=False,
@@ -439,14 +427,6 @@ class ChameleonSwinDecoderLayer(nn.Module):
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
- if rope_scaling is not None and getattr(
- config, "original_max_position_embeddings", None
- ):
- rope_scaling["original_max_position_embeddings"] = (
- config.original_max_position_embeddings
- )
max_position_embeddings = getattr(config, "max_position_embeddings", 4096)
self.self_attn = ChameleonAttention(
@@ -455,8 +435,7 @@ class ChameleonSwinDecoderLayer(nn.Module):
num_kv_heads=getattr(
config, "num_key_value_heads", config.num_attention_heads
),
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=False,
diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py
index 5d6f5e9125a28..dbfcd62d0bcab 100644
--- a/vllm/model_executor/models/chatglm.py
+++ b/vllm/model_executor/models/chatglm.py
@@ -99,6 +99,7 @@ class GLMAttention(nn.Module):
# https://huggingface.co/zai-org/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
rope_ratio = getattr(config, "rope_ratio", 1.0)
max_positions = getattr(config, "seq_length", 8192)
+ rope_parameters = {"rope_type": "default", "rope_theta": 10000 * rope_ratio}
# NOTE: zai-org/cogagent-9b-20241220 uses original_rope=False,
# which is equivalent to is_neox_style=True
is_neox_style = not config.original_rope
@@ -106,7 +107,7 @@ class GLMAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim // 2,
max_position=max_positions,
- base=10000 * rope_ratio,
+ rope_parameters=rope_parameters,
is_neox_style=is_neox_style,
)
self.attn = Attention(
diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py
index 77bb178519813..5ed920927c772 100644
--- a/vllm/model_executor/models/commandr.py
+++ b/vllm/model_executor/models/commandr.py
@@ -156,8 +156,6 @@ class CohereAttention(nn.Module):
self.max_position_embeddings = getattr(
config, "model_max_length", None
) or getattr(config, "max_position_embeddings", 8192)
- self.rope_theta = config.rope_theta
- self.rope_scaling = getattr(config, "rope_scaling", None)
self.use_qk_norm = getattr(config, "use_qk_norm", False)
self.qkv_proj = QKVParallelLinear(
self.hidden_size,
@@ -179,8 +177,7 @@ class CohereAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
- base=self.rope_theta,
- rope_scaling=self.rope_scaling,
+ rope_parameters=config.rope_parameters,
is_neox_style=False,
)
diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py
index 66b246878b0aa..3cf4bf991e667 100644
--- a/vllm/model_executor/models/config.py
+++ b/vllm/model_executor/models/config.py
@@ -8,6 +8,7 @@ import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import current_platform
+from vllm.transformers_utils.config import set_default_rope_theta
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec
@@ -46,8 +47,7 @@ class GteNewModelConfig(VerifyAndUpdateConfig):
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
- "base": config.rope_theta,
- "rope_scaling": getattr(config, "rope_scaling", None),
+ "rope_parameters": config.rope_parameters,
}
@@ -78,12 +78,13 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig):
if not model_config.enforce_eager:
max_position = round_up(max_position, 8)
+ set_default_rope_theta(config, default_theta=config.rotary_emb_base)
+
config.rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": max_position,
- "base": getattr(config, "rope_theta", config.rotary_emb_base),
- "rope_scaling": getattr(config, "rope_scaling", None),
+ "rope_parameters": config.rope_parameters,
}
@@ -117,18 +118,20 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
head_dim = config.hidden_size // config.num_attention_heads
rotary_emb_dim = int(head_dim * config.rotary_emb_fraction)
max_trained_positions = getattr(config, "max_trained_positions", 2048)
+
+ set_default_rope_theta(config, default_theta=config.rotary_emb_base)
+
config.rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": rotary_emb_dim,
"max_position": max_trained_positions,
- "base": getattr(config, "rope_theta", config.rotary_emb_base),
- "rope_scaling": getattr(config, "rope_scaling", None),
+ "rope_parameters": config.rope_parameters,
}
# we ignore config.rotary_scaling_factor so that for datasets shorter
# than max_trained_positions 2048, the results are consistent
# with SentenceTransformer.
- # The context extension uses vllm style rope_theta and rope_scaling.
+ # The context extension uses vllm style rope_theta and rope_parameters.
# See #17785 #18755
if (
not vllm_config.model_config.hf_overrides
@@ -172,7 +175,7 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
if hasattr(hf_text_config, "max_model_len"):
delattr(hf_text_config, "max_model_len")
hf_text_config.max_position_embeddings = max_trained_positions
- hf_text_config.rope_scaling = config.rotary_kwargs["rope_scaling"]
+ hf_text_config.rope_parameters = config.rotary_kwargs["rope_parameters"]
# The priority of sentence_bert_config.json is higher
# than max_position_embeddings
@@ -246,8 +249,7 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
- "base": config.rope_theta,
- "rope_scaling": getattr(config, "rope_scaling", None),
+ "rope_parameters": config.rope_parameters,
}
diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py
index 528ef4f76742d..2c729019081a4 100644
--- a/vllm/model_executor/models/dbrx.py
+++ b/vllm/model_executor/models/dbrx.py
@@ -197,7 +197,10 @@ class DbrxAttention(nn.Module):
self.head_dim = self.d_model // self.total_num_heads
self.total_num_kv_heads = config.attn_config.kv_n_heads
self.clip_qkv = config.attn_config.clip_qkv
- self.rope_theta = config.attn_config.rope_theta
+ rope_parameters = {
+ "rope_type": "default",
+ "rope_theta": int(config.attn_config.rope_theta),
+ }
self.max_position = config.max_seq_len
# pylint: disable=invalid-name
@@ -221,7 +224,7 @@ class DbrxAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position,
- base=int(self.rope_theta),
+ rope_parameters=rope_parameters,
is_neox_style=True,
)
diff --git a/vllm/model_executor/models/deepseek_eagle.py b/vllm/model_executor/models/deepseek_eagle.py
index 3fb04c3b70dd1..4d7a37292cb02 100644
--- a/vllm/model_executor/models/deepseek_eagle.py
+++ b/vllm/model_executor/models/deepseek_eagle.py
@@ -8,7 +8,6 @@ import torch.nn as nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
-from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -172,10 +171,6 @@ class DeepseekV2Model(nn.Module):
)
break
else:
- # if PP disabled then draft will share embed with target
- if get_pp_group().world_size == 1 and "embed_tokens." in name:
- continue
-
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py
index e028dc497aa6a..6e23037b919ab 100644
--- a/vllm/model_executor/models/deepseek_mtp.py
+++ b/vllm/model_executor/models/deepseek_mtp.py
@@ -1,15 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Iterable
+import typing
+from collections.abc import Callable, Iterable
import torch
import torch.nn as nn
from transformers import PretrainedConfig
+from vllm._aiter_ops import rocm_aiter_ops
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -231,6 +233,9 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
return self.model.compute_logits(hidden_states, spec_step_idx)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ rocm_aiter_moe_shared_expert_enabled = (
+ rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
+ )
stacked_params_mapping = [
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
@@ -238,11 +243,16 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
]
- expert_params_mapping = FusedMoE.make_expert_params_mapping(
+ expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
- num_experts=self.config.n_routed_experts,
+ num_experts=self.config.n_routed_experts
+ + (
+ self.config.n_shared_experts
+ if rocm_aiter_moe_shared_expert_enabled
+ else 0
+ ),
)
params_dict = dict(self.named_parameters())
@@ -253,6 +263,9 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is None:
continue
+ is_fusion_moe_shared_experts_layer = (
+ rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
+ )
name = self._rewrite_spec_layer_name(spec_layer, name)
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
@@ -266,6 +279,8 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
+ if is_fusion_moe_shared_experts_layer:
+ continue
name_mapped = name.replace(weight_name, param_name)
# QKV fusion is optional, fall back to normal
@@ -286,45 +301,105 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
weight_loader(param, loaded_weight, shard_id)
break
else:
- for mapping in expert_params_mapping:
- param_name, weight_name, expert_id, shard_id = mapping
- if weight_name not in name:
- continue
- name = name.replace(weight_name, param_name)
-
- param = params_dict[name]
- weight_loader = param.weight_loader
- weight_loader(
- param,
- loaded_weight,
- name,
- shard_id=shard_id,
- expert_id=expert_id,
+ # Special handling: when AITER fusion_shared_experts is enabled,
+ # checkpoints may provide a single widened shared_experts tensor
+ # without explicit expert indices
+ # (e.g. ...mlp.shared_experts.gate_proj.weight).
+ # For models with multiple shared experts, split that tensor
+ # evenly into per-shared-expert slices and load them into
+ # appended expert slots mlp.experts.{n_routed_experts + j}.*
+ # accordingly.
+ num_chunks = 1
+ if is_fusion_moe_shared_experts_layer:
+ num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
+ # Determine split axis based on op type
+ # gate/up: ColumnParallel → split along dim 0
+ # down: RowParallel → split along dim 1
+ split_dim = 1 if "down_proj.weight" in name else 0
+ total = loaded_weight.shape[split_dim]
+ assert total % num_chunks == 0, (
+ f"Shared expert weight dim {total} "
+ f"not divisible by num_chunks {num_chunks}"
)
- break
- else:
- # Skip loading extra bias for GPTQ models.
- if name.endswith(".bias") and name not in params_dict:
- continue
+ chunk_size = total // num_chunks
- name = maybe_remap_kv_scale_name(name, params_dict)
- if name is None:
- continue
+ for j in range(num_chunks):
+ chunk_name = name
+ weight_to_load = loaded_weight
- # According to DeepSeek-V3 Technical Report, MTP modules
- # shares embedding layer. We only load the first weights.
- if (
- spec_layer != self.model.mtp_start_layer_idx
- and ".layers" not in name
- ):
- continue
+ if is_fusion_moe_shared_experts_layer:
+ if split_dim == 0:
+ weight_to_load = loaded_weight[
+ j * chunk_size : (j + 1) * chunk_size, :
+ ]
+ else:
+ weight_to_load = loaded_weight[
+ :, j * chunk_size : (j + 1) * chunk_size
+ ]
+ # Synthesize an expert-style name so expert mapping
+ # can route it
+ chunk_name = name.replace(
+ "mlp.shared_experts",
+ f"mlp.experts.{self.config.n_routed_experts + j}",
+ )
- param = params_dict[name]
- weight_loader = getattr(
- param, "weight_loader", default_weight_loader
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(name)
+ # Use expert_params_mapping to locate the destination
+ # param and delegate to its expert-aware weight_loader
+ # with expert_id.
+ for mapping in expert_params_mapping:
+ param_name, weight_name, expert_id, shard_id = mapping
+ if weight_name not in chunk_name:
+ continue
+
+ # Do not modify `name` since the loop may continue here
+ # Instead, create a new variable
+ name_mapped = chunk_name.replace(weight_name, param_name)
+
+ param = params_dict[name_mapped]
+ # We should ask the weight loader to return success or
+ # not here since otherwise we may skip experts with
+ # other available replicas.
+ weight_loader = typing.cast(
+ Callable[..., bool], param.weight_loader
+ )
+ success = weight_loader(
+ param,
+ weight_to_load,
+ name_mapped,
+ shard_id=shard_id,
+ expert_id=expert_id,
+ return_success=True,
+ )
+ if success:
+ if not is_fusion_moe_shared_experts_layer:
+ name = name_mapped
+ else:
+ loaded_params.add(name_mapped)
+ break
+ else:
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+
+ name = maybe_remap_kv_scale_name(name, params_dict)
+ if name is None:
+ continue
+
+ # According to DeepSeek-V3 Technical Report, MTP modules
+ # shares embedding layer. We only load the first weights.
+ if (
+ spec_layer != self.model.mtp_start_layer_idx
+ and ".layers" not in name
+ ):
+ continue
+
+ param = params_dict[name]
+ weight_loader = getattr(
+ param, "weight_loader", default_weight_loader
+ )
+ weight_loader(param, loaded_weight)
+ if not is_fusion_moe_shared_experts_layer:
+ loaded_params.add(name)
return loaded_params
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index e8ee9951d6119..ad932559b983d 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -27,7 +27,6 @@
import typing
from collections.abc import Callable, Iterable
from itertools import islice
-from typing import Any
import torch
from torch import nn
@@ -111,8 +110,6 @@ class DeepseekAttention(nn.Module):
config: DeepseekV2Config | DeepseekV3Config,
hidden_size: int,
num_heads: int,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
max_position_embeddings: int = 8192,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
@@ -139,7 +136,6 @@ class DeepseekAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
@@ -162,8 +158,7 @@ class DeepseekAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
- base=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
)
self.attn = Attention(
self.num_heads,
@@ -292,7 +287,10 @@ class DeepseekV2MoE(nn.Module):
)
self.is_rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
- if config.n_shared_experts is None or self.is_rocm_aiter_moe_enabled:
+ self.is_fusion_moe_shared_experts_enabled = (
+ rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
+ )
+ if config.n_shared_experts is None or self.is_fusion_moe_shared_experts_enabled:
self.shared_experts = None
else:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
@@ -332,7 +330,7 @@ class DeepseekV2MoE(nn.Module):
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
n_shared_experts=config.n_shared_experts
- if rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
+ if self.is_fusion_moe_shared_experts_enabled
else None,
)
@@ -409,8 +407,6 @@ class DeepseekV2Attention(nn.Module):
v_head_dim: int,
q_lora_rank: int,
kv_lora_rank: int,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
max_position_embeddings: int = 8192,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
@@ -430,7 +426,6 @@ class DeepseekV2Attention(nn.Module):
assert num_heads % tp_size == 0
self.num_local_heads = num_heads // tp_size
self.scaling = self.qk_head_dim**-0.5
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
assert topk_indices_buffer is None, (
"topk_indices_buffer is not \
@@ -485,21 +480,20 @@ class DeepseekV2Attention(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
- if rope_scaling:
- rope_scaling["rope_type"] = "deepseek_yarn"
+ if config.rope_parameters["rope_type"] != "default":
+ config.rope_parameters["rope_type"] = "deepseek_yarn"
self.rotary_emb = get_rope(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
- base=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
is_neox_style=False,
)
- if rope_scaling:
- mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
- scaling_factor = rope_scaling["factor"]
+ if config.rope_parameters["rope_type"] != "default":
+ mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
+ scaling_factor = config.rope_parameters["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
@@ -600,6 +594,7 @@ def sparse_attn_indexer(
) -> torch.Tensor:
# careful! this will be None in dummy run
attn_metadata = get_forward_context().attn_metadata
+ fp8_dtype = current_platform.fp8_dtype()
# assert isinstance(attn_metadata, dict)
if not isinstance(attn_metadata, dict):
return sparse_attn_indexer_fake(
@@ -639,7 +634,7 @@ def sparse_attn_indexer(
k_fp8 = torch.empty(
[chunk.total_seq_lens, head_dim],
device=k.device,
- dtype=torch.float8_e4m3fn,
+ dtype=fp8_dtype,
)
k_scale = torch.empty(
[chunk.total_seq_lens, 4],
@@ -653,7 +648,12 @@ def sparse_attn_indexer(
chunk.block_table,
chunk.cu_seq_lens,
)
- logits = fp8_mqa_logits(
+ fp8_mqa_logits_func = fp8_mqa_logits
+ if current_platform.is_rocm():
+ from vllm.attention.ops.rocm_aiter_mla_sparse import rocm_fp8_mqa_logits
+
+ fp8_mqa_logits_func = rocm_fp8_mqa_logits
+ logits = fp8_mqa_logits_func(
q_fp8[chunk.token_start : chunk.token_end],
(k_fp8, k_scale.view(torch.float32)),
weights[chunk.token_start : chunk.token_end],
@@ -698,7 +698,14 @@ def sparse_attn_indexer(
next_n = padded_q_fp8_decode_tokens.shape[1]
assert batch_size == decode_metadata.seq_lens.shape[0]
num_padded_tokens = batch_size * next_n
- logits = fp8_paged_mqa_logits(
+ fp8_paged_mqa_logits_func = fp8_paged_mqa_logits
+ if current_platform.is_rocm():
+ from vllm.attention.ops.rocm_aiter_mla_sparse import (
+ rocm_fp8_paged_mqa_logits,
+ )
+
+ fp8_paged_mqa_logits_func = rocm_fp8_paged_mqa_logits
+ logits = fp8_paged_mqa_logits_func(
padded_q_fp8_decode_tokens,
kv_cache,
weights[:num_padded_tokens],
@@ -755,7 +762,8 @@ def sparse_attn_indexer_fake(
_flattened_kv = torch.empty(
[total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8
)
- _k_fp8 = _flattened_kv[..., :head_dim].view(torch.float8_e4m3fn).contiguous()
+ fp8_dtype = current_platform.fp8_dtype()
+ _k_fp8 = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous()
_k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous()
return topk_indices_buffer
@@ -846,8 +854,8 @@ class Indexer(nn.Module):
)
q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
- q = torch.cat([q_pe, q_nope], dim=-1)
- k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1)
+ q = torch.cat([q_pe.squeeze(0), q_nope], dim=-1)
+ k = torch.cat([k_pe.squeeze((0, 2)), k_nope], dim=-1)
# we only quant q here since k quant is fused with cache insertion
q = q.view(-1, self.head_dim)
@@ -903,8 +911,6 @@ class DeepseekV2MLAAttention(nn.Module):
v_head_dim: int,
q_lora_rank: int | None,
kv_lora_rank: int,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
max_position_embeddings: int = 8192,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
@@ -927,7 +933,6 @@ class DeepseekV2MLAAttention(nn.Module):
self.num_local_heads = num_heads // tp_size
self.scaling = self.qk_head_dim**-0.5
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
if self.q_lora_rank is not None:
@@ -981,25 +986,31 @@ class DeepseekV2MLAAttention(nn.Module):
prefix=f"{prefix}.o_proj",
)
- if rope_scaling:
- rope_scaling["rope_type"] = "deepseek_yarn"
+ if config.rope_parameters["rope_type"] != "default":
+ config.rope_parameters["rope_type"] = "deepseek_yarn"
self.rotary_emb = get_rope(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
- base=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
is_neox_style=False,
)
- if rope_scaling:
- mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
- scaling_factor = rope_scaling["factor"]
+ if config.rope_parameters["rope_type"] != "default":
+ mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
+ scaling_factor = config.rope_parameters["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
self.is_v32 = hasattr(config, "index_topk")
if self.is_v32:
+ self.indexer_rope_emb = get_rope(
+ qk_rope_head_dim,
+ rotary_dim=qk_rope_head_dim,
+ max_position=max_position_embeddings,
+ rope_parameters=config.rope_parameters,
+ is_neox_style=True,
+ )
self.indexer = Indexer(
vllm_config,
config,
@@ -1011,6 +1022,7 @@ class DeepseekV2MLAAttention(nn.Module):
f"{prefix}.indexer",
)
else:
+ self.indexer_rope_emb = None
self.indexer = None
mla_modules = MLAModules(
@@ -1028,6 +1040,7 @@ class DeepseekV2MLAAttention(nn.Module):
q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None,
q_proj=self.q_proj if self.q_lora_rank is None else None,
indexer=self.indexer,
+ indexer_rotary_emb=self.indexer_rope_emb,
is_sparse=self.is_v32,
topk_indices_buffer=topk_indices_buffer,
)
@@ -1073,8 +1086,6 @@ class DeepseekV2DecoderLayer(nn.Module):
parallel_config = vllm_config.parallel_config
self.hidden_size = config.hidden_size
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
moe_layer_freq = getattr(config, "moe_layer_freq", 1)
# DecoderLayers are created with `make_layers` which passes the prefix
@@ -1107,8 +1118,6 @@ class DeepseekV2DecoderLayer(nn.Module):
v_head_dim=v_head_dim,
q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None,
kv_lora_rank=kv_lora_rank,
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
@@ -1470,8 +1479,8 @@ class DeepseekV2ForCausalLM(
if spec_layer is not None:
continue # skip spec decode layers for main model
- is_fuse_shared_experts_layer = rocm_aiter_moe_shared_expert_enabled and (
- "mlp.shared_experts" in name
+ is_fusion_moe_shared_experts_layer = (
+ rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
)
for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -1486,7 +1495,7 @@ class DeepseekV2ForCausalLM(
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
- if is_fuse_shared_experts_layer:
+ if is_fusion_moe_shared_experts_layer:
continue
name_mapped = name.replace(weight_name, param_name)
@@ -1522,7 +1531,7 @@ class DeepseekV2ForCausalLM(
# appended expert slots mlp.experts.{n_routed_experts + j}.*
# accordingly.
num_chunks = 1
- if is_fuse_shared_experts_layer:
+ if is_fusion_moe_shared_experts_layer:
num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
# Determine split axis based on op type
# gate/up: ColumnParallel → split along dim 0
@@ -1539,7 +1548,7 @@ class DeepseekV2ForCausalLM(
chunk_name = name
weight_to_load = loaded_weight
- if is_fuse_shared_experts_layer:
+ if is_fusion_moe_shared_experts_layer:
if split_dim == 0:
weight_to_load = loaded_weight[
j * chunk_size : (j + 1) * chunk_size, :
@@ -1590,7 +1599,7 @@ class DeepseekV2ForCausalLM(
return_success=True,
)
if success:
- if not is_fuse_shared_experts_layer:
+ if not is_fusion_moe_shared_experts_layer:
name = name_mapped
else:
loaded_params.add(name_mapped)
@@ -1619,7 +1628,7 @@ class DeepseekV2ForCausalLM(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
- if not is_fuse_shared_experts_layer:
+ if not is_fusion_moe_shared_experts_layer:
loaded_params.add(name)
return loaded_params
diff --git a/vllm/model_executor/models/dots1.py b/vllm/model_executor/models/dots1.py
index d24da0c42a254..e65c275106a4e 100644
--- a/vllm/model_executor/models/dots1.py
+++ b/vllm/model_executor/models/dots1.py
@@ -27,7 +27,6 @@
from collections.abc import Iterable
from itertools import islice
-from typing import Any
import torch
from torch import nn
@@ -202,8 +201,6 @@ class Dots1Attention(nn.Module):
num_heads: int,
num_kv_heads: int,
config: Dots1Config,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
max_position_embeddings: int = 8192,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
@@ -229,7 +226,6 @@ class Dots1Attention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
attention_bias = config.attention_bias
@@ -255,8 +251,7 @@ class Dots1Attention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
- base=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
)
self.attn = Attention(
self.num_heads,
@@ -296,8 +291,6 @@ class Dots1DecoderLayer(nn.Module):
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
layer_idx = int(prefix.split(sep=".")[-1])
self.layer_idx = layer_idx
@@ -307,8 +300,6 @@ class Dots1DecoderLayer(nn.Module):
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
config=config,
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py
index 2d2251e83b5b1..5460018d0d67a 100644
--- a/vllm/model_executor/models/dots_ocr.py
+++ b/vllm/model_executor/models/dots_ocr.py
@@ -306,7 +306,6 @@ class DotsVisionAttention(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
- AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -324,7 +323,6 @@ class DotsVisionAttention(nn.Module):
rotary_pos_emb: torch.Tensor | None = None,
*,
max_seqlen: int | None = None,
- seqlens: list[int] | None = None,
) -> torch.Tensor:
# [S, C] -> [S, B=1, C]
x = hidden_states.unsqueeze(1)
@@ -374,16 +372,6 @@ class DotsVisionAttention(nn.Module):
out_i = out_i.permute(0, 2, 1, 3)
outputs.append(out_i)
context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- from xformers import ops as xops
- from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-
- attn_bias = BlockDiagonalMask.from_seqlens(
- q_seqlen=seqlens, kv_seqlen=None, device=q.device
- )
- context_layer = xops.memory_efficient_attention_forward(
- q, k, v, attn_bias=attn_bias, p=0, scale=None
- )
else:
raise RuntimeError("Unsupported attention backend")
@@ -545,14 +533,12 @@ class DotsVisionBlock(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: int | None = None,
- seqlens: list[int] | None = None,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
@@ -663,18 +649,14 @@ class DotsVisionTransformer(nn.Module):
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
- def compute_attn_mask_seqlen(
- self, cu_seqlens: torch.Tensor
- ) -> tuple[int | None, list[int] | None]:
- max_seqlen, seqlens = None, None
+ def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
+ max_seqlen = None
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
- return max_seqlen, seqlens
+ return max_seqlen
def forward(
self, hidden_states: torch.Tensor, grid_thw: list[list[int]]
@@ -694,14 +676,13 @@ class DotsVisionTransformer(nn.Module):
)
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
- max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+ max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
for blk in self.blocks:
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
if self.post_trunk_norm is not None:
diff --git a/vllm/model_executor/models/ernie45_moe.py b/vllm/model_executor/models/ernie45_moe.py
index f2999968669f6..a7df3509e3ecd 100644
--- a/vllm/model_executor/models/ernie45_moe.py
+++ b/vllm/model_executor/models/ernie45_moe.py
@@ -62,6 +62,7 @@ from vllm.model_executor.model_loader.weight_utils import (
maybe_remap_kv_scale_name,
)
from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.config import set_default_rope_theta
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
from .utils import (
@@ -232,9 +233,8 @@ class Ernie4_5_MoeAttention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
+ rope_parameters: dict[str, Any],
head_dim: int | None = None,
- rope_theta: float = 500000,
- rope_scaling: dict[str, Any] | None = None,
max_position_embeddings: int = 131072,
rms_norm_eps: float = 1e-05,
qkv_bias: bool = False,
@@ -266,7 +266,6 @@ class Ernie4_5_MoeAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
@@ -291,9 +290,8 @@ class Ernie4_5_MoeAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
- base=rope_theta,
+ rope_parameters=rope_parameters,
is_neox_style=False,
- rope_scaling=rope_scaling,
)
self.attn = Attention(
self.num_heads,
@@ -333,16 +331,14 @@ class Ernie4_5_MoeDecoderLayer(nn.Module):
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
- rope_theta = getattr(config, "rope_theta", 500000)
- rope_scaling = getattr(config, "rope_scaling", None)
+ set_default_rope_theta(config, default_theta=500000)
max_position_embeddings = getattr(config, "max_position_embeddings", 131072)
self.self_attn = Ernie4_5_MoeAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
head_dim=getattr(config, "head_dim", None),
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
max_position_embeddings=max_position_embeddings,
rms_norm_eps=config.rms_norm_eps,
qkv_bias=getattr(config, "use_bias", False),
diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py
index daa5bf03ea4a9..07b34fbc8addb 100644
--- a/vllm/model_executor/models/ernie45_vl.py
+++ b/vllm/model_executor/models/ernie45_vl.py
@@ -214,7 +214,6 @@ class Ernie4_5_VisionAttention(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
- AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -259,7 +258,6 @@ class Ernie4_5_VisionAttention(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
- seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
@@ -311,20 +309,6 @@ class Ernie4_5_VisionAttention(nn.Module):
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- from xformers import ops as xops
- from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-
- attn_bias = BlockDiagonalMask.from_seqlens(
- q_seqlen=seqlens, kv_seqlen=None, device=q.device
- )
-
- context_layer = xops.memory_efficient_attention_forward(
- q, k, v, attn_bias=attn_bias, p=0, scale=None
- )
- context_layer = rearrange(
- context_layer, "b s h d -> s b (h d)"
- ).contiguous()
output, _ = self.proj(context_layer)
return output
@@ -404,14 +388,12 @@ class Ernie4_5_VisionBlock(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
- seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
@@ -562,18 +544,14 @@ class Ernie4_5_VisionTransformer(nn.Module):
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
- def compute_attn_mask_seqlen(
- self, cu_seqlens: torch.Tensor
- ) -> tuple[int | None, list[int] | None]:
- max_seqlen, seqlens = None, None
+ def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
+ max_seqlen = None
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
- return max_seqlen, seqlens
+ return max_seqlen
def forward(
self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, num_pad=0
@@ -598,8 +576,8 @@ class Ernie4_5_VisionTransformer(nn.Module):
if hidden_states.ndim == 2:
hidden_states = hidden_states.unsqueeze(dim=1)
- # pre-compute seqlens for attn mask to reduce cuMemcpy operations
- max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+ # pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
+ max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
for i, blk in enumerate(self.blocks):
hidden_states = blk(
@@ -607,7 +585,6 @@ class Ernie4_5_VisionTransformer(nn.Module):
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
final_output = self.ln(hidden_states)
diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py
index e8ef86f9b7f01..50e033d77606d 100644
--- a/vllm/model_executor/models/ernie45_vl_moe.py
+++ b/vllm/model_executor/models/ernie45_vl_moe.py
@@ -58,6 +58,7 @@ from vllm.model_executor.model_loader.weight_utils import (
maybe_remap_kv_scale_name,
)
from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.config import set_default_rope_theta
from .ernie45_moe import Ernie4_5_MoeMLP
from .interfaces import SupportsPP
@@ -91,9 +92,8 @@ class Ernie4_5_VLMoeAttention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
+ rope_parameters: dict[str, Any],
head_dim: int | None = None,
- rope_theta: float = 500000,
- rope_scaling: dict[str, Any] | None = None,
freq_allocation: int = 20,
max_position_embeddings: int = 131072,
rms_norm_eps: float = 1e-05,
@@ -126,7 +126,6 @@ class Ernie4_5_VLMoeAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
@@ -155,7 +154,7 @@ class Ernie4_5_VLMoeAttention(nn.Module):
head_size=self.head_dim,
rotary_dim=self.head_dim,
max_position_embeddings=max_position_embeddings,
- base=rope_theta,
+ base=rope_parameters["rope_theta"],
is_neox_style=False,
dtype=torch.get_default_dtype(),
mrope_section=[h_rope, w_rope, t_rope],
@@ -413,8 +412,7 @@ class Ernie4_5_VLMoeDecoderLayer(nn.Module):
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
- rope_theta = getattr(config, "rope_theta", 500000)
- rope_scaling = getattr(config, "rope_scaling", None)
+ set_default_rope_theta(config, default_theta=500000)
freq_allocation = getattr(config, "freq_allocation", 20)
max_position_embeddings = getattr(config, "max_position_embeddings", 131072)
@@ -423,8 +421,7 @@ class Ernie4_5_VLMoeDecoderLayer(nn.Module):
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
head_dim=getattr(config, "head_dim", None),
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
freq_allocation=freq_allocation,
max_position_embeddings=max_position_embeddings,
rms_norm_eps=config.rms_norm_eps,
diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py
index 6c56bfc433c7a..d13275488fe99 100644
--- a/vllm/model_executor/models/exaone.py
+++ b/vllm/model_executor/models/exaone.py
@@ -27,7 +27,6 @@
from collections.abc import Iterable
from itertools import islice
-from typing import Any
import torch
from torch import nn
@@ -113,8 +112,6 @@ class ExaoneAttention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
max_position_embeddings: int = 8192,
quant_config: QuantizationConfig | None = None,
bias: bool = False,
@@ -144,7 +141,6 @@ class ExaoneAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
@@ -173,8 +169,7 @@ class ExaoneAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
- base=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style,
)
self.attn = Attention(
@@ -207,8 +202,6 @@ class ExaoneBlockAttention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
max_position_embeddings: int = 8192,
quant_config: QuantizationConfig | None = None,
bias: bool = False,
@@ -221,8 +214,6 @@ class ExaoneBlockAttention(nn.Module):
hidden_size=hidden_size,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=bias,
@@ -251,14 +242,6 @@ class ExaoneDecoderLayer(nn.Module):
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
- if rope_scaling is not None and getattr(
- config, "original_max_position_embeddings", None
- ):
- rope_scaling["original_max_position_embeddings"] = (
- config.original_max_position_embeddings
- )
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
@@ -272,8 +255,6 @@ class ExaoneDecoderLayer(nn.Module):
num_kv_heads=getattr(
config, "num_key_value_heads", config.num_attention_heads
),
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
diff --git a/vllm/model_executor/models/exaone4.py b/vllm/model_executor/models/exaone4.py
index b89e168ada20e..70f3cce2b7c56 100644
--- a/vllm/model_executor/models/exaone4.py
+++ b/vllm/model_executor/models/exaone4.py
@@ -23,7 +23,6 @@
from collections.abc import Iterable
from itertools import islice
-from typing import Any
import torch
from torch import nn
@@ -52,6 +51,7 @@ from vllm.model_executor.model_loader.weight_utils import (
maybe_remap_kv_scale_name,
)
from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.config import set_default_rope_theta
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (
@@ -110,8 +110,6 @@ class Exaone4Attention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
- rope_theta: float = 1000000,
- rope_scaling: dict[str, Any] | None = None,
max_position_embeddings: int = 8192,
quant_config: QuantizationConfig | None = None,
bias: bool = False,
@@ -141,7 +139,6 @@ class Exaone4Attention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
@@ -176,12 +173,12 @@ class Exaone4Attention(nn.Module):
# apply rotary embeddings to every layer in full attention models
self.apply_rope_all_layers = "sliding_attention" not in config.layer_types
+ set_default_rope_theta(config, default_theta=1000000)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
- base=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style,
)
self.attn = Attention(
@@ -227,14 +224,6 @@ class Exaone4DecoderLayer(nn.Module):
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
- rope_theta = getattr(config, "rope_theta", 1000000)
- rope_scaling = getattr(config, "rope_scaling", None)
- if rope_scaling is not None and getattr(
- config, "original_max_position_embeddings", None
- ):
- rope_scaling["original_max_position_embeddings"] = (
- config.original_max_position_embeddings
- )
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
@@ -249,8 +238,6 @@ class Exaone4DecoderLayer(nn.Module):
num_kv_heads=getattr(
config, "num_key_value_heads", config.num_attention_heads
),
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py
index 85acdff3d96b4..dc2d51f340c8c 100644
--- a/vllm/model_executor/models/falcon.py
+++ b/vllm/model_executor/models/falcon.py
@@ -164,13 +164,12 @@ class FalconAttention(nn.Module):
)
if self.use_rotary:
- rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
- base=rope_theta,
+ rope_parameters=config.rope_parameters,
)
self.attn = Attention(
self.num_heads,
diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py
index b985847af5daf..9433f0d1b4a49 100644
--- a/vllm/model_executor/models/falcon_h1.py
+++ b/vllm/model_executor/models/falcon_h1.py
@@ -35,6 +35,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.config import set_default_rope_theta
from .interfaces import (
HasInnerState,
@@ -214,8 +215,7 @@ class FalconH1AttentionDecoderLayer(nn.Module):
prefix: str = "",
) -> None:
super().__init__()
- rope_theta = getattr(config, "rope_theta", 1e11)
- rope_scaling = getattr(config, "rope_scaling", None)
+ set_default_rope_theta(config, default_theta=1e11)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
@@ -240,7 +240,6 @@ class FalconH1AttentionDecoderLayer(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
if hasattr(config, "partial_rotary_factor"):
@@ -254,8 +253,7 @@ class FalconH1AttentionDecoderLayer(nn.Module):
head_size=self.head_dim,
rotary_dim=rotary_dim,
max_position=max_position_embeddings,
- rope_scaling=rope_scaling,
- base=rope_theta,
+ rope_parameters=config.rope_parameters,
is_neox_style=True,
dtype=None, # see impl of get_rope
)
diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index 7aaae7c503b58..00c7f59a08094 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -20,6 +20,7 @@
from collections.abc import Iterable
from functools import cache
from itertools import islice
+from typing import Any
import torch
from torch import nn
@@ -127,8 +128,8 @@ class GemmaAttention(nn.Module):
num_heads: int,
num_kv_heads: int,
head_dim: int,
+ rope_parameters: dict[str, Any],
max_position_embeddings: int = 8192,
- rope_theta: float = 10000,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
@@ -153,7 +154,6 @@ class GemmaAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.qkv_proj = QKVParallelLinear(
hidden_size,
@@ -176,7 +176,7 @@ class GemmaAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
- base=self.rope_theta,
+ rope_parameters=rope_parameters,
is_neox_style=True,
)
self.attn = Attention(
@@ -218,7 +218,7 @@ class GemmaDecoderLayer(nn.Module):
num_kv_heads=config.num_key_value_heads,
head_dim=config.head_dim,
max_position_embeddings=config.max_position_embeddings,
- rope_theta=config.rope_theta,
+ rope_parameters=config.rope_parameters,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py
index 4d5d6cbb37c62..9b6cfe6932300 100644
--- a/vllm/model_executor/models/gemma2.py
+++ b/vllm/model_executor/models/gemma2.py
@@ -107,7 +107,6 @@ class Gemma2Attention(nn.Module):
num_kv_heads: int,
head_dim: int,
max_position_embeddings: int,
- rope_theta: float,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
attn_logits_soft_cap: float | None = None,
@@ -134,7 +133,6 @@ class Gemma2Attention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = config.query_pre_attn_scalar**-0.5
- self.rope_theta = rope_theta
self.qkv_proj = QKVParallelLinear(
hidden_size,
@@ -156,7 +154,7 @@ class Gemma2Attention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
- base=self.rope_theta,
+ rope_parameters=config.rope_parameters,
is_neox_style=True,
)
@@ -206,7 +204,6 @@ class Gemma2DecoderLayer(nn.Module):
num_kv_heads=config.num_key_value_heads,
head_dim=config.head_dim,
max_position_embeddings=config.max_position_embeddings,
- rope_theta=config.rope_theta,
cache_config=cache_config,
quant_config=quant_config,
attn_logits_soft_cap=config.attn_logit_softcapping,
diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py
index 357e61a4e78bf..4ad6fc89dcaf2 100644
--- a/vllm/model_executor/models/gemma3.py
+++ b/vllm/model_executor/models/gemma3.py
@@ -155,25 +155,30 @@ class Gemma3Attention(nn.Module):
self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
layer_idx = extract_layer_index(prefix)
- self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+ layer_type = config.layer_types[layer_idx]
+ self.is_sliding = layer_type == "sliding_attention"
sliding_window = config.sliding_window if self.is_sliding else None
# Initialize the rotary embedding.
- if self.is_sliding:
- # Local attention. Override the values in config.json.
- self.rope_theta = config.rope_local_base_freq
- self.rope_scaling = {"rope_type": "default"}
+ if layer_type in config.rope_parameters:
+ # Transformers v5 rope config.
+ rope_parameters = config.rope_parameters[layer_type]
else:
+ # Transformers v4 rope config.
# Global attention. Use the values in config.json.
- self.rope_theta = config.rope_theta
- self.rope_scaling = config.rope_scaling
+ rope_parameters = config.rope_parameters
+ # Local attention. Override the values in config.json.
+ if self.is_sliding:
+ rope_parameters = dict(
+ rope_type="default", rope_theta=config.rope_local_base_freq
+ )
+
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
- base=self.rope_theta,
+ rope_parameters=rope_parameters,
is_neox_style=True,
- rope_scaling=self.rope_scaling,
)
if getattr(config, "is_causal", True):
diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
index fe83c8b63b018..43c69e5e13992 100644
--- a/vllm/model_executor/models/gemma3_mm.py
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -596,7 +596,7 @@ class Gemma3ForConditionalGeneration(
def get_language_model(self) -> torch.nn.Module:
return self.language_model
- def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
@@ -644,142 +644,6 @@ class Gemma3ForConditionalGeneration(
return hidden_states
- def generate_attention_masks(
- self,
- input_ids: torch.Tensor,
- positions: torch.Tensor,
- mask_dtype: torch.dtype,
- ) -> dict[str, Any]:
- """Generate custom attention masks for Gemma3 multimodal inputs.
-
- This is called by V1 engine's gpu_model_runner during preprocessing
- to generate attention masks that allow bidirectional attention between
- image tokens while maintaining causal attention for text.
- """
- # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
- # This is a HACK. Fix this.
- start_indices = (positions == 0).cpu().nonzero()
- num_seqs = len(start_indices)
- seq_lens = []
- for i in range(num_seqs):
- start_idx = start_indices[i]
- end_idx = start_indices[i + 1] if i < num_seqs - 1 else len(input_ids)
- seq_lens.append(end_idx - start_idx)
-
- global_attn_masks = []
- local_attn_masks = []
- start_idx = 0
- for seq_idx, seq_len in enumerate(seq_lens):
- end_idx = start_idx + seq_len
- input_token_ids = input_ids[start_idx:end_idx]
-
- # Find image token positions
- img_pos = input_token_ids == self.config.image_token_index
-
- start_idx = end_idx
-
- # Create a global causal mask
- global_attn_mask = torch.empty(
- 1,
- 1,
- seq_len,
- seq_len,
- dtype=mask_dtype,
- device=input_ids.device,
- )
- global_attn_mask.fill_(float("-inf"))
- # Fill the lower triangle with 0 (causal attention)
- global_attn_mask = global_attn_mask.triu(diagonal=1)
-
- # Enable bidirectional attention between image tokens
- img_mask = torch.zeros_like(global_attn_mask)
- img_mask[:, :, :, img_pos] += 1
- img_mask[:, :, img_pos, :] += 1
- global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
- global_attn_masks.append(global_attn_mask)
-
- # GGUF compatibility: config might be Gemma3TextConfig directly
- text_config = getattr(self.config, "text_config", self.config)
- sliding_window = text_config.sliding_window
- if sliding_window is not None:
- # Create a local causal mask with sliding window (1024)
- local_attn_mask = torch.ones_like(global_attn_mask)
- local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
- local_attn_mask = torch.where(
- local_attn_mask == 0, global_attn_mask, float("-inf")
- )
- local_attn_masks.append(local_attn_mask)
-
- return {
- "has_images": True,
- "seq_lens": seq_lens,
- "global_attn_masks": global_attn_masks,
- "local_attn_masks": local_attn_masks,
- }
-
- def prepare_attn_masks(
- self,
- input_ids: torch.Tensor,
- positions: torch.Tensor,
- mask_dtype: torch.dtype,
- **kwargs,
- ):
- kwargs["has_images"] = True
- # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
- # This is a HACK. Fix this.
- start_indices = (positions == 0).cpu().nonzero()
- num_seqs = len(start_indices)
- seq_lens = []
- for i in range(num_seqs):
- start_idx = start_indices[i].item()
- if i < num_seqs - 1:
- end_idx = start_indices[i + 1].item()
- else:
- end_idx = len(input_ids)
- seq_lens.append(end_idx - start_idx)
- kwargs["seq_lens"] = seq_lens
-
- global_attn_masks = []
- local_attn_masks = []
- start_idx = 0
- for seq_len in seq_lens:
- end_idx = start_idx + seq_len
- input_token_ids = input_ids[start_idx:end_idx]
- start_idx = end_idx
- # Create a global causal mask.
- global_attn_mask = torch.empty(
- 1,
- 1,
- seq_len,
- seq_len,
- dtype=mask_dtype,
- device=input_ids.device,
- )
- global_attn_mask.fill_(float("-inf"))
- # Fill the lower triangle with 0.
- global_attn_mask = global_attn_mask.triu(diagonal=1)
-
- # Consider the bidirectional attention between image tokens.
- img_mask = torch.zeros_like(global_attn_mask)
- img_pos = input_token_ids == self.config.image_token_index
- img_mask[:, :, :, img_pos] += 1
- img_mask[:, :, img_pos, :] += 1
- global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
- global_attn_masks.append(global_attn_mask)
-
- sliding_window = self.config.text_config.sliding_window
- if sliding_window is not None:
- # Create a local causal mask with sliding window (1024).
- local_attn_mask = torch.ones_like(global_attn_mask)
- local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
- local_attn_mask = torch.where(
- local_attn_mask == 0, global_attn_mask, float("-inf")
- )
- local_attn_masks.append(local_attn_mask)
- kwargs["global_attn_masks"] = global_attn_masks
- kwargs["local_attn_masks"] = local_attn_masks
- return kwargs
-
def compute_logits(
self,
hidden_states: torch.Tensor,
diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py
index 64443190f53ed..8f1447ba34a81 100644
--- a/vllm/model_executor/models/gemma3n.py
+++ b/vllm/model_executor/models/gemma3n.py
@@ -332,18 +332,21 @@ class Gemma3nAttention(nn.Module):
)
layer_idx = extract_layer_index(prefix)
- is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+ layer_type = config.layer_types[layer_idx]
+ is_sliding = layer_type == "sliding_attention"
self.sliding_window = config.sliding_window if is_sliding else None
# Initialize the rotary embedding.
- if is_sliding:
- # Local attention. Override the values in config.json.
- rope_theta = config.rope_local_base_freq
- rope_scaling = {"rope_type": "default"}
+ if layer_type in config.rope_parameters:
+ # Transformers v5 rope config.
+ rope_parameters = config.rope_parameters[layer_type]
else:
+ # Transformers v4 rope config.
# Global attention. Use the values in config.json.
- rope_theta = config.rope_theta
- rope_scaling = config.rope_scaling
+ rope_parameters = config.rope_parameters.copy()
+ # Local attention. Override the values in config.json.
+ if is_sliding:
+ rope_parameters["rope_theta"] = config.rope_local_base_freq
first_kv_shared_layer_idx = (
config.num_hidden_layers - config.num_kv_shared_layers
@@ -383,9 +386,8 @@ class Gemma3nAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
- base=rope_theta,
+ rope_parameters=rope_parameters,
is_neox_style=True,
- rope_scaling=rope_scaling,
)
self.attn = Attention(
diff --git a/vllm/model_executor/models/glm4.py b/vllm/model_executor/models/glm4.py
index faa0674a2e43d..f8ef3b0385fb1 100644
--- a/vllm/model_executor/models/glm4.py
+++ b/vllm/model_executor/models/glm4.py
@@ -57,10 +57,8 @@ class Glm4Attention(nn.Module):
max_position: int = 4096 * 32,
head_dim: int | None = None,
qkv_bias: bool = False,
- rope_theta: float = 10000,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
- rope_scaling: tuple | None = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
) -> None:
@@ -86,7 +84,6 @@ class Glm4Attention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
@@ -107,8 +104,7 @@ class Glm4Attention(nn.Module):
self.head_dim,
rotary_dim=self.rotary_dim,
max_position=max_position,
- base=self.rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
partial_rotary_factor=partial_rotary_factor,
is_neox_style=False,
)
@@ -150,8 +146,6 @@ class Glm4DecoderLayer(nn.Module):
quant_config = vllm_config.quant_config
self.hidden_size = config.hidden_size
- rope_theta = getattr(config, "rope_theta", 1000000)
- rope_scaling = getattr(config, "rope_scaling", None)
self.self_attn = Glm4Attention(
config=config,
@@ -159,12 +153,10 @@ class Glm4DecoderLayer(nn.Module):
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
- rope_theta=rope_theta,
qkv_bias=getattr(config, "attention_bias", False),
head_dim=getattr(config, "head_dim", None),
cache_config=cache_config,
quant_config=quant_config,
- rope_scaling=rope_scaling,
prefix=f"{prefix}.self_attn",
attn_type=AttentionType.DECODER,
)
diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py
index 2c2f45c2453ee..7e0370886884f 100644
--- a/vllm/model_executor/models/glm4_1v.py
+++ b/vllm/model_executor/models/glm4_1v.py
@@ -37,7 +37,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
-from transformers import BatchFeature
+from transformers import BatchFeature, Glm4vProcessor
from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig
from transformers.models.glm4v.image_processing_glm4v import (
Glm4vImageProcessor,
@@ -309,7 +309,6 @@ class Glm4vVisionAttention(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
- AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -345,7 +344,6 @@ class Glm4vVisionAttention(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
- seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
@@ -400,20 +398,6 @@ class Glm4vVisionAttention(nn.Module):
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- from xformers import ops as xops
- from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-
- attn_bias = BlockDiagonalMask.from_seqlens(
- q_seqlen=seqlens, kv_seqlen=None, device=q.device
- )
-
- context_layer = xops.memory_efficient_attention_forward(
- q, k, v, attn_bias=attn_bias, p=0, scale=None
- )
- context_layer = rearrange(
- context_layer, "b s h d -> s b (h d)"
- ).contiguous()
output, _ = self.proj(context_layer)
return output
@@ -461,7 +445,6 @@ class Glm4vVisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
- seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
x_attn = self.attn(
self.norm1(x),
@@ -469,7 +452,6 @@ class Glm4vVisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
x_fused_norm, residual = self.norm2(x, residual=x_attn)
x = residual + self.mlp(x_fused_norm)
@@ -703,7 +685,6 @@ class Glm4vVisionTransformer(nn.Module):
head_size=head_dim,
rotary_dim=head_dim // 2,
max_position=8192,
- base=10000.0,
is_neox_style=True,
)
self.blocks = nn.ModuleList(
@@ -797,27 +778,21 @@ class Glm4vVisionTransformer(nn.Module):
# Use pre-computed cos_sin_cache from RotaryEmbedding
cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
- cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2)
- cos_w = cos[pos_ids[:, 1]]
- sin_h = sin[pos_ids[:, 0]]
- sin_w = sin[pos_ids[:, 1]]
-
- cos_combined = torch.cat([cos_h, cos_w], dim=-1)
- sin_combined = torch.cat([sin_h, sin_w], dim=-1)
+ cos_combined = cos[pos_ids].flatten(1)
+ sin_combined = sin[pos_ids].flatten(1)
return cos_combined, sin_combined, pos_ids
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
- ) -> tuple[int | None, list[int] | None]:
- max_seqlen, seqlens = None, None
- seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+ ) -> int | None:
+ max_seqlen = None
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
- return max_seqlen, seqlens
+ return max_seqlen
def forward(
self,
@@ -842,8 +817,9 @@ class Glm4vVisionTransformer(nn.Module):
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
- # pre-compute seqlens for attn mask to reduce cuMemcpy operations
- max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+ # pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
+ max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
+ seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
x = self.embeddings(
x, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]
)
@@ -857,7 +833,6 @@ class Glm4vVisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
# adapter
@@ -1034,7 +1009,7 @@ class Glm4vProcessingInfo(BaseProcessingInfo):
return max(max_frames_per_video, 1)
- def _get_video_second_idx(
+ def _get_video_second_idx_glm4v(
self, metadata: dict[str, Any], total_frames: int
) -> list[int]:
video_processor = self.get_video_processor()
@@ -1085,6 +1060,83 @@ class Glm4vProcessingInfo(BaseProcessingInfo):
selected_timestamps.append(timestamps_list[idx])
return selected_timestamps
+ def _get_video_second_idx_glm46v(
+ self, metadata: dict[str, Any], total_frames: int
+ ) -> list[int]:
+ video_processor = self.get_video_processor()
+
+ video_fps = metadata["fps"]
+ meta_frames = metadata.get("total_num_frames", total_frames)
+ max_frame_idx = meta_frames - 1
+ duration = metadata.get("duration", round(max_frame_idx / video_fps) + 1)
+
+ do_sample_frames = metadata.get("do_sample_frames", True)
+ if not do_sample_frames:
+ frame_indices = metadata["frames_indices"]
+ else:
+ DYNAMIC_FPS_THRES = {30: 3, 300: 1, 2400: 0.5}
+ MAX_FRAME_COUNT_DYNAMIC = 640
+ MAX_DURATION = 2400
+
+ effective_duration = min(duration, MAX_DURATION)
+ if effective_duration <= 30:
+ target_fps = DYNAMIC_FPS_THRES[30]
+ elif effective_duration <= 300:
+ target_fps = DYNAMIC_FPS_THRES[300]
+ else:
+ target_fps = DYNAMIC_FPS_THRES[2400]
+
+ temporal_patch_size = getattr(video_processor, "temporal_patch_size", 1)
+ extract_t = int(effective_duration * target_fps * temporal_patch_size)
+ extract_t = min(extract_t, MAX_FRAME_COUNT_DYNAMIC)
+
+ duration_per_frame = 1 / video_fps
+ timestamps = [i * duration_per_frame for i in range(meta_frames)]
+ max_second = int(duration)
+
+ if meta_frames < extract_t:
+ frame_indices = np.linspace(
+ 0, meta_frames - 1, extract_t, dtype=int
+ ).tolist()
+ else:
+ frame_indices = []
+ current_second = 0.0
+ inv_fps = 1 / (temporal_patch_size * target_fps)
+ for frame_index in range(meta_frames):
+ if timestamps[frame_index] >= current_second:
+ current_second += inv_fps
+ frame_indices.append(frame_index)
+ if current_second >= max_second:
+ break
+
+ if len(frame_indices) < extract_t:
+ if len(frame_indices) == 0:
+ start, end = 0, max(meta_frames - 1, 0)
+ else:
+ start, end = frame_indices[0], frame_indices[-1]
+ frame_indices = np.linspace(start, end, extract_t, dtype=int).tolist()
+ elif len(frame_indices) > extract_t:
+ frame_indices = np.linspace(
+ 0, meta_frames - 1, extract_t, dtype=int
+ ).tolist()
+
+ seen, uniq = set(), []
+ for idx in frame_indices:
+ if idx not in seen:
+ seen.add(idx)
+ uniq.append(idx)
+
+ if len(uniq) & 1:
+ uniq.append(uniq[-1])
+
+ frame_indices = uniq
+ full_second_idxs = [int(idx / video_fps) for idx in frame_indices]
+ timestamps_list = full_second_idxs[::2]
+ selected_timestamps = []
+ for idx in range(len(timestamps_list)):
+ selected_timestamps.append(timestamps_list[idx])
+ return selected_timestamps
+
def _construct_video_placeholder(
self,
video_array: np.ndarray,
@@ -1103,9 +1155,18 @@ class Glm4vProcessingInfo(BaseProcessingInfo):
merge_length = image_processor.merge_size**2
assert isinstance(grid_thw, torch.Tensor)
- timestamps = self._get_video_second_idx(metadata, len(video_array))
+ timestamps = (
+ self._get_video_second_idx_glm4v(metadata, len(video_array))
+ if isinstance(hf_processor, Glm4vProcessor)
+ else self._get_video_second_idx_glm46v(metadata, len(video_array))
+ )
+
+ timestamp_format = (
+ "{}" if isinstance(hf_processor, Glm4vProcessor) else "{:.1f} seconds"
+ )
frames_idx_token = [
- tokenizer.encode(str(i), add_special_tokens=False) for i in timestamps
+ tokenizer.encode(timestamp_format.format(i), add_special_tokens=False)
+ for i in timestamps
]
T, H, W = grid_thw
num_tokens_per_frame = int(H * W) // merge_length
diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py
index 1422dbe9b3cd0..5aa51af54a00b 100644
--- a/vllm/model_executor/models/glm4_moe.py
+++ b/vllm/model_executor/models/glm4_moe.py
@@ -26,7 +26,6 @@
import typing
from collections.abc import Callable, Iterable
from itertools import islice
-from typing import Any
import torch
from torch import nn
@@ -233,8 +232,6 @@ class Glm4MoeAttention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
max_position_embeddings: int = 131072,
head_dim: int | None = None,
rms_norm_eps: float = 1e-05,
@@ -264,7 +261,6 @@ class Glm4MoeAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.use_qk_norm = use_qk_norm
@@ -291,8 +287,7 @@ class Glm4MoeAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
- base=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
partial_rotary_factor=partial_rotary_factor,
)
self.attn = Attention(
@@ -341,8 +336,6 @@ class Glm4MoeDecoderLayer(nn.Module):
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 131072)
# DecoderLayers are created with `make_layers` which passes the prefix
# with the layer's index.
@@ -354,8 +347,6 @@ class Glm4MoeDecoderLayer(nn.Module):
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
head_dim=config.head_dim,
rms_norm_eps=config.rms_norm_eps,
diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py
index e416ecde0c1e0..e94de8952fa63 100644
--- a/vllm/model_executor/models/gpt_j.py
+++ b/vllm/model_executor/models/gpt_j.py
@@ -95,13 +95,12 @@ class GPTJAttention(nn.Module):
scaling = self.head_size**-0.5
assert getattr(config, "rotary", True)
assert config.rotary_dim % 2 == 0
- rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.rotary_emb = get_rope(
self.head_size,
rotary_dim=config.rotary_dim,
max_position=max_position_embeddings,
- base=rope_theta,
+ rope_parameters=config.rope_parameters,
is_neox_style=False,
)
self.attn = Attention(
diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py
index af0c9209231cb..815c2fba4d9fe 100644
--- a/vllm/model_executor/models/gpt_neox.py
+++ b/vllm/model_executor/models/gpt_neox.py
@@ -92,13 +92,12 @@ class GPTNeoXAttention(nn.Module):
scaling = self.head_size**-0.5
rotary_dim = int(self.head_size * config.rotary_pct)
assert rotary_dim % 2 == 0
- rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.rotary_emb = get_rope(
self.head_size,
rotary_dim=rotary_dim,
max_position=max_position_embeddings,
- base=rope_theta,
+ rope_parameters=config.rope_parameters,
)
self.attn = Attention(
self.num_heads,
diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py
index 7df3b087ccb88..1bc0ad38765d5 100644
--- a/vllm/model_executor/models/gpt_oss.py
+++ b/vllm/model_executor/models/gpt_oss.py
@@ -13,6 +13,7 @@ from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
get_dp_group,
get_ep_group,
+ get_pcp_group,
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@@ -67,16 +68,17 @@ class OAIAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=config.max_position_embeddings,
- base=config.rope_theta,
dtype=torch.float32,
- rope_scaling={
+ rope_parameters={
+ "rope_theta": config.rope_parameters["rope_theta"],
"rope_type": "yarn",
- "factor": config.rope_scaling["factor"],
- "original_max_position_embeddings": config.rope_scaling[
+ "factor": config.rope_parameters["factor"],
+ "original_max_position_embeddings": config.rope_parameters[
"original_max_position_embeddings"
],
- "beta_fast": config.rope_scaling["beta_fast"],
- "beta_slow": config.rope_scaling["beta_slow"],
+ "beta_fast": config.rope_parameters["beta_fast"],
+ "beta_slow": config.rope_parameters["beta_slow"],
+ "truncate": config.rope_parameters.get("truncate", True),
},
is_neox_style=True,
)
@@ -90,7 +92,6 @@ class OAIAttention(nn.Module):
self.q_size = self.num_attention_heads * self.head_dim // tp_size
self.kv_size = self.num_key_value_heads * self.head_dim // tp_size
self.scaling = self.head_dim**-0.5
- self.rope_theta = config.rope_theta
self.qkv_proj = QKVParallelLinear(
hidden_size=self.hidden_size,
@@ -323,10 +324,12 @@ class GptOssModel(nn.Module):
# In MoE, we need to flatten the tensor parallel size across the data
# parallel size when EP is disabled.
- tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
+ tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
tp_size=get_tensor_model_parallel_world_size(),
dp_size=get_dp_group().world_size,
dp_rank=get_dp_group().rank_in_group,
+ pcp_size=get_pcp_group().world_size,
+ pcp_rank=get_pcp_group().rank_in_group,
)
intermediate_size = self.config.intermediate_size
@@ -508,10 +511,12 @@ class GptOssModel(nn.Module):
# In MoE, we need to flatten the tensor parallel size across the data
# parallel size when EP is disabled.
- tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
+ tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
tp_size=get_tensor_model_parallel_world_size(),
dp_size=get_dp_group().world_size,
dp_rank=get_dp_group().rank_in_group,
+ pcp_size=get_pcp_group().world_size,
+ pcp_rank=get_pcp_group().rank_in_group,
)
intermediate_size = self.config.intermediate_size
@@ -651,6 +656,7 @@ class GptOssModel(nn.Module):
class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
+ is_3d_moe_weight: bool = True
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
hf_to_vllm_mapper = WeightsMapper(
diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py
index c44b4021471ef..cd7ce2fc8f00a 100644
--- a/vllm/model_executor/models/granite.py
+++ b/vllm/model_executor/models/granite.py
@@ -26,7 +26,6 @@
from collections.abc import Iterable
from itertools import islice
-from typing import Any
import torch
from torch import nn
@@ -47,7 +46,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
- DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
@@ -112,8 +110,6 @@ class GraniteAttention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
max_position_embeddings: int = 8192,
quant_config: QuantizationConfig | None = None,
bias: bool = False,
@@ -143,7 +139,6 @@ class GraniteAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = config.attention_multiplier
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
@@ -167,8 +162,7 @@ class GraniteAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
- base=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
)
self.attn = Attention(
self.num_heads,
@@ -204,14 +198,6 @@ class GraniteDecoderLayer(nn.Module):
super().__init__()
self.hidden_size = config.hidden_size
self.residual_multiplier = config.residual_multiplier
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
- if rope_scaling is not None and getattr(
- config, "original_max_position_embeddings", None
- ):
- rope_scaling["original_max_position_embeddings"] = (
- config.original_max_position_embeddings
- )
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
@@ -225,8 +211,6 @@ class GraniteDecoderLayer(nn.Module):
num_kv_heads=getattr(
config, "num_key_value_heads", config.num_attention_heads
),
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
@@ -276,29 +260,16 @@ class GraniteModel(nn.Module):
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
- lora_config = vllm_config.lora_config
self.config = config
self.quant_config = quant_config
- lora_vocab = (
- (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
- if lora_config
- else 0
- )
- self.vocab_size = config.vocab_size + lora_vocab
- self.org_vocab_size = config.vocab_size
+
if get_pp_group().is_first_rank or (
config.tie_word_embeddings and get_pp_group().is_last_rank
):
self.embed_tokens = VocabParallelEmbedding(
- self.vocab_size,
+ config.vocab_size,
config.hidden_size,
- org_num_embeddings=config.vocab_size,
- padding_size=DEFAULT_VOCAB_PADDING_SIZE
- # We need bigger padding if using lora for kernel
- # compatibility
- if not lora_config
- else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
)
else:
@@ -435,28 +406,18 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
- lora_config = vllm_config.lora_config
self.config = config
- self.lora_config = lora_config
+
self.quant_config = quant_config
self.model = GraniteModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
if get_pp_group().is_last_rank:
- self.unpadded_vocab_size = config.vocab_size
- if lora_config:
- self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
- self.unpadded_vocab_size,
+ config.vocab_size,
config.hidden_size,
- org_num_embeddings=config.vocab_size,
- padding_size=DEFAULT_VOCAB_PADDING_SIZE
- # We need bigger padding if using lora for kernel
- # compatibility
- if not lora_config
- else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
@@ -468,7 +429,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
logit_scale /= config.logits_scaling
self.logits_processor = LogitsProcessor(
- self.unpadded_vocab_size, config.vocab_size, scale=logit_scale
+ config.vocab_size, scale=logit_scale
)
else:
self.lm_head = PPMissingLayer()
diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py
index 5c6759ded0669..8f4139d63c3f6 100644
--- a/vllm/model_executor/models/granitemoe.py
+++ b/vllm/model_executor/models/granitemoe.py
@@ -141,8 +141,7 @@ class GraniteMoeAttention(nn.Module):
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
+ rope_parameters: dict[str, Any] | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
attention_multiplier: float | None = None,
@@ -172,7 +171,6 @@ class GraniteMoeAttention(nn.Module):
if attention_multiplier is not None
else self.head_dim**-1
)
- self.rope_theta = rope_theta
self.qkv_proj = QKVParallelLinear(
hidden_size,
@@ -194,9 +192,8 @@ class GraniteMoeAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
- base=int(self.rope_theta),
+ rope_parameters=rope_parameters,
is_neox_style=True,
- rope_scaling=rope_scaling,
)
self.attn = Attention(
self.num_heads,
@@ -235,16 +232,12 @@ class GraniteMoeDecoderLayer(nn.Module):
parallel_config = vllm_config.parallel_config
self.hidden_size = config.hidden_size
- # Requires transformers > 4.32.0
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
self.self_attn = GraniteMoeAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py
index a340112ec62ae..9d5eeef198a61 100644
--- a/vllm/model_executor/models/granitemoehybrid.py
+++ b/vllm/model_executor/models/granitemoehybrid.py
@@ -273,10 +273,7 @@ class GraniteMoeHybridAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=config.max_position_embeddings,
- base=int(config.rope_theta),
- rope_scaling=config.rope_scaling
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None
- else None,
+ rope_parameters=config.rope_parameters,
is_neox_style=True,
)
else:
diff --git a/vllm/model_executor/models/granitemoeshared.py b/vllm/model_executor/models/granitemoeshared.py
index 926c539af33be..fd346db7e35aa 100644
--- a/vllm/model_executor/models/granitemoeshared.py
+++ b/vllm/model_executor/models/granitemoeshared.py
@@ -84,16 +84,12 @@ class GraniteMoeSharedDecoderLayer(nn.Module):
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
- # Requires transformers > 4.32.0
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
self.self_attn = GraniteMoeAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py
index 9dc231863f74f..4bf23cd6fd19a 100644
--- a/vllm/model_executor/models/grok1.py
+++ b/vllm/model_executor/models/grok1.py
@@ -25,6 +25,7 @@
from collections.abc import Iterable
from itertools import islice
+from typing import Any
import torch
import torch.nn.functional as F
@@ -134,7 +135,7 @@ class Grok1Attention(nn.Module):
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
- rope_theta: float = 10000,
+ rope_parameters: dict[str, Any] | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
@@ -161,7 +162,6 @@ class Grok1Attention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.qkv_proj = QKVParallelLinear(
hidden_size,
@@ -183,7 +183,7 @@ class Grok1Attention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
- base=int(self.rope_theta),
+ rope_parameters=rope_parameters,
is_neox_style=True,
)
@@ -234,15 +234,12 @@ class Grok1DecoderLayer(nn.Module):
if not self.use_fp8 and hasattr(quant_config, "is_fp8"):
self.use_fp8 = quant_config.is_fp8
- # Requires transformers > 4.32.0
- # Default rope_theta value if not in config
- rope_theta = 10000
self.attn = Grok1Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
- rope_theta=rope_theta,
+ rope_parameters=config.rope_parameters,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py
index 1eadcbe67ade3..53fb444ed622d 100644
--- a/vllm/model_executor/models/hunyuan_v1.py
+++ b/vllm/model_executor/models/hunyuan_v1.py
@@ -27,7 +27,6 @@
import typing
from collections.abc import Callable, Iterable
from itertools import islice
-from typing import Any
import regex as re
import torch
@@ -142,8 +141,6 @@ class HunYuanAttention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
max_position_embeddings: int = 8192,
quant_config: QuantizationConfig | None = None,
bias: bool = False,
@@ -177,7 +174,6 @@ class HunYuanAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.use_qk_norm = getattr(config, "use_qk_norm", False)
self.layer_id = layer_id
@@ -204,8 +200,7 @@ class HunYuanAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
- base=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
is_neox_style=True,
)
self.attn = Attention(
@@ -254,8 +249,6 @@ class HunYuanCrossAttention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
max_position_embeddings: int = 8192,
quant_config: QuantizationConfig | None = None,
bias: bool = False,
@@ -289,7 +282,6 @@ class HunYuanCrossAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.use_qk_norm = getattr(config, "use_qk_norm", False)
self.layer_id = layer_id
@@ -314,8 +306,7 @@ class HunYuanCrossAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
- base=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
is_neox_style=True,
)
self.attn = Attention(
@@ -494,14 +485,6 @@ class HunYuanDecoderLayer(nn.Module):
if isinstance(config.intermediate_size, int)
else config.intermediate_size[layer_id]
)
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
- if rope_scaling is not None and getattr(
- config, "original_max_position_embeddings", None
- ):
- rope_scaling["original_max_position_embeddings"] = (
- config.original_max_position_embeddings
- )
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "bias", False
@@ -520,8 +503,6 @@ class HunYuanDecoderLayer(nn.Module):
num_kv_heads=getattr(
config, "num_key_value_heads", config.num_attention_heads
),
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
@@ -537,8 +518,6 @@ class HunYuanDecoderLayer(nn.Module):
num_kv_heads=getattr(
config, "num_key_value_heads", config.num_attention_heads
),
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
@@ -597,7 +576,16 @@ class HunYuanDecoderLayer(nn.Module):
return hidden_states, residual, ori_kv_states
-@support_torch_compile
+@support_torch_compile(
+ dynamic_arg_dims={
+ "input_ids": 0,
+ # positions is of shape (xd, seq_len) if xdrope is enabled for hunyuan-vl,
+ # otherwise (seq_len, ).
+ "positions": -1,
+ "intermediate_tensors": 0,
+ "inputs_embeds": 0,
+ }
+)
class HunYuanModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
diff --git a/vllm/model_executor/models/hunyuan_vision.py b/vllm/model_executor/models/hunyuan_vision.py
new file mode 100644
index 0000000000000..e83addd0c092f
--- /dev/null
+++ b/vllm/model_executor/models/hunyuan_vision.py
@@ -0,0 +1,1028 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# coding=utf-8
+# Copyright 2025 The HunYuan team.
+# Copyright 2025 The vLLM team.
+# Copyright 2025 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only HunYuan-VL model compatible with HuggingFace weights."""
+
+from collections.abc import Callable, Iterable, Mapping, Sequence
+from functools import partial
+from typing import Annotated, Any, Literal, TypeAlias
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import BatchFeature
+
+from vllm.attention.backends.registry import AttentionBackendEnum
+from vllm.attention.layer import MultiHeadAttention
+from vllm.config import MultiModalConfig, VllmConfig
+from vllm.config.multimodal import BaseDummyOptions
+from vllm.distributed import parallel_state
+from vllm.distributed import utils as dist_utils
+from vllm.logger import init_logger
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+ ColumnParallelLinear,
+ QKVParallelLinear,
+ RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.module_mapping import MultiModelKeys
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (
+ ImageItem,
+ ModalityData,
+ MultiModalDataDict,
+ MultiModalFeatureSpec,
+ MultiModalFieldConfig,
+ MultiModalKwargsItems,
+)
+from vllm.multimodal.parse import (
+ DictEmbeddingItems,
+ ImageSize,
+ MultiModalDataItems,
+ MultiModalDataParser,
+)
+from vllm.multimodal.processing import (
+ BaseMultiModalProcessor,
+ BaseProcessingInfo,
+ PromptReplacement,
+ PromptUpdate,
+)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.configs.hunyuan_vl import (
+ HunYuanVLConfig,
+ HunYuanVLVisionConfig,
+)
+from vllm.transformers_utils.processors.hunyuan_vl import HunYuanVLProcessor
+from vllm.transformers_utils.processors.hunyuan_vl_image import smart_resize
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
+
+from .interfaces import (
+ MultiModalEmbeddings,
+ SupportsLoRA,
+ SupportsMultiModal,
+ SupportsPP,
+ SupportsQuant,
+ SupportsXDRoPE,
+)
+from .utils import (
+ AutoWeightsLoader,
+ WeightsMapper,
+ init_vllm_registered_model,
+ maybe_prefix,
+)
+
+logger = init_logger(__name__)
+
+# === Vision Inputs === #
+
+
+class HunYuanVLImagePixelInputs(TensorSchema):
+ """
+ Dimensions:
+ - np: Number of patches
+ - ni: Number of images
+ - cps: Number of channels * patch_size * patch_size
+ """
+
+ type: Literal["pixel_values"]
+
+ pixel_values: Annotated[
+ torch.Tensor,
+ TensorShape("np", "cps"),
+ ]
+
+ image_grid_thw: Annotated[
+ torch.Tensor,
+ TensorShape("ni", 3),
+ ]
+
+
+class HunYuanVLImageEmbeddingInputs(TensorSchema):
+ """
+ Dimensions:
+ - nf: Number of image features
+ - hs: Hidden size
+ - ni: Number of images
+ """
+
+ type: Literal["image_embeds"]
+
+ image_embeds: Annotated[
+ torch.Tensor,
+ TensorShape("nf", "hs"),
+ ]
+
+ image_grid_thw: Annotated[
+ torch.Tensor,
+ TensorShape("ni", 3),
+ ]
+
+
+HunYuanVLImageInputs: TypeAlias = (
+ HunYuanVLImagePixelInputs | HunYuanVLImageEmbeddingInputs
+)
+
+# === Vision Encoder === #
+
+
+class HunYuanVisionMLP(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: int,
+ bias: bool = True,
+ act_fn: Callable[[torch.Tensor], torch.Tensor] = F.gelu,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ use_data_parallel: bool = False,
+ ):
+ super().__init__()
+ self.dense_h_to_4h = ColumnParallelLinear(
+ in_features,
+ hidden_features,
+ bias=bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.dense_h_to_4h",
+ disable_tp=use_data_parallel,
+ )
+ self.dense_4h_to_h = RowParallelLinear(
+ hidden_features,
+ in_features,
+ bias=bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.dense_4h_to_h",
+ disable_tp=use_data_parallel,
+ )
+ self.act_fn = act_fn
+
+ def forward(self, x: torch.Tensor):
+ x_up, _ = self.dense_h_to_4h(x)
+ x_down, _ = self.dense_4h_to_h(self.act_fn(x_up))
+ return x_down
+
+
+class HunYuanVisionAttention(nn.Module):
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ projection_size: int,
+ quant_config: QuantizationConfig | None = None,
+ multimodal_config: MultiModalConfig | None = None,
+ prefix: str = "",
+ use_data_parallel: bool = False,
+ ) -> None:
+ super().__init__()
+ # Per attention head and per partition values.
+ self.tp_size = (
+ 1
+ if use_data_parallel
+ else parallel_state.get_tensor_model_parallel_world_size()
+ )
+ self.hidden_size_per_attention_head = dist_utils.divide(
+ projection_size, num_heads
+ )
+ self.num_attention_heads_per_partition = dist_utils.divide(
+ num_heads, self.tp_size
+ )
+
+ self.qkv = QKVParallelLinear(
+ hidden_size=embed_dim,
+ head_size=self.hidden_size_per_attention_head,
+ total_num_heads=num_heads,
+ total_num_kv_heads=num_heads,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv",
+ disable_tp=use_data_parallel,
+ )
+
+ self.o_proj = RowParallelLinear(
+ input_size=projection_size,
+ output_size=embed_dim,
+ quant_config=quant_config,
+ prefix=f"{prefix}.o_proj",
+ disable_tp=use_data_parallel,
+ )
+
+ self.scale = self.hidden_size_per_attention_head**-0.5
+ self.attn = MultiHeadAttention(
+ self.num_attention_heads_per_partition,
+ self.hidden_size_per_attention_head,
+ self.scale,
+ prefix=f"{prefix}.attn",
+ multimodal_config=multimodal_config,
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ ) -> torch.Tensor:
+ qkv, _ = self.qkv(x)
+ q, k, v = qkv.chunk(3, dim=-1)
+ out = self.attn(q, k, v)
+ output, _ = self.o_proj(out)
+ return output
+
+
+class HunYuanVisionBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_hidden_dim: int,
+ act_fn: Callable[[torch.Tensor], torch.Tensor] = F.gelu,
+ norm_layer: Callable[[int], nn.Module] | None = None,
+ quant_config: QuantizationConfig | None = None,
+ multimodal_config: MultiModalConfig | None = None,
+ prefix: str = "",
+ use_data_parallel: bool = False,
+ ) -> None:
+ super().__init__()
+ if norm_layer is None:
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+ self.input_layernorm = norm_layer(dim)
+ self.post_attention_layernorm = norm_layer(dim)
+ self.self_attn = HunYuanVisionAttention(
+ embed_dim=dim,
+ num_heads=num_heads,
+ projection_size=dim,
+ quant_config=quant_config,
+ multimodal_config=multimodal_config,
+ prefix=f"{prefix}.self_attn",
+ use_data_parallel=use_data_parallel,
+ )
+ self.mlp = HunYuanVisionMLP(
+ dim,
+ mlp_hidden_dim,
+ act_fn=act_fn,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp",
+ use_data_parallel=use_data_parallel,
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ ) -> torch.Tensor:
+ x = x + self.self_attn(self.input_layernorm(x))
+ x = x + self.mlp(self.post_attention_layernorm(x))
+ return x
+
+
+class HunYuanVisionPatchEmbed(nn.Module):
+ def __init__(self, config: HunYuanVLVisionConfig):
+ super().__init__()
+
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.patch_size = config.patch_size
+ self.num_channels = config.num_channels
+ self.spatial_merge_size = config.spatial_merge_size
+ self.interpolate_mode = config.interpolate_mode
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ bias=True,
+ )
+
+ self.max_num_patches = (config.max_image_size // self.patch_size) ** 2
+
+ self.num_positions = self.max_num_patches + 1
+ self.position_edge = int(self.num_positions**0.5)
+ # first token is cls token, skip it
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+ self.patch_pos_embed = None
+
+ def forward(
+ self, pixel_values: torch.Tensor, grid_thw: list[list[int]]
+ ) -> torch.Tensor:
+ num_patches = pixel_values.size(0)
+ pixel_values = pixel_values.reshape(
+ num_patches, self.num_channels, self.patch_size, self.patch_size
+ )
+
+ patch_embeds = self.patch_embedding(pixel_values)
+ patch_embeds = patch_embeds.squeeze(-1).squeeze(-1).unsqueeze(0)
+
+ if self.patch_pos_embed is None:
+ patch_pos_shape = (
+ 1,
+ self.position_edge,
+ self.position_edge,
+ self.embed_dim,
+ )
+ self.patch_pos_embed = (
+ self.position_embedding.weight[1:, :]
+ .reshape(patch_pos_shape)
+ .permute(0, 3, 1, 2)
+ .float()
+ )
+
+ patch_pos_embed_list = []
+ for grid in grid_thw:
+ _, h0, w0 = grid
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ h0, w0 = h0 + 0.1, w0 + 0.1
+ patch_pos_embed = nn.functional.interpolate(
+ self.patch_pos_embed,
+ scale_factor=(h0 / self.position_edge, w0 / self.position_edge),
+ mode=self.interpolate_mode,
+ align_corners=False,
+ )
+
+ patch_pos_embed = (
+ patch_pos_embed.reshape(self.embed_dim, -1)
+ .transpose(0, 1)
+ .unsqueeze(0)
+ .to(patch_embeds.dtype)
+ )
+ patch_pos_embed_list.append(patch_pos_embed)
+
+ patch_pos_embed = torch.cat(patch_pos_embed_list, dim=1)
+ embeddings = patch_embeds + patch_pos_embed
+
+ return embeddings
+
+
+class HunYuanVisionPatchMerger(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ spatial_merge_size=2,
+ rms_norm_eps=1e-5,
+ prefix="",
+ ):
+ super().__init__()
+ self.spatial_merge_size = spatial_merge_size
+ embed_std = out_channels**-0.5
+
+ self.proj = nn.Sequential(
+ nn.Conv2d(
+ in_channels,
+ in_channels * 2,
+ kernel_size=spatial_merge_size,
+ stride=spatial_merge_size,
+ ),
+ nn.GELU(),
+ nn.Conv2d(in_channels * 2, in_channels * 4, kernel_size=1),
+ )
+ self.mlp = nn.Linear(in_channels * 4, out_channels)
+
+ self.image_newline = nn.Parameter(torch.randn(in_channels * 4) * embed_std)
+ self.image_begin = nn.Parameter(torch.randn(out_channels) * embed_std)
+ self.image_end = nn.Parameter(torch.randn(out_channels) * embed_std)
+ self.image_sep = nn.Parameter(torch.randn(out_channels) * embed_std)
+
+ self.before_rms = RMSNorm(in_channels, eps=rms_norm_eps)
+ self.after_rms = RMSNorm(out_channels, eps=rms_norm_eps)
+
+ def forward(self, x, size=(16, 16)):
+ x = self.before_rms(x)
+
+ h, w = size
+ dtype = x.dtype
+ x = x.permute(0, 2, 1).reshape(x.shape[0], -1, h, w)
+
+ x = self.proj(x) # b,c,h,w
+ b, c, h, w = x.shape
+ x = torch.cat(
+ [x, self.image_newline.reshape(1, c, 1, 1).expand(b, c, h, 1).to(dtype)],
+ dim=-1,
+ )
+ x = x.reshape(b, c, -1).permute(0, 2, 1)
+ x = self.mlp(x)
+
+ begin = self.image_begin.reshape(1, 1, -1).expand(b, 1, x.shape[-1]).to(dtype)
+ end = self.image_end.reshape(1, 1, -1).expand(b, 1, x.shape[-1]).to(dtype)
+ x = torch.cat([begin, x, end], dim=1)
+
+ return self.after_rms(x)
+
+
+class HunYuanVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ vision_config: HunYuanVLVisionConfig,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ use_data_parallel: bool = False,
+ multimodal_config: MultiModalConfig | None = None,
+ attn_backend_override: AttentionBackendEnum | None = None,
+ ) -> None:
+ super().__init__()
+
+ num_hidden_layers = vision_config.num_hidden_layers
+ self.hidden_size = vision_config.hidden_size
+ self.num_heads = vision_config.num_attention_heads
+ self.spatial_merge_size = vision_config.spatial_merge_size
+
+ from vllm.compilation.backends import set_model_tag
+
+ with set_model_tag("HunYuanVisionPatchEmbed"):
+ self.embeddings = HunYuanVisionPatchEmbed(vision_config)
+
+ norm_layer = partial(nn.LayerNorm, eps=vision_config.rms_norm_eps)
+
+ with set_model_tag("HunYuanVisionBlock"):
+ self.layers = nn.ModuleList(
+ [
+ HunYuanVisionBlock(
+ dim=vision_config.hidden_size,
+ num_heads=vision_config.num_attention_heads,
+ mlp_hidden_dim=vision_config.intermediate_size,
+ act_fn=get_act_fn(vision_config.hidden_act),
+ norm_layer=norm_layer,
+ quant_config=quant_config,
+ multimodal_config=multimodal_config,
+ prefix=f"{prefix}.layers.{layer_idx}",
+ use_data_parallel=use_data_parallel,
+ )
+ for layer_idx in range(num_hidden_layers)
+ ]
+ )
+
+ with set_model_tag("HunYuanVisionPatchMerger"):
+ self.perceive = HunYuanVisionPatchMerger(
+ vision_config.hidden_size,
+ vision_config.out_hidden_size,
+ spatial_merge_size=vision_config.spatial_merge_size,
+ rms_norm_eps=vision_config.rms_norm_eps,
+ prefix=f"{prefix}.perceive",
+ )
+
+ @property
+ def dtype(self) -> torch.dtype:
+ return self.embeddings.patch_embedding.weight.dtype
+
+ @property
+ def device(self) -> torch.device:
+ return self.embeddings.patch_embedding.weight.device
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ grid_thw: list[list[int]],
+ ) -> torch.Tensor:
+ # patchify
+ seq_len = x.size(0)
+ cu_seqlens: list = [0]
+
+ hidden_states = x.to(device=self.device, dtype=self.dtype)
+ hidden_states = self.embeddings(hidden_states, grid_thw)
+
+ for t, h, w in grid_thw:
+ t, h, w = int(t), int(h), int(w)
+ cu_seqlens.append(h * w)
+
+ cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32)
+ cu_seqlens = torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32)
+
+ cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True)
+
+ hidden_states = hidden_states.reshape(seq_len, -1)
+ hidden_states = hidden_states.unsqueeze(0)
+ for layer_num, layer in enumerate(self.layers):
+ hidden_states = layer(hidden_states)
+
+ # adapter
+ split_lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+ split_items = hidden_states.split(split_lengths, dim=1)
+ image_embeds_list = []
+ for grid, split_item in zip(grid_thw, split_items):
+ image_embeds_list.append(
+ self.perceive(split_item.contiguous(), size=grid[1:]).squeeze(0)
+ )
+
+ return image_embeds_list
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ (".qkv", ".q_proj", "q"),
+ (".qkv", ".k_proj", "k"),
+ (".qkv", ".v_proj", "v"),
+ ]
+ params_dict = dict(self.named_parameters(remove_duplicate=False))
+ loaded_params: set[str] = set()
+
+ for name, loaded_weight in weights:
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
+
+
+def _hunyuan_vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
+ image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
+ image_grid_sizes = image_grid_thw.prod(-1)
+ return dict(
+ pixel_values=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes),
+ image_embeds=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes),
+ image_grid_thw=MultiModalFieldConfig.batched("image"),
+ )
+
+
+class HunYuanVLMultiModalDataParser(MultiModalDataParser):
+ def _parse_image_data(
+ self,
+ data: dict[str, torch.Tensor] | ModalityData[ImageItem],
+ ):
+ if isinstance(data, dict):
+ return DictEmbeddingItems(
+ data,
+ modality="image",
+ required_fields={"image_embeds", "image_grid_thw"},
+ fields_factory=_hunyuan_vl_field_config,
+ )
+
+ return super()._parse_image_data(data)
+
+
+class HunYuanVLProcessingInfo(BaseProcessingInfo):
+ def get_hf_config(self):
+ return self.ctx.get_hf_config(HunYuanVLConfig)
+
+ def get_hf_processor(
+ self,
+ **kwargs: object,
+ ) -> HunYuanVLProcessor:
+ return self.ctx.get_hf_processor(
+ HunYuanVLProcessor,
+ use_fast=kwargs.pop("use_fast", True),
+ **kwargs,
+ )
+
+ def get_image_processor(
+ self,
+ **kwargs: object,
+ ) -> HunYuanVLProcessor:
+ return self.get_hf_processor(**kwargs).image_processor
+
+ def get_supported_mm_limits(self) -> Mapping[str, int | None]:
+ return {"image": None}
+
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
+ max_image_tokens = self.get_max_image_tokens()
+ # TODO: support video
+ max_video_tokens = 0
+ return {"image": max_image_tokens, "video": max_video_tokens}
+
+ def _get_vision_info(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ num_frames: int = 1,
+ do_resize: bool = True,
+ image_processor: HunYuanVLProcessor | None,
+ ) -> tuple[ImageSize, int]:
+ if image_processor is None:
+ image_processor = self.get_image_processor()
+
+ hf_config = self.get_hf_config()
+ vision_config = hf_config.vision_config
+ patch_size = vision_config.patch_size
+ spatial_merge_size = vision_config.spatial_merge_size
+
+ if do_resize:
+ resized_height, resized_width = smart_resize(
+ height=image_height,
+ width=image_width,
+ factor=patch_size * spatial_merge_size,
+ min_pixels=image_processor.min_pixels,
+ max_pixels=image_processor.max_pixels,
+ )
+ preprocessed_size = ImageSize(width=resized_width, height=resized_height)
+ else:
+ preprocessed_size = ImageSize(width=image_width, height=image_height)
+
+ grid_t = 1
+ grid_h = preprocessed_size.height // patch_size
+ grid_w = preprocessed_size.width // patch_size
+
+ num_vision_tokens = (
+ grid_t * grid_h // spatial_merge_size * (grid_w // spatial_merge_size + 1)
+ + 2
+ )
+
+ return preprocessed_size, num_vision_tokens
+
+ def get_num_image_tokens(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ image_processor: HunYuanVLProcessor | None,
+ ) -> int:
+ _, num_image_tokens = self._get_vision_info(
+ image_width=image_width,
+ image_height=image_height,
+ image_processor=image_processor,
+ )
+ return num_image_tokens
+
+ def get_image_size_with_most_features(self) -> ImageSize:
+ max_image_size, _ = self._get_vision_info(
+ image_width=512,
+ image_height=8192,
+ image_processor=None,
+ )
+ return max_image_size
+
+ def get_max_image_tokens(self) -> int:
+ target_width, target_height = self.get_image_size_with_most_features()
+ return self.get_num_image_tokens(
+ image_width=target_width,
+ image_height=target_height,
+ image_processor=None,
+ )
+
+
+class HunYuanVLDummyInputsBuilder(BaseDummyInputsBuilder[HunYuanVLProcessingInfo]):
+ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+ num_images = mm_counts.get("image", 0)
+
+ hf_processor = self.info.get_hf_processor()
+ image_token: str = hf_processor.image_token
+
+ return image_token * num_images
+
+ def get_dummy_mm_data(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ mm_options: Mapping[str, BaseDummyOptions] | None = None,
+ ) -> MultiModalDataDict:
+ num_images = mm_counts.get("image", 1)
+
+ target_width, target_height = self.info.get_image_size_with_most_features()
+
+ return {
+ "image": self._get_dummy_images(
+ width=target_width, height=target_height, num_images=num_images
+ ),
+ }
+
+
+class HunYuanVLMultiModalProcessor(BaseMultiModalProcessor[HunYuanVLProcessingInfo]):
+ def _get_data_parser(self) -> MultiModalDataParser:
+ return HunYuanVLMultiModalDataParser()
+
+ def _call_hf_processor(
+ self,
+ prompt: str,
+ mm_data: Mapping[str, object],
+ mm_kwargs: Mapping[str, object],
+ tok_kwargs: Mapping[str, object],
+ ) -> BatchFeature:
+ return self.info.ctx.call_hf_processor(
+ self.info.get_hf_processor(**mm_kwargs),
+ dict(text=prompt, **mm_data),
+ dict(**mm_kwargs, **tok_kwargs),
+ )
+
+ def _get_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, Any],
+ out_mm_kwargs: MultiModalKwargsItems,
+ ) -> Sequence[PromptUpdate]:
+ hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+ image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
+
+ placeholder = {
+ "image": hf_processor.image_token_id,
+ }
+
+ merge_size = image_processor.merge_size
+
+ def get_replacement_hunyuan_vl(item_idx: int, modality: str):
+ out_item = out_mm_kwargs[modality][item_idx]
+ grid_thw = out_item[f"{modality}_grid_thw"].data
+ assert isinstance(grid_thw, torch.Tensor)
+
+ _, grid_h, grid_w = grid_thw
+ num_tokens = (int(grid_h) // merge_size) * (
+ int(grid_w) // merge_size + 1
+ ) + 2
+ return [placeholder[modality]] * num_tokens
+
+ return [
+ PromptReplacement(
+ modality=modality,
+ target=[placeholder[modality]],
+ replacement=partial(get_replacement_hunyuan_vl, modality=modality),
+ )
+ for modality in ("image",)
+ ]
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ return _hunyuan_vl_field_config(hf_inputs)
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ HunYuanVLMultiModalProcessor,
+ info=HunYuanVLProcessingInfo,
+ dummy_inputs=HunYuanVLDummyInputsBuilder,
+)
+class HunYuanVLForConditionalGeneration(
+ nn.Module,
+ SupportsMultiModal,
+ SupportsLoRA,
+ SupportsPP,
+ SupportsQuant,
+ SupportsXDRoPE,
+):
+ multimodal_cpu_fields = {"image_grid_thw"}
+
+ # To ensure correct weight loading and mapping.
+ hf_to_vllm_mapper = WeightsMapper(
+ orig_to_new_prefix={
+ # mapping for new names in checkpoint saved after transformers v4.52
+ "vit.vit.": "visual.",
+ "vit.": "visual.",
+ "model.": "language_model.model.",
+ }
+ )
+
+ supports_encoder_tp_data = True
+
+ def get_xdrope_input_positions(
+ self,
+ input_tokens: list[int],
+ mm_features: list[MultiModalFeatureSpec],
+ ) -> torch.Tensor:
+ kwargs = MultiModalFeatureSpec.gather_kwargs(
+ mm_features,
+ {"image_grid_thw"},
+ )
+ image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
+
+ hf_config = self.config
+ image_start_token_id = hf_config.image_start_token_id
+ spatial_merge_size = hf_config.vision_config.spatial_merge_size
+ xd_num = len(hf_config.rope_scaling["xdrope_section"])
+
+ input_tokens_tensor = torch.tensor(input_tokens)
+ image_start_indices = torch.argwhere(
+ input_tokens_tensor == image_start_token_id
+ ).squeeze(1)
+
+ p_index = torch.arange(len(input_tokens_tensor))
+ w_index = torch.arange(len(input_tokens_tensor))
+ h_index = torch.arange(len(input_tokens_tensor))
+ t_index = torch.arange(len(input_tokens_tensor))
+ for image_index in range(len(image_start_indices)):
+ # +1 : first image_token, +2: for xdrope positions
+ pos = image_start_indices[image_index] + 2
+ t, h, w = image_grid_thw[image_index]
+ _, llm_grid_h, llm_grid_w = (
+ t,
+ h // spatial_merge_size,
+ w // spatial_merge_size,
+ )
+
+ token_num = (llm_grid_w + 1) * llm_grid_h
+ w_index[pos : pos + token_num].copy_(
+ torch.arange(0, llm_grid_w + 1)
+ .reshape(1, -1)
+ .expand(llm_grid_h, -1)
+ .reshape(-1)
+ )
+ h_index[pos : pos + token_num].copy_(
+ torch.arange(0, llm_grid_h)
+ .reshape(-1, 1)
+ .expand(-1, llm_grid_w + 1)
+ .reshape(-1)
+ )
+ h_index[pos : pos + token_num] = 0
+
+ if xd_num == 4:
+ llm_positions = torch.stack([p_index, w_index, h_index, t_index])
+ elif xd_num == 3:
+ llm_positions = torch.stack([w_index, h_index, t_index])
+
+ return llm_positions
+
+ @classmethod
+ def get_placeholder_str(cls, modality: str, i: int) -> str | None:
+ if modality.startswith("image"):
+ return "<|hy_place▁holder▁no▁100|><|hy_place▁holder▁no▁102|><|hy_place▁holder▁no▁101|>" # noqa: E501
+
+ raise ValueError("Only image modality is supported")
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config: HunYuanVLConfig = vllm_config.model_config.hf_config
+ multimodal_config = vllm_config.model_config.multimodal_config
+
+ self.config = config
+ self.multimodal_config = multimodal_config
+
+ if multimodal_config.get_limit_per_prompt("image"):
+ attn_backend_override = (
+ multimodal_config.mm_encoder_attn_backend
+ if multimodal_config is not None
+ else None
+ )
+ self.visual = HunYuanVisionTransformer(
+ config.vision_config,
+ quant_config=self.quant_config,
+ prefix=maybe_prefix(prefix, "visual"),
+ multimodal_config=multimodal_config,
+ attn_backend_override=attn_backend_override,
+ )
+ else:
+ self.visual = None
+
+ self.language_model = init_vllm_registered_model(
+ vllm_config=vllm_config,
+ prefix=maybe_prefix(prefix, "language_model.model"),
+ architectures=[
+ "HunYuanDenseV1ForCausalLM",
+ "HunYuanMoEV1ForCausalLM",
+ ],
+ )
+
+ self.make_empty_intermediate_tensors = (
+ self.language_model.make_empty_intermediate_tensors
+ )
+
+ def _parse_and_validate_image_input(
+ self, **kwargs: object
+ ) -> HunYuanVLImageInputs | None:
+ pixel_values = kwargs.pop("pixel_values", None)
+ image_embeds = kwargs.pop("image_embeds", None)
+ image_grid_thw = kwargs.pop("image_grid_thw", None)
+
+ if pixel_values is None and image_embeds is None:
+ return None
+
+ # TODO: refine
+ if isinstance(pixel_values, list):
+ pixel_values = torch.cat(pixel_values, dim=0)
+ if len(pixel_values.shape) == 3:
+ last_dim = pixel_values.shape[-1]
+ pixel_values = pixel_values.reshape(-1, last_dim)
+ image_grid_thw = image_grid_thw.reshape(-1, 3)
+
+ if pixel_values is not None:
+ return HunYuanVLImagePixelInputs(
+ type="pixel_values",
+ pixel_values=pixel_values,
+ image_grid_thw=image_grid_thw,
+ )
+
+ if image_embeds is not None:
+ return HunYuanVLImageEmbeddingInputs(
+ type="image_embeds",
+ image_embeds=image_embeds,
+ image_grid_thw=image_grid_thw,
+ )
+
+ def _process_image_input(
+ self, image_input: HunYuanVLImageInputs
+ ) -> tuple[torch.Tensor, ...]:
+ grid_thw = image_input["image_grid_thw"]
+ assert grid_thw.ndim == 2
+ grid_thw_list = grid_thw.tolist()
+
+ if image_input["type"] == "image_embeds":
+ image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+ else:
+ pixel_values = image_input["pixel_values"]
+
+ # TODO: use_data_parallel (split image_embeds in visual)
+ image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
+
+ return image_embeds
+
+ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
+ mm_input_by_modality = {}
+
+ # Preserve the order of modalities if there are multiple of them
+ # from the order of kwargs.
+ for input_key in kwargs:
+ if (
+ input_key in ("pixel_values", "image_embeds")
+ and "image" not in mm_input_by_modality
+ ):
+ mm_input_by_modality["image"] = self._parse_and_validate_image_input(
+ **kwargs
+ )
+ return mm_input_by_modality
+
+ def get_language_model(self) -> torch.nn.Module:
+ return self.language_model
+
+ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
+ mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
+ if not mm_input_by_modality:
+ return []
+
+ # The result multimodal_embeddings is tuple of tensors, with each
+        # tensor corresponding to a multimodal data item (image or video).
+ multimodal_embeddings: tuple[torch.Tensor, ...] = ()
+
+ # NOTE: It is important to iterate over the keys in this dictionary
+ # to preserve the order of the modalities.
+ for modality in mm_input_by_modality:
+ multimodal_input = mm_input_by_modality[modality]
+ if modality == "image":
+ image_embeddings = self._process_image_input(multimodal_input)
+ multimodal_embeddings += tuple(image_embeddings)
+ return multimodal_embeddings
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None,
+ inputs_embeds: torch.Tensor | None,
+ **kwargs: object,
+ ) -> torch.Tensor | IntermediateTensors:
+ if intermediate_tensors is not None:
+ inputs_embeds = None
+
+ hidden_states = self.language_model(
+ input_ids=input_ids,
+ positions=positions,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ )
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor | None:
+ return self.language_model.compute_logits(hidden_states)
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ loader = AutoWeightsLoader(
+ self,
+ skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
+ )
+ return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
+ def get_mm_mapping(self) -> MultiModelKeys:
+ """
+ Get the module prefix in multimodal models
+ """
+ return MultiModelKeys.from_string_field(
+ language_model="language_model.model",
+ connector="visual.perceive",
+ tower_model="visual",
+ )
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index dc4caf2f02f9d..6f6ce32538b71 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -586,13 +586,11 @@ class IsHybrid(Protocol):
def get_mamba_state_shape_from_config(
cls,
vllm_config: VllmConfig,
- use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
vllm_config: vLLM config
- use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
@@ -1049,7 +1047,7 @@ class SupportsMRoPE(Protocol):
supports_mrope: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports M-RoPE.
-
+
Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
@@ -1090,3 +1088,52 @@ def supports_mrope(
model: type[object] | object,
) -> TypeIs[type[SupportsMRoPE]] | TypeIs[SupportsMRoPE]:
return isinstance(model, SupportsMRoPE)
+
+
+@runtime_checkable
+class SupportsXDRoPE(Protocol):
+ """The interface required for all models that support XD-RoPE."""
+
+ supports_xdrope: ClassVar[Literal[True]] = True
+ """
+ A flag that indicates this model supports XD-RoPE.
+
+ Note:
+ There is no need to redefine this flag if this class is in the
+        MRO of your model class.
+ """
+
+ def get_xdrope_input_positions(
+ self,
+ input_tokens: list[int],
+ mm_features: list["MultiModalFeatureSpec"],
+ ) -> torch.Tensor:
+ """
+ Get XD-RoPE input positions and delta value for this specific model.
+
+ This method should be implemented by each model that supports XD-RoPE
+ to provide model-specific logic for computing input positions.
+
+ Args:
+ input_tokens: List of input token IDs
+ mm_features: Information about each multi-modal data item
+
+ Returns:
+ llm_positions: Tensor of shape `[xdrope_dim, num_tokens]` with
+ 4D(P/W/H/T) or 3D(W/H/T) positions.
+ """
+ ...
+
+
+@overload
+def supports_xdrope(model: type[object]) -> TypeIs[type[SupportsXDRoPE]]: ...
+
+
+@overload
+def supports_xdrope(model: object) -> TypeIs[SupportsXDRoPE]: ...
+
+
+def supports_xdrope(
+ model: type[object] | object,
+) -> TypeIs[type[SupportsXDRoPE]] | TypeIs[SupportsXDRoPE]:
+ return isinstance(model, SupportsXDRoPE)
diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py
index 60fbeb842dd4b..dc8f821bd134f 100644
--- a/vllm/model_executor/models/internlm2.py
+++ b/vllm/model_executor/models/internlm2.py
@@ -91,8 +91,7 @@ class InternLM2Attention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
+ rope_parameters: dict[str, Any] | None = None,
max_position_embeddings: int = 8192,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
@@ -120,7 +119,6 @@ class InternLM2Attention(nn.Module):
self.kv_size = self.num_kv_heads * self.head_dim
self.key_value_groups = int(self.num_heads / self.num_kv_heads)
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.wqkv = QKVParallelLinear(
@@ -144,8 +142,7 @@ class InternLM2Attention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
- base=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=rope_parameters,
)
self.attn = Attention(
self.num_heads,
@@ -204,15 +201,12 @@ class InternLMDecoderLayer(nn.Module):
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.attention = InternLM2Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
diff --git a/vllm/model_executor/models/internlm2_ve.py b/vllm/model_executor/models/internlm2_ve.py
index 6dc081e34157b..a57db82242af9 100644
--- a/vllm/model_executor/models/internlm2_ve.py
+++ b/vllm/model_executor/models/internlm2_ve.py
@@ -30,15 +30,12 @@ class InternLM2VEDecoderLayer(nn.Module):
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.attention = InternLM2Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py
index 8fc3db296aa79..302260b952992 100644
--- a/vllm/model_executor/models/keye.py
+++ b/vllm/model_executor/models/keye.py
@@ -9,6 +9,7 @@ from typing import Annotated, Any, Literal, TypeAlias, TypeVar
import numpy as np
import torch
import torch.nn as nn
+import torch.nn.functional as F
from einops import rearrange
from transformers import PretrainedConfig
from transformers.activations import GELUActivation
@@ -424,7 +425,7 @@ class KeyeSiglipAttention(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
- AttentionBackendEnum.XFORMERS,
+ AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -451,7 +452,6 @@ class KeyeSiglipAttention(nn.Module):
)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
- seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
batch_size = q.shape[0]
if rope_emb is None:
@@ -498,17 +498,21 @@ class KeyeSiglipAttention(nn.Module):
softmax_scale=self.scale,
)
context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- from xformers import ops as xops
- from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-
- attn_bias = BlockDiagonalMask.from_seqlens(
- q_seqlen=seqlens, kv_seqlen=None, device=q.device
- )
-
- context_layer = xops.memory_efficient_attention_forward(
- q, k, v, attn_bias=attn_bias, p=0, scale=None
- )
+ elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
+ outputs = []
+ for i in range(1, len(cu_seqlens)):
+ start_idx = cu_seqlens[i - 1]
+ end_idx = cu_seqlens[i]
+ q_i = q[:, start_idx:end_idx]
+ k_i = k[:, start_idx:end_idx]
+ v_i = v[:, start_idx:end_idx]
+ q_i, k_i, v_i = (
+ rearrange(x, "b s h d -> b h s d") for x in (q_i, k_i, v_i)
+ )
+ output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
+ output_i = rearrange(output_i, "b h s d -> b s h d ")
+ outputs.append(output_i)
+ context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]
context_layer = rearrange(context_layer, "b s h d -> b s (h d)").contiguous()
diff --git a/vllm/model_executor/models/kimi_linear.py b/vllm/model_executor/models/kimi_linear.py
index f3675075a48f4..4562b2202c5ec 100644
--- a/vllm/model_executor/models/kimi_linear.py
+++ b/vllm/model_executor/models/kimi_linear.py
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
-from typing import Any
import torch
from torch import nn
@@ -190,9 +189,7 @@ class KimiMLAAttention(nn.Module):
v_head_dim: int,
q_lora_rank: int | None,
kv_lora_rank: int,
- rope_theta: float = 10000,
use_nope: bool = False,
- rope_scaling: dict[str, Any] | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
@@ -210,11 +207,9 @@ class KimiMLAAttention(nn.Module):
tp_size = get_tensor_model_parallel_world_size()
self.num_local_heads = num_heads // tp_size
self.scaling = self.qk_head_dim**-0.5
- self.rope_theta = rope_theta
self.use_nope = use_nope
assert self.use_nope is True
assert self.q_lora_rank is None
- assert rope_scaling is None
assert num_heads % tp_size == 0
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py
index aeb25602f11a4..74bdde27ece5c 100644
--- a/vllm/model_executor/models/lfm2.py
+++ b/vllm/model_executor/models/lfm2.py
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from itertools import islice
-from typing import Any
import torch
import torch.nn as nn
@@ -96,8 +95,6 @@ class Lfm2Attention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
max_position_embeddings: int = 8192,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
@@ -126,7 +123,6 @@ class Lfm2Attention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
@@ -149,8 +145,7 @@ class Lfm2Attention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
- base=self.rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
is_neox_style=True,
)
self.attn = Attention(
@@ -199,14 +194,6 @@ class Lfm2AttentionDecoderLayer(nn.Module):
self.config = config
self.layer_idx = layer_idx
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
- if rope_scaling is not None and getattr(
- config, "original_max_position_embeddings", None
- ):
- rope_scaling["original_max_position_embeddings"] = (
- config.original_max_position_embeddings
- )
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.self_attn = Lfm2Attention(
@@ -215,8 +202,6 @@ class Lfm2AttentionDecoderLayer(nn.Module):
hidden_size=config.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
diff --git a/vllm/model_executor/models/lfm2_moe.py b/vllm/model_executor/models/lfm2_moe.py
index 6b7b5564ee989..c088a08211527 100644
--- a/vllm/model_executor/models/lfm2_moe.py
+++ b/vllm/model_executor/models/lfm2_moe.py
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from itertools import islice
-from typing import Any
import torch
import torch.nn as nn
@@ -189,8 +188,6 @@ class Lfm2MoeAttention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
max_position_embeddings: int = 8192,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
@@ -219,7 +216,6 @@ class Lfm2MoeAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
@@ -242,8 +238,7 @@ class Lfm2MoeAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
- base=self.rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
is_neox_style=True,
)
self.attn = Attention(
@@ -293,14 +288,6 @@ class Lfm2MoeAttentionDecoderLayer(nn.Module):
self.config = config
self.layer_idx = layer_idx
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
- if rope_scaling is not None and getattr(
- config, "original_max_position_embeddings", None
- ):
- rope_scaling["original_max_position_embeddings"] = (
- config.original_max_position_embeddings
- )
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.self_attn = Lfm2MoeAttention(
@@ -309,8 +296,6 @@ class Lfm2MoeAttentionDecoderLayer(nn.Module):
hidden_size=config.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 0a3f37c30ab5f..eebb9e07fa89d 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -26,7 +26,6 @@
from collections.abc import Iterable
from itertools import islice
-from typing import Any
import torch
from torch import nn
@@ -48,7 +47,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
- DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
@@ -120,8 +118,6 @@ class LlamaAttention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
max_position_embeddings: int = 8192,
quant_config: QuantizationConfig | None = None,
bias: bool = False,
@@ -157,7 +153,6 @@ class LlamaAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
llama_4_scaling_config = getattr(config, "llama_4_scaling", None)
@@ -186,9 +181,7 @@ class LlamaAttention(nn.Module):
prefix=f"{prefix}.o_proj",
)
- self._init_rotary_emb(
- config, rope_scaling=rope_scaling, quant_config=quant_config
- )
+ self._init_rotary_emb(config, quant_config=quant_config)
sliding_window = None
if layer_types := getattr(config, "layer_types", None):
@@ -258,7 +251,6 @@ class LlamaAttention(nn.Module):
def _init_rotary_emb(
self,
config: LlamaConfig,
- rope_scaling: dict[str, Any] | None,
quant_config: QuantizationConfig | None,
) -> None:
is_neox_style = True
@@ -270,8 +262,7 @@ class LlamaAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
- base=self.rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style,
partial_rotary_factor=self.partial_rotary_factor,
)
@@ -291,14 +282,6 @@ class LlamaDecoderLayer(nn.Module):
quant_config = self.get_quant_config(vllm_config)
self.hidden_size = config.hidden_size
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
- if rope_scaling is not None and getattr(
- config, "original_max_position_embeddings", None
- ):
- rope_scaling["original_max_position_embeddings"] = (
- config.original_max_position_embeddings
- )
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
@@ -326,8 +309,6 @@ class LlamaDecoderLayer(nn.Module):
num_kv_heads=getattr(
config, "num_key_value_heads", config.num_attention_heads
),
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
@@ -373,7 +354,17 @@ class LlamaDecoderLayer(nn.Module):
return vllm_config.quant_config
-@support_torch_compile
+def llama_model_invariants(
+ input_ids, positions, intermediate_tensors=None, inputs_embeds=None
+):
+ """Shape invariants for Llama model compilation, those are translated to
+ runtime assertions for unbacked dynamic shapes and are compiled away for
+ backed"""
+ if input_ids is not None:
+ torch._check(positions.size()[0] == input_ids.size()[0])
+
+
+@support_torch_compile(shape_invariants=llama_model_invariants)
class LlamaModel(nn.Module):
def __init__(
self,
@@ -386,24 +377,18 @@ class LlamaModel(nn.Module):
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
- lora_config = vllm_config.lora_config
self.config = config
self.quant_config = quant_config
- lora_vocab = (
- (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
- if lora_config
- else 0
- )
- self.vocab_size = config.vocab_size + lora_vocab
- self.org_vocab_size = config.vocab_size
+
+ self.vocab_size = config.vocab_size
+
if get_pp_group().is_first_rank or (
config.tie_word_embeddings and get_pp_group().is_last_rank
):
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
- org_num_embeddings=config.vocab_size,
quant_config=quant_config,
)
else:
@@ -580,9 +565,7 @@ class LlamaForCausalLM(
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
- lora_config = vllm_config.lora_config
self.config = config
- self.lora_config = lora_config
self.model = self._init_model(
vllm_config=vllm_config,
@@ -591,20 +574,9 @@ class LlamaForCausalLM(
)
if get_pp_group().is_last_rank:
- self.unpadded_vocab_size = config.vocab_size
- if lora_config:
- self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
- self.unpadded_vocab_size,
+ config.vocab_size,
config.hidden_size,
- org_num_embeddings=config.vocab_size,
- padding_size=(
- DEFAULT_VOCAB_PADDING_SIZE
- # We need bigger padding if using lora for kernel
- # compatibility
- if not lora_config
- else lora_config.lora_vocab_padding_size
- ),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
@@ -613,7 +585,7 @@ class LlamaForCausalLM(
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(
- self.unpadded_vocab_size, config.vocab_size, logit_scale
+ config.vocab_size, scale=logit_scale
)
else:
self.lm_head = PPMissingLayer()
diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py
index a7e0732ec71e2..e1bdfc3405f70 100644
--- a/vllm/model_executor/models/llama4.py
+++ b/vllm/model_executor/models/llama4.py
@@ -19,7 +19,6 @@
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from collections.abc import Iterable
-from typing import Any
import torch
from torch import nn
@@ -54,6 +53,7 @@ from vllm.model_executor.models.utils import sequence_parallel_chunk
from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
from .utils import (
AutoWeightsLoader,
+ PPMissingLayer,
extract_layer_index,
fast_topk,
is_pp_missing_parameter,
@@ -171,8 +171,6 @@ class Llama4Attention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
max_position_embeddings: int = 8192,
quant_config: QuantizationConfig | None = None,
bias: bool = False,
@@ -208,7 +206,6 @@ class Llama4Attention(nn.Module):
self.floor_scale = getattr(config, "floor_scale", 8192.0)
self.attn_scale = getattr(config, "attn_scale", 0.1)
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.n_rep = self.num_heads // self.num_kv_heads
self.qk_norm = (
@@ -248,8 +245,7 @@ class Llama4Attention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
- base=int(rope_theta),
- rope_scaling=rope_scaling if rope_scaling != "default" else None,
+ rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style,
)
if not self.nope
@@ -331,8 +327,6 @@ class Llama4DecoderLayer(nn.Module):
self.layer_idx = extract_layer_index(prefix)
self.global_layer = config.no_rope_layers[self.layer_idx] == 0
self.hidden_size = config.hidden_size
- rope_theta = config.rope_theta
- rope_scaling = config.rope_scaling
max_position_embeddings = config.max_position_embeddings
self.self_attn = Llama4Attention(
@@ -340,8 +334,6 @@ class Llama4DecoderLayer(nn.Module):
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=False,
@@ -738,6 +730,9 @@ class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):
self.moe_layers = []
example_moe = None
for layer in self.model.layers:
+ if isinstance(layer, PPMissingLayer):
+ continue
+
assert isinstance(layer, Llama4DecoderLayer)
if isinstance(layer.feed_forward, Llama4MoE):
# Pick last one layer since the first ones may be dense layers.
@@ -774,6 +769,9 @@ class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):
self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = num_physical_experts - self.num_logical_experts
for layer in self.model.layers:
+ if isinstance(layer, PPMissingLayer):
+ continue
+
if isinstance(layer.feed_forward, Llama4MoE):
moe = layer.feed_forward
moe.n_local_physical_experts = num_local_physical_experts
diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py
index 660c8f1bb5226..0146b30579287 100644
--- a/vllm/model_executor/models/llama4_eagle.py
+++ b/vllm/model_executor/models/llama4_eagle.py
@@ -23,7 +23,6 @@ import torch.nn as nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
-from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -127,17 +126,11 @@ class LlamaModel(nn.Module):
weight_loader(param, loaded_weight, shard_id)
break
else:
- # if PP disabled then draft will share embed with target
- if get_pp_group().world_size == 1 and "embed_tokens." in name:
- continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
for name in params_dict:
- # if PP disabled then draft will share embed with target
- if get_pp_group().world_size == 1 and "embed_tokens." in name:
- continue
assert name in loaded_params, f"{name} is not loaded!"
return loaded_params
diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py
index 90ab5c50361b6..05cb456e7776e 100644
--- a/vllm/model_executor/models/llama_eagle.py
+++ b/vllm/model_executor/models/llama_eagle.py
@@ -9,7 +9,6 @@ from transformers import LlamaConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
-from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -155,10 +154,6 @@ class LlamaModel(nn.Module):
weight_loader(param, loaded_weight, shard_id)
break
else:
- # if PP disabled then draft will share embed with target
- if get_pp_group().world_size == 1 and "embed_tokens." in name:
- continue
-
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index 75c671311b491..7a57644db1b13 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -23,7 +23,6 @@ from vllm.model_executor.model_loader.weight_utils import (
maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
-from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from .utils import (
@@ -121,13 +120,12 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
@support_torch_compile(
- # torch.compile is disabled for multimodal EAGLE3 models due to constraint
- # violations with dynamic shapes during tensor concatenation operations.
- # See: https://github.com/vllm-project/vllm/pull/22872/files#r2362028132
- # Non-multimodal EAGLE3 models can still use torch.compile safely.
- enable_if=lambda vllm_config: not MULTIMODAL_REGISTRY.supports_multimodal_inputs(
- vllm_config.model_config
- ),
+ dynamic_arg_dims={
+ "input_ids": 0,
+ "positions": -1,
+ "hidden_states": 0,
+ "input_embeds": 0,
+ }
)
class LlamaModel(nn.Module):
def __init__(
@@ -144,6 +142,12 @@ class LlamaModel(nn.Module):
# Get drafter's quantization config
self.quant_config = get_draft_quant_config(vllm_config)
+ eagle_config = getattr(self.config, "eagle_config", None)
+ if eagle_config is not None and "use_aux_hidden_state" in eagle_config:
+ self.use_aux_hidden_state = eagle_config["use_aux_hidden_state"]
+ else:
+ self.use_aux_hidden_state = True
+
current_vllm_config = get_current_vllm_config()
self.embed_tokens = VocabParallelEmbedding(
@@ -163,20 +167,20 @@ class LlamaModel(nn.Module):
for layer_idx in range(self.config.num_hidden_layers)
]
)
- if hasattr(self.config, "target_hidden_size"):
- fc_input_size = self.config.target_hidden_size * 3
- else:
- fc_input_size = self.config.hidden_size * 3
- self.fc = ReplicatedLinear(
- input_size=fc_input_size,
- output_size=self.config.hidden_size,
- bias=False,
- params_dtype=vllm_config.model_config.dtype,
- quant_config=self.quant_config,
- prefix=maybe_prefix(prefix, "fc"),
- return_bias=False,
- )
-
+ if self.use_aux_hidden_state:
+ if hasattr(self.config, "target_hidden_size"):
+ fc_input_size = self.config.target_hidden_size * 3
+ else:
+ fc_input_size = self.config.hidden_size * 3
+ self.fc = ReplicatedLinear(
+ input_size=fc_input_size,
+ output_size=self.config.hidden_size,
+ bias=False,
+ params_dtype=vllm_config.model_config.dtype,
+ quant_config=self.quant_config,
+ prefix=maybe_prefix(prefix, "fc"),
+ return_bias=False,
+ )
self.norm = RMSNorm(
self.config.hidden_size,
eps=self.config.rms_norm_eps,
@@ -334,6 +338,8 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
+ if not self.model.use_aux_hidden_state:
+ return hidden_states
# combine multiple auxiliary hidden states returned by eagle3
return self.model.fc(hidden_states)
@@ -359,6 +365,8 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
skip_substrs.append("draft_id_to_target_id")
if not includes_embed_tokens:
skip_substrs.append("embed_tokens")
+ if not self.model.use_aux_hidden_state:
+ skip_substrs.append("fc.")
loader = AutoWeightsLoader(
self,
skip_prefixes=None,
diff --git a/vllm/model_executor/models/longcat_flash.py b/vllm/model_executor/models/longcat_flash.py
index 5de10e7086830..c5441283f9711 100644
--- a/vllm/model_executor/models/longcat_flash.py
+++ b/vllm/model_executor/models/longcat_flash.py
@@ -108,8 +108,7 @@ class FlashConfig(PretrainedConfig):
eos_token_id=100001,
pretraining_tp=1,
tie_word_embeddings=False,
- rope_theta=1000000.0,
- rope_scaling=None,
+ rope_parameters=None,
attention_bias=False,
attention_dropout=0.0,
mla_scale_q_lora=False,
@@ -119,7 +118,7 @@ class FlashConfig(PretrainedConfig):
router_dtype="float32",
router_bias=False,
topk_method=None,
- routed_scaling_factor=None,
+ routed_scaling_factor=1.0,
zero_expert_num=0,
zero_expert_type=None,
nextn_use_scmoe=False,
@@ -162,8 +161,13 @@ class FlashConfig(PretrainedConfig):
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
- self.rope_theta = rope_theta
- self.rope_scaling = rope_scaling
+ # Try to set `rope_scaling` if available, otherwise use `rope_parameters`
+ rope_scaling = kwargs.pop("rope_scaling", None)
+ rope_parameters = rope_scaling or rope_parameters or {"rope_type": "default"}
+ rope_theta = kwargs.pop("rope_theta", 1000000.0)
+ if "rope_theta" not in rope_parameters:
+ rope_parameters["rope_theta"] = rope_theta
+ self.rope_parameters = rope_parameters
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.mla_scale_q_lora = mla_scale_q_lora
@@ -336,15 +340,7 @@ class FlashDecoderLayer(nn.Module):
super().__init__()
self.layer_idx = int(prefix.split(sep=".")[-1])
self.hidden_size = config.hidden_size
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
- if rope_scaling is not None and getattr(
- config, "original_max_position_embeddings", None
- ):
- rope_scaling["original_max_position_embeddings"] = (
- config.original_max_position_embeddings
- )
# Dual attention structure
self.self_attn = nn.ModuleList(
@@ -361,8 +357,6 @@ class FlashDecoderLayer(nn.Module):
config.q_lora_rank if hasattr(config, "q_lora_rank") else None
),
kv_lora_rank=config.kv_lora_rank,
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=None
diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py
index 914b097fe199e..04923833065f3 100644
--- a/vllm/model_executor/models/minicpm.py
+++ b/vllm/model_executor/models/minicpm.py
@@ -230,8 +230,7 @@ class MiniCPMAttention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
+ rope_parameters: dict[str, Any] | None = None,
max_position_embeddings: int = 8192,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
@@ -257,7 +256,6 @@ class MiniCPMAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
@@ -281,8 +279,7 @@ class MiniCPMAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
- base=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=rope_parameters,
)
self.attn = Attention(
@@ -324,8 +321,6 @@ class MiniCPMDecoderLayer(nn.Module):
self.cache_config = cache_config
self.quant_config = quant_config
self.hidden_size = config.hidden_size
- self.rope_theta = getattr(config, "rope_theta", 10000)
- self.rope_scaling = getattr(config, "rope_scaling", None)
self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.prefix = prefix
self._init_attn_block()
@@ -339,8 +334,7 @@ class MiniCPMDecoderLayer(nn.Module):
hidden_size=self.hidden_size,
num_heads=self.config.num_attention_heads,
num_kv_heads=self.config.num_key_value_heads,
- rope_theta=self.rope_theta,
- rope_scaling=self.rope_scaling,
+ rope_parameters=self.config.rope_parameters,
max_position_embeddings=self.max_position_embeddings,
cache_config=self.cache_config,
quant_config=self.quant_config,
diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py
index d3b6966ee3a7f..2d775219fc972 100644
--- a/vllm/model_executor/models/minicpm3.py
+++ b/vllm/model_executor/models/minicpm3.py
@@ -25,8 +25,6 @@
# limitations under the License.
"""Inference-only MiniCPM3 model compatible with HuggingFace weights."""
-from typing import Any
-
import torch
from torch import nn
from transformers import PretrainedConfig
@@ -62,8 +60,6 @@ class MiniCPM3Attention(nn.Module):
v_head_dim: int,
q_lora_rank: int,
kv_lora_rank: int,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
max_position_embeddings: int = 8192,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
@@ -84,7 +80,6 @@ class MiniCPM3Attention(nn.Module):
self.num_local_heads = num_heads // tp_size
self.scaling = self.qk_head_dim**-0.5
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.q_a_proj = ReplicatedLinear(
@@ -127,8 +122,7 @@ class MiniCPM3Attention(nn.Module):
self.qk_rope_head_dim,
rotary_dim=self.qk_rope_head_dim,
max_position=max_position_embeddings,
- base=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
)
self.attn = Attention(
self.num_local_heads,
@@ -204,8 +198,6 @@ class MiniCPM3DecoderLayer(MiniCPMDecoderLayer):
v_head_dim=self.config.v_head_dim,
q_lora_rank=self.config.q_lora_rank,
kv_lora_rank=self.config.kv_lora_rank,
- rope_theta=self.rope_theta,
- rope_scaling=self.rope_scaling,
max_position_embeddings=self.max_position_embeddings,
cache_config=self.cache_config,
quant_config=self.quant_config,
diff --git a/vllm/model_executor/models/minicpm_eagle.py b/vllm/model_executor/models/minicpm_eagle.py
index d0cdb70aa8574..e6bccfcac4f1a 100644
--- a/vllm/model_executor/models/minicpm_eagle.py
+++ b/vllm/model_executor/models/minicpm_eagle.py
@@ -69,8 +69,6 @@ class EagleMiniCPMDecoderLayer(nn.Module):
self.cache_config = cache_config
self.quant_config = quant_config
self.hidden_size = config.hidden_size
- self.rope_theta = getattr(config, "rope_theta", 10000)
- self.rope_scaling = getattr(config, "rope_scaling", None)
self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.prefix = prefix
self._init_attn_block()
@@ -84,8 +82,7 @@ class EagleMiniCPMDecoderLayer(nn.Module):
hidden_size=self.hidden_size,
num_heads=self.config.num_attention_heads,
num_kv_heads=self.config.num_key_value_heads,
- rope_theta=self.rope_theta,
- rope_scaling=self.rope_scaling,
+ rope_parameters=self.config.rope_parameters,
max_position_embeddings=self.max_position_embeddings,
cache_config=self.cache_config,
quant_config=self.quant_config,
diff --git a/vllm/model_executor/models/minimax_m2.py b/vllm/model_executor/models/minimax_m2.py
index 49d2f2d261969..4955c68c0cda8 100644
--- a/vllm/model_executor/models/minimax_m2.py
+++ b/vllm/model_executor/models/minimax_m2.py
@@ -149,8 +149,7 @@ class MiniMaxM2Attention(nn.Module):
num_heads: int,
num_kv_heads: int,
rotary_dim: int,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
+ rope_parameters: dict[str, Any] | None = None,
attn_window_size: int | None = None,
max_position_embeddings: int = 8192,
head_dim: int | None = None,
@@ -180,7 +179,6 @@ class MiniMaxM2Attention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
@@ -205,8 +203,7 @@ class MiniMaxM2Attention(nn.Module):
self.head_dim,
rotary_dim=rotary_dim,
max_position=max_position_embeddings,
- base=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=rope_parameters,
)
self.attn = Attention(
self.num_heads,
@@ -252,8 +249,6 @@ class MiniMaxM2DecoderLayer(nn.Module):
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
if hasattr(config, "max_model_len") and isinstance(config.max_model_len, int):
max_position_embeddings = max(
@@ -269,8 +264,7 @@ class MiniMaxM2DecoderLayer(nn.Module):
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rotary_dim=config.rotary_dim,
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
max_position_embeddings=max_position_embeddings,
rms_norm_eps=config.rms_norm_eps,
qkv_bias=getattr(config, "attention_bias", False),
diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py
index bf1ecc822756d..50f7396e2de60 100644
--- a/vllm/model_executor/models/minimax_text_01.py
+++ b/vllm/model_executor/models/minimax_text_01.py
@@ -188,7 +188,7 @@ class MiniMaxText01Attention(nn.Module):
num_kv_heads: int,
rotary_dim: int,
max_position: int = 4096 * 32,
- rope_theta: float = 10000,
+ rope_parameters: dict | None = None,
sliding_window: int | None = None,
quant_config: QuantizationConfig | None = None,
layer_idx: int = None,
@@ -214,7 +214,6 @@ class MiniMaxText01Attention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.sliding_window = sliding_window
self.prefix = prefix
@@ -247,7 +246,7 @@ class MiniMaxText01Attention(nn.Module):
head_size=self.head_dim,
rotary_dim=rotary_dim,
max_position=max_position,
- base=int(rope_theta),
+ rope_parameters=rope_parameters,
is_neox_style=True,
dtype=torch.float32,
)
@@ -287,8 +286,6 @@ class MiniMaxText01DecoderLayer(nn.Module):
self.hidden_size = config.hidden_size
self.expert_num = expert_num
- rope_theta = getattr(config, "rope_theta", 10000)
-
head_dim = getattr(config, "head_dim", None)
if head_dim is None:
head_dim = config.hidden_size // config.num_attention_heads
@@ -328,7 +325,7 @@ class MiniMaxText01DecoderLayer(nn.Module):
else head_dim,
num_kv_heads=config.num_key_value_heads,
max_position=max_position_embeddings,
- rope_theta=rope_theta,
+ rope_parameters=config.rope_parameters,
sliding_window=config.sliding_window,
quant_config=quant_config,
layer_idx=self._ilayer,
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index d7a1cb82fb4fb..0a9c3f136964e 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -51,7 +51,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
- DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
@@ -161,7 +160,6 @@ class MixtralAttention(nn.Module):
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
- rope_theta: float = 10000,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
@@ -189,7 +187,6 @@ class MixtralAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.qkv_proj = QKVParallelLinear(
hidden_size,
@@ -211,7 +208,7 @@ class MixtralAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
- base=int(self.rope_theta),
+ rope_parameters=config.rope_parameters,
is_neox_style=True,
)
self.attn = Attention(
@@ -248,15 +245,12 @@ class MixtralDecoderLayer(nn.Module):
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
- # Requires transformers > 4.32.0
- rope_theta = getattr(config, "rope_theta", 10000)
self.self_attn = MixtralAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
- rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
@@ -306,23 +300,18 @@ class MixtralModel(nn.Module):
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
- lora_config = vllm_config.lora_config
+
parallel_config = vllm_config.parallel_config
self.config = config
self.quant_config = quant_config
- lora_vocab = (
- (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
- if lora_config
- else 0
- )
- self.vocab_size = config.vocab_size + lora_vocab
+
+ self.vocab_size = config.vocab_size
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
- org_num_embeddings=config.vocab_size,
)
self.enable_eplb = parallel_config.enable_eplb
@@ -513,34 +502,24 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
- lora_config = vllm_config.lora_config
+
self.config = config
- self.lora_config = lora_config
+
self.quant_config = quant_config
self.model = MixtralModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
- self.unpadded_vocab_size = config.vocab_size
- if lora_config:
- self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+
self.lm_head = ParallelLMHead(
- self.unpadded_vocab_size,
+ config.vocab_size,
config.hidden_size,
- org_num_embeddings=config.vocab_size,
- padding_size=DEFAULT_VOCAB_PADDING_SIZE
- # We need bigger padding if using lora for kernel
- # compatibility
- if not lora_config
- else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
- self.logits_processor = LogitsProcessor(
- self.unpadded_vocab_size, config.vocab_size
- )
+ self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors
)
diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py
index e25a104d822a7..286859d188d34 100644
--- a/vllm/model_executor/models/mllama4.py
+++ b/vllm/model_executor/models/mllama4.py
@@ -292,13 +292,17 @@ class Llama4VisionAttention(nn.Module):
prefix=f"{prefix}.o_proj",
)
+ rope_parameters = {
+ "rope_type": "mllama4",
+ "rope_theta": config.rope_parameters["rope_theta"],
+ }
+
self.rotary_emb = get_rope(
head_size=self.head_dim,
rotary_dim=config.hidden_size // config.num_attention_heads // 2,
# number of image patches
max_position=(config.image_size // config.patch_size) ** 2,
- base=config.rope_theta,
- rope_scaling={"rope_type": "mllama4"},
+ rope_parameters=rope_parameters,
is_neox_style=False,
dtype=torch.complex64, # important
)
diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index ab83a271e30a0..dc06938d5d6e1 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -410,7 +410,6 @@ class MolmoAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.max_position_embeddings = config.max_position_embeddings
- self.rope_theta = config.rope_theta
# Attention input projection. Projects x -> (q, k, v)
self.qkv_proj = QKVParallelLinear(
@@ -437,7 +436,7 @@ class MolmoAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
- base=self.rope_theta,
+ rope_parameters=config.rope_parameters,
)
self.scaling = self.head_dim**-0.5
self.attn = Attention(
diff --git a/vllm/model_executor/models/moonvit.py b/vllm/model_executor/models/moonvit.py
index 2e3e6dc166ad8..63ea6b259a71d 100644
--- a/vllm/model_executor/models/moonvit.py
+++ b/vllm/model_executor/models/moonvit.py
@@ -56,10 +56,13 @@ from transformers.utils import is_flash_attn_2_available
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.models.utils import maybe_prefix
+from vllm.platforms import current_platform
from vllm.transformers_utils.configs.moonvit import MoonViTConfig
if is_flash_attn_2_available():
from flash_attn import flash_attn_varlen_func
+elif current_platform.is_xpu():
+ from vllm.attention.utils.fa_utils import flash_attn_varlen_func
else:
flash_attn_varlen_func = None
@@ -106,10 +109,10 @@ def multihead_attention(
q,
k,
v,
- q_cu_seqlens,
- k_cu_seqlens,
- max_seqlen_q,
- max_seqlen_k,
+ cu_seqlens_q=q_cu_seqlens,
+ cu_seqlens_k=k_cu_seqlens,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_k=max_seqlen_k,
causal=False,
)
attn_out = attn_out.flatten(start_dim=-2)
@@ -291,7 +294,12 @@ class Rope2DPosEmb(nn.Module):
"""
def __init__(
- self, dim: int, max_height: int, max_width: int, theta_base=10000, device="cuda"
+ self,
+ dim: int,
+ max_height: int,
+ max_width: int,
+ theta_base=10000,
+ device=current_platform.device_type,
):
super().__init__()
self.dim = dim
@@ -437,7 +445,7 @@ class MoonVitEncoderLayer(nn.Module):
self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
self.attn_implementation = attn_implementation
# use fa2 in vllm by default
- if is_flash_attn_2_available():
+ if is_flash_attn_2_available() or current_platform.is_xpu():
self.attn_implementation = "flash_attention_2"
self.norm0 = nn.LayerNorm(hidden_dim)
diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py
index 92dcf5ea57008..c3337bd1ea699 100644
--- a/vllm/model_executor/models/nemotron.py
+++ b/vllm/model_executor/models/nemotron.py
@@ -26,7 +26,6 @@
from collections.abc import Iterable
from itertools import islice
-from typing import Any
import torch
from torch import nn
@@ -150,8 +149,6 @@ class NemotronAttention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
max_position_embeddings: int = 8192,
quant_config: QuantizationConfig | None = None,
bias: bool = False,
@@ -181,7 +178,6 @@ class NemotronAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.partial_rotary_factor = config.partial_rotary_factor
self.max_position_embeddings = max_position_embeddings
@@ -206,8 +202,7 @@ class NemotronAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
- base=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
partial_rotary_factor=self.partial_rotary_factor,
)
self.attn = Attention(
@@ -243,14 +238,6 @@ class NemotronDecoderLayer(nn.Module):
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
- if rope_scaling is not None and getattr(
- config, "original_max_position_embeddings", None
- ):
- rope_scaling["original_max_position_embeddings"] = (
- config.original_max_position_embeddings
- )
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
@@ -264,8 +251,6 @@ class NemotronDecoderLayer(nn.Module):
num_kv_heads=getattr(
config, "num_key_value_heads", config.num_attention_heads
),
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
diff --git a/vllm/model_executor/models/nemotron_nas.py b/vllm/model_executor/models/nemotron_nas.py
index b839206a3094d..2eebe38051cbd 100644
--- a/vllm/model_executor/models/nemotron_nas.py
+++ b/vllm/model_executor/models/nemotron_nas.py
@@ -26,7 +26,6 @@
from collections.abc import Iterable
from itertools import islice
-from typing import Any
import torch
from torch import nn
@@ -82,8 +81,6 @@ class DeciLMAttention(LlamaAttention):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
max_position_embeddings: int = 8192,
quant_config: QuantizationConfig | None = None,
bias: bool = False,
@@ -97,8 +94,6 @@ class DeciLMAttention(LlamaAttention):
hidden_size,
num_heads,
num_kv_heads,
- rope_theta,
- rope_scaling,
max_position_embeddings,
quant_config,
bias,
@@ -111,7 +106,6 @@ class DeciLMAttention(LlamaAttention):
def _init_rotary_emb(
self,
config,
- rope_scaling: dict[str, Any] | None,
quant_config: QuantizationConfig | None,
) -> None:
# Enables YARN for Mistral and LLaMA4 derivatives.
@@ -126,8 +120,7 @@ class DeciLMAttention(LlamaAttention):
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
- base=self.rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style,
partial_rotary_factor=self.partial_rotary_factor,
)
@@ -148,14 +141,6 @@ class DeciLMDecoderLayer(nn.Module):
self._is_no_op_ffn = block_config.ffn.no_op
self.hidden_size = config.hidden_size
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
- if rope_scaling is not None and getattr(
- config, "original_max_position_embeddings", None
- ):
- rope_scaling["original_max_position_embeddings"] = (
- config.original_max_position_embeddings
- )
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
@@ -176,8 +161,6 @@ class DeciLMDecoderLayer(nn.Module):
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=num_kv_heads,
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py
index 487e3f671a455..bd8a8e317544f 100644
--- a/vllm/model_executor/models/olmo.py
+++ b/vllm/model_executor/models/olmo.py
@@ -87,7 +87,6 @@ class OlmoAttention(nn.Module):
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
self.head_dim = self.hidden_size // self.total_num_heads
self.max_position_embeddings = config.max_position_embeddings
- self.rope_theta = config.rope_theta
self.clip_qkv = config.clip_qkv
# Attention input projection. Projects x -> (q, k, v)
@@ -105,7 +104,7 @@ class OlmoAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
- base=self.rope_theta,
+ rope_parameters=config.rope_parameters,
)
self.scaling = self.head_dim**-0.5
self.attn = Attention(
diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py
index 045582c889ee4..f0f6b2f6b3e6d 100644
--- a/vllm/model_executor/models/olmo2.py
+++ b/vllm/model_executor/models/olmo2.py
@@ -99,7 +99,6 @@ class Olmo2Attention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.max_position_embeddings = self.config.max_position_embeddings
- self.rope_theta = self.config.rope_theta
# Attention input projection. Projects x -> (q, k, v)
self.qkv_proj = QKVParallelLinear(
@@ -139,15 +138,17 @@ class Olmo2Attention(nn.Module):
prefix=f"{prefix}.attn",
)
- # Rotary embeddings. Rope scaling is only applied on full attention
- # layers.
- self.rope_scaling = self.config.rope_scaling if sliding_window is None else None
+ # Rotary embeddings. Rope scaling is only applied on full attention layers.
+ if sliding_window is None:
+ rope_parameters = self.config.rope_parameters
+ else:
+ rope_theta = self.config.rope_parameters["rope_theta"]
+ rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
- base=self.rope_theta, # type: ignore
- rope_scaling=self.rope_scaling,
+ rope_parameters=rope_parameters,
)
# Attention output projection.
diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py
index 499eb05de76e4..c39e338d72e22 100644
--- a/vllm/model_executor/models/olmoe.py
+++ b/vllm/model_executor/models/olmoe.py
@@ -123,8 +123,6 @@ class OlmoeAttention(nn.Module):
quant_config = vllm_config.quant_config
self.hidden_size = config.hidden_size
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 4096)
num_heads = config.num_attention_heads
@@ -148,7 +146,6 @@ class OlmoeAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
@@ -176,8 +173,7 @@ class OlmoeAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
- base=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
is_neox_style=True,
)
self.attn = Attention(
diff --git a/vllm/model_executor/models/opencua.py b/vllm/model_executor/models/opencua.py
new file mode 100644
index 0000000000000..121bf896fa6ba
--- /dev/null
+++ b/vllm/model_executor/models/opencua.py
@@ -0,0 +1,271 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+#
+# Adapted from Qwen2.5-VL implementation
+# Copyright 2025 The vLLM team.
+# Copyright 2025 XLANG Lab, The University of Hong Kong
+
+"""Inference-only OpenCUA-7B model compatible with HuggingFace weights."""
+
+from collections.abc import Mapping, Sequence
+from typing import Any
+
+import torch
+import torch.nn as nn
+from transformers import BatchFeature
+from transformers.models.qwen2_vl import (
+ Qwen2VLImageProcessor,
+ Qwen2VLProcessor,
+ Qwen2VLVideoProcessor,
+)
+
+from vllm.config import VllmConfig
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (
+ MultiModalFieldConfig,
+ MultiModalKwargs,
+)
+from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
+from vllm.multimodal.processing import (
+ BaseMultiModalProcessor,
+ PromptReplacement,
+ PromptUpdate,
+)
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+from .qwen2_5_vl import (
+ Qwen2_5_VisionTransformer as OpenCUAVisionTransformer,
+)
+from .qwen2_5_vl import (
+ Qwen2_5_VLForConditionalGeneration,
+)
+from .qwen2_vl import (
+ Qwen2VLDummyInputsBuilder,
+ Qwen2VLMultiModalDataParser,
+ Qwen2VLProcessingInfo,
+ _create_qwen2vl_field_factory,
+)
+from .utils import (
+ WeightsMapper,
+ init_vllm_registered_model,
+ maybe_prefix,
+)
+
+
+class OpenCUAProcessingInfo(Qwen2VLProcessingInfo):
+ def get_hf_config(self):
+ return self.ctx.get_hf_config()
+
+ def get_supported_mm_limits(self) -> Mapping[str, int | None]:
+ return {"image": None}
+
+ def get_hf_processor(self, **kwargs: object):
+ """Load OpenCUA processor."""
+ tokenizer = self.get_tokenizer()
+ vision_config = self.ctx.get_hf_image_processor_config()
+ return OpenCUAProcessor(
+ vision_config=vision_config,
+ tokenizer=tokenizer,
+ **kwargs,
+ )
+
+
+class OpenCUAProcessor(Qwen2VLProcessor):
+ def check_argument_for_proper_class(self, attribute_name: str, arg: object) -> None:
+ if attribute_name == "tokenizer":
+ return
+ return super().check_argument_for_proper_class(attribute_name, arg)
+
+ def __init__(
+ self,
+ vision_config: dict,
+ tokenizer: AnyTokenizer,
+ **kwargs,
+ ):
+ image_processor = Qwen2VLImageProcessor(**vision_config)
+ video_processor = Qwen2VLVideoProcessor(**vision_config)
+ chat_template = kwargs.pop("chat_template", None)
+
+ super().__init__(
+ image_processor=image_processor,
+ tokenizer=tokenizer,
+ video_processor=video_processor,
+ chat_template=chat_template,
+ **kwargs,
+ )
+
+ self.image_token = "<|media_placeholder|>"
+
+ def __call__(
+ self,
+ text=None,
+ images=None,
+ return_tensors=None,
+ **kwargs,
+ ):
+ if text is not None:
+ if not isinstance(text, list):
+ text = [text]
+ text_inputs = self.tokenizer(text, **kwargs)
+ else:
+ text_inputs = {}
+
+ image_inputs = {}
+ if images is not None:
+ if not isinstance(images, list):
+ images = [images]
+ if len(images) > 0:
+ image_inputs = self.image_processor(
+ images, return_tensors=return_tensors or "pt"
+ )
+
+ combined_inputs = {**text_inputs, **image_inputs}
+
+ return BatchFeature(combined_inputs, tensor_type=return_tensors)
+
+
+class OpenCUAMultiModalProcessor(BaseMultiModalProcessor[OpenCUAProcessingInfo]):
+ def _get_data_parser(self) -> MultiModalDataParser:
+ return Qwen2VLMultiModalDataParser(
+ self.info.get_hf_config().vision_config.spatial_merge_size
+ )
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ return _create_qwen2vl_field_factory(
+ self.info.get_hf_config().vision_config.spatial_merge_size
+ )(hf_inputs)
+
+ def _hf_processor_applies_updates(
+ self,
+ prompt_text: str,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ tokenization_kwargs: Mapping[str, object],
+ ) -> bool:
+ """Return False so that vLLM applies the prompt updates itself."""
+ return False
+
+ def _get_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, Any],
+ out_mm_kwargs: MultiModalKwargs,
+ ) -> Sequence[PromptUpdate]:
+ hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+ image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
+ tokenizer = self.info.get_tokenizer()
+ vocab = tokenizer.get_vocab()
+ hf_config = self.info.get_hf_config()
+
+ image_token_str = getattr(hf_processor, "image_token", "<|media_placeholder|>")
+ image_token_id = vocab.get(
+ image_token_str,
+ getattr(hf_config, "media_placeholder_token_id", 151664),
+ )
+
+ merge_length = image_processor.merge_size**2
+
+ def get_replacement_opencua(item_idx: int):
+ out_item = out_mm_kwargs["image"][item_idx]
+ grid_thw = out_item["image_grid_thw"].data
+ assert isinstance(grid_thw, torch.Tensor)
+
+ num_tokens = int(grid_thw.prod()) // merge_length
+ return [image_token_id] * num_tokens
+
+ return [
+ PromptReplacement(
+ modality="image",
+ target=[image_token_id],
+ replacement=get_replacement_opencua,
+ )
+ ]
+
+
+class OpenCUADummyInputsBuilder(Qwen2VLDummyInputsBuilder):
+ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+ num_images = mm_counts.get("image", 0)
+
+ image_token = "<|media_placeholder|>"
+
+ return image_token * num_images
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ OpenCUAMultiModalProcessor,
+ info=OpenCUAProcessingInfo,
+ dummy_inputs=OpenCUADummyInputsBuilder,
+)
+class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
+ merge_by_field_config = True
+ multimodal_cpu_fields = {"image_grid_thw"}
+
+ packed_modules_mapping = {
+ "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+ "gate_up_proj": ["gate_proj", "up_proj"],
+ }
+
+ hf_to_vllm_mapper = WeightsMapper(
+ orig_to_new_prefix={
+ "model.language_model.": "language_model.model.",
+ "model.visual.": "visual.",
+ "vision_tower.": "visual.",
+ "lm_head.": "language_model.lm_head.",
+ "model.": "language_model.model.",
+ }
+ )
+
+ supports_encoder_tp_data = True
+
+ @classmethod
+ def get_placeholder_str(cls, modality: str, i: int) -> str | None:
+ if modality.startswith("image"):
+ return "<|media_placeholder|>"
+ raise ValueError("Only image modality is supported")
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ nn.Module.__init__(self)
+ config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ multimodal_config = vllm_config.model_config.multimodal_config
+
+ self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
+ self.config = config
+ self.vllm_config = vllm_config
+ self.multimodal_config = multimodal_config
+ self.quant_config = quant_config
+ self.is_multimodal_pruning_enabled = (
+ multimodal_config.is_multimodal_pruning_enabled()
+ )
+
+ if multimodal_config.get_limit_per_prompt("image"):
+ attn_backend_override = (
+ multimodal_config.mm_encoder_attn_backend
+ if multimodal_config is not None
+ else None
+ )
+ self.visual = OpenCUAVisionTransformer(
+ vision_config=config.vision_config,
+ norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+ quant_config=self.quant_config,
+ prefix=maybe_prefix(prefix, "visual"),
+ use_data_parallel=self.use_data_parallel,
+ attn_backend_override=attn_backend_override,
+ )
+ else:
+ self.visual = None
+
+ self.language_model = init_vllm_registered_model(
+ vllm_config=vllm_config,
+ hf_config=config.text_config,
+ prefix=maybe_prefix(prefix, "language_model"),
+ architectures=["Qwen2ForCausalLM"],
+ )
+
+ self.make_empty_intermediate_tensors = (
+ self.language_model.make_empty_intermediate_tensors
+ )
diff --git a/vllm/model_executor/models/openpangu.py b/vllm/model_executor/models/openpangu.py
index f46fd3c7f319d..0486032645ad2 100644
--- a/vllm/model_executor/models/openpangu.py
+++ b/vllm/model_executor/models/openpangu.py
@@ -86,6 +86,7 @@ from vllm.v1.kv_cache_interface import (
FullSinkAttentionSpec,
KVCacheSpec,
)
+from vllm.transformers_utils.config import set_default_rope_theta
def check_ffn_act_fn(act_fn: str):
@@ -410,7 +411,6 @@ class OpenPanguMLAAttention(nn.Module):
v_head_dim: int,
q_lora_rank: int | None,
kv_lora_rank: int,
- rope_theta: float = 10000,
max_position_embeddings: int = 8192,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
@@ -425,8 +425,6 @@ class OpenPanguMLAAttention(nn.Module):
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
- self.rope_theta = rope_theta
-
self.tp_size = get_tensor_model_parallel_world_size()
if num_heads % self.tp_size != 0:
raise ValueError(
@@ -490,7 +488,9 @@ class OpenPanguMLAAttention(nn.Module):
)
# TODO: remove hard coding
- rope_scaling = {
+ set_default_rope_theta(config, default_theta=10000)
+ rope_parameters = {
+ "rope_theta": config.rope_parameters["rope_theta"],
"beta_fast": 32,
"beta_slow": 1,
"factor": 1,
@@ -504,8 +504,7 @@ class OpenPanguMLAAttention(nn.Module):
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
- base=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=rope_parameters,
is_neox_style=False,
)
@@ -558,8 +557,6 @@ class OpenPanguEmbeddedAttention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
max_position_embeddings: int = 8192,
quant_config: QuantizationConfig | None = None,
bias: bool = False,
@@ -605,7 +602,6 @@ class OpenPanguEmbeddedAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
@@ -626,9 +622,7 @@ class OpenPanguEmbeddedAttention(nn.Module):
prefix=f"{prefix}.o_proj",
)
- self._init_rotary_emb(
- config, rope_scaling=rope_scaling, quant_config=quant_config
- )
+ self._init_rotary_emb(config, quant_config=quant_config)
if hasattr(config, "interleaved_sliding_window"):
interleaved_sliding_window = config.interleaved_sliding_window
@@ -672,7 +666,6 @@ class OpenPanguEmbeddedAttention(nn.Module):
def _init_rotary_emb(
self,
config: PretrainedConfig,
- rope_scaling: dict[str, Any] | None,
quant_config: QuantizationConfig | None,
) -> None:
is_neox_style = True
@@ -684,8 +677,7 @@ class OpenPanguEmbeddedAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
- base=self.rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style,
)
@@ -974,7 +966,6 @@ class OpenPanguDecoderLayer(nn.Module):
parallel_config = vllm_config.parallel_config
self.hidden_size = config.hidden_size
- rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
layer_idx = int(prefix.split(sep=".")[-1])
@@ -1001,7 +992,6 @@ class OpenPanguDecoderLayer(nn.Module):
config.q_lora_rank if hasattr(config, "q_lora_rank") else None
),
kv_lora_rank=config.kv_lora_rank,
- rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
@@ -1060,8 +1050,6 @@ class OpenPanguDecoderLayer(nn.Module):
num_kv_heads=getattr(
config, "num_key_value_heads", config.num_attention_heads
),
- rope_theta=rope_theta,
- rope_scaling=getattr(config, "rope_scaling", None),
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
@@ -1090,7 +1078,7 @@ class OpenPanguDecoderLayer(nn.Module):
bias=getattr(config, "mlp_bias", False),
prefix=f"{prefix}.mlp",
)
- self.routed_scaling_factor = getattr(config, "routed_scaling_factor", None)
+ self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
self.num_hidden_layers = config.num_hidden_layers
self.first_k_dense_replace = getattr(
config, "first_k_dense_replace", self.num_hidden_layers
diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py
index 859cd2cecf897..b30be93ca726f 100644
--- a/vllm/model_executor/models/orion.py
+++ b/vllm/model_executor/models/orion.py
@@ -88,8 +88,7 @@ class OrionAttention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
+ rope_parameters: dict[str, Any] | None = None,
max_position_embeddings: int = 8192,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
@@ -115,7 +114,6 @@ class OrionAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
@@ -139,8 +137,7 @@ class OrionAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
- base=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=rope_parameters,
)
self.attn = Attention(
self.num_heads,
@@ -175,15 +172,12 @@ class OrionDecoderLayer(nn.Module):
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.self_attn = OrionAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
diff --git a/vllm/model_executor/models/ouro.py b/vllm/model_executor/models/ouro.py
index 9db6c317c26a8..63d2fff6ec8bc 100644
--- a/vllm/model_executor/models/ouro.py
+++ b/vllm/model_executor/models/ouro.py
@@ -112,10 +112,8 @@ class OuroAttention(nn.Module):
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
- rope_theta: float = 10000,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
- rope_scaling: tuple | None = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
dual_chunk_attention_config: dict[str, Any] | None = None,
@@ -140,7 +138,6 @@ class OuroAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.dual_chunk_attention_config = dual_chunk_attention_config
# Get total_ut_steps from config, default to 4 if not specified
@@ -170,8 +167,7 @@ class OuroAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
- base=self.rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.attn = nn.ModuleList()
@@ -226,9 +222,6 @@ class OuroDecoderLayer(nn.Module):
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
- # Requires transformers > 4.32.0
- rope_theta = getattr(config, "rope_theta", 1000000)
- rope_scaling = getattr(config, "rope_scaling", None)
dual_chunk_attention_config = getattr(
config, "dual_chunk_attention_config", None
)
@@ -244,10 +237,8 @@ class OuroDecoderLayer(nn.Module):
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
- rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config,
- rope_scaling=rope_scaling,
prefix=f"{prefix}.self_attn",
attn_type=attn_type,
dual_chunk_attention_config=dual_chunk_attention_config,
diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py
index dee0c16ab0f63..74bb868492da9 100644
--- a/vllm/model_executor/models/paddleocr_vl.py
+++ b/vllm/model_executor/models/paddleocr_vl.py
@@ -38,7 +38,6 @@ from vllm.attention.layer import (
)
from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
- vit_xformers_attn_wrapper,
)
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
@@ -657,7 +656,6 @@ class SiglipAttention(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor | None,
max_seqlen: torch.Tensor | None,
- seqlens: torch.Tensor | None,
) -> torch.Tensor:
batch_size, _, _ = hidden_states.shape
@@ -703,10 +701,6 @@ class SiglipAttention(nn.Module):
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- if seqlens is None:
- raise ValueError("xFormers attention backend requires seqlens tensor.")
- context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
else:
raise RuntimeError(
f"PaddleOCR-VL does not support {self.attn_backend} backend now."
@@ -818,7 +812,6 @@ class SiglipEncoderLayer(nn.Module):
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor | None,
max_seqlen: torch.Tensor | None,
- seqlens: torch.Tensor | None,
) -> torch.Tensor:
residual = hidden_states
@@ -828,7 +821,6 @@ class SiglipEncoderLayer(nn.Module):
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
hidden_states = residual + hidden_states
@@ -870,7 +862,6 @@ class SiglipEncoder(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
- AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -943,14 +934,11 @@ class SiglipEncoder(nn.Module):
cu_seqlens = cu_seqlens.to(device=device)
max_seqlen = None
- seqlens = None
if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
hidden_states = inputs_embeds
for encoder_layer in self.layers:
@@ -959,7 +947,6 @@ class SiglipEncoder(nn.Module):
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
return hidden_states
diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py
index 3bf6a1d9763d0..98963d52e4848 100644
--- a/vllm/model_executor/models/persimmon.py
+++ b/vllm/model_executor/models/persimmon.py
@@ -106,7 +106,6 @@ class PersimmonAttention(nn.Module):
self.num_heads = self.total_num_heads // tensor_parallel_world_size
self.head_dim = self.hidden_size // self.total_num_heads
self.max_position_embeddings = config.max_position_embeddings
- self.rope_theta = config.rope_theta
self.partial_rotary_factor = config.partial_rotary_factor
self.is_causal = True
@@ -138,7 +137,7 @@ class PersimmonAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
- base=self.rope_theta,
+ rope_parameters=config.rope_parameters,
partial_rotary_factor=self.partial_rotary_factor,
)
self.scaling = self.head_dim**-0.5
diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py
index 8fee53c23fb4b..da476f621627b 100644
--- a/vllm/model_executor/models/phi.py
+++ b/vllm/model_executor/models/phi.py
@@ -115,16 +115,12 @@ class PhiAttention(nn.Module):
)
assert rotary_dim % 2 == 0
- # pylint: disable=C0301
- # Refer to:
- # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
- rope_theta = getattr(config, "rope_theta", 10000.0)
max_position_embeddings = getattr(config, "max_position_embeddings", 2048)
self.rotary_emb = get_rope(
self.head_size,
rotary_dim=rotary_dim,
max_position=max_position_embeddings,
- base=rope_theta,
+ rope_parameters=config.rope_parameters,
)
self.attn = Attention(
self.num_heads,
diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py
index 92fd858b608bc..8ffac95d93960 100644
--- a/vllm/model_executor/models/phimoe.py
+++ b/vllm/model_executor/models/phimoe.py
@@ -86,7 +86,7 @@ class PhiMoEConfig(PretrainedConfig):
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
- rope_theta=1e6,
+ rope_parameters=None,
sliding_window=None,
attention_dropout=0.0,
num_experts_per_tok=2,
@@ -119,7 +119,9 @@ class PhiMoEConfig(PretrainedConfig):
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
- self.rope_theta = rope_theta
+ if rope_parameters is None:
+ rope_theta = kwargs.pop("rope_theta", 1e6)
+ rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
self.attention_dropout = attention_dropout
self.num_experts_per_tok = num_experts_per_tok
@@ -302,12 +304,11 @@ class PhiMoEAttention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
+ rope_parameters: dict,
head_dim: int | None = None,
max_position: int = 4096 * 32,
- rope_theta: float = 10000,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
- rope_scaling: dict | None = None,
prefix: str = "",
) -> None:
super().__init__()
@@ -332,8 +333,6 @@ class PhiMoEAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
- self.rope_scaling = rope_scaling
self.qkv_proj = QKVParallelLinear(
hidden_size,
@@ -355,9 +354,8 @@ class PhiMoEAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
- base=int(self.rope_theta),
+ rope_parameters=rope_parameters,
is_neox_style=True,
- rope_scaling=self.rope_scaling,
)
self.attn = Attention(
self.num_heads,
@@ -393,7 +391,6 @@ class PhiMoEDecoderLayer(nn.Module):
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
- rope_theta = getattr(config, "rope_theta", 10000)
self.self_attn = PhiMoEAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
@@ -402,10 +399,9 @@ class PhiMoEDecoderLayer(nn.Module):
head_dim=getattr(
config, "head_dim", self.hidden_size // config.num_attention_heads
),
- rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config,
- rope_scaling=config.rope_scaling,
+ rope_parameters=config.rope_parameters,
prefix=f"{prefix}.self_attn",
)
self.block_sparse_moe = PhiMoE(
diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index 8a034fd72b02a..6011d93a795d1 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -74,6 +74,7 @@ from .vision import (
)
try:
+ # Note: vLLM does not install xformers by default.
from xformers import ops as xops
if current_platform.is_cuda() and current_platform.has_device_capability(100):
diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py
index 0c87f5000ff45..472de5590dcf8 100644
--- a/vllm/model_executor/models/plamo2.py
+++ b/vllm/model_executor/models/plamo2.py
@@ -4,10 +4,6 @@
from collections.abc import Iterable
from itertools import islice
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
- from vllm.attention.backends.abstract import AttentionBackend
import torch
from torch import nn
@@ -294,7 +290,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
has_decode = num_decodes > 0
num_actual_tokens = num_prefill_tokens + num_decodes
- # NOTE: V0 put prefill before decode, v1 puts decode before prefill
# Separate prefill and decode by splitting varlen input
# Split along token dimension
hidden_states_d, hidden_states_p = torch.split(
@@ -467,11 +462,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
def mamba_type(self) -> str:
return "mamba2"
- def get_attn_backend(self) -> type["AttentionBackend"]:
- from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
-
- return Mamba2AttentionBackend
-
def plamo2_mamba_mixer(
hidden_states: torch.Tensor,
@@ -576,10 +566,6 @@ class Plamo2AttentionMixer(nn.Module):
prefix=f"{prefix}.o_proj",
)
- self.rope_theta = config.rope_theta if hasattr(config, "rope_theta") else 10000
- self.rope_scaling = (
- config.rope_scaling if hasattr(config, "rope_scaling") else None
- )
max_position = config.max_position_embeddings
if hasattr(vllm_config.model_config, "max_model_len") and isinstance(
vllm_config.model_config.max_model_len, int
@@ -590,8 +576,7 @@ class Plamo2AttentionMixer(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
- base=self.rope_theta,
- rope_scaling=self.rope_scaling,
+ rope_parameters=config.rope_parameters,
)
self.q_norm = RMSNorm(config.hidden_size_per_head, eps=config.rms_norm_eps)
self.q_norm.weight = torch.nn.Parameter(
diff --git a/vllm/model_executor/models/plamo3.py b/vllm/model_executor/models/plamo3.py
new file mode 100644
index 0000000000000..4aeb9d432dcc6
--- /dev/null
+++ b/vllm/model_executor/models/plamo3.py
@@ -0,0 +1,441 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Inference-only PLaMo3 model."""
+
+from collections.abc import Iterable
+from itertools import islice
+from typing import Any
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from vllm.attention.layer import Attention
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import VllmConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.parallel_state import get_pp_group
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+ MergedColumnParallelLinear,
+ QKVParallelLinear,
+ RowParallelLinear,
+)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ DEFAULT_VOCAB_PADDING_SIZE,
+ ParallelLMHead,
+ VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import (
+ LoaderFunction,
+ composed_weight_loader,
+ default_weight_loader,
+)
+from vllm.model_executor.models.interfaces import SupportsPP
+from vllm.model_executor.models.utils import (
+ AutoWeightsLoader,
+ extract_layer_index,
+ make_empty_intermediate_tensors_factory,
+ make_layers,
+ maybe_prefix,
+)
+from vllm.model_executor.utils import set_weight_attrs
+from vllm.sequence import IntermediateTensors
+
+
+# Only used for type hinting.
+class Plamo3Config(PretrainedConfig): # type: ignore
+ model_type: str = "plamo3"
+
+ hidden_size: int
+ num_hidden_layers: int
+ rms_norm_eps: float
+ # Attention
+ num_attention_heads: int
+ head_dim: int
+ num_key_value_heads: int
+    # vLLM renames the `sliding_window` attribute to
+    # `interleaved_sliding_window` when `sliding_window` is a list.
+ interleaved_sliding_window: list[int | None]
+ sliding_window_pattern: int
+ rope_parameters: dict[str, Any]
+ rope_local_theta: int
+ # MLP
+ intermediate_size: int
+ # Tokenizer
+ vocab_size: int
+
+
+def rms_norm_weight_loader(offset: float) -> LoaderFunction:
+ return composed_weight_loader(
+ default_weight_loader,
+ lambda x: x + offset,
+ )
+
+
+class DenseMLP(nn.Module):
+ def __init__(
+ self,
+ config: Plamo3Config,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_up_proj = MergedColumnParallelLinear(
+ self.hidden_size,
+ [self.intermediate_size] * 2,
+ bias=False,
+ prefix=f"{prefix}.gate_up_proj",
+ quant_config=quant_config,
+ return_bias=False,
+ )
+ self.act = SiluAndMul()
+ self.down_proj = RowParallelLinear(
+ self.intermediate_size,
+ self.hidden_size,
+ bias=False,
+ prefix=f"{prefix}.down_proj",
+ quant_config=quant_config,
+ return_bias=False,
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ h = self.gate_up_proj(hidden_states)
+ h = self.act(h)
+ return self.down_proj(h)
+
+
+class Plamo3AttentionMixer(nn.Module):
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+
+ self.hidden_size = config.hidden_size
+ tp_size = get_tensor_model_parallel_world_size()
+ self.total_num_heads = config.num_attention_heads
+ assert self.total_num_heads % tp_size == 0
+ self.num_heads = self.total_num_heads // tp_size
+ self.total_num_kv_heads = config.num_key_value_heads
+ if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is at least the TP size, so we partition
+ # the KV heads across multiple tensor parallel GPUs.
+ assert self.total_num_kv_heads % tp_size == 0
+ else:
+ # Number of KV heads is less than TP size, so we replicate
+ # the KV heads across multiple tensor parallel GPUs.
+ assert tp_size % self.total_num_kv_heads == 0
+ self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+ self.head_dim = config.head_dim
+ self.q_size = self.num_heads * self.head_dim
+ self.kv_size = self.num_kv_heads * self.head_dim
+ self.scaling = self.head_dim**-0.5
+
+ self.qkv_proj = QKVParallelLinear(
+ config.hidden_size,
+ self.head_dim,
+ self.total_num_heads,
+ self.total_num_kv_heads,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv_proj",
+ )
+ self.o_proj = RowParallelLinear(
+ self.total_num_heads * self.head_dim,
+ config.hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.o_proj",
+ )
+
+ layer_idx = extract_layer_index(prefix)
+ layer_type = config.layer_types[layer_idx]
+ is_sliding = layer_type == "sliding_attention"
+
+ # Initialize the rotary embedding.
+ if layer_type in config.rope_parameters:
+ # Transformers v5 rope config.
+ rope_parameters = config.rope_parameters[layer_type]
+ else:
+ # Transformers v4 rope config.
+ # Global attention. Use the values in config.json.
+ rope_parameters = config.rope_parameters
+ # Local attention. Override the values in config.json.
+ if is_sliding:
+ rope_parameters = dict(
+ rope_type="default", rope_theta=config.rope_local_theta
+ )
+ max_position = config.max_position_embeddings
+ if hasattr(vllm_config.model_config, "max_model_len") and isinstance(
+ vllm_config.model_config.max_model_len, int
+ ):
+ max_position = min(max_position, vllm_config.model_config.max_model_len)
+
+ self.rotary_emb = get_rope(
+ self.head_dim,
+ rotary_dim=self.head_dim,
+ max_position=max_position,
+ rope_parameters=rope_parameters,
+ )
+ self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+ set_weight_attrs(
+ self.q_norm.weight, {"weight_loader": rms_norm_weight_loader(offset=1.0)}
+ )
+ self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+ set_weight_attrs(
+ self.k_norm.weight, {"weight_loader": rms_norm_weight_loader(offset=1.0)}
+ )
+ self.attn = Attention(
+ self.num_heads,
+ self.head_dim,
+ self.scaling,
+ num_kv_heads=self.num_kv_heads,
+ cache_config=vllm_config.cache_config,
+ per_layer_sliding_window=config.interleaved_sliding_window[layer_idx],
+ prefix=f"{prefix}.attn",
+ )
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor | None,
+ **kwargs: Any,
+ ) -> torch.Tensor:
+ qkv, _ = self.qkv_proj(hidden_states)
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+ q_shape = q.shape
+ q = q.reshape(q_shape[:-1] + (q_shape[-1] // self.head_dim, self.head_dim))
+ q = self.q_norm.forward_native(q).reshape(q_shape)
+ k_shape = k.shape
+ k = k.reshape(k_shape[:-1] + (k_shape[-1] // self.head_dim, self.head_dim))
+ k = self.k_norm.forward_native(k).reshape(k_shape)
+
+ q, k = self.rotary_emb(positions, q, k)
+ attn_output = self.attn(q, k, v)
+ output, _ = self.o_proj(attn_output)
+ return output
+
+
+class Plamo3DecoderLayer(nn.Module):
+ def __init__(
+ self, vllm_config: VllmConfig, prefix: str = "", **kwargs: Any
+ ) -> None:
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+
+ self.mixer = Plamo3AttentionMixer(
+ vllm_config=vllm_config,
+ prefix=f"{prefix}.mixer",
+ )
+
+ self.mlp = DenseMLP(
+ config=config, quant_config=quant_config, prefix=f"{prefix}.mlp"
+ )
+ self.pre_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ set_weight_attrs(
+ self.pre_mixer_norm.weight,
+ {"weight_loader": rms_norm_weight_loader(offset=1.0)},
+ )
+ self.post_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ set_weight_attrs(
+ self.post_mixer_norm.weight,
+ {"weight_loader": rms_norm_weight_loader(offset=1.0 / 5)},
+ )
+ self.pre_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ set_weight_attrs(
+ self.pre_mlp_norm.weight,
+ {"weight_loader": rms_norm_weight_loader(offset=1.0)},
+ )
+ self.post_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ set_weight_attrs(
+ self.post_mlp_norm.weight,
+ {"weight_loader": rms_norm_weight_loader(offset=1.0 / (5**1.5))},
+ )
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor | None,
+ **kwargs: Any,
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ if residual is None:
+ residual = hidden_states
+ hidden_states = self.pre_mixer_norm(hidden_states)
+ else:
+ hidden_states, residual = self.pre_mixer_norm(hidden_states, residual)
+
+ hidden_states = self.mixer(
+ positions=positions, hidden_states=hidden_states, residual=residual
+ )
+ hidden_states = self.post_mixer_norm(hidden_states)
+ # Fully Connected
+ hidden_states, residual = self.pre_mlp_norm(hidden_states, residual)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = self.post_mlp_norm(hidden_states)
+ return hidden_states, residual
+
+
+class Plamo3Decoder(torch.nn.Module):
+ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
+ super().__init__()
+ num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
+
+ self.start_layer, self.end_layer, self.layers = make_layers(
+ num_hidden_layers,
+ lambda prefix: Plamo3DecoderLayer(vllm_config, prefix=prefix),
+ prefix=f"{prefix}.layers",
+ )
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor | None,
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ for layer in islice(self.layers, self.start_layer, self.end_layer):
+ hidden_states, residual = layer(
+ positions=positions,
+ hidden_states=hidden_states,
+ residual=residual,
+ )
+ return hidden_states, residual
+
+
+@support_torch_compile
+class Plamo3Model(nn.Module):
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+
+ self.config = config
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+ self.org_vocab_size = config.vocab_size
+
+ self.embed_tokens = VocabParallelEmbedding(
+ self.vocab_size,
+ config.hidden_size,
+ org_num_embeddings=config.vocab_size,
+ prefix=f"{prefix}.embed_tokens",
+ )
+ self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
+ ["hidden_states", "residual"], config.hidden_size
+ )
+ self.layers = Plamo3Decoder(vllm_config, prefix=f"{prefix}.layers")
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ set_weight_attrs(
+ self.norm.weight,
+ {"weight_loader": rms_norm_weight_loader(offset=1.0)},
+ )
+
+ def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.embed_tokens(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ if get_pp_group().is_first_rank:
+ if inputs_embeds is not None:
+ hidden_states = inputs_embeds
+ else:
+ hidden_states = self.embed_input_ids(input_ids)
+ residual = None
+ else:
+ assert intermediate_tensors is not None
+ hidden_states = intermediate_tensors["hidden_states"]
+ residual = intermediate_tensors["residual"]
+
+ hidden_states, residual = self.layers(
+ positions=positions, hidden_states=hidden_states, residual=residual
+ )
+ if not get_pp_group().is_last_rank:
+ return IntermediateTensors(
+ {"hidden_states": hidden_states, "residual": residual}
+ )
+ hidden_states, _ = self.norm(hidden_states, residual)
+ return hidden_states
+
+
+class Plamo3ForCausalLM(nn.Module, SupportsPP):
+ packed_modules_mapping = {
+ "qkv_proj": [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ ],
+ }
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
+ super().__init__()
+ self.config = vllm_config.model_config.hf_config
+ self.vllm_config = vllm_config
+ self.model_config = vllm_config.model_config
+ self.scheduler_config = vllm_config.scheduler_config
+
+ self.model = Plamo3Model(
+ vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
+ )
+
+ self.vocab_size = self.config.vocab_size
+ self.unpadded_vocab_size = self.config.vocab_size
+
+ num_embeddings = ((self.vocab_size + 15) // 16) * 16
+ self.lm_head = ParallelLMHead(
+ num_embeddings,
+ self.config.hidden_size,
+ org_num_embeddings=self.config.vocab_size,
+ padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+ prefix=f"{prefix}.lm_head",
+ )
+ if self.config.tie_word_embeddings:
+ self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
+
+ self.logits_processor = LogitsProcessor(
+ self.unpadded_vocab_size, self.config.vocab_size
+ )
+ self.make_empty_intermediate_tensors = (
+ self.model.make_empty_intermediate_tensors
+ )
+
+ def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.model.embed_input_ids(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ hidden_states = self.model(
+ input_ids, positions, intermediate_tensors, inputs_embeds
+ )
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor | None:
+ logits = self.logits_processor(self.lm_head, hidden_states)
+ return logits
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+ loader = AutoWeightsLoader(
+ self,
+ skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
+ )
+ return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py
index 50a125c3f5973..c973e79170982 100644
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -83,8 +83,7 @@ class QWenAttention(nn.Module):
hidden_size: int,
num_heads: int,
max_position_embeddings: int,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
+ rope_parameters: dict[str, Any] | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
@@ -117,8 +116,7 @@ class QWenAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
- base=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=rope_parameters,
)
self.attn = Attention(
self.num_heads,
@@ -153,14 +151,11 @@ class QWenBlock(nn.Module):
super().__init__()
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
self.attn = QWenAttention(
config.hidden_size,
config.num_attention_heads,
config.max_position_embeddings,
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index 1bbb969ce5aa3..5831ce0b3d64b 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -57,7 +57,7 @@ from vllm.model_executor.model_loader.weight_utils import (
maybe_remap_kv_scale_name,
)
from vllm.sequence import IntermediateTensors
-from vllm.transformers_utils.config import is_interleaved
+from vllm.transformers_utils.config import is_interleaved, set_default_rope_theta
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import (
@@ -114,11 +114,10 @@ class Qwen2Attention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
+ rope_parameters: dict[str, Any],
max_position: int = 4096 * 32,
- rope_theta: float = 10000,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
- rope_scaling: tuple | None = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
dual_chunk_attention_config: dict[str, Any] | None = None,
@@ -143,7 +142,6 @@ class Qwen2Attention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.dual_chunk_attention_config = dual_chunk_attention_config
self.qkv_proj = QKVParallelLinear(
@@ -167,8 +165,7 @@ class Qwen2Attention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
- base=self.rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=rope_parameters,
dual_chunk_attention_config=dual_chunk_attention_config,
)
attn_cls = (
@@ -216,9 +213,7 @@ class Qwen2DecoderLayer(nn.Module):
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
- # Requires transformers > 4.32.0
- rope_theta = getattr(config, "rope_theta", 1000000)
- rope_scaling = getattr(config, "rope_scaling", None)
+ set_default_rope_theta(config, default_theta=1000000)
dual_chunk_attention_config = getattr(
config, "dual_chunk_attention_config", None
)
@@ -237,10 +232,9 @@ class Qwen2DecoderLayer(nn.Module):
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
- rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
prefix=f"{prefix}.self_attn",
attn_type=attn_type,
dual_chunk_attention_config=dual_chunk_attention_config,
@@ -280,6 +274,38 @@ class Qwen2DecoderLayer(nn.Module):
return hidden_states, residual
+def qwen_2_model_invariants(
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+):
+    """Shape invariants for the Qwen2 model; these are translated to
+    runtime assertions for unbacked dynamic shapes and are compiled away
+    for backed ones."""
+ # All these should be equal.
+ # input_ids.size()[0]
+ # positions.size()[-1]
+ # intermediate_tensors["hidden_states"].size()[0]
+ # inputs_embeds.size()[0]
+ torch._check(input_ids.size()[0] == positions.size()[-1])
+ if intermediate_tensors is not None:
+ torch._check(
+ input_ids.size()[0] == intermediate_tensors["hidden_states"].size()[0]
+ )
+
+ if inputs_embeds is not None:
+ torch._check(input_ids.size()[0] == inputs_embeds.size()[0])
+
+ # Hidden dimensions should match (hidden_size)
+ # intermediate_tensors["hidden_states"].size()[1]
+ # inputs_embeds.size()[1]
+ if inputs_embeds is not None and intermediate_tensors is not None:
+ torch._check(
+ inputs_embeds.size()[1] == intermediate_tensors["hidden_states"].size()[1]
+ )
+
+
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
@@ -288,7 +314,8 @@ class Qwen2DecoderLayer(nn.Module):
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
- }
+ },
+ shape_invariants=qwen_2_model_invariants,
)
class Qwen2Model(nn.Module):
def __init__(
diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py
index 262ea771d9cdf..7506ee8656fda 100644
--- a/vllm/model_executor/models/qwen2_5_omni_thinker.py
+++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py
@@ -23,7 +23,6 @@
"""Inference-only Qwen2.5-Omni model (thinker part)."""
from collections.abc import Callable, Iterable, Mapping, Sequence
-from copy import copy
from functools import partial
from typing import Annotated, Any, Literal
@@ -387,15 +386,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
self._validate_mm_updates(mm_prompt_updates, mm_item_counts)
- use_audio_in_video = False
- if "video" in mm_kwargs:
- video_items = [item for item in mm_kwargs["video"] if item is not None]
- # only check video items (if there are any)
- if video_items:
- use_audio_in_video = all(
- item["use_audio_in_video"].data for item in video_items
- )
-
if is_update_applied:
mm_placeholders = self._find_mm_placeholders(
prompt_ids,
@@ -404,7 +394,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
self._validate_mm_placeholders(
mm_placeholders,
mm_item_counts,
- use_audio_in_video=use_audio_in_video,
)
else:
prompt_ids, mm_placeholders = self._apply_prompt_updates(
@@ -414,7 +403,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
self._validate_mm_placeholders(
mm_placeholders,
mm_item_counts,
- use_audio_in_video=use_audio_in_video,
)
return prompt_ids, mm_placeholders
@@ -640,19 +628,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
return mm_processed_data
- def _validate_mm_placeholders(
- self,
- mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
- mm_item_counts: Mapping[str, int],
- use_audio_in_video: bool = False,
- ) -> None:
- if use_audio_in_video:
- mm_item_counts = copy(mm_item_counts)
- if "video" in mm_item_counts:
- assert "audio" in mm_item_counts
- mm_item_counts["audio"] -= mm_item_counts["video"]
- super()._validate_mm_placeholders(mm_placeholders, mm_item_counts)
-
class Qwen2_5OmniConditionalGenerationMixin:
def _parse_and_validate_audio_input(
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 2e4fd9645d88f..8c707c2561af1 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -46,7 +46,6 @@ from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
vit_torch_sdpa_wrapper,
- vit_xformers_attn_wrapper,
)
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
@@ -230,6 +229,9 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
- hidden_size must match the hidden size of language model backbone.
- video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
format
+ - second_per_grid_ts: The video time interval (in seconds) for each
+ grid along the temporal dimension in the 3D position IDs. Returned
+ when `videos` is not `None`.
"""
type: Literal["video_embeds"]
@@ -244,6 +246,11 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
TensorShape("nv", 3),
]
+ second_per_grid_ts: Annotated[
+ torch.Tensor | None,
+ TensorShape("nv"),
+ ] = None
+
Qwen2_5_VLVideoInputs: TypeAlias = (
Qwen2_5_VLVideoPixelInputs | Qwen2_5_VLVideoEmbeddingInputs
@@ -367,7 +374,6 @@ class Qwen2_5_VisionAttention(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
- seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
@@ -427,8 +433,6 @@ class Qwen2_5_VisionAttention(nn.Module):
v,
cu_seqlens,
)
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
output, _ = self.proj(context_layer)
return output
@@ -440,9 +444,7 @@ class Qwen2_5_VisionAttention(nn.Module):
"cu_seqlens": 0,
"rotary_pos_emb_cos": 0,
"rotary_pos_emb_sin": 0,
- "seqlens": 0,
},
- mark_unbacked_dims={"seqlens": 0},
enable_if=should_torch_compile_mm_vit,
)
class Qwen2_5_VisionBlock(nn.Module):
@@ -493,7 +495,6 @@ class Qwen2_5_VisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
- seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
x_attn = self.attn(
self.norm1(x),
@@ -501,7 +502,6 @@ class Qwen2_5_VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
x_fused_norm, residual = self.norm2(x, residual=x_attn)
x = residual + self.mlp(x_fused_norm)
@@ -641,7 +641,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
head_size=head_dim,
rotary_dim=head_dim // 2,
max_position=8192,
- base=10000.0,
is_neox_style=True,
)
@@ -663,7 +662,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
- AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -738,13 +736,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
# Use pre-computed cos_sin_cache from RotaryEmbedding
cos, sin = self.rotary_pos_emb.get_cos_sin(max_size)
- cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2)
- cos_w = cos[pos_ids[:, 1]]
- sin_h = sin[pos_ids[:, 0]]
- sin_w = sin[pos_ids[:, 1]]
-
- cos_combined = torch.cat([cos_h, cos_w], dim=-1)
- sin_combined = torch.cat([sin_h, sin_w], dim=-1)
+ cos_combined = cos[pos_ids].flatten(1)
+ sin_combined = sin[pos_ids].flatten(1)
cos_combined = cos_combined.reshape(
cos_combined.shape[0] // self.spatial_merge_unit,
@@ -820,17 +813,14 @@ class Qwen2_5_VisionTransformer(nn.Module):
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
+ ) -> torch.Tensor:
max_seqlen = torch.zeros([], device=cu_seqlens.device)
- seqlens = torch.zeros(1, device=cu_seqlens.device)
if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
- return max_seqlen, seqlens
+ return max_seqlen
@staticmethod
def invert_permutation(perm: torch.Tensor) -> torch.Tensor:
@@ -895,10 +885,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
# transformers
# pre-compute seqlens for window/full attn to reduce cuMemcpy operations
- max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen(cu_seqlens)
- max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen(
- cu_window_seqlens
- )
+ max_seqlen_full = self.compute_attn_mask_seqlen(cu_seqlens)
+ max_seqlen_window = self.compute_attn_mask_seqlen(cu_window_seqlens)
cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True)
cu_window_seqlens = cu_window_seqlens.to(device=self.device, non_blocking=True)
@@ -925,11 +913,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
if layer_num in self.fullatt_block_indexes:
cu_seqlens_now = cu_seqlens
max_seqlen_now = max_seqlen_full
- seqlens_now = seqlens_full
else:
cu_seqlens_now = cu_window_seqlens
max_seqlen_now = max_seqlen_window
- seqlens_now = seqlens_window
hidden_states = blk(
hidden_states,
@@ -937,7 +923,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen_now,
- seqlens=seqlens_now,
)
# For Qwen2.5-VL-3B, float16 will overflow at last block
@@ -1317,6 +1302,7 @@ class Qwen2_5_VLForConditionalGeneration(
type="video_embeds",
video_embeds=video_embeds,
video_grid_thw=video_grid_thw,
+ second_per_grid_ts=second_per_grid_ts,
)
def _process_image_input(
@@ -1428,7 +1414,13 @@ class Qwen2_5_VLForConditionalGeneration(
# Cast to long to match the original code
# https://github.com/huggingface/transformers/blob/41980ce93e775f6c88500c51c8db7946fc6a2add/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py#L491 # noqa
- second_per_grid_ts = video_input["second_per_grid_ts"].long()
+ second_per_grid_ts = video_input.get("second_per_grid_ts")
+ if second_per_grid_ts is None:
+ raise ValueError(
+ "second_per_grid_ts is required when video_pruning_rate > 0 "
+ "is enabled for video inputs, including the video_embeds path."
+ )
+ second_per_grid_ts = second_per_grid_ts.long()
tokens_per_second = self.config.vision_config.tokens_per_second
video_embeds_out = []
diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py
index 2ff0d19df238c..6b97d0b2ca2e3 100644
--- a/vllm/model_executor/models/qwen2_moe.py
+++ b/vllm/model_executor/models/qwen2_moe.py
@@ -194,8 +194,7 @@ class Qwen2MoeAttention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
+ rope_parameters: dict[str, Any] | None = None,
max_position_embeddings: int = 8192,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
@@ -222,7 +221,6 @@ class Qwen2MoeAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.dual_chunk_attention_config = dual_chunk_attention_config
@@ -248,8 +246,7 @@ class Qwen2MoeAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
- base=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=rope_parameters,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.attn = Attention(
@@ -291,8 +288,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
dual_chunk_attention_config = getattr(
config, "dual_chunk_attention_config", None
)
@@ -301,8 +296,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 53df5972a8fe1..9d1d023aed172 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -29,6 +29,7 @@ from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import partial
from typing import Annotated, Any, Literal, TypeAlias
+import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -347,7 +348,6 @@ class Qwen2VisionAttention(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
- AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -383,7 +383,6 @@ class Qwen2VisionAttention(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
- seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
# [s, b, c] --> [s, b, 3 * head * head_dim]
x, _ = self.qkv(x)
@@ -444,20 +443,6 @@ class Qwen2VisionAttention(nn.Module):
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- from xformers import ops as xops
- from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-
- attn_bias = BlockDiagonalMask.from_seqlens(
- q_seqlen=seqlens, kv_seqlen=None, device=q.device
- )
-
- context_layer = xops.memory_efficient_attention_forward(
- q, k, v, attn_bias=attn_bias, p=0, scale=None
- )
- context_layer = rearrange(
- context_layer, "b s h d -> s b (h d)"
- ).contiguous()
output, _ = self.proj(context_layer)
return output
@@ -508,7 +493,6 @@ class Qwen2VisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
- seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
@@ -516,7 +500,6 @@ class Qwen2VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
x = x + self.mlp(self.norm2(x))
@@ -643,7 +626,6 @@ class Qwen2VisionTransformer(nn.Module):
head_size=head_dim,
rotary_dim=head_dim // 2,
max_position=8192,
- base=10000.0,
is_neox_style=True,
)
@@ -724,27 +706,18 @@ class Qwen2VisionTransformer(nn.Module):
# Use pre-computed cos_sin_cache from RotaryEmbedding
cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
- cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2)
- cos_w = cos[pos_ids[:, 1]]
- sin_h = sin[pos_ids[:, 0]]
- sin_w = sin[pos_ids[:, 1]]
-
- cos_combined = torch.cat([cos_h, cos_w], dim=-1)
- sin_combined = torch.cat([sin_h, sin_w], dim=-1)
+ cos_combined = cos[pos_ids].flatten(1)
+ sin_combined = sin[pos_ids].flatten(1)
return cos_combined, sin_combined
- def compute_attn_mask_seqlen(
- self, cu_seqlens: torch.Tensor
- ) -> tuple[int | None, list[int] | None]:
- max_seqlen, seqlens = None, None
+ def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
+ max_seqlen = None
if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
- return max_seqlen, seqlens
+ return max_seqlen
def forward(
self,
@@ -757,25 +730,27 @@ class Qwen2VisionTransformer(nn.Module):
if isinstance(grid_thw, list):
grid_thw_list = grid_thw
- grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
+ grid_thw = np.array(grid_thw, dtype=np.int32)
else:
grid_thw_list = grid_thw.tolist()
+ grid_thw = grid_thw.numpy()
# compute position embedding
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
# compute cu_seqlens
- cu_seqlens = torch.repeat_interleave(
- grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
- ).cumsum(dim=0, dtype=torch.int32)
- cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
- cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
+ cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+ axis=0, dtype=np.int32
+ )
+ cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
+ cu_seqlens = torch.from_numpy(cu_seqlens)
# transformers
x = x.unsqueeze(1)
# pre-compute seqlens for attn mask to reduce cuMemcpy operations
- max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+ max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
+ cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
for blk in self.blocks:
x = blk(
x,
@@ -783,7 +758,6 @@ class Qwen2VisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
# adapter
diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py
index 8d7f22a33fe6c..93a629d81e8ff 100644
--- a/vllm/model_executor/models/qwen3.py
+++ b/vllm/model_executor/models/qwen3.py
@@ -42,6 +42,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.config import set_default_rope_theta
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
from .qwen2 import Qwen2MLP as Qwen3MLP
@@ -57,14 +58,13 @@ class Qwen3Attention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
+ rope_parameters: dict,
max_position: int = 4096 * 32,
head_dim: int | None = None,
rms_norm_eps: float = 1e-06,
qkv_bias: bool = False,
- rope_theta: float = 10000,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
- rope_scaling: tuple | None = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
dual_chunk_attention_config: dict[str, Any] | None = None,
@@ -89,7 +89,6 @@ class Qwen3Attention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.dual_chunk_attention_config = dual_chunk_attention_config
self.qkv_proj = QKVParallelLinear(
@@ -113,8 +112,7 @@ class Qwen3Attention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
- base=self.rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=rope_parameters,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.attn = Attention(
@@ -166,9 +164,7 @@ class Qwen3DecoderLayer(nn.Module):
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
- # Requires transformers > 4.32.0
- rope_theta = getattr(config, "rope_theta", 1000000)
- rope_scaling = getattr(config, "rope_scaling", None)
+ set_default_rope_theta(config, default_theta=1000000)
dual_chunk_attention_config = getattr(
config, "dual_chunk_attention_config", None
)
@@ -187,13 +183,12 @@ class Qwen3DecoderLayer(nn.Module):
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
- rope_theta=rope_theta,
rms_norm_eps=config.rms_norm_eps,
qkv_bias=getattr(config, "attention_bias", False),
head_dim=getattr(config, "head_dim", None),
cache_config=cache_config,
quant_config=quant_config,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
prefix=f"{prefix}.self_attn",
attn_type=attn_type,
dual_chunk_attention_config=dual_chunk_attention_config,
diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py
index 96751fee800bb..8ee3dd99e11db 100644
--- a/vllm/model_executor/models/qwen3_moe.py
+++ b/vllm/model_executor/models/qwen3_moe.py
@@ -216,8 +216,7 @@ class Qwen3MoeAttention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
+ rope_parameters: dict[str, Any],
max_position_embeddings: int = 8192,
head_dim: int | None = None,
rms_norm_eps: float = 1e-06,
@@ -247,7 +246,6 @@ class Qwen3MoeAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.dual_chunk_attention_config = dual_chunk_attention_config
@@ -273,8 +271,7 @@ class Qwen3MoeAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
- base=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=rope_parameters,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.attn = Attention(
@@ -326,8 +323,6 @@ class Qwen3MoeDecoderLayer(nn.Module):
quant_config = vllm_config.quant_config
self.hidden_size = config.hidden_size
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
dual_chunk_attention_config = getattr(
config, "dual_chunk_attention_config", None
@@ -336,8 +331,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
max_position_embeddings=max_position_embeddings,
rms_norm_eps=config.rms_norm_eps,
qkv_bias=getattr(config, "attention_bias", False),
diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py
index 86508a7c64317..bfed64728305e 100644
--- a/vllm/model_executor/models/qwen3_next.py
+++ b/vllm/model_executor/models/qwen3_next.py
@@ -10,7 +10,7 @@ from einops import rearrange
from torch import nn
from transformers.activations import ACT2FN
-from vllm.attention import Attention, AttentionBackend, AttentionMetadata
+from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
CacheConfig,
@@ -216,12 +216,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
@property
def mamba_type(self) -> str:
- return "linear_attention"
-
- def get_attn_backend(self) -> type["AttentionBackend"]:
- from vllm.v1.attention.backends.gdn_attn import GDNAttentionBackend
-
- return GDNAttentionBackend
+ return "gdn_attention"
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
@@ -753,8 +748,7 @@ class Qwen3NextAttention(nn.Module):
head_size=self.head_dim,
rotary_dim=self.head_dim,
max_position=config.max_position_embeddings,
- base=config.rope_theta,
- rope_scaling=config.rope_scaling,
+ rope_parameters=config.rope_parameters,
partial_rotary_factor=config.partial_rotary_factor,
dual_chunk_attention_config=self.dual_chunk_attention_config,
)
@@ -1154,8 +1148,8 @@ class QwenNextMixtureOfExperts(MixtureOfExperts):
example_moe = layer.mlp
self.moe_layers.append(layer.mlp.experts)
- if example_moe is None:
- raise RuntimeError("No Qwen3Next layer found in the model.layers.")
+ if example_moe is None:
+ raise RuntimeError("No Qwen3Next layer found in the model.layers.")
# Set MoE hyperparameters
self.num_moe_layers = len(self.moe_layers)
diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py
index 8274b92138f78..f5f88f66eff91 100755
--- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py
+++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py
@@ -68,11 +68,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems
from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataItems
from vllm.multimodal.processing import (
- BaseMultiModalProcessor,
MultiModalPromptUpdates,
PlaceholderFeaturesInfo,
PromptReplacement,
PromptUpdate,
+ PromptUpdateDetails,
)
from vllm.sequence import IntermediateTensors
@@ -87,7 +87,6 @@ from .qwen2_5_omni_thinker import (
Qwen2_5OmniConditionalGenerationMixin,
Qwen2_5OmniThinkerDummyInputsBuilder,
Qwen2_5OmniThinkerMultiModalProcessor,
- Qwen2_5OmniThinkerProcessingInfo,
)
from .qwen2_5_vl import (
Qwen2_5_VisionAttention,
@@ -224,7 +223,6 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
- seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
@@ -232,7 +230,6 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
x = x + self.mlp(self.norm2(x))
@@ -338,7 +335,6 @@ class Qwen3Omni_VisionTransformer(nn.Module):
head_size=head_dim,
rotary_dim=head_dim // 2,
max_position=8192,
- base=10000.0,
is_neox_style=True,
)
@@ -428,13 +424,8 @@ class Qwen3Omni_VisionTransformer(nn.Module):
# Use pre-computed cos_sin_cache from RotaryEmbedding
cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
- cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2)
- cos_w = cos[pos_ids[:, 1]]
- sin_h = sin[pos_ids[:, 0]]
- sin_w = sin[pos_ids[:, 1]]
-
- cos_combined = torch.cat([cos_h, cos_w], dim=-1)
- sin_combined = torch.cat([sin_h, sin_w], dim=-1)
+ cos_combined = cos[pos_ids].flatten(1)
+ sin_combined = sin[pos_ids].flatten(1)
return cos_combined, sin_combined
@@ -506,14 +497,11 @@ class Qwen3Omni_VisionTransformer(nn.Module):
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
+ ) -> torch.Tensor:
max_seqlen = torch.zeros([], device=cu_seqlens.device)
- seqlens = torch.zeros(1, device=cu_seqlens.device)
if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
- return max_seqlen, seqlens
+ return max_seqlen
def forward(
self,
@@ -539,7 +527,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
hidden_states = hidden_states.unsqueeze(1)
rotary_pos_emb_cos = rotary_pos_emb_cos.to(hidden_states.device)
rotary_pos_emb_sin = rotary_pos_emb_sin.to(hidden_states.device)
- max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+ max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
hidden_states_list = []
deepstack_visual_indexes = self.deepstack_visual_indexes
@@ -551,7 +539,6 @@ class Qwen3Omni_VisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
if (
deepstack_visual_indexes is not None
@@ -819,24 +806,8 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
else:
use_audio_in_video = False
- if use_audio_in_video and "video" in mm_item_counts:
- assert "audio" in mm_item_counts
- mm_item_counts["audio"] -= mm_item_counts["video"]
-
- # Special case with `use_audio_in_video=True`
- if use_audio_in_video:
- if is_update_applied:
- prompt_ids = self._get_raw_input_ids(prompt_ids, use_audio_in_video)
- (
- prompt_ids,
- mm_placeholders,
- ) = self._apply_prompt_updates(
- prompt_ids,
- mm_prompt_updates,
- )
- self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
# normal case with `use_audio_in_video=False`
- elif is_update_applied:
+ if is_update_applied:
mm_placeholders = self._find_mm_placeholders(
prompt_ids,
mm_prompt_updates,
@@ -846,10 +817,24 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
mm_item_counts,
)
else:
- prompt_ids, mm_placeholders = self._apply_prompt_updates(
- prompt_ids,
- mm_prompt_updates,
- )
+ if use_audio_in_video and "audio" in mm_prompt_updates:
+ filtered_updates = {
+ k: v for k, v in mm_prompt_updates.items() if k != "audio"
+ }
+ prompt_ids, mm_placeholders = self._apply_prompt_updates(
+ prompt_ids,
+ filtered_updates,
+ )
+ # Derive audio placeholders from video placeholders
+ mm_placeholders = self._derive_audio_from_video_placeholders(
+ mm_placeholders, mm_prompt_updates
+ )
+ else:
+ prompt_ids, mm_placeholders = self._apply_prompt_updates(
+ prompt_ids,
+ mm_prompt_updates,
+ )
+
self._validate_mm_placeholders(
mm_placeholders,
mm_item_counts,
@@ -974,7 +959,9 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
def get_replacement_qwen2_use_audio_in_video(item_idx: int):
nonlocal audio_in_video_item_idx
- audio_num_features = audio_output_lengths[audio_item_idx + item_idx]
+ audio_num_features = audio_output_lengths[
+ audio_in_video_item_idx + item_idx
+ ]
video_grid_thw = out_mm_data["video_grid_thw"][item_idx]
audio_in_video_item_idx += 1
@@ -983,14 +970,17 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
if second_per_grid_ts:
video_second_per_grid_t = second_per_grid_ts[item_idx]
else:
- video_second_per_grid_t = 1.0
+ video_second_per_grid_t = 2.0
- return self.get_updates_use_audio_in_video(
+ placeholder = self.get_updates_use_audio_in_video(
thinker_config=thinker_config,
audio_len=audio_num_features,
video_grid_thw=video_grid_thw,
video_second_per_grid_t=video_second_per_grid_t,
)
+ return PromptUpdateDetails.select_token_id(
+ placeholder, embed_token_id=video_token_id
+ )
video_replacement_fn = (
get_replacement_qwen2_use_audio_in_video
@@ -1016,14 +1006,50 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
),
]
- def _validate_mm_placeholders(
+ def _derive_audio_from_video_placeholders(
self,
- mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
- mm_item_counts: Mapping[str, int],
- ) -> None:
- BaseMultiModalProcessor[
- Qwen2_5OmniThinkerProcessingInfo
- ]._validate_mm_placeholders(self, mm_placeholders, mm_item_counts)
+ placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
+ mm_prompt_updates: MultiModalPromptUpdates,
+ ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
+ """
+ Helper to derive audio placeholders from video placeholders when
+ use_audio_in_video=True.
+ """
+ if "video" not in placeholders:
+ return placeholders
+
+ # Validate audio and video counts match
+ num_videos = len(placeholders["video"])
+ num_audios = len(mm_prompt_updates.get("audio", []))
+ if num_audios != num_videos:
+ raise ValueError(
+ f"use_audio_in_video requires equal number of audio and video items, "
+ f"got {num_audios=}, {num_videos=}"
+ )
+
+ tokenizer = self.info.get_tokenizer()
+ processor = self.info.get_hf_processor()
+ audio_token_id = tokenizer.get_vocab()[processor.audio_token]
+
+ result_placeholders = dict(placeholders)
+ audio_placeholders = []
+
+ # Each video is paired with one audio
+ for video_idx, video_placeholder in enumerate(placeholders["video"]):
+ # Create is_embed mask selecting only audio tokens
+ audio_is_embed = torch.tensor(video_placeholder.tokens) == audio_token_id
+
+ audio_placeholder = PlaceholderFeaturesInfo(
+ modality="audio",
+ item_idx=video_idx,
+ start_idx=video_placeholder.start_idx,
+ tokens=video_placeholder.tokens,
+ is_embed=audio_is_embed,
+ )
+ audio_placeholders.append(audio_placeholder)
+
+ result_placeholders["audio"] = audio_placeholders
+ return result_placeholders
def _get_raw_input_ids(
self,
@@ -1466,7 +1492,11 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
)
if not len(second_per_grid_ts) and len(video_grid_thw):
- second_per_grids = torch.ones(len(video_grid_thw), dtype=torch.float32)
+ second_per_grid_ts = 2.0
+ second_per_grids = (
+ torch.ones(len(video_grid_thw), dtype=torch.float32)
+ * second_per_grid_ts
+ )
else:
second_per_grids = torch.tensor(second_per_grid_ts, dtype=torch.float32)
diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index 99a4007ef7f23..4cd6fa14c32df 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -235,7 +235,6 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
- seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
@@ -243,7 +242,6 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
x = x + self.mlp(self.norm2(x))
@@ -345,7 +343,6 @@ class Qwen3_VisionTransformer(nn.Module):
head_size=head_dim,
rotary_dim=head_dim // 2,
max_position=8192,
- base=10000.0,
is_neox_style=True,
)
@@ -392,7 +389,6 @@ class Qwen3_VisionTransformer(nn.Module):
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
- AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -459,18 +455,13 @@ class Qwen3_VisionTransformer(nn.Module):
else self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1)
for t, h, w in grid_thw
]
- pos_ids = torch.cat(pos_ids, dim=0)
+ pos_ids = torch.cat(pos_ids, dim=0).to(self.device, non_blocking=True)
# Use pre-computed cos_sin_cache from RotaryEmbedding
cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
- cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2)
- cos_w = cos[pos_ids[:, 1]]
- sin_h = sin[pos_ids[:, 0]]
- sin_w = sin[pos_ids[:, 1]]
-
- cos_combined = torch.cat([cos_h, cos_w], dim=-1)
- sin_combined = torch.cat([sin_h, sin_w], dim=-1)
+ cos_combined = cos[pos_ids].flatten(1)
+ sin_combined = sin[pos_ids].flatten(1)
return cos_combined, sin_combined
@@ -537,17 +528,14 @@ class Qwen3_VisionTransformer(nn.Module):
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
+ ) -> torch.Tensor:
max_seqlen = torch.zeros([], device=cu_seqlens.device)
- seqlens = torch.zeros(1, device=cu_seqlens.device)
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
- elif self.attn_backend == AttentionBackendEnum.XFORMERS:
- seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
- return max_seqlen, seqlens
+ return max_seqlen
def forward(
self,
@@ -559,27 +547,23 @@ class Qwen3_VisionTransformer(nn.Module):
if isinstance(grid_thw, list):
grid_thw_list = grid_thw
- grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
+ grid_thw = np.array(grid_thw, dtype=np.int32)
else:
grid_thw_list = grid_thw.tolist()
+ grid_thw = grid_thw.numpy()
pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
hidden_states = hidden_states + pos_embeds
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
- rotary_pos_emb_cos = rotary_pos_emb_cos.to(
- hidden_states.device, non_blocking=True
- )
- rotary_pos_emb_sin = rotary_pos_emb_sin.to(
- hidden_states.device, non_blocking=True
- )
- cu_seqlens = torch.repeat_interleave(
- grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
- ).cumsum(dim=0, dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32)
- cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
+ cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+ axis=0, dtype=np.int32
+ )
+ cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
+ cu_seqlens = torch.from_numpy(cu_seqlens)
hidden_states = hidden_states.unsqueeze(1)
- max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+ max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
deepstack_feature_lists = []
@@ -590,7 +574,6 @@ class Qwen3_VisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
- seqlens=seqlens,
)
if layer_num in self.deepstack_visual_indexes:
deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num)
diff --git a/vllm/model_executor/models/qwen3_vl_moe.py b/vllm/model_executor/models/qwen3_vl_moe.py
index 5c3205faf9c2f..e2c129120b1a5 100644
--- a/vllm/model_executor/models/qwen3_vl_moe.py
+++ b/vllm/model_executor/models/qwen3_vl_moe.py
@@ -15,7 +15,7 @@
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
-# http://www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
@@ -29,7 +29,9 @@ from collections.abc import Callable, Iterable
from itertools import islice
import torch
-from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeConfig
+from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import (
+ Qwen3VLMoeConfig,
+)
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
@@ -44,7 +46,12 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
-from .qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
+from .interfaces import MixtureOfExperts
+from .qwen3_moe import (
+ Qwen3MoeForCausalLM,
+ Qwen3MoeModel,
+ Qwen3MoeSparseMoeBlock,
+)
from .qwen3_vl import (
Qwen3_VisionTransformer,
Qwen3VLDummyInputsBuilder,
@@ -344,12 +351,56 @@ class Qwen3MoeLLMForCausalLM(Qwen3MoeForCausalLM):
)
+class Qwen3VLMoeMixtureOfExperts(MixtureOfExperts):
+ def update_physical_experts_metadata(
+ self,
+ num_physical_experts: int,
+ num_local_physical_experts: int,
+ ) -> None:
+ assert self.num_local_physical_experts == num_local_physical_experts
+ self.num_physical_experts = num_physical_experts
+ self.num_local_physical_experts = num_local_physical_experts
+ self.num_redundant_experts = num_physical_experts - self.num_logical_experts
+ for layer in self.language_model.model.layers:
+ if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
+ moe = layer.mlp
+ moe.n_local_physical_experts = num_local_physical_experts
+ moe.n_physical_experts = num_physical_experts
+ moe.n_redundant_experts = self.num_redundant_experts
+ moe.experts.update_expert_map()
+
+ def set_moe_parameters(self):
+ self.expert_weights = []
+
+ self.moe_layers = []
+ example_moe = None
+ for layer in self.language_model.model.layers:
+ if hasattr(layer, "mlp") and isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
+ example_moe = layer.mlp
+ self.moe_layers.append(layer.mlp.experts)
+
+ if example_moe is None:
+ raise RuntimeError("No Qwen3Moe layer found in the language_model.")
+
+ # Set MoE hyperparameters
+ self.num_moe_layers = len(self.moe_layers)
+ self.num_expert_groups = 1
+ self.num_shared_experts = 0
+ self.num_logical_experts = example_moe.n_logical_experts
+ self.num_physical_experts = example_moe.n_physical_experts
+ self.num_local_physical_experts = example_moe.n_local_physical_experts
+ self.num_routed_experts = example_moe.n_routed_experts
+ self.num_redundant_experts = example_moe.n_redundant_experts
+
+
@MULTIMODAL_REGISTRY.register_processor(
Qwen3VLMultiModalProcessor,
info=Qwen3VLMoeProcessingInfo,
dummy_inputs=Qwen3VLDummyInputsBuilder,
)
-class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
+class Qwen3VLMoeForConditionalGeneration(
+ Qwen3VLForConditionalGeneration, Qwen3VLMoeMixtureOfExperts
+):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -413,3 +464,6 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
self.deepstack_input_embeds = None
self.visual_dim = config.vision_config.out_hidden_size
self.multiscale_dim = self.visual_dim * self.deepstack_num_level
+
+ # Set MoE hyperparameters
+ self.set_moe_parameters()
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 8211321c39537..287cb6b59ba9e 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -158,6 +158,7 @@ _TEXT_GENERATION_MODELS = {
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
"Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
+ "Plamo3ForCausalLM": ("plamo3", "Plamo3ForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
@@ -287,8 +288,16 @@ _MULTIMODAL_MODELS = {
"GraniteSpeechForConditionalGeneration",
),
"H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
+ "HunYuanVLForConditionalGeneration": (
+ "hunyuan_vision",
+ "HunYuanVLForConditionalGeneration",
+ ),
"InternVLChatModel": ("internvl", "InternVLChatModel"),
"NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
+ "OpenCUAForConditionalGeneration": (
+ "opencua",
+ "OpenCUAForConditionalGeneration",
+ ),
"InternS1ForConditionalGeneration": (
"interns1",
"InternS1ForConditionalGeneration",
diff --git a/vllm/model_executor/models/seed_oss.py b/vllm/model_executor/models/seed_oss.py
index bf211d28f1844..4744d8e44f390 100644
--- a/vllm/model_executor/models/seed_oss.py
+++ b/vllm/model_executor/models/seed_oss.py
@@ -54,6 +54,7 @@ from vllm.model_executor.model_loader.weight_utils import (
maybe_remap_kv_scale_name,
)
from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.config import set_default_rope_theta
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (
@@ -112,11 +113,10 @@ class SeedOssAttention(nn.Module):
num_heads: int,
num_kv_heads: int,
head_dim: int,
+ rope_parameters: dict,
max_position: int = 4096 * 32,
- rope_theta: float = 10000,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
- rope_scaling: tuple | None = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
) -> None:
@@ -140,7 +140,6 @@ class SeedOssAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.qkv_proj = QKVParallelLinear(
hidden_size,
@@ -163,8 +162,7 @@ class SeedOssAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
- base=self.rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=rope_parameters,
)
self.attn = Attention(
self.num_heads,
@@ -200,9 +198,7 @@ class SeedOssDecoderLayer(nn.Module):
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
- # Requires transformers > 4.32.0
- rope_theta = getattr(config, "rope_theta", 1000000)
- rope_scaling = getattr(config, "rope_scaling", None)
+ set_default_rope_theta(config, default_theta=1000000)
# By default, SeedOss uses causal attention as it is a
# decoder-only model.
@@ -219,10 +215,9 @@ class SeedOssDecoderLayer(nn.Module):
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
head_dim=config.head_dim,
- rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
prefix=f"{prefix}.self_attn",
attn_type=attn_type,
)
diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py
index 46f5e67d659ef..c185b45345bd5 100644
--- a/vllm/model_executor/models/siglip2navit.py
+++ b/vllm/model_executor/models/siglip2navit.py
@@ -191,7 +191,7 @@ def apply_rotary_pos_emb(
cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous()
if is_flash_attn_backend and not current_platform.is_xpu():
- from flash_attn.layers.rotary import apply_rotary_emb
+ from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
apply_rotary_emb_func = apply_rotary_emb
else:
diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py
index 4ec855f794446..7e9fc51036d2e 100644
--- a/vllm/model_executor/models/solar.py
+++ b/vllm/model_executor/models/solar.py
@@ -25,7 +25,6 @@
"""Inference-only Solar model compatible with HuggingFace weights."""
from collections.abc import Iterable
-from typing import Any
import torch
from torch import nn
@@ -111,8 +110,6 @@ class SolarAttention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
- rope_theta: float = 10000,
- rope_scaling: dict[str, Any] | None = None,
max_position_embeddings: int = 8192,
quant_config: QuantizationConfig | None = None,
bias: bool = False,
@@ -142,7 +139,6 @@ class SolarAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
@@ -166,8 +162,7 @@ class SolarAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
- base=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
)
self.attn = Attention(
self.num_heads,
@@ -202,15 +197,6 @@ class SolarDecoderLayer(nn.Module):
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
-
- if rope_scaling is not None and getattr(
- config, "original_max_position_embeddings", None
- ):
- rope_scaling["original_max_position_embeddings"] = (
- config.original_max_position_embeddings
- )
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
@@ -224,8 +210,6 @@ class SolarDecoderLayer(nn.Module):
num_kv_heads=getattr(
config, "num_key_value_heads", config.num_attention_heads
),
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py
index 06eb7201c1a89..a738fcbb4ee28 100644
--- a/vllm/model_executor/models/stablelm.py
+++ b/vllm/model_executor/models/stablelm.py
@@ -153,7 +153,7 @@ class StablelmAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.config.max_position_embeddings,
- base=self.config.rope_theta,
+ rope_parameters=self.config.rope_parameters,
partial_rotary_factor=self.partial_rotary_factor,
)
self.attn = Attention(
diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py
index 0f2942acd5006..1118fca3cac91 100644
--- a/vllm/model_executor/models/starcoder2.py
+++ b/vllm/model_executor/models/starcoder2.py
@@ -91,7 +91,6 @@ class Starcoder2Attention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
- self.rope_theta = config.rope_theta
self.max_position_embeddings = config.max_position_embeddings
self.use_bias = config.use_bias
@@ -115,7 +114,7 @@ class Starcoder2Attention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
- base=int(self.rope_theta),
+ rope_parameters=config.rope_parameters,
is_neox_style=True,
)
self.attn = Attention(
diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py
index 4fff356b29e28..3c377a2c539df 100644
--- a/vllm/model_executor/models/step3_text.py
+++ b/vllm/model_executor/models/step3_text.py
@@ -36,6 +36,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.configs.step3_vl import Step3TextConfig
from .interfaces import SupportsPP
from .utils import (
@@ -144,9 +145,8 @@ class Step3TextAttention(nn.Module):
num_heads: int,
num_kv_heads: int,
norm_eps: float,
- rope_theta: int,
+ rope_parameters: dict[str, Any],
share_q_dim: int | None = None,
- rope_scaling: dict[str, Any] | None = None,
max_position_embedding: int = 8192,
head_dim: int = 256,
cache_config: CacheConfig | None = None,
@@ -198,8 +198,7 @@ class Step3TextAttention(nn.Module):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embedding,
- base=rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=rope_parameters,
)
scaling = self.head_dim**-0.5
self.attn = Attention(
@@ -227,15 +226,13 @@ class Step3TextAttention(nn.Module):
class Step3TextDecoderLayer(nn.Module):
def __init__(
self,
- config: ModelConfig,
+ config: Step3TextConfig,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
- config = config.hf_config
self.hidden_size = config.hidden_size
- rope_scaling = getattr(config, "rope_scaling", None)
self.self_attn = Step3TextAttention(
hidden_size=self.hidden_size,
@@ -247,8 +244,7 @@ class Step3TextDecoderLayer(nn.Module):
max_position_embedding=config.max_position_embedding,
head_dim=config.head_dim,
share_q_dim=config.share_q_dim,
- rope_theta=config.rope_theta,
- rope_scaling=rope_scaling,
+ rope_parameters=config.rope_parameters,
prefix=f"{prefix}.self_attn",
)
@@ -338,7 +334,7 @@ class Step3TextModel(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Step3TextDecoderLayer(
- config=vllm_config.model_config,
+ config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
diff --git a/vllm/model_executor/models/teleflm.py b/vllm/model_executor/models/teleflm.py
index 8a0bec9dff848..bebd7bcaa9249 100644
--- a/vllm/model_executor/models/teleflm.py
+++ b/vllm/model_executor/models/teleflm.py
@@ -74,5 +74,5 @@ class TeleFLMForCausalLM(LlamaForCausalLM):
self.output_mult = self.config.output_mult / self.mup_scale_factor
logit_scale = self.output_mult
self.logits_processor = LogitsProcessor(
- self.unpadded_vocab_size, self.config.vocab_size, logit_scale
+ self.config.vocab_size, scale=logit_scale
)
diff --git a/vllm/model_executor/models/transformers/moe.py b/vllm/model_executor/models/transformers/moe.py
index 4973014c3d4ed..31db9d682bd40 100644
--- a/vllm/model_executor/models/transformers/moe.py
+++ b/vllm/model_executor/models/transformers/moe.py
@@ -256,7 +256,14 @@ class MoEMixin(MixtureOfExperts):
def _recursive_replace(module: nn.Module, prefix: str):
for child_name, child_module in module.named_children():
qual_name = maybe_prefix(prefix, child_name)
- if child_name == "experts" and isinstance(child_module, nn.ModuleList):
+ # Naive implementations will have experts as ModuleList
+ is_modulelist = isinstance(child_module, nn.ModuleList)
+ # Packed implementations will have experts as 3D tensors of shapes like:
+ # gate_up_proj = (num_experts, 2 * intermediate_size, hidden_size)
+ # down_proj = (num_experts, intermediate_size, hidden_size)
+ params = list(child_module.parameters())
+ is_3d = len(params) > 0 and all(p.ndim == 3 for p in params)
+ if child_name == "experts" and (is_modulelist or is_3d):
# Alias for readability
mlp = module
experts = child_module
diff --git a/vllm/model_executor/models/transformers/utils.py b/vllm/model_executor/models/transformers/utils.py
index 517eb54d53ac6..b807f45b5d52b 100644
--- a/vllm/model_executor/models/transformers/utils.py
+++ b/vllm/model_executor/models/transformers/utils.py
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Literal
import torch
from torch import nn
+from transformers.configuration_utils import ALLOWED_LAYER_TYPES
from vllm.config.utils import getattr_iter
from vllm.logger import init_logger
@@ -203,5 +204,10 @@ def can_enable_torch_compile(vllm_config: "VllmConfig") -> bool:
"""
text_config = vllm_config.model_config.hf_config.get_text_config()
# Dynamic rope scaling is not compatible with torch.compile
- rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {}
- return rope_scaling.get("rope_type") != "dynamic"
+ rope_parameters: dict | None = getattr(text_config, "rope_parameters", None) or {}
+ if rope_parameters:
+ # Nest rope_parameters if not nested already to simplify logic
+ if not set(rope_parameters.keys()).issubset(ALLOWED_LAYER_TYPES):
+ rope_parameters = {"": rope_parameters}
+ return all(rp["rope_type"] != "dynamic" for rp in rope_parameters.values())
+ return True
diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py
index 91a10b95a08c0..50587c627160d 100644
--- a/vllm/model_executor/models/whisper.py
+++ b/vllm/model_executor/models/whisper.py
@@ -599,15 +599,16 @@ class WhisperModel(nn.Module):
def forward(
self,
- input_features: torch.Tensor | list[torch.Tensor] | None,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
+ encoder_outputs: list[torch.Tensor],
) -> torch.Tensor:
- encoder_outputs = self.get_encoder_outputs(input_features)
+ assert len(encoder_outputs) in (0, 1)
+ enc_states = encoder_outputs[0] if len(encoder_outputs) == 1 else None
decoder_outputs = self.decoder(
input_ids=input_ids,
positions=positions,
- encoder_hidden_states=encoder_outputs,
+ encoder_hidden_states=enc_states,
)
return decoder_outputs
@@ -894,13 +895,15 @@ class WhisperForConditionalGeneration(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
+ encoder_outputs: list[torch.Tensor] | None = None,
**kwargs,
) -> torch.Tensor:
- audio_input = self._parse_and_validate_audio_input(**kwargs)
+ if encoder_outputs is None:
+ encoder_outputs = []
decoder_outputs = self.model(
- input_features=audio_input["input_features"],
input_ids=input_ids,
positions=positions,
+ encoder_outputs=encoder_outputs,
)
return decoder_outputs
diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py
index 729a9655d0879..653b5b9beef7b 100644
--- a/vllm/model_executor/models/zamba2.py
+++ b/vllm/model_executor/models/zamba2.py
@@ -128,7 +128,6 @@ class Zamba2Attention(nn.Module):
tp_size = get_tensor_model_parallel_world_size()
self.config = config
self.num_hybrid_layers = num_hybrid_layers
- self.rope_theta = config.rope_theta
self.attention_hidden_size = config.attention_hidden_size
self.total_num_attention_heads = config.num_attention_heads
@@ -233,8 +232,7 @@ class Zamba2Attention(nn.Module):
head_size=self.attention_head_dim,
rotary_dim=self.attention_head_dim,
max_position=config.max_position_embeddings,
- base=self.rope_theta,
- rope_scaling=None,
+ rope_parameters=config.rope_parameters,
is_neox_style=True,
)
diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py
index 759b809433b14..8aad59e84ff25 100644
--- a/vllm/model_executor/utils.py
+++ b/vllm/model_executor/utils.py
@@ -10,7 +10,7 @@ import torch
from vllm.utils.torch_utils import is_torch_equal_or_newer
-def set_random_seed(seed: int) -> None:
+def set_random_seed(seed: int | None) -> None:
from vllm.platforms import current_platform
current_platform.seed_everything(seed)
diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py
index 53052ddc6343c..b93a42ffd24c1 100644
--- a/vllm/multimodal/audio.py
+++ b/vllm/multimodal/audio.py
@@ -7,6 +7,8 @@ from typing import Literal
import numpy as np
import numpy.typing as npt
+import pybase64
+import torch
from vllm.utils.import_utils import PlaceholderModule
@@ -116,3 +118,25 @@ class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
data = buffer.getvalue()
return base64.b64encode(data).decode("utf-8")
+
+
+class AudioEmbeddingMediaIO(MediaIO[torch.Tensor]):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def load_bytes(self, data: bytes) -> torch.Tensor:
+ buffer = BytesIO(data)
+ return torch.load(buffer, weights_only=True)
+
+ def load_base64(self, media_type: str, data: str) -> torch.Tensor:
+ return self.load_bytes(pybase64.b64decode(data, validate=True))
+
+ def load_file(self, filepath: Path) -> torch.Tensor:
+ return torch.load(filepath, weights_only=True)
+
+ def encode_base64(self, media: torch.Tensor) -> str:
+ buffer = BytesIO()
+ torch.save(media, buffer)
+ buffer.seek(0)
+ binary_data = buffer.read()
+ return pybase64.b64encode(binary_data).decode("utf-8")
diff --git a/vllm/multimodal/evs.py b/vllm/multimodal/evs.py
index 4a288d2d238c2..8a36ea415da4d 100644
--- a/vllm/multimodal/evs.py
+++ b/vllm/multimodal/evs.py
@@ -185,7 +185,7 @@ def recompute_mrope_positions(
Args:
input_ids: (N,) All input tokens of the prompt (entire sequence).
- multimodal_positions: List of mrope positsions for each media.
+ multimodal_positions: List of mrope positions for each media.
mrope_positions: Existing mrope positions (4, N) for entire sequence.
num_computed_tokens: A number of computed tokens so far.
vision_start_token_id: Token indicating start of vision media.
diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py
index 3f55c46ca334d..1020554e2e073 100644
--- a/vllm/multimodal/utils.py
+++ b/vllm/multimodal/utils.py
@@ -3,7 +3,7 @@
import asyncio
import atexit
-from collections.abc import Iterable, Set
+from collections.abc import Generator, Set
from concurrent.futures import ThreadPoolExecutor
from itertools import groupby
from pathlib import Path
@@ -22,7 +22,7 @@ from vllm.logger import init_logger
from vllm.utils.jsontree import json_map_leaves
from vllm.utils.registry import ExtensionManager
-from .audio import AudioMediaIO
+from .audio import AudioEmbeddingMediaIO, AudioMediaIO
from .base import MediaIO
from .image import ImageEmbeddingMediaIO, ImageMediaIO
from .video import VideoMediaIO
@@ -342,6 +342,17 @@ class MediaConnector:
return image_embedding_io.load_base64("", data)
+ def fetch_audio_embedding(
+ self,
+ data: str,
+ ) -> torch.Tensor:
+ """
+ Load audio embedding from a URL.
+ """
+ audio_embedding_io = AudioEmbeddingMediaIO()
+
+ return audio_embedding_io.load_base64("", data)
+
def encode_audio_base64(
audio: np.ndarray,
@@ -403,7 +414,7 @@ def group_mm_kwargs_by_modality(
pin_memory: bool = False,
merge_by_field_config: bool | None = None,
multimodal_cpu_fields: Set[str] = frozenset(),
-) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
+) -> Generator[tuple[str, int, BatchedTensorInputs], None, None]:
"""Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
modality together into the same `MultiModalKwargs` instance.
diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py
index 369c5e6cb4d10..763f90fde7b6d 100644
--- a/vllm/multimodal/video.py
+++ b/vllm/multimodal/video.py
@@ -63,6 +63,58 @@ class VideoLoader:
) -> tuple[npt.NDArray, dict[str, Any]]:
raise NotImplementedError
+ @staticmethod
+ def _read_frames(
+ cap,
+ frame_indices: set[int],
+ num_expected_frames: int,
+ max_frame_idx: int,
+ ) -> tuple[npt.NDArray, int, list[int]]:
+ import cv2
+
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ frames = np.empty((num_expected_frames, height, width, 3), dtype=np.uint8)
+
+ i = 0
+ valid_frame_indices = []
+ for idx in range(max_frame_idx + 1):
+ ok = cap.grab()
+ if not ok:
+ # Frame is broken/unreadable, log warning
+ if idx in frame_indices:
+ logger.warning(
+ "Failed to grab frame %d during video loading. "
+ "This frame will be skipped.",
+ idx,
+ )
+ continue
+ if idx in frame_indices:
+ ret, frame = cap.retrieve()
+ if ret:
+ frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ valid_frame_indices.append(idx)
+ i += 1
+ else:
+ # retrieve() failed even though grab() succeeded
+ logger.warning(
+ "Failed to retrieve frame %d during video loading. "
+ "This frame will be skipped.",
+ idx,
+ )
+
+ valid_num_frames = len(valid_frame_indices)
+ if valid_num_frames < num_expected_frames:
+ logger.warning(
+ "Video loading completed with %d broken/unreadable frames. "
+ "Expected %d frames but only loaded %d frames.",
+ num_expected_frames - valid_num_frames,
+ num_expected_frames,
+ valid_num_frames,
+ )
+
+ return frames[:valid_num_frames], valid_num_frames, valid_frame_indices
+
VIDEO_LOADER_REGISTRY = ExtensionManager()
@@ -120,24 +172,10 @@ class OpenCVVideoBackend(VideoLoader):
)
frame_idx = uniform_sampled_frames.tolist()
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
- frames = np.empty((len(frame_idx), height, width, 3), dtype=np.uint8)
-
- i = 0
- for idx in range(max(frame_idx) + 1):
- ok = cap.grab()
- if not ok:
- break
- if idx in frame_idx:
- ret, frame = cap.retrieve()
- if ret:
- frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
- i += 1
-
- assert i == num_frames_to_sample, (
- f"Expected reading {num_frames_to_sample} frames, "
- f"but only loaded {i} frames from video."
+ # Convert to set for O(1) lookup performance
+ frame_idx_set = set(frame_idx)
+ frames, valid_num_frames, valid_frame_indices = cls._read_frames(
+ cap, frame_idx_set, num_frames_to_sample, max(frame_idx)
)
# Use transformers transformers.video_utils.VideoMetadata format
@@ -148,10 +186,10 @@ class OpenCVVideoBackend(VideoLoader):
"fps": original_fps,
"duration": duration,
"video_backend": "opencv",
- "frames_indices": list(frame_idx),
+ "frames_indices": valid_frame_indices,
# extra field used to control hf processor's video
# sampling behavior
- "do_sample_frames": num_frames_to_sample == total_frames_num,
+ "do_sample_frames": valid_num_frames == total_frames_num,
}
return frames, metadata
@@ -185,10 +223,10 @@ class OpenCVDynamicVideoBackend(OpenCVVideoBackend):
# Refer to:
# https://github.com/huggingface/transformers/blob/v4.55.4/src/transformers/models/glm4v/video_processing_glm4v.py#L103-L140
- frame_indices: range | list[int]
+ frame_indices_list: list[int]
if duration <= max_duration:
n = int(math.floor(duration * fps))
- frame_indices = sorted(
+ frame_indices_list = sorted(
{
min(max_frame_idx, int(math.ceil(i * original_fps / fps)))
for i in range(n)
@@ -197,34 +235,23 @@ class OpenCVDynamicVideoBackend(OpenCVVideoBackend):
else:
num_samples = int(max_duration * fps)
if num_samples >= total_frames_num:
- frame_indices = range(total_frames_num)
+ frame_indices_list = list(range(total_frames_num))
else:
target_seconds = np.linspace(0, duration, num_samples, endpoint=True)
- frame_indices = sorted(
+ frame_indices_list = sorted(
{
min(max_frame_idx, int(math.ceil(t * original_fps)))
for t in target_seconds
}
)
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
- frames = np.empty((len(frame_indices), height, width, 3), dtype=np.uint8)
-
- i = 0
- for idx in range(total_frames_num):
- ok = cap.grab()
- if not ok:
- break
- if idx in frame_indices:
- ret, frame = cap.retrieve()
- if ret:
- frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
- i += 1
-
- assert i == len(frame_indices), (
- f"Expected reading {len(frame_indices)} frames, "
- f"but only loaded {i} frames from video."
+ # Convert to set for O(1) lookup performance
+ frame_indices_set = set(frame_indices_list)
+ frames, valid_num_frames, valid_frame_indices = cls._read_frames(
+ cap,
+ frame_indices_set,
+ len(frame_indices_list),
+ total_frames_num - 1,
)
# Use transformers transformers.video_utils.VideoMetadata format
@@ -233,7 +260,7 @@ class OpenCVDynamicVideoBackend(OpenCVVideoBackend):
"fps": original_fps,
"duration": duration,
"video_backend": "opencv_dynamic",
- "frames_indices": list(frame_indices),
+ "frames_indices": valid_frame_indices,
"do_sample_frames": False,
}
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 2e4dd8bb808b4..06793a3d1bb14 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -267,25 +267,17 @@ class CudaPlatformBase(Platform):
) -> "AttentionBackendEnum":
from vllm.attention.backends.registry import AttentionBackendEnum
- # For Blackwell GPUs, force TORCH_SDPA for now.
- # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
- if cls.has_device_capability(100):
- return AttentionBackendEnum.TORCH_SDPA
-
- if dtype not in (torch.float16, torch.bfloat16):
- return AttentionBackendEnum.XFORMERS
-
- if cls.has_device_capability(80):
+ # Try FlashAttention first
+ try:
backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
if backend_class.supports_head_size(
head_size
) and backend_class.supports_dtype(dtype):
return AttentionBackendEnum.FLASH_ATTN
- else:
- return AttentionBackendEnum.XFORMERS
- else:
- # Fallback for Volta/Turing GPUs or FA not supported
- return AttentionBackendEnum.XFORMERS
+ except ImportError:
+ pass
+
+ return AttentionBackendEnum.TORCH_SDPA
@classmethod
def get_valid_backends(
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 0471c20429b1d..1e6b53021f888 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -134,6 +134,11 @@ class Platform:
_global_graph_pool: Any | None = None
+ @property
+ def pass_key(self) -> str:
+ """Inductor config key for the PassManager custom pass"""
+ return "post_grad_custom_post_pass"
+
@property
def supported_dtypes(self) -> list[torch.dtype]:
"""Returns the supported dtypes for the current platform."""
@@ -177,6 +182,21 @@ class Platform:
# all ROCm platforms for now.
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
+ @classmethod
+ def get_pass_manager_cls(cls) -> str:
+ """
+ Get the pass manager class for this platform.
+ It will be registered as a custom pass under the current_platform.pass_key.
+ """
+ return "vllm.compilation.pass_manager.PostGradPassManager"
+
+ @classmethod
+ def get_compile_backend(cls) -> str:
+ """
+ Get the custom compile backend for current platform.
+ """
+ return cls.simple_compile_backend
+
@classmethod
def device_id_to_physical_device_id(cls, device_id: int):
# Treat empty device control env var as unset. This is a valid
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index bb116792fed54..0483f6c06ada8 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -225,7 +225,15 @@ class RocmPlatform(Platform):
from vllm.attention.backends.registry import AttentionBackendEnum
if use_sparse:
- raise NotImplementedError("Sparse Attention is not supported on ROCm.")
+ if kv_cache_dtype.startswith("fp8"):
+ raise ValueError(
+ "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
+ )
+ assert block_size == 1, (
+ "Sparse MLA backend on ROCm only supports block size 1 for now."
+ )
+ logger.info_once("Using Sparse MLA backend on V1 engine.")
+ return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()
if use_mla:
if selected_backend is None:
@@ -234,7 +242,6 @@ class RocmPlatform(Platform):
if rocm_aiter_ops.is_mla_enabled() or block_size == 1
else AttentionBackendEnum.TRITON_MLA
)
-
if selected_backend == AttentionBackendEnum.TRITON_MLA:
if block_size != 1:
logger.info_once("Using Triton MLA backend.")
@@ -246,6 +253,9 @@ class RocmPlatform(Platform):
if selected_backend == AttentionBackendEnum.ROCM_AITER_MLA:
logger.info("Using AITER MLA backend.")
return AttentionBackendEnum.ROCM_AITER_MLA.get_path()
+ if selected_backend == AttentionBackendEnum.ROCM_AITER_TRITON_MLA:
+ logger.info("Using AITER TRITON MLA backend.")
+ return AttentionBackendEnum.ROCM_AITER_TRITON_MLA.get_path()
raise ValueError(
f" The selected backend, {selected_backend.name},"
@@ -254,28 +264,66 @@ class RocmPlatform(Platform):
if selected_backend == AttentionBackendEnum.FLEX_ATTENTION:
logger.info("Using FlexAttention backend.")
- return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
- if (
- rocm_aiter_ops.is_mha_enabled()
- ) or selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
- logger.info("Using Aiter Flash Attention backend.")
- return AttentionBackendEnum.ROCM_AITER_FA.get_path()
- if (
- rocm_aiter_ops.is_triton_unified_attn_enabled()
- ) or selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
- logger.info("Using Aiter Unified Attention backend.")
- return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
- if (
- envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
- or selected_backend == AttentionBackendEnum.ROCM_ATTN
- ):
- # rocm specific backend, with aiter and/or
- # triton prefix-prefill
- logger.info("Using Rocm Attention backend.")
+ return AttentionBackendEnum.FLEX_ATTENTION.get_path()
+
+ if selected_backend == AttentionBackendEnum.TRITON_ATTN:
+ logger.info("Using Triton Attention backend on V1 engine.")
+ return AttentionBackendEnum.TRITON_ATTN.get_path()
+
+ if selected_backend == AttentionBackendEnum.ROCM_ATTN:
+ logger.info("Using Rocm Attention backend on V1 engine.")
return AttentionBackendEnum.ROCM_ATTN.get_path()
- # default case, using triton unified attention
- logger.info("Using Triton Attention backend.")
- return AttentionBackendEnum.TRITON_ATTN.get_path()
+
+ if selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
+ if on_gfx9():
+ logger.info("Using Aiter Flash Attention backend on V1 engine.")
+ return AttentionBackendEnum.ROCM_AITER_FA.get_path()
+ else:
+ raise ValueError(
+ f"The selected backend, {selected_backend.name}, "
+ "is only supported on gfx9 architectures."
+ )
+
+ if selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
+ logger.info("Using Aiter Unified Attention backend on V1 engine.")
+ return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
+
+ # Handle automatic backend selection based on environment variables
+ if selected_backend is None:
+ # Priority 1: Check for AITER Unified Attention (must check before MHA)
+ if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
+ logger.info("Using Aiter Unified Attention backend on V1 engine.")
+ return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
+
+ # Priority 2: Check for AITER MHA (Flash Attention)
+ # Only use if explicitly enabled (not just VLLM_ROCM_USE_AITER=1)
+ if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
+ logger.info("Using Aiter Flash Attention backend on V1 engine.")
+ return AttentionBackendEnum.ROCM_AITER_FA.get_path()
+
+ # Priority 3: Check for ROCM_ATTN (prefill-decode split)
+ if envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION:
+ logger.info("Using Rocm Attention backend on V1 engine.")
+ return AttentionBackendEnum.ROCM_ATTN.get_path()
+
+ # Priority 4: Check for AITER enabled without specific flags
+ # This defaults to AITER FA only if MHA is not explicitly disabled
+ if (
+ envs.VLLM_ROCM_USE_AITER
+ and on_gfx9()
+ and envs.VLLM_ROCM_USE_AITER_MHA is not False
+ ):
+ logger.info("Using Aiter Flash Attention backend on V1 engine.")
+ return AttentionBackendEnum.ROCM_AITER_FA.get_path()
+
+ # Default: Triton Unified Attention
+ logger.info("Using Triton Attention backend on V1 engine.")
+ return AttentionBackendEnum.TRITON_ATTN.get_path()
+
+ raise RuntimeError(
+ f"Attention backend {selected_backend.name} is not supported on "
+ "ROCm. Note that V0 attention backends have been removed."
+ )
@classmethod
def set_device(cls, device: torch.device) -> None:
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index 65516827a16da..18a3186b142f1 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -251,10 +251,6 @@ class XPUPlatform(Platform):
) -> None:
"""Copy blocks from src_cache to dst_cache on XPU."""
_src_cache = src_cache[:, src_block_indices]
- if _src_cache.shape[2:] != dst_cache.shape[2:]:
- # To support TP_ratio, HOST KV might be initiated with HND
- # while XPU device KV is with NHD
- _src_cache = _src_cache.permute(0, 1, 3, 2, 4)
dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)
@classmethod
@@ -267,8 +263,4 @@ class XPUPlatform(Platform):
) -> None:
"""Copy blocks from XPU to host (CPU)."""
_src_cache = src_cache[:, src_block_indices]
- if _src_cache.shape[2:] != dst_cache.shape[2:]:
- # XPU device KV is with NHD while HOST KV
- # might be initiated with HND for TP_ratio support
- _src_cache = _src_cache.permute(0, 1, 3, 2, 4)
dst_cache[:, dst_block_indices] = _src_cache.cpu()
diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py
index 0d8988f27959f..4c59d5364a763 100644
--- a/vllm/plugins/__init__.py
+++ b/vllm/plugins/__init__.py
@@ -17,6 +17,9 @@ IO_PROCESSOR_PLUGINS_GROUP = "vllm.io_processor_plugins"
# Platform plugins group will be loaded in all processes when
# `vllm.platforms.current_platform` is called and the value not initialized,
PLATFORM_PLUGINS_GROUP = "vllm.platform_plugins"
+# Stat logger plugins group will be loaded in process0 only when serve vLLM with
+# async mode.
+STAT_LOGGER_PLUGINS_GROUP = "vllm.stat_logger_plugins"
# make sure one process only loads plugins once
plugins_loaded = False
diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py
index 5c3dfa8ac9cbc..d1aab98c274e1 100644
--- a/vllm/pooling_params.py
+++ b/vllm/pooling_params.py
@@ -57,7 +57,7 @@ class PoolingParams(
## Internal use only
task: PoolingTask | None = None
requires_token_ids: bool = False
- skip_reading_prefix_cache: bool = None
+ skip_reading_prefix_cache: bool | None = None
extra_kwargs: dict[str, Any] | None = None
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
diff --git a/vllm/profiler/gpu_profiler.py b/vllm/profiler/gpu_profiler.py
index 58c6689531615..3e2cbe7296e9d 100644
--- a/vllm/profiler/gpu_profiler.py
+++ b/vllm/profiler/gpu_profiler.py
@@ -1,37 +1,213 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from abc import ABC, abstractmethod
+from contextlib import nullcontext
+
+import torch
+from typing_extensions import override
+
+import vllm.envs as envs
from vllm.logger import init_logger
logger = init_logger(__name__)
-class CudaProfilerWrapper:
+class WorkerProfiler(ABC):
def __init__(self) -> None:
- self._profiler_running = False
+ self._delay_iters = envs.VLLM_PROFILER_DELAY_ITERS
+ if self._delay_iters > 0:
+ logger.info_once(
+ "GPU profiling will start "
+ f"{self._delay_iters} steps after start_profile."
+ )
+
+ self._max_iters = envs.VLLM_PROFILER_MAX_ITERS
+ if self._max_iters > 0:
+ logger.info_once(
+ "GPU profiling will stop "
+ f"after {self._max_iters} worker steps, "
+ "or when stop_profile is received."
+ )
+
+ # Track when the profiler gets triggered by start_profile
+ self._active_iteration_count = 0
+ self._active = False
+
+ # Track when the profiler is actually running
+ self._profiling_for_iters = 0
+ self._running = False
+
+ @abstractmethod
+ def _start(self) -> None:
+ """Start the profiler."""
+ pass
+
+ @abstractmethod
+ def _stop(self) -> None:
+ """Stop the profiler."""
+ pass
+
+ def _call_start(self) -> None:
+ """Call _start with error handling but no safeguards."""
+ try:
+ self._start()
+ self._running = True # Only mark as running if start succeeds
+ except Exception as e:
+ logger.warning("Failed to start profiler: %s", e)
+
+ def _call_stop(self) -> None:
+ """Call _stop with error handling but no safeguards."""
+ try:
+ self._stop()
+ logger.info("Profiler stopped successfully.")
+ except Exception as e:
+ logger.warning("Failed to stop profiler: %s", e)
+ self._running = False # Always mark as not running, assume stop worked
+
+ def start(self) -> None:
+ """Attempt to start the profiler, accounting for delayed starts."""
+ if self._active:
+ logger.debug(
+ "start_profile received when profiler is already active. "
+ "Ignoring request."
+ )
+ return
+ self._active = True
+ if self._delay_iters == 0:
+ self._call_start()
+
+ def step(self) -> None:
+ """Update the profiler state at each worker step,
+ to handle delayed starts and max iteration limits."""
+ if not self._active:
+ return
+
+ self._active_iteration_count += 1
+
+ if (
+ not self._running
+ and self._delay_iters > 0
+ and self._active_iteration_count == self._delay_iters
+ ):
+ logger.info("Starting profiler after delay...")
+ self._call_start()
+
+ if self._running:
+ self._profiling_for_iters += 1
+
+ if (
+ self._max_iters > 0
+ and self._running
+ and self._profiling_for_iters > self._max_iters
+ ):
+ # Automatically stop the profiler after max iters
+ # will be marked as not running, but leave as active so that stop
+ # can clean up properly
+ logger.info("Max profiling iterations reached. Stopping profiler...")
+ self._call_stop()
+ return
+
+ def stop(self) -> None:
+ """Attempt to stop the profiler, accounting for overlapped calls."""
+ if not self._active:
+ logger.debug(
+ "stop_profile received when profiler is not active. Ignoring request."
+ )
+ return
+ self._active = False
+ self._active_iteration_count = 0
+ self._profiling_for_iters = 0
+
+ if self._running:
+ self._call_stop()
+
+ def shutdown(self) -> None:
+ """Ensure profiler is stopped when shutting down."""
+ logger.info_once("Shutting down profiler")
+ if self._running:
+ self.stop()
+
+ def annotate_context_manager(self, name: str):
+ """Return a context manager to annotate profiler traces."""
+ return nullcontext()
+
+
+class TorchProfilerWrapper(WorkerProfiler):
+ def __init__(self, worker_name: str, local_rank: int) -> None:
+ super().__init__()
+
+ self.local_rank = local_rank
+ torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
+ if local_rank in (None, 0):
+ logger.info(
+ "Torch profiling enabled. Traces will be saved to: %s",
+ torch_profiler_trace_dir,
+ )
+ logger.debug(
+ "Profiler config: record_shapes=%s,"
+ "profile_memory=%s,with_stack=%s,with_flops=%s",
+ envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
+ envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
+ envs.VLLM_TORCH_PROFILER_WITH_STACK,
+ envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
+ )
+ self.profiler = torch.profiler.profile(
+ activities=[
+ torch.profiler.ProfilerActivity.CPU,
+ torch.profiler.ProfilerActivity.CUDA,
+ ],
+ record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
+ profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
+ with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
+ with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
+ on_trace_ready=torch.profiler.tensorboard_trace_handler(
+ torch_profiler_trace_dir, worker_name=worker_name, use_gzip=True
+ ),
+ )
+
+ @override
+ def _start(self) -> None:
+ self.profiler.start()
+
+ @override
+ def _stop(self) -> None:
+ self.profiler.stop()
+
+ rank = self.local_rank
+ profiler_dir = envs.VLLM_TORCH_PROFILER_DIR
+ profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt"
+ sort_key = "self_cuda_time_total"
+ table = self.profiler.key_averages().table(sort_by=sort_key)
+
+ with open(profiler_out_file, "w") as f:
+ print(table, file=f)
+
+ # only print profiler results on rank 0
+ if rank == 0:
+ print(table)
+
+ @override
+ def annotate_context_manager(self, name: str):
+ return torch.profiler.record_function(name)
+
+
+class CudaProfilerWrapper(WorkerProfiler):
+ def __init__(self) -> None:
+ super().__init__()
# Note: lazy import to avoid dependency issues if CUDA is not available.
import torch.cuda.profiler as cuda_profiler
self._cuda_profiler = cuda_profiler
- def start(self) -> None:
- try:
- self._cuda_profiler.start()
- self._profiler_running = True
- logger.info_once("Started CUDA profiler")
- except Exception as e:
- logger.warning_once("Failed to start CUDA profiler: %s", e)
+ @override
+ def _start(self) -> None:
+ self._cuda_profiler.start()
- def stop(self) -> None:
- if self._profiler_running:
- try:
- self._cuda_profiler.stop()
- logger.info_once("Stopped CUDA profiler")
- except Exception as e:
- logger.warning_once("Failed to stop CUDA profiler: %s", e)
- finally:
- self._profiler_running = False
+ @override
+ def _stop(self) -> None:
+ self._cuda_profiler.stop()
- def shutdown(self) -> None:
- """Ensure profiler is stopped when shutting down."""
- self.stop()
+ @override
+ def annotate_context_manager(self, name: str):
+ return torch.cuda.nvtx.range(name)
diff --git a/vllm/ray/lazy_utils.py b/vllm/ray/lazy_utils.py
index 64b5f51571a35..06c91cc3943ae 100644
--- a/vllm/ray/lazy_utils.py
+++ b/vllm/ray/lazy_utils.py
@@ -10,6 +10,8 @@ def is_ray_initialized():
return ray.is_initialized()
except ImportError:
return False
+ except AttributeError:
+ return False
def is_in_ray_actor():
@@ -24,3 +26,5 @@ def is_in_ray_actor():
)
except ImportError:
return False
+ except AttributeError:
+ return False
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 0fb1d67687c82..8de961e62db1b 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -3,7 +3,6 @@
"""Sampling parameters for text generation."""
import copy
-import warnings
from dataclasses import field
from enum import Enum, IntEnum
from functools import cached_property
@@ -100,19 +99,6 @@ class StructuredOutputsParams:
)
-@dataclass
-class GuidedDecodingParams(StructuredOutputsParams):
- def __post_init__(self):
- warnings.warn(
- "GuidedDecodingParams is deprecated. This will be removed in "
- "v0.12.0 or v1.0.0, which ever is soonest. Please use "
- "StructuredOutputsParams instead.",
- DeprecationWarning,
- stacklevel=2,
- )
- return super().__post_init__()
-
-
class RequestOutputKind(Enum):
# Return entire output so far in every RequestOutput
CUMULATIVE = 0
@@ -144,12 +130,6 @@ class SamplingParams(
are generated and streamed cumulatively per request. To see all `n`
outputs upon completion, use `output_kind=RequestOutputKind.FINAL_ONLY`
in `SamplingParams`."""
- best_of: int | None = None
- """Number of output sequences that are generated from the prompt. From
- these `best_of` sequences, the top `n` sequences are returned. `best_of`
- must be greater than or equal to `n`. By default, `best_of` is set to `n`.
- Warning, this is only supported in V0."""
- _real_n: int | None = None
presence_penalty: float = 0.0
"""Penalizes new tokens based on whether they appear in the generated text
so far. Values > 0 encourage the model to use new tokens, while values < 0
@@ -240,8 +220,6 @@ class SamplingParams(
# Fields used to construct logits processors
structured_outputs: StructuredOutputsParams | None = None
"""Parameters for configuring structured outputs."""
- guided_decoding: GuidedDecodingParams | None = None
- """Deprecated alias for structured_outputs."""
logit_bias: dict[int, float] | None = None
"""If provided, the engine will construct a logits processor that applies
these logit biases."""
@@ -260,12 +238,11 @@ class SamplingParams(
generated token can complete the sequence."""
_bad_words_token_ids: list[list[int]] | None = None
- skip_reading_prefix_cache: bool = None
+ skip_reading_prefix_cache: bool | None = None
@staticmethod
def from_optional(
n: int | None = 1,
- best_of: int | None = None,
presence_penalty: float | None = 0.0,
frequency_penalty: float | None = 0.0,
repetition_penalty: float | None = 1.0,
@@ -290,7 +267,6 @@ class SamplingParams(
truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None,
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
structured_outputs: StructuredOutputsParams | None = None,
- guided_decoding: GuidedDecodingParams | None = None,
logit_bias: dict[int, float] | dict[str, float] | None = None,
allowed_token_ids: list[int] | None = None,
extra_args: dict[str, Any] | None = None,
@@ -302,20 +278,9 @@ class SamplingParams(
int(token): min(100.0, max(-100.0, bias))
for token, bias in logit_bias.items()
}
- if guided_decoding is not None:
- warnings.warn(
- "guided_decoding is deprecated. This will be removed in "
- "v0.12.0 or v1.0.0, which ever is soonest. Please use "
- "structured_outputs instead.",
- DeprecationWarning,
- stacklevel=2,
- )
- structured_outputs = guided_decoding
- guided_decoding = None
return SamplingParams(
n=1 if n is None else n,
- best_of=best_of,
presence_penalty=0.0 if presence_penalty is None else presence_penalty,
frequency_penalty=0.0 if frequency_penalty is None else frequency_penalty,
repetition_penalty=1.0
@@ -348,22 +313,6 @@ class SamplingParams(
)
def __post_init__(self) -> None:
- # how we deal with `best_of`:
- # if `best_of` is not set, we default to `n`;
- # if `best_of` is set, we set `n` to `best_of`,
- # and set `_real_n` to the original `n`.
- # when we return the result, we will check
- # if we need to return `n` or `_real_n` results
- if self.best_of:
- if self.best_of < self.n:
- raise ValueError(
- f"best_of must be greater than or equal to n, "
- f"got n={self.n} and best_of={self.best_of}."
- )
- if not self._real_n:
- self._real_n = self.n
- self.n = self.best_of
-
if 0 < self.temperature < _MAX_TEMP:
logger.warning(
"temperature %s is less than %s, which may cause numerical "
@@ -411,17 +360,6 @@ class SamplingParams(
# eos_token_id is added to this by the engine
self._all_stop_token_ids.update(self.stop_token_ids)
- if self.guided_decoding is not None:
- warnings.warn(
- "guided_decoding is deprecated. This will be removed in "
- "v0.12.0 or v1.0.0, which ever is soonest. Please use "
- "structured_outputs instead.",
- DeprecationWarning,
- stacklevel=2,
- )
- self.structured_outputs = self.guided_decoding
- self.guided_decoding = None
-
if self.skip_reading_prefix_cache is None:
# If prefix caching is enabled,
# the output of prompt logprobs may less than n_prompt_tokens,
@@ -433,18 +371,6 @@ class SamplingParams(
raise ValueError(f"n must be an int, but is of type {type(self.n)}")
if self.n < 1:
raise ValueError(f"n must be at least 1, got {self.n}.")
- if self.best_of is not None:
- if not isinstance(self.best_of, int):
- raise ValueError(
- f"best_of must be an integer, got {type(self.best_of)}"
- )
- if self.best_of < 1:
- raise ValueError(f"best_of must be at least 1, got {self.best_of}")
- if self.best_of < self.n:
- raise ValueError(
- f"best_of must be greater than or equal to n, "
- f"got n={self.n} and best_of={self.best_of}."
- )
if not -2.0 <= self.presence_penalty <= 2.0:
raise ValueError(
f"presence_penalty must be in [-2, 2], got {self.presence_penalty}."
@@ -519,10 +445,6 @@ class SamplingParams(
"stop strings are only supported when detokenize is True. "
"Set detokenize=True to use stop."
)
- if self.best_of != self._real_n and self.output_kind == (
- RequestOutputKind.DELTA
- ):
- raise ValueError("best_of must equal n to use output_kind=DELTA")
def _verify_greedy_sampling(self) -> None:
if self.n > 1:
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index ac4a71648cec8..c1880a3fba0ee 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -1,14 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import fnmatch
import json
import os
import time
from collections.abc import Callable
from dataclasses import asdict
from functools import cache, partial
+from importlib.metadata import version
from pathlib import Path
-from typing import Any, Literal, TypeVar
+from typing import Any, Literal, TypeAlias, TypeVar
import huggingface_hub
from huggingface_hub import (
@@ -24,7 +26,9 @@ from huggingface_hub.utils import (
RepositoryNotFoundError,
RevisionNotFoundError,
)
-from transformers import DeepseekV3Config, GenerationConfig, PretrainedConfig
+from packaging.version import Version
+from transformers import GenerationConfig, PretrainedConfig
+from transformers.configuration_utils import ALLOWED_LAYER_TYPES
from transformers.models.auto.image_processing_auto import get_image_processor_config
from transformers.models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
@@ -80,8 +84,9 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
afmoe="AfmoeConfig",
chatglm="ChatGLMConfig",
deepseek_vl_v2="DeepseekVLV2Config",
- deepseek_v32=DeepseekV3Config,
+ deepseek_v32="DeepseekV3Config",
flex_olmo="FlexOlmoConfig",
+ hunyuan_vl="HunYuanVLConfig",
kimi_linear="KimiLinearConfig",
kimi_vl="KimiVLConfig",
RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct)
@@ -200,7 +205,19 @@ class MistralConfigParser(ConfigParserBase):
from vllm.transformers_utils.configs.mistral import adapt_config_dict
- config = adapt_config_dict(config_dict)
+ # Get missing fields from HF config if available
+ try:
+ hf_config_dict, _ = PretrainedConfig.get_config_dict(
+ model,
+ revision=revision,
+ code_revision=code_revision,
+ token=_get_hf_token(),
+ **kwargs,
+ )
+ except OSError: # Not found
+ hf_config_dict = {}
+
+ config = adapt_config_dict(config_dict, defaults=hf_config_dict)
# Mistral configs may define sliding_window as list[int]. Convert it
# to int and add the layer_types list[str] to make it HF compatible
@@ -352,6 +369,41 @@ def list_repo_files(
return with_retry(lookup_files, "Error retrieving file list")
+def list_filtered_repo_files(
+ model_name_or_path: str,
+ allow_patterns: list[str],
+ revision: str | None = None,
+ repo_type: str | None = None,
+ token: str | bool | None = None,
+) -> list[str]:
+ try:
+ all_files = list_repo_files(
+ repo_id=model_name_or_path,
+ revision=revision,
+ token=token,
+ repo_type=repo_type,
+ )
+ except Exception:
+ logger.error(
+ "Error retrieving file list. Please ensure your `model_name_or_path`"
+ "`repo_type`, `token` and `revision` arguments are correctly set. "
+ "Returning an empty list."
+ )
+ return []
+
+ file_list = []
+ # Filter patterns on filenames
+ for pattern in allow_patterns:
+ file_list.extend(
+ [
+ file
+ for file in all_files
+ if fnmatch.fnmatch(os.path.basename(file), pattern)
+ ]
+ )
+ return file_list
+
+
def file_exists(
repo_id: str,
file_name: str,
@@ -390,21 +442,61 @@ def file_or_path_exists(
)
-def patch_rope_scaling(config: PretrainedConfig) -> None:
+def set_default_rope_theta(config: PretrainedConfig, default_theta: float) -> None:
+ """Some models may have no rope_theta in their config but still use RoPE.
+ This function sets a default rope_theta if it's missing."""
+ if getattr(config, "rope_parameters", None) is None:
+ config.rope_parameters = {"rope_type": "default"}
+ if "rope_theta" not in config.rope_parameters:
+ config.rope_parameters["rope_theta"] = default_theta
+
+
+def patch_rope_parameters(config: PretrainedConfig) -> None:
"""Provide backwards compatibility for RoPE."""
- text_config = getattr(config, "text_config", None)
- if text_config is not None:
- patch_rope_scaling(text_config)
+ # Retrieve rope_parameters differently based on Transformers version
+ if Version(version("transformers")) >= Version("5.0.0.dev0"):
+ from transformers.modeling_rope_utils import RopeParameters
- rope_scaling = getattr(config, "rope_scaling", None)
- if rope_scaling is not None:
- patch_rope_scaling_dict(rope_scaling)
+ rope_parameters: RopeParameters | dict[str, RopeParameters] | None = getattr(
+ config, "rope_parameters", None
+ )
+ elif hasattr(config, "rope_parameters"):
+ # We are in Transformers v4 and rope_parameters
+ # has already been patched for this config
+ return
+ else:
+ # Convert Transformers v4 rope_theta and rope_scaling into rope_parameters
+ rope_theta: float | None = getattr(config, "rope_theta", None)
+ rope_scaling: dict | None = getattr(config, "rope_scaling", None)
+ rope_parameters = rope_scaling
+ # Move rope_theta into rope_parameters
+ if rope_theta is not None:
+ rope_parameters = rope_parameters or {"rope_type": "default"}
+ rope_parameters["rope_theta"] = rope_theta
+ # Add original_max_position_embeddings if present
+ if rope_parameters and (
+ ompe := getattr(config, "original_max_position_embeddings", None)
+ ):
+ rope_parameters["original_max_position_embeddings"] = ompe
+ # Write back to config
+ config.rope_parameters = rope_parameters
+
+ # No RoPE parameters to patch
+ if rope_parameters is None:
+ return
+
+ # Handle nested rope_parameters in interleaved sliding attention models
+ if set(rope_parameters.keys()).issubset(ALLOWED_LAYER_TYPES):
+ for rope_parameters_layer_type in rope_parameters.values():
+ patch_rope_parameters_dict(rope_parameters_layer_type)
+ else:
+ patch_rope_parameters_dict(rope_parameters)
-def patch_rope_scaling_dict(rope_scaling: dict[str, Any]) -> None:
- if "rope_type" in rope_scaling and "type" in rope_scaling:
- rope_type = rope_scaling["rope_type"]
- rope_type_legacy = rope_scaling["type"]
+def patch_rope_parameters_dict(rope_parameters: dict[str, Any]) -> None:
+ if "rope_type" in rope_parameters and "type" in rope_parameters:
+ rope_type = rope_parameters["rope_type"]
+ rope_type_legacy = rope_parameters["type"]
if rope_type != rope_type_legacy:
raise ValueError(
f"Found conflicts between 'rope_type={rope_type}' (modern "
@@ -412,28 +504,28 @@ def patch_rope_scaling_dict(rope_scaling: dict[str, Any]) -> None:
"You should only specify one of them."
)
- if "rope_type" not in rope_scaling and "type" in rope_scaling:
- rope_scaling["rope_type"] = rope_scaling["type"]
+ if "rope_type" not in rope_parameters and "type" in rope_parameters:
+ rope_parameters["rope_type"] = rope_parameters["type"]
logger.info("Replacing legacy 'type' key with 'rope_type'")
- if "rope_type" not in rope_scaling:
- raise ValueError("rope_scaling should have a 'rope_type' key")
+ if "rope_type" not in rope_parameters:
+ raise ValueError("rope_parameters should have a 'rope_type' key")
- if rope_scaling["rope_type"] == "su":
- rope_scaling["rope_type"] = "longrope"
+ if rope_parameters["rope_type"] == "su":
+ rope_parameters["rope_type"] = "longrope"
logger.warning("Replacing legacy rope_type 'su' with 'longrope'")
- elif rope_scaling["rope_type"] == "mrope":
- assert "mrope_section" in rope_scaling
- rope_scaling["rope_type"] = "default"
+ elif rope_parameters["rope_type"] == "mrope":
+ assert "mrope_section" in rope_parameters
+ rope_parameters["rope_type"] = "default"
logger.warning("Replacing legacy rope_type 'mrope' with 'default'")
def _uses_mrope(config: PretrainedConfig) -> bool:
- rope_scaling = getattr(config, "rope_scaling", None)
- if rope_scaling is None:
+ rope_parameters = getattr(config, "rope_parameters", None)
+ if rope_parameters is None:
return False
- return "mrope_section" in rope_scaling
+ return "mrope_section" in rope_parameters
def uses_mrope(config: PretrainedConfig) -> bool:
@@ -458,6 +550,23 @@ def thinker_uses_mrope(config: PretrainedConfig) -> bool:
return uses_mrope(thinker_text_config)
+def uses_xdrope_dim(config: PretrainedConfig) -> int:
+ """Detect if the model with this config uses XD-ROPE."""
+ xdrope_section = getattr(config, "xdrope_section", None)
+ if xdrope_section is not None and isinstance(xdrope_section, list):
+ return len(xdrope_section)
+ rope_scaling = getattr(config, "rope_scaling", None)
+ if rope_scaling is None:
+ return 0
+
+ if isinstance(rope_scaling, dict) and "xdrope_section" in rope_scaling:
+ xdrope_section = rope_scaling["xdrope_section"]
+ if xdrope_section is not None and isinstance(xdrope_section, list):
+ return len(xdrope_section)
+
+ return 0
+
+
def is_encoder_decoder(config: PretrainedConfig) -> bool:
"""Detect if the model with this config is used as an encoder/decoder."""
@@ -477,17 +586,6 @@ def is_interleaved(config: PretrainedConfig) -> bool:
return False
-def uses_custom_attention_masks(config: PretrainedConfig) -> bool:
- """Detect if model uses custom attention mask generation for multimodal.
-
- Some multimodal models require custom attention masks that enable
- bidirectional attention between image tokens while maintaining causal
- attention for text tokens. Currently applies to Gemma3 multimodal models.
- """
- architectures = getattr(config, "architectures", [])
- return "Gemma3ForConditionalGeneration" in architectures
-
-
def _maybe_update_auto_config_kwargs(kwargs: dict[str, Any], model_type: str):
"""
Update kwargs for AutoConfig initialization based on model_type
@@ -587,10 +685,14 @@ def get_config(
if config_format == "auto":
try:
- if is_gguf or file_or_path_exists(model, HF_CONFIG_NAME, revision=revision):
- config_format = "hf"
- elif file_or_path_exists(model, MISTRAL_CONFIG_NAME, revision=revision):
+ # First check for Mistral to avoid defaulting to
+ # Transformers implementation.
+ if file_or_path_exists(model, MISTRAL_CONFIG_NAME, revision=revision):
config_format = "mistral"
+ elif is_gguf or file_or_path_exists(
+ model, HF_CONFIG_NAME, revision=revision
+ ):
+ config_format = "hf"
else:
raise ValueError(
"Could not detect config format for no config file found. "
@@ -690,7 +792,14 @@ def get_config(
logger.debug("Overriding HF config with %s", hf_overrides_fn)
config = hf_overrides_fn(config)
- patch_rope_scaling(config)
+ # Exhaustively patch RoPE parameters everywhere they might be
+ patch_rope_parameters(config)
+ patch_rope_parameters(config.get_text_config())
+ SubConfigs: TypeAlias = dict[str, PretrainedConfig]
+ sub_configs: SubConfigs | None = getattr(config, "sub_configs", None)
+ if sub_configs:
+ for sub_config in sub_configs:
+ patch_rope_parameters(getattr(config, sub_config))
if trust_remote_code:
maybe_register_config_serialize_by_value()
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index dcae05a15fec3..109f2b6986514 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -5,8 +5,13 @@ Model configs may be defined in this directory for the following reasons:
- There is no configuration file defined by HF Hub or Transformers library.
- There is a need to override the existing config to support vLLM.
+- The HF model_type isn't recognized by the Transformers library but can
+ be mapped to an existing Transformers config, such as
+ deepseek-ai/DeepSeek-V3.2-Exp.
"""
+from transformers import DeepseekV3Config
+
from vllm.transformers_utils.configs.afmoe import AfmoeConfig
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config
@@ -18,6 +23,11 @@ from vllm.transformers_utils.configs.eagle import EAGLEConfig
# `FalconConfig` class from the official HuggingFace transformers library.
from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.flex_olmo import FlexOlmoConfig
+from vllm.transformers_utils.configs.hunyuan_vl import (
+ HunYuanVLConfig,
+ HunYuanVLTextConfig,
+ HunYuanVLVisionConfig,
+)
from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.kimi_linear import KimiLinearConfig
from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig
@@ -44,9 +54,13 @@ __all__ = [
"AfmoeConfig",
"ChatGLMConfig",
"DeepseekVLV2Config",
+ "DeepseekV3Config",
"DotsOCRConfig",
"EAGLEConfig",
"FlexOlmoConfig",
+ "HunYuanVLConfig",
+ "HunYuanVLTextConfig",
+ "HunYuanVLVisionConfig",
"RWConfig",
"JAISConfig",
"Lfm2MoeConfig",
diff --git a/vllm/transformers_utils/configs/afmoe.py b/vllm/transformers_utils/configs/afmoe.py
index 9b634fd037a33..47fee9882f9fc 100644
--- a/vllm/transformers_utils/configs/afmoe.py
+++ b/vllm/transformers_utils/configs/afmoe.py
@@ -24,7 +24,7 @@ class AfmoeConfig(PretrainedConfig):
rms_norm_eps: float = 1e-5,
use_cache: bool = True,
tie_word_embeddings: bool = False,
- rope_theta: float = 10000.0,
+ rope_parameters: dict | None = None,
rope_scaling: dict | None = None,
num_experts: int = 64,
num_experts_per_tok: int = 6,
@@ -56,7 +56,10 @@ class AfmoeConfig(PretrainedConfig):
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
- self.rope_theta = rope_theta
+ rope_theta = kwargs.pop("rope_theta", 10000.0)
+ if rope_parameters is None:
+ rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
+ self.rope_parameters = rope_parameters
self.rope_scaling = rope_scaling
self.moe_intermediate_size = moe_intermediate_size
diff --git a/vllm/transformers_utils/configs/arctic.py b/vllm/transformers_utils/configs/arctic.py
index 1707e15285c89..ba4b1a8f701f0 100644
--- a/vllm/transformers_utils/configs/arctic.py
+++ b/vllm/transformers_utils/configs/arctic.py
@@ -85,8 +85,15 @@ class ArcticConfig(PretrainedConfig):
The id of the "end-of-sequence" token.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
- rope_theta (`float`, *optional*, defaults to 1000000.0):
- The base period of the RoPE embeddings.
+ rope_parameters (`dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_theta` (`float`): The base period of the RoPE embeddings.
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
sliding_window (`int`, *optional*):
Sliding window attention window size. If not specified, will default to `4096`.
attention_dropout (`float`, *optional*, defaults to 0.0):
@@ -132,7 +139,7 @@ class ArcticConfig(PretrainedConfig):
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
- rope_theta=1e6,
+ rope_parameters: dict[str, Any] | None = None,
sliding_window=None,
attention_dropout=0.0,
num_experts_per_tok=1,
@@ -165,7 +172,10 @@ class ArcticConfig(PretrainedConfig):
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
- self.rope_theta = rope_theta
+ rope_theta = kwargs.pop("rope_theta", 1e6)
+ if rope_parameters is None:
+ rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
+ self.rope_parameters = rope_parameters
self.attention_dropout = attention_dropout
self.num_experts_per_tok = num_experts_per_tok
diff --git a/vllm/transformers_utils/configs/flex_olmo.py b/vllm/transformers_utils/configs/flex_olmo.py
index 1f2f4d446288b..c343dc0999a87 100644
--- a/vllm/transformers_utils/configs/flex_olmo.py
+++ b/vllm/transformers_utils/configs/flex_olmo.py
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any
from transformers.configuration_utils import PretrainedConfig
@@ -25,8 +26,7 @@ class FlexOlmoConfig(PretrainedConfig):
bos_token_id=None,
eos_token_id=100257,
tie_word_embeddings=False,
- rope_theta=500000.0,
- rope_scaling=None,
+ rope_parameters: dict[str, Any] | None = None,
attention_bias=False,
attention_dropout=0.0,
num_experts_per_tok=5,
@@ -62,8 +62,13 @@ class FlexOlmoConfig(PretrainedConfig):
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
- self.rope_theta = rope_theta
- self.rope_scaling = rope_scaling
+        # Prefer the legacy `rope_scaling` kwarg if present, else `rope_parameters`
+ rope_scaling = kwargs.pop("rope_scaling", None)
+ rope_parameters = rope_scaling or rope_parameters or {"rope_type": "default"}
+ rope_theta = kwargs.pop("rope_theta", 500000.0)
+ if "rope_theta" not in rope_parameters:
+ rope_parameters["rope_theta"] = rope_theta
+ self.rope_parameters = rope_parameters
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.num_experts_per_tok = num_experts_per_tok
@@ -73,5 +78,5 @@ class FlexOlmoConfig(PretrainedConfig):
self.norm_topk_prob = norm_topk_prob
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
- if self.rope_scaling is not None and "type" in self.rope_scaling:
- self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ if self.rope_parameters is not None and "type" in self.rope_parameters:
+ self.rope_parameters["rope_type"] = self.rope_parameters["type"]
diff --git a/vllm/transformers_utils/configs/hunyuan_vl.py b/vllm/transformers_utils/configs/hunyuan_vl.py
new file mode 100644
index 0000000000000..a826ed9b5155d
--- /dev/null
+++ b/vllm/transformers_utils/configs/hunyuan_vl.py
@@ -0,0 +1,322 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# adapted from https://github.com/ManaEstras/transformers/blob/v4.57.1.hyvl/src/transformers/models/hunyuan_vl/configuration_hunyuan_vl.py
+
+from transformers import PretrainedConfig
+
+
+class HunYuanVLVisionConfig(PretrainedConfig):
+ model_type = "hunyuan_vl"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ hidden_act="gelu",
+ hidden_size=1152,
+ intermediate_size=4304,
+ interpolate_mode="bilinear",
+ rms_norm_eps=1e-05,
+ learnable_mlp_pooling_size=0,
+ num_attention_heads=16,
+ num_key_value_heads=None,
+ num_channels=3,
+ num_hidden_layers=27,
+ out_hidden_size=4096,
+ patch_size=16,
+ remove_prenorm=True,
+ spatial_merge_size=2,
+ temporal_patch_size=1,
+ resize_resolution=2048,
+ img_max_token_num=4096,
+ max_image_size=2048,
+ video_max_image_size=768,
+ video_min_image_size=256,
+ min_image_size=512,
+ anyres_vit_max_image_size=2048,
+ max_vit_seq_len=16384,
+ text_hidden_size=3072,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_act = hidden_act
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.interpolate_mode = interpolate_mode
+ self.learnable_mlp_pooling_size = learnable_mlp_pooling_size
+ self.num_attention_heads = num_attention_heads
+ if not num_key_value_heads:
+ self.num_key_value_heads = num_attention_heads
+ else:
+ self.num_key_value_heads = num_key_value_heads
+ self.num_channels = num_channels
+ self.num_hidden_layers = num_hidden_layers
+ self.out_hidden_size = out_hidden_size
+ self.patch_size = patch_size
+ self.remove_prenorm = remove_prenorm
+ self.spatial_merge_size = spatial_merge_size
+ self.temporal_patch_size = temporal_patch_size
+ self.rms_norm_eps = rms_norm_eps
+
+ self.resize_resolution = resize_resolution
+ self.img_max_token_num = img_max_token_num
+ self.max_image_size = max_image_size
+ self.min_image_size = min_image_size
+ self.video_max_image_size = video_max_image_size
+ self.video_min_image_size = video_min_image_size
+ self.anyres_vit_max_image_size = anyres_vit_max_image_size
+ self.max_vit_seq_len = max_vit_seq_len
+ self.text_hidden_size = text_hidden_size
+
+
+class HunYuanVLTextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`HunYuanVLTextConfig`]. It is used to instantiate an
+ HunYuan model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the HunYuan-7B.
+ Hunyuan-7B-Instruct [tencent/Hunyuan-7B-Instruct](https://huggingface.co/tencent/Hunyuan-7B-Instruct).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 290943):
+ Vocabulary size of the HunYuan model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`HunYuanVLTextConfig`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 11008):
+ Dimension of the MLP representations or shared MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ eod_token_id (int, *optional*, defaults to 3):
+ Token ID representing the end-of-document marker. Used to indicate the termination of a text sequence.
+ Example: In multi-document processing, this token helps the model distinguish between separate documents.
+ pretraining_tp (`int`, *optional*, defaults to 1):
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
+ issue](https://github.com/pytorch/pytorch/issues/76232).
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
+ these scaling strategies behave:
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
+ experimental feature, subject to breaking API changes in future versions.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ head_dim (`int`, *optional*, defaults to 128):
+ The attention head dimension.
+ """ # noqa: E501
+
+ model_type = "hunyuan_vl_text"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=290943,
+ hidden_size=4096,
+ intermediate_size: int = 11008,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ hidden_act="silu",
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ eod_token_id=3,
+ pretraining_tp=1,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ head_dim=None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.head_dim = head_dim
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.pretraining_tp = pretraining_tp
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ # self._rope_scaling_validation() # TODO: Need validation?
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+ def _rope_scaling_validation(self):
+ """
+ Validate the `rope_scaling` configuration.
+ """
+ if self.rope_scaling is None:
+ return
+
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
+ raise ValueError(
+                "`rope_scaling` must be a dictionary with two fields, `type` and "
+ f"`factor` or `type` and `alpha`, got {self.rope_scaling}"
+ )
+ rope_scaling_type = self.rope_scaling.get("type", None)
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
+ rope_scaling_alpha = self.rope_scaling.get("alpha", None)
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
+ raise ValueError(
+ "`rope_scaling`'s type field must be one of ['linear', 'dynamic'], "
+ f"got {rope_scaling_type}"
+ )
+ if rope_scaling_factor is None and rope_scaling_alpha is None:
+ raise ValueError(
+                "`rope_scaling` must contain one of the `factor` or `alpha` "
+                "fields, got neither"
+ )
+ if rope_scaling_factor is not None and (
+ not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0
+ ):
+ raise ValueError(
+ "`rope_scaling`'s factor field must be a float > 1.0, "
+ f"got {rope_scaling_factor}"
+ )
+ if rope_scaling_alpha is not None and (
+ not isinstance(rope_scaling_alpha, float) or rope_scaling_alpha <= 1.0
+ ):
+ raise ValueError(
+ "`rope_scaling`'s alpha field must be a float > 1.0, "
+ f"got {rope_scaling_alpha}"
+ )
+
+
+class HunYuanVLConfig(PretrainedConfig):
+ model_type = "hunyuan_vl"
+ sub_configs = {
+ "vision_config": HunYuanVLVisionConfig,
+ "text_config": HunYuanVLTextConfig,
+ }
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ text_config=None,
+ vision_config=None,
+ im_start_id=120118,
+ im_end_id=120119,
+ image_token_id=120120,
+ im_newline_id=120121,
+ video_start_id=120122,
+ video_end_id=120123,
+ **kwargs,
+ ):
+ # We need to init super() here so that it does not reset values
+ # that are in text config to the BaseClass defaults. The Base
+ # config has many text related defaults and not all defaults are
+ # same as for `HunYuanVLTextConfig`.
+ super().__init__(**kwargs)
+
+ if isinstance(vision_config, dict):
+ self.vision_config = self.sub_configs["vision_config"](**vision_config)
+ elif vision_config is None:
+ self.vision_config = self.sub_configs["vision_config"]()
+
+ if isinstance(text_config, dict):
+ self.text_config = self.sub_configs["text_config"](**text_config)
+ elif text_config is None:
+ # For BC use all kwargs to init `TextConfig`
+ self.text_config = self.sub_configs["text_config"](**kwargs)
+
+ self.image_token_id = image_token_id
+ self.im_start_id = im_start_id
+ self.im_end_id = im_end_id
+ self.im_newline_id = im_newline_id
+ self.video_start_id = video_start_id
+ self.video_end_id = video_end_id
+
+ self.vision_config.text_hidden_size = self.text_config.hidden_size
+
+ # Attention implementation to use. It sets it recursively on sub-configs
+ # so we call it again in the end.
+ self._attn_implementation = kwargs.pop("attn_implementation", None)
+
+ def __setattr__(self, key, value):
+ if (
+ (text_config := super().__getattribute__("__dict__").get("text_config"))
+ is not None
+ and key not in ["dtype", "_attn_implementation_internal"]
+ and key in text_config.__dict__
+ ):
+ setattr(text_config, key, value)
+ else:
+ super().__setattr__(key, value)
+
+ def __getattribute__(self, key):
+ if "text_config" in super().__getattribute__("__dict__") and key not in [
+ "_name_or_path",
+ "model_type",
+ "dtype",
+ "_attn_implementation_internal",
+ ]:
+ text_config = super().__getattribute__("text_config")
+ if key in text_config.__dict__:
+ return getattr(text_config, key)
+
+ return super().__getattribute__(key)
diff --git a/vllm/transformers_utils/configs/kimi_linear.py b/vllm/transformers_utils/configs/kimi_linear.py
index 65ddf48c5249b..14894816801d1 100644
--- a/vllm/transformers_utils/configs/kimi_linear.py
+++ b/vllm/transformers_utils/configs/kimi_linear.py
@@ -29,8 +29,7 @@ class KimiLinearConfig(PretrainedConfig):
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
- rope_theta=10000.0,
- rope_scaling=None,
+ rope_parameters=None,
tie_word_embeddings=False,
moe_intermediate_size: int | None = None,
moe_renormalize: bool = True,
@@ -73,8 +72,13 @@ class KimiLinearConfig(PretrainedConfig):
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
- self.rope_theta = rope_theta
- self.rope_scaling = rope_scaling
+        # Prefer the legacy `rope_scaling` kwarg if present, else `rope_parameters`
+ rope_scaling = kwargs.pop("rope_scaling", None)
+ rope_parameters = rope_scaling or rope_parameters or {"rope_type": "default"}
+ rope_theta = kwargs.pop("rope_theta", 10000.0)
+ if "rope_theta" not in rope_parameters:
+ rope_parameters["rope_theta"] = rope_theta
+ self.rope_parameters = rope_parameters
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
diff --git a/vllm/transformers_utils/configs/lfm2_moe.py b/vllm/transformers_utils/configs/lfm2_moe.py
index 37c038e12db80..b399a03c030f0 100644
--- a/vllm/transformers_utils/configs/lfm2_moe.py
+++ b/vllm/transformers_utils/configs/lfm2_moe.py
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any
from transformers.configuration_utils import PretrainedConfig
@@ -35,8 +36,8 @@ class Lfm2MoeConfig(PretrainedConfig):
End of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether to tie weight embeddings
- rope_theta (`float`, *optional*, defaults to 1000000.0):
- The base period of the RoPE embeddings.
+ rope_parameters (`dict`, *optional*):
+ The parameters of the RoPE embeddings.
max_position_embeddings (`int`, *optional*, defaults to 128000):
The maximum sequence length that this model might ever be used with.
use_cache (`bool`, *optional*, defaults to `True`):
@@ -100,7 +101,7 @@ class Lfm2MoeConfig(PretrainedConfig):
bos_token_id: int = 1,
eos_token_id: int = 2,
tie_word_embeddings: bool = True,
- rope_theta: float = 1000000.0,
+ rope_parameters: dict[str, Any] | None = None,
max_position_embeddings: int = 128_000,
use_cache: bool = True,
norm_eps: float = 0.00001,
@@ -121,7 +122,10 @@ class Lfm2MoeConfig(PretrainedConfig):
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
- self.rope_theta = rope_theta
+ rope_theta = kwargs.pop("rope_theta", 1000000.0)
+ if rope_parameters is None:
+ rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
+ self.rope_parameters = rope_parameters
self.max_position_embeddings = max_position_embeddings
self.use_cache = use_cache
self.norm_eps = norm_eps
diff --git a/vllm/transformers_utils/configs/midashenglm.py b/vllm/transformers_utils/configs/midashenglm.py
index e49bd26b2b00c..f1bbd057103e4 100644
--- a/vllm/transformers_utils/configs/midashenglm.py
+++ b/vllm/transformers_utils/configs/midashenglm.py
@@ -98,6 +98,6 @@ class MiDashengLMConfig(PretrainedConfig):
if text_config
else Qwen2_5OmniTextConfig()
)
- self.text_config.rope_scaling = None # uses_mrope is false
+ self.text_config.rope_parameters = None # uses_mrope is false
self.audio_token_id = audio_token_id
super().__init__(**kwargs)
diff --git a/vllm/transformers_utils/configs/mistral.py b/vllm/transformers_utils/configs/mistral.py
index c6f04febe37e1..966737aad0867 100644
--- a/vllm/transformers_utils/configs/mistral.py
+++ b/vllm/transformers_utils/configs/mistral.py
@@ -9,14 +9,18 @@ from vllm.logger import init_logger
logger = init_logger(__name__)
-def adapt_config_dict(config_dict: dict[str, Any], **kwargs) -> PretrainedConfig:
- config_dict.update(kwargs)
+def adapt_config_dict(
+ config_dict: dict[str, Any],
+ defaults: dict[str, Any],
+) -> PretrainedConfig:
config_dict = _remap_general_mistral_args(config_dict)
if bool(config_dict.get("quantization")):
config_dict = _remap_mistral_quantization_args(config_dict)
- if bool(config_dict.get("moe")):
+ if config_dict.get("model_type") == "mamba":
+ config_dict["architectures"] = ["Mamba2ForCausalLM"]
+ elif bool(config_dict.get("moe")):
config_dict["architectures"] = ["MixtralForCausalLM"]
else:
config_dict["architectures"] = ["MistralForCausalLM"]
@@ -52,6 +56,9 @@ def adapt_config_dict(config_dict: dict[str, Any], **kwargs) -> PretrainedConfig
if is_audio:
config_dict = _remap_mistral_audio_args(config_dict)
+ for k, v in defaults.items():
+ config_dict.setdefault(k, v)
+
config = PretrainedConfig.from_dict(config_dict)
logger.debug("Initialized config %s", config)
@@ -86,13 +93,17 @@ def _remap_mistral_yarn_args(config: dict) -> dict:
"apply_scale": "apply_yarn_scaling",
}
yarn_config = config.get("yarn") or {}
- config["rope_scaling"] = {
+ config["rope_parameters"] = {
"rope_type": "yarn",
"mscale_all_dim": 1,
}
+
+ if rope_theta := config.pop("rope_theta", None):
+ config["rope_parameters"]["rope_theta"] = rope_theta
+
for old_name, new_name in yarn_config_map.items():
if old_name in yarn_config:
- config["rope_scaling"][new_name] = yarn_config.pop(old_name)
+ config["rope_parameters"][new_name] = yarn_config.pop(old_name)
assert len(yarn_config) == 0, f"Unparsed yarn config: {yarn_config}"
@@ -114,7 +125,7 @@ def _remap_general_mistral_args(config: dict) -> dict:
"model_type": ("model_type", "transformer"),
"hidden_act": ("activation", "silu"),
"tie_word_embeddings": ("tied_embeddings", False),
- "max_seq_len": ("max_seq_len", 128_000),
+ "max_seq_len": ("max_seq_len", config.get("max_position_embeddings", 128_000)),
"max_position_embeddings": ("max_position_embeddings", 128_000),
}
diff --git a/vllm/transformers_utils/configs/nemotron.py b/vllm/transformers_utils/configs/nemotron.py
index 60eed549561fb..d112c71d7d20b 100644
--- a/vllm/transformers_utils/configs/nemotron.py
+++ b/vllm/transformers_utils/configs/nemotron.py
@@ -88,8 +88,8 @@ class NemotronConfig(PretrainedConfig):
End of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
- rope_theta (`float`, *optional*, defaults to 10000.0):
- The base period of the RoPE embeddings.
+ rope_parameters (`dict`, *optional*):
+ The parameters of the RoPE embeddings.
partial_rotary_factor (`float`, *optional*, defaults to 0.5):
Percentage of the query and keys which will have rotary embedding.
attention_bias (`bool`, *optional*, defaults to `False`):
@@ -132,8 +132,7 @@ class NemotronConfig(PretrainedConfig):
bos_token_id=2,
eos_token_id=3,
tie_word_embeddings=False,
- rope_theta=10000.0,
- rope_scaling=None,
+ rope_parameters=None,
partial_rotary_factor=0.5,
attention_bias=False,
attention_dropout=0.0,
@@ -160,8 +159,13 @@ class NemotronConfig(PretrainedConfig):
self.initializer_range = initializer_range
self.norm_eps = norm_eps
self.use_cache = use_cache
- self.rope_theta = rope_theta
- self.rope_scaling = rope_scaling
+        # Prefer the legacy `rope_scaling` kwarg if present, else `rope_parameters`
+ rope_scaling = kwargs.pop("rope_scaling", None)
+ rope_parameters = rope_scaling or rope_parameters or {"rope_type": "default"}
+ rope_theta = kwargs.pop("rope_theta", 10000.0)
+ if "rope_theta" not in rope_parameters:
+ rope_parameters["rope_theta"] = rope_theta
+ self.rope_parameters = rope_parameters
# for backward compatibility
partial_rotary_factor = (
kwargs.get("rope_percent")
@@ -169,7 +173,7 @@ class NemotronConfig(PretrainedConfig):
or partial_rotary_factor
)
self.partial_rotary_factor = partial_rotary_factor
- self._rope_scaling_validation()
+ self._rope_parameters_validation()
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
@@ -182,31 +186,29 @@ class NemotronConfig(PretrainedConfig):
**kwargs,
)
- def _rope_scaling_validation(self):
+ def _rope_parameters_validation(self):
"""
- Validate the `rope_scaling` configuration.
+ Validate the `rope_parameters` configuration.
"""
- if self.rope_scaling is None:
+ if self.rope_parameters is None:
return
- if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
+ rope_type: str | None = self.rope_parameters.get("rope_type", None)
+ factor: float | None = self.rope_parameters.get("factor", None)
+
+ if rope_type not in {"default", "linear", "dynamic"}:
raise ValueError(
- "`rope_scaling` must be a dictionary with two fields, "
- f"`type` and `factor`, got {self.rope_scaling}"
- )
- rope_scaling_type = self.rope_scaling.get("type", None)
- rope_scaling_factor = self.rope_scaling.get("factor", None)
- if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
- raise ValueError(
- "`rope_scaling`'s type field must be one of ['linear', "
- f"'dynamic'], got {rope_scaling_type}"
- )
- if (
- rope_scaling_factor is None
- or not isinstance(rope_scaling_factor, float)
- or rope_scaling_factor <= 1.0
- ):
- raise ValueError(
- "`rope_scaling`'s factor field must be a float > 1, got "
- f"{rope_scaling_factor}"
+ "`rope_type` must be one of ['default', 'linear', 'dynamic'], "
+ f"got {rope_type}"
)
+ if rope_type != "default":
+ if factor is None:
+ raise ValueError(
+ "If `rope_type` is not 'default', `rope_parameters` "
+ "must include a `factor` field. Got `None`."
+ )
+ if not isinstance(factor, float) or factor <= 1.0:
+ raise ValueError(
+ "`rope_parameters`'s factor field must be a float > 1, got "
+ f"{factor}"
+ )
diff --git a/vllm/transformers_utils/configs/olmo3.py b/vllm/transformers_utils/configs/olmo3.py
index f5a9a7cd36bdb..c4691b661af39 100644
--- a/vllm/transformers_utils/configs/olmo3.py
+++ b/vllm/transformers_utils/configs/olmo3.py
@@ -24,8 +24,7 @@ class Olmo3Config(PretrainedConfig):
bos_token_id=None,
eos_token_id=50279,
tie_word_embeddings=False,
- rope_theta=10000.0,
- rope_scaling=None,
+ rope_parameters=None,
attention_bias=False,
attention_dropout=0.0,
rms_norm_eps=1e-5,
@@ -63,8 +62,13 @@ class Olmo3Config(PretrainedConfig):
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.use_cache = use_cache
- self.rope_theta = rope_theta
- self.rope_scaling = rope_scaling
+        # Prefer the legacy `rope_scaling` kwarg if present, else `rope_parameters`
+ rope_scaling = kwargs.pop("rope_scaling", None)
+ rope_parameters = rope_scaling or rope_parameters or {"rope_type": "default"}
+ rope_theta = kwargs.pop("rope_theta", 10000.0)
+ if "rope_theta" not in rope_parameters:
+ rope_parameters["rope_theta"] = rope_theta
+ self.rope_parameters = rope_parameters
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
diff --git a/vllm/transformers_utils/configs/qwen3_next.py b/vllm/transformers_utils/configs/qwen3_next.py
index 21750bde2f878..d2fe58d48da6f 100644
--- a/vllm/transformers_utils/configs/qwen3_next.py
+++ b/vllm/transformers_utils/configs/qwen3_next.py
@@ -66,13 +66,12 @@ class Qwen3NextConfig(PretrainedConfig):
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
- rope_theta (`float`, *optional*, defaults to 10000.0):
- The base period of the RoPE embeddings.
- rope_scaling (`Dict`, *optional*):
+ rope_parameters (`dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
+ `rope_theta` (`float`): The base period of the RoPE embeddings.
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
@@ -199,8 +198,7 @@ class Qwen3NextConfig(PretrainedConfig):
rms_norm_eps=1e-6,
use_cache=True,
tie_word_embeddings=False,
- rope_theta=10000.0,
- rope_scaling=None,
+ rope_parameters=None,
partial_rotary_factor=0.25,
attention_bias=False,
attention_dropout=0.0,
@@ -236,8 +234,13 @@ class Qwen3NextConfig(PretrainedConfig):
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
- self.rope_theta = rope_theta
- self.rope_scaling = rope_scaling
+ # Prefer the legacy `rope_scaling` kwarg when given; otherwise fall back to `rope_parameters`.
+ rope_scaling = kwargs.pop("rope_scaling", None)
+ rope_parameters = rope_scaling or rope_parameters or {"rope_type": "default"}
+ rope_theta = kwargs.pop("rope_theta", 10000.0)
+ if "rope_theta" not in rope_parameters:
+ rope_parameters["rope_theta"] = rope_theta
+ self.rope_parameters = rope_parameters
self.partial_rotary_factor = partial_rotary_factor
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
diff --git a/vllm/transformers_utils/configs/step3_vl.py b/vllm/transformers_utils/configs/step3_vl.py
index 637b82d88e265..0ee650a70451f 100644
--- a/vllm/transformers_utils/configs/step3_vl.py
+++ b/vllm/transformers_utils/configs/step3_vl.py
@@ -52,8 +52,7 @@ class Step3TextConfig(PretrainedConfig):
moe_intermediate_size: int = 5120,
moe_num_experts: int = 48,
moe_top_k: int = 3,
- rope_theta: float = 500000,
- rope_scaling: dict[str, Any] | None = None,
+ rope_parameters: dict[str, Any] | None = None,
max_position_embedding: int = 65536,
share_expert_dim: int = 5120,
share_q_dim: int = 2048,
@@ -130,8 +129,13 @@ class Step3TextConfig(PretrainedConfig):
self.moe_intermediate_size = moe_intermediate_size
self.moe_num_experts = moe_num_experts
self.moe_top_k = moe_top_k
- self.rope_theta = rope_theta
- self.rope_scaling = rope_scaling
+ # Prefer the legacy `rope_scaling` kwarg when given; otherwise fall back to `rope_parameters`.
+ rope_scaling = kwargs.pop("rope_scaling", None)
+ rope_parameters = rope_scaling or rope_parameters or {"rope_type": "default"}
+ rope_theta = kwargs.pop("rope_theta", 500000.0)
+ if "rope_theta" not in rope_parameters:
+ rope_parameters["rope_theta"] = rope_theta
+ self.rope_parameters = rope_parameters
self.max_position_embedding = max_position_embedding
self.share_expert_dim = share_expert_dim
self.share_q_dim = share_q_dim
diff --git a/vllm/transformers_utils/processors/__init__.py b/vllm/transformers_utils/processors/__init__.py
index 76b6d3dc9c99a..b49fdbe9ce776 100644
--- a/vllm/transformers_utils/processors/__init__.py
+++ b/vllm/transformers_utils/processors/__init__.py
@@ -9,7 +9,15 @@ reasons:
"""
from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor
+from vllm.transformers_utils.processors.hunyuan_vl import HunYuanVLProcessor
+from vllm.transformers_utils.processors.hunyuan_vl_image import HunYuanVLImageProcessor
from vllm.transformers_utils.processors.ovis import OvisProcessor
from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor
-__all__ = ["DeepseekVLV2Processor", "OvisProcessor", "Ovis2_5Processor"]
+__all__ = [
+ "DeepseekVLV2Processor",
+ "HunYuanVLProcessor",
+ "HunYuanVLImageProcessor",
+ "OvisProcessor",
+ "Ovis2_5Processor",
+]
diff --git a/vllm/transformers_utils/processors/hunyuan_vl.py b/vllm/transformers_utils/processors/hunyuan_vl.py
new file mode 100644
index 0000000000000..615a8bff85912
--- /dev/null
+++ b/vllm/transformers_utils/processors/hunyuan_vl.py
@@ -0,0 +1,233 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# adapted from https://github.com/ManaEstras/transformers/blob/v4.57.1.hyvl/src/transformers/models/hunyuan_vl/processing_hunyuan_vl.py
+
+import numpy as np
+import torch
+from transformers import AutoProcessor
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.image_utils import ImageInput
+from transformers.processing_utils import ProcessorMixin
+from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
+from transformers.video_utils import VideoInput
+
+
+class HunYuanVLProcessor(ProcessorMixin):
+ attributes = ["image_processor", "tokenizer"]
+ valid_kwargs = ["chat_template"]
+ image_processor_class = "AutoImageProcessor"
+ tokenizer_class = "AutoTokenizer" # ("AutoTokenizer", None)
+
+ def __init__(
+ self,
+ image_processor=None,
+ tokenizer=None,
+ video_processor=None,
+ chat_template=None,
+ **kwargs,
+ ):
+ # TODO Fix the init
+ self.tokenizer = tokenizer
+ self.image_token_id = 120120 # self.tokenizer.image_token_id
+ self.image_token = self.tokenizer.convert_ids_to_tokens(self.image_token_id)
+ self.im_start_token_id = 120118 # self.tokenizer.im_start_id
+ self.im_start_token = self.tokenizer.convert_ids_to_tokens(
+ self.im_start_token_id
+ )
+ self.im_end_token_id = 120119 # self.tokenizer.im_end_id
+ self.im_end_token = self.tokenizer.convert_ids_to_tokens(self.im_end_token_id)
+ self.placeholder_token = self.tokenizer.convert_ids_to_tokens(
+ self.tokenizer.vocab_size - 1
+ )
+ self.pad_id = 120002 # self.tokenizer.pad_token_id
+
+ super().__init__(
+ image_processor, tokenizer, video_processor, chat_template=chat_template
+ )
+
+ def __call__(
+ self,
+ images: ImageInput = None,
+ text: TextInput
+ | PreTokenizedInput
+ | list[TextInput]
+ | list[PreTokenizedInput] = None,
+ videos: VideoInput = None,
+ **kwargs,
+ ) -> BatchFeature:
+ image_inputs = {}
+ if images is not None:
+ image_inputs = self.image_processor(images=images)
+ image_grid_thw = image_inputs["image_grid_thw"]
+
+ if not isinstance(text, list):
+ text = [text]
+
+ text = text.copy() # below lines change text in-place
+
+ image_tokens_cumsum = [0]
+ if images is not None:
+ index = 0
+ for i in range(len(text)):
+ while self.image_token in text[i]:
+ grid_h, grid_w = image_grid_thw[index][-2:]
+ patch_h = grid_h // self.image_processor.merge_size
+ patch_w = grid_w // self.image_processor.merge_size
+ num_image_tokens = patch_h * (patch_w + 1) + 2
+ image_tokens_cumsum.append(
+ image_tokens_cumsum[-1] + num_image_tokens
+ )
+ # text[i] = text[i].replace(self.image_token, self.im_start_token + self.placeholder_token * num_image_tokens + self.im_end_token, 1) # noqa: E501
+ text[i] = text[i].replace(
+ self.image_token, self.placeholder_token * num_image_tokens, 1
+ )
+ index += 1
+ text[i] = text[i].replace(self.placeholder_token, self.image_token)
+ # text[i] = self.tokenizer.bos_token + text[i]
+
+ text_inputs = self.tokenizer(text, add_special_tokens=False, **kwargs)
+ self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
+
+ input_ids = text_inputs["input_ids"]
+ position_ids = torch.arange(len(input_ids[0]))
+ position_ids_w = torch.arange(len(input_ids[0]))
+ position_ids_h = torch.arange(len(input_ids[0]))
+ position_ids_t = torch.arange(len(input_ids[0]))
+
+ if images is not None:
+ image_token_pos_indices = torch.where(input_ids[0] == self.image_token_id)[
+ 0
+ ]
+ for i in range(len(image_grid_thw)):
+ grid_h, grid_w = image_grid_thw[i][-2:]
+ patch_h = grid_h // self.image_processor.merge_size
+ patch_w = grid_w // self.image_processor.merge_size
+ start_pos = image_token_pos_indices[image_tokens_cumsum[i]].item() + 1
+ replace_num = (patch_w + 1) * patch_h
+ position_ids_w[start_pos : start_pos + replace_num] = torch.tensor(
+ list(range(patch_w + 1)) * patch_h, dtype=torch.int64
+ )
+ patch_h_list = []
+ for h in range(patch_h):
+ patch_h_list += [h] * (patch_w + 1)
+ position_ids_h[start_pos : start_pos + replace_num] = torch.tensor(
+ patch_h_list, dtype=torch.int64
+ )
+ position_ids_t[start_pos : start_pos + replace_num] = 0
+
+ position_ids = torch.stack(
+ [position_ids, position_ids_w, position_ids_h, position_ids_t]
+ ).unsqueeze(0)
+ text_inputs["position_ids"] = position_ids
+
+ attention_mask = input_ids.ne(self.pad_id)
+ text_inputs["attention_mask"] = attention_mask
+ text_inputs["imgs_pos"] = [self.get_imgs_pos(input_ids)]
+ # image_inputs["imgs"] = [[image_inputs["pixel_values"]]]
+
+ return_tensors = kwargs.pop("return_tensors", None)
+ return BatchFeature(
+ data={**text_inputs, **image_inputs},
+ tensor_type=return_tensors,
+ )
+
+ def batch_decode(self, *args, **kwargs):
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ return self.tokenizer.decode(*args, **kwargs)
+
+ def post_process_image_text_to_text(
+ self,
+ generated_outputs,
+ skip_special_tokens=True,
+ clean_up_tokenization_spaces=False,
+ **kwargs,
+ ):
+ assert 0
+
+ def apply_chat_template(self, *args, **kwargs):
+ token_ids = self.tokenizer.apply_chat_template(*args, **kwargs)
+ return token_ids
+
+ def get_imgs_pos(self, doc_ids):
+ doc_ids = np.array(doc_ids, dtype=np.int64)
+ img_begin_index = np.where(doc_ids == self.im_start_token_id)[0]
+ img_end_index = np.where(doc_ids == self.im_end_token_id)[0]
+ imgs_pos = np.concatenate(
+ (
+ np.reshape(img_begin_index + 1, (-1, 1)),
+ np.reshape(img_end_index, (-1, 1)),
+ ),
+ axis=-1,
+ ).tolist()
+ return imgs_pos
+
+ @property
+ def model_input_names(self):
+ tokenizer_input_names = self.tokenizer.model_input_names
+ image_processor_input_names = self.image_processor.model_input_names
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
+
+
+def split_image_into_patch_blocks(
+ pixel_values: torch.Tensor, # shape: [batch_size, 3, H, W]
+ patch_size: int = 16, # e.g. 16
+ adaptor_patch_div: int = 4, # e.g. 4 --> each large patch is split into adaptor_patch_div x adaptor_patch_div regions of side patch_size // adaptor_patch_div # noqa: E501
+) -> torch.Tensor:
+ """
+ Split the input image tensor (supporting batch) into large patches of size `patch_size`,
+ and then further divide each large patch into smaller regions of size
+ (patch_size // adaptor_patch_div) x (patch_size // adaptor_patch_div).
+ Each small region is extracted as a tensor of shape [3, patch_size, patch_size].
+ The final output contains all such small region tensors.
+
+ Args:
+ pixel_values: Input image tensor of shape [batch_size, 3, H, W].
+ patch_size: Size of the large patch, e.g., 16.
+ adaptor_patch_div: Each large patch is divided into
+ (patch_size // adaptor_patch_div) x (patch_size // adaptor_patch_div)
+ smaller regions.
+
+ Returns:
+ patches: A tensor of shape [N, 3, patch_size, patch_size],
+ where N = batch_size * (H // patch_size) * (W // patch_size) (total element count divided by 3 * patch_size^2).
+ Each element in the batch corresponds to one small image region.
+ """ # noqa: E501
+ batch_size, channels, height, width = pixel_values.shape
+ assert channels == 3, "Pixel values must have 3 channels in dim=1"
+ assert height % patch_size == 0 and width % patch_size == 0, (
+ "H and W must be divisible by patch_size"
+ )
+
+ patch_height_num = height // patch_size
+ patch_width_num = width // patch_size
+
+ # Reshape to [B, 3, ph, ps, pw, ps]
+ img = pixel_values.reshape(
+ batch_size, 3, patch_height_num, patch_size, patch_width_num, patch_size
+ )
+
+ # Further split each psxps patch into (ps//aps)x(ps//aps) small regions
+ img = img.reshape(
+ batch_size,
+ 3,
+ patch_height_num,
+ patch_size // adaptor_patch_div, # ps // aps
+ adaptor_patch_div,
+ patch_width_num,
+ patch_size // adaptor_patch_div, # ps // aps
+ adaptor_patch_div,
+ )
+
+ # Permute to group the small regions: [B, ph, pw, ps//aps, ps//aps, 3, aps, aps]
+ img = img.permute(0, 2, 5, 3, 6, 1, 4, 7)
+
+ # Reshape into [B * ph * pw * (ps//aps)^2, 3, patch_size, patch_size]
+ patches = img.reshape(-1, 3, patch_size, patch_size)
+
+ return patches
+
+
+AutoProcessor.register("HunYuanVLProcessor", HunYuanVLProcessor)
diff --git a/vllm/transformers_utils/processors/hunyuan_vl_image.py b/vllm/transformers_utils/processors/hunyuan_vl_image.py
new file mode 100644
index 0000000000000..0a7e7865c783a
--- /dev/null
+++ b/vllm/transformers_utils/processors/hunyuan_vl_image.py
@@ -0,0 +1,477 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# adapted from https://github.com/ManaEstras/transformers/blob/v4.57.1.hyvl/src/transformers/models/hunyuan_vl/image_processing_hunyuan_vl.py
+"""Image processor class for HunYuanVL."""
+
+# isort conflicts with ruff for transformers imports
+# isort: skip_file
+import math
+
+import numpy as np
+import torchvision.transforms as transforms
+from transformers import AutoImageProcessor
+from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
+from transformers.image_transforms import (
+ convert_to_rgb,
+)
+from transformers.image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ make_flat_list_of_images,
+ make_list_of_images,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from transformers.utils import TensorType, logging
+from transformers.video_utils import VideoInput, make_batched_videos
+
+logger = logging.get_logger(__name__)
+
+
+def smart_resize(
+ height: int,
+ width: int,
+ factor: int = 16,
+ min_pixels: int = 512 * 512,
+ max_pixels: int = 2048 * 2048,
+):
+ """Rescales the image so that the following conditions are met:
+
+ 1. Both dimensions (height and width) are divisible by 'factor'.
+
+ 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+
+ 3. The aspect ratio of the image is maintained as closely as possible.
+
+ """
+ if max(height, width) / min(height, width) > 200:
+ raise ValueError(
+ "absolute aspect ratio must be smaller than 200, got "
+ f"{max(height, width) / min(height, width)}"
+ )
+ h_bar = round(height / factor) * factor
+ w_bar = round(width / factor) * factor
+ if h_bar * w_bar > max_pixels:
+ beta = math.sqrt((height * width) / max_pixels)
+ h_bar = max(factor, math.floor(height / beta / factor) * factor)
+ w_bar = max(factor, math.floor(width / beta / factor) * factor)
+ elif h_bar * w_bar < min_pixels:
+ beta = math.sqrt(min_pixels / (height * width))
+ h_bar = math.ceil(height * beta / factor) * factor
+ w_bar = math.ceil(width * beta / factor) * factor
+ return h_bar, w_bar
+
+
+class HunYuanVLImageProcessor(BaseImageProcessor):
+ model_input_names = [
+ "pixel_values",
+ "image_grid_thw",
+ "pixel_values_videos",
+ "video_grid_thw",
+ ]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: dict[str, int] | None = None,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_rescale: bool = True,
+ rescale_factor: int | float = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: float | list[float] | None = None,
+ image_std: float | list[float] | None = None,
+ do_convert_rgb: bool = True,
+ min_pixels: int | None = None,
+ max_pixels: int | None = None,
+ patch_size: int = 16,
+ temporal_patch_size: int = 2,
+ merge_size: int = 2,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ if size is not None and (
+ "shortest_edge" not in size or "longest_edge" not in size
+ ):
+ raise ValueError(
+ "size must contain 'shortest_edge' and 'longest_edge' keys."
+ )
+ else:
+ size = {"shortest_edge": 512 * 512, "longest_edge": 2048 * 2048}
+ # backward compatibility: override size with min_pixels and max_pixels
+ # if they are provided.
+ if min_pixels is not None:
+ size["shortest_edge"] = min_pixels
+ if max_pixels is not None:
+ size["longest_edge"] = max_pixels
+ self.min_pixels = size["shortest_edge"]
+ self.max_pixels = size["longest_edge"]
+ self.size = size
+
+ self.do_resize = do_resize
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
+
+ self.patch_size = patch_size
+ self.temporal_patch_size = temporal_patch_size
+ self.merge_size = merge_size
+ self.do_convert_rgb = do_convert_rgb
+
+ # hard-code
+
+ def _preprocess(
+ self,
+ images: ImageInput | VideoInput,
+ do_resize: bool | None = None,
+ size: dict[str, int] | None = None,
+ resample: PILImageResampling = None,
+ do_rescale: bool | None = None,
+ rescale_factor: float | None = None,
+ do_normalize: bool | None = None,
+ image_mean: float | list[float] | None = None,
+ image_std: float | list[float] | None = None,
+ patch_size: int = 16,
+ temporal_patch_size: int = 2,
+ merge_size: int = 2,
+ do_convert_rgb: bool | None = None,
+ data_format: ChannelDimension | None = ChannelDimension.FIRST,
+ input_data_format: str | ChannelDimension | None = None,
+ ):
+ """
+ Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
+
+ Args:
+ images (`ImageInput`):
+ Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present.
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Scale factor to use if rescaling the image.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
+ patch_size (`int`, *optional*, defaults to `self.patch_size`):
+ The spatial patch size of the vision encoder.
+ temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
+ The temporal patch size of the vision encoder.
+ merge_size (`int`, *optional*, defaults to `self.merge_size`):
+ The merge size of the vision encoder to llm encoder.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """ # noqa: E501
+ images = make_list_of_images(images)
+
+ if do_convert_rgb:
+ images = [convert_to_rgb(image) for image in images]
+
+ width, height = images[0].width, images[0].height
+ resized_width, resized_height = width, height
+ processed_images = []
+ for image in images:
+ if do_resize:
+ resized_width, resized_height = smart_resize(
+ width,
+ height,
+ factor=patch_size * merge_size,
+ min_pixels=self.min_pixels,
+ max_pixels=self.max_pixels,
+ )
+ image = image.resize((resized_width, resized_height))
+
+ if do_normalize:
+ image = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize(self.image_mean, self.image_std),
+ ]
+ )(image)
+ processed_images.append(image)
+
+ patches = np.array(processed_images)
+ channel = patches.shape[1]
+ grid_t = patches.shape[0] // temporal_patch_size
+ grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
+ patches = patches.reshape(
+ 1,
+ channel,
+ grid_h // merge_size,
+ merge_size,
+ patch_size,
+ grid_w // merge_size,
+ merge_size,
+ patch_size,
+ )
+ patches = patches.transpose(0, 2, 3, 5, 6, 1, 4, 7)
+ flatten_patches = patches.reshape(
+ 1 * grid_h * grid_w, channel * patch_size * patch_size
+ )
+
+ return flatten_patches, (grid_t, grid_h, grid_w)
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ videos: VideoInput = None,
+ do_resize: bool | None = None,
+ size: dict[str, int] | None = None,
+ min_pixels: int | None = None,
+ max_pixels: int | None = None,
+ resample: PILImageResampling = None,
+ do_rescale: bool | None = None,
+ rescale_factor: float | None = None,
+ do_normalize: bool | None = None,
+ image_mean: float | list[float] | None = None,
+ image_std: float | list[float] | None = None,
+ patch_size: int | None = None,
+ temporal_patch_size: int | None = None,
+ merge_size: int | None = None,
+ do_convert_rgb: bool | None = None,
+ return_tensors: str | TensorType | None = None,
+ data_format: ChannelDimension | None = ChannelDimension.FIRST,
+ input_data_format: str | ChannelDimension | None = None,
+ ):
+ """
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ videos (`VideoInput`):
+ Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If
+ passing in videos with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
+ the longest edge resized to keep the input aspect ratio.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+ has an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
+ `True`.
+ min_pixels (`int`, *optional*, defaults to `self.min_pixels`):
+ The min pixels of the image to resize the image.
+ max_pixels (`int`, *optional*, defaults to `self.max_pixels`):
+ The max pixels of the image to resize the image.
+ patch_size (`int`, *optional*, defaults to `self.patch_size`):
+ The spatial patch size of the vision encoder.
+ temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
+ The temporal patch size of the vision encoder.
+ merge_size (`int`, *optional*, defaults to `self.merge_size`):
+ The merge size of the vision encoder to llm encoder.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+ """ # noqa: E501
+ min_pixels = min_pixels if min_pixels is not None else self.min_pixels
+ max_pixels = max_pixels if max_pixels is not None else self.max_pixels
+
+ if size is not None:
+ if "shortest_edge" not in size or "longest_edge" not in size:
+ raise ValueError(
+ "size must contain 'shortest_edge' and 'longest_edge' keys."
+ )
+ min_pixels = size["shortest_edge"]
+ elif min_pixels is not None and max_pixels is not None:
+ # backward compatibility: override size with min_pixels and max_pixels
+ # if they are provided.
+ size = {"shortest_edge": min_pixels, "longest_edge": max_pixels}
+ else:
+ size = {**self.size}
+
+ do_resize = do_resize if do_resize is not None else self.do_resize
+
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = (
+ rescale_factor if rescale_factor is not None else self.rescale_factor
+ )
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ patch_size = patch_size if patch_size is not None else self.patch_size
+ temporal_patch_size = (
+ temporal_patch_size
+ if temporal_patch_size is not None
+ else self.temporal_patch_size
+ )
+ merge_size = merge_size if merge_size is not None else self.merge_size
+ do_convert_rgb = (
+ do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+ )
+
+ if images is not None:
+ images = make_flat_list_of_images(images)
+
+ if images is not None and not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ validate_preprocess_arguments(
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ data = {}
+ if images is not None:
+ pixel_values, vision_grid_thws = [], []
+ for image in images:
+ patches, image_grid_thw = self._preprocess(
+ image,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ patch_size=patch_size,
+ temporal_patch_size=temporal_patch_size,
+ merge_size=merge_size,
+ data_format=data_format,
+ do_convert_rgb=do_convert_rgb,
+ input_data_format=input_data_format,
+ )
+ pixel_values.extend(patches)
+ vision_grid_thws.append(image_grid_thw)
+ pixel_values = np.array(pixel_values)
+ vision_grid_thws = np.array(vision_grid_thws)
+ data.update(
+ {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}
+ )
+
+ # kept for BC only and should be removed after v5.0
+ if videos is not None:
+ logger.warning(
+ "`HunYuanVLV1ImageProcessor` works only with image inputs "
+ "and doesn't process videos anymore. "
+ "This is a deprecated behavior and will be removed in v5.0. "
+ "Your videos should be forwarded to `HunYuanVLV1VideoProcessor`. "
+ )
+ videos = make_batched_videos(videos)
+ pixel_values_videos, vision_grid_thws_videos = [], []
+ for images in videos:
+ patches, video_grid_thw = self._preprocess(
+ images,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ patch_size=patch_size,
+ temporal_patch_size=temporal_patch_size,
+ merge_size=merge_size,
+ data_format=data_format,
+ do_convert_rgb=do_convert_rgb,
+ input_data_format=input_data_format,
+ )
+ pixel_values_videos.extend(patches)
+ vision_grid_thws_videos.append(video_grid_thw)
+ data.update(
+ {
+ "pixel_values_videos": np.array(pixel_values_videos),
+ "video_grid_thw": np.array(vision_grid_thws_videos),
+ }
+ )
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
+ """
+ A utility that returns number of image patches for a given image size.
+
+ Args:
+ height (`int`):
+ Height of the input image.
+ width (`int`):
+ Width of the input image.
+ images_kwargs (`dict`, *optional*):
+ Any kwargs to override defaults of the image processor.
+ Returns:
+ `int`: Number of image patches per image.
+ """
+ min_pixels = (
+ images_kwargs["min_pixels"]
+ if "min_pixels" in images_kwargs
+ else self.size["shortest_edge"]
+ )
+ max_pixels = (
+ images_kwargs["max_pixels"]
+ if "max_pixels" in images_kwargs
+ else self.size["longest_edge"]
+ )
+ patch_size = images_kwargs.get("patch_size", self.patch_size)
+ merge_size = images_kwargs.get("merge_size", self.merge_size)
+
+ factor = patch_size * merge_size
+ resized_height, resized_width = smart_resize(
+ height, width, factor, min_pixels=min_pixels, max_pixels=max_pixels
+ )
+ grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
+ return grid_h * (grid_w + 1) + 2
+
+
+AutoImageProcessor.register("HunYuanVLImageProcessor", HunYuanVLImageProcessor)
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index a393568909d27..233076741503d 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -3,8 +3,8 @@
import contextlib
import copy
+import importlib.util
import os
-import warnings
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeAlias
@@ -15,7 +15,10 @@ from typing_extensions import assert_never
from vllm import envs
from vllm.logger import init_logger
-from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config
+from vllm.transformers_utils.config import (
+ get_sentence_transformer_tokenizer_config,
+ list_filtered_repo_files,
+)
from vllm.transformers_utils.tokenizers import MistralTokenizer
from vllm.transformers_utils.utils import check_gguf_file
@@ -182,25 +185,29 @@ def get_tokenizer(
kwargs["gguf_file"] = Path(tokenizer_name).name
tokenizer_name = Path(tokenizer_name).parent
- # if tokenizer is from official mistral org
- is_from_mistral_org = str(tokenizer_name).split("/")[0] == "mistralai"
- if is_from_mistral_org and tokenizer_mode != "mistral":
- warnings.warn(
- "It is strongly recommended to run mistral models with "
- '`--tokenizer-mode "mistral"` to ensure correct '
- "encoding and decoding.",
- FutureWarning,
- stacklevel=2,
+ # if `tokenizer_mode` == "auto", check if tokenizer can be loaded via Mistral format
+ # first to use official Mistral tokenizer if possible.
+ mistral_common_installed = importlib.util.find_spec("mistral_common") is not None
+ if tokenizer_mode == "auto" and mistral_common_installed:
+ allow_patterns = ["tekken.json", "tokenizer.model.v*"]
+ files_list = list_filtered_repo_files(
+ model_name_or_path=str(tokenizer_name),
+ allow_patterns=allow_patterns,
+ revision=revision,
)
+ if len(files_list) > 0:
+ tokenizer_mode = "mistral"
tokenizer: AnyTokenizer
if tokenizer_mode == "mistral":
+ logger.debug_once(f"Loading MistralTokenizer from {tokenizer_name}")
tokenizer = MistralTokenizer.from_pretrained(
str(tokenizer_name), revision=revision
)
elif tokenizer_mode == "custom":
from vllm.transformers_utils.tokenizer_base import TokenizerRegistry
+ logger.debug_once(f"Loading CustomTokenizer from {tokenizer_name}")
tokenizer = TokenizerRegistry.get_tokenizer(
str(tokenizer_name),
*args,
@@ -210,6 +217,7 @@ def get_tokenizer(
)
else:
try:
+ logger.debug_once(f"Loading AutoTokenizer from {tokenizer_name}")
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
*args,
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index 3ef44e7703204..fddcc27204307 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -49,13 +49,14 @@ STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
# Possible string values of STR_BACKEND_ENV_VAR
# register, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
-STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"
+MASK_64_BITS = (1 << 64) - 1
+
def random_uuid() -> str:
- return str(uuid.uuid4().hex)
+ return f"{uuid.uuid4().int & MASK_64_BITS:016x}" # 16 hex chars
def length_from_prompt_token_ids_or_embeds(
diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py
index b5ab37534dd78..b25c1e3e1ece3 100644
--- a/vllm/utils/deep_gemm.py
+++ b/vllm/utils/deep_gemm.py
@@ -325,6 +325,7 @@ DEFAULT_BLOCK_SIZE = [128, 128]
def per_block_cast_to_fp8(
x: torch.Tensor, block_size: list[int] = DEFAULT_BLOCK_SIZE, use_ue8m0: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
+ fp8_dtype = current_platform.fp8_dtype()
assert x.dim() == 2
m, n = x.shape
block_m, block_n = block_size
@@ -334,9 +335,9 @@ def per_block_cast_to_fp8(
x_padded[:m, :n] = x
x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
- sf = x_amax / 448.0
+ sf = x_amax / 224.0 if current_platform.is_fp8_fnuz() else x_amax / 448.0
sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf
- x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
+ x_scaled = (x_view * (1.0 / sf)).to(fp8_dtype)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
x_view.size(0), x_view.size(2)
)
@@ -365,11 +366,18 @@ def should_use_deepgemm_for_fp8_linear(
):
if supports_deep_gemm is None:
supports_deep_gemm = is_deep_gemm_supported()
+
+ # Verify DeepGEMM N/K dims requirements
+ # NOTE: Also synchronized with test_w8a8_block_fp8_deep_gemm_matmul
+    # test inside kernels/quantization/test_block_fp8.py
+ N_MULTIPLE = 64
+ K_MULTIPLE = 128
+
return (
supports_deep_gemm
and output_dtype == torch.bfloat16
- and weight.shape[0] % 128 == 0
- and weight.shape[1] % 128 == 0
+ and weight.shape[0] % N_MULTIPLE == 0
+ and weight.shape[1] % K_MULTIPLE == 0
)
diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py
index 1209d64901bf5..9f9976d52b4ae 100644
--- a/vllm/utils/flashinfer.py
+++ b/vllm/utils/flashinfer.py
@@ -114,7 +114,17 @@ flashinfer_trtllm_fp8_per_tensor_scale_moe = _lazy_import_wrapper(
flashinfer_cutlass_fused_moe = _lazy_import_wrapper(
"flashinfer.fused_moe", "cutlass_fused_moe"
)
+flashinfer_cutedsl_grouped_gemm_nt_masked = _lazy_import_wrapper(
+ "flashinfer.cute_dsl.blockscaled_gemm", "grouped_gemm_nt_masked"
+)
flashinfer_fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
+nvfp4_batched_quantize = _lazy_import_wrapper("flashinfer", "nvfp4_batched_quantize")
+silu_and_mul_scaled_nvfp4_experts_quantize = _lazy_import_wrapper(
+ "flashinfer", "silu_and_mul_scaled_nvfp4_experts_quantize"
+)
+scaled_fp4_grouped_quantize = _lazy_import_wrapper(
+ "flashinfer", "scaled_fp4_grouped_quantize"
+)
nvfp4_block_scale_interleave = _lazy_import_wrapper(
"flashinfer", "nvfp4_block_scale_interleave"
)
@@ -166,6 +176,14 @@ def has_flashinfer_moe() -> bool:
)
+@functools.cache
+def has_flashinfer_cutedsl() -> bool:
+ """Return ``True`` if FlashInfer cutedsl module is available."""
+ return (
+ has_flashinfer() and importlib.util.find_spec("flashinfer.cute_dsl") is not None
+ )
+
+
@functools.cache
def has_flashinfer_cutlass_fused_moe() -> bool:
"""Return `True` if FlashInfer CUTLASS fused MoE is available."""
@@ -187,6 +205,26 @@ def has_flashinfer_cutlass_fused_moe() -> bool:
return True
+@functools.cache
+def has_flashinfer_cutedsl_grouped_gemm_nt_masked() -> bool:
+ """Return ``True`` if FlashInfer CUTLASS fused MoE is available."""
+ if not has_flashinfer_cutedsl():
+ return False
+
+ # Check if all required functions are available
+ required_functions = [
+ ("flashinfer.cute_dsl.blockscaled_gemm", "grouped_gemm_nt_masked"),
+ ("flashinfer", "scaled_fp4_grouped_quantize"),
+ ("flashinfer", "silu_and_scaled_nvfp4_experts_quantize"),
+ ]
+
+ for module_name, attr_name in required_functions:
+ mod = _get_submodule(module_name)
+ if not mod or not hasattr(mod, attr_name):
+ return False
+ return True
+
+
@functools.cache
def has_nvidia_artifactory() -> bool:
"""Return `True` if NVIDIA's artifactory is accessible.
@@ -472,7 +510,10 @@ __all__ = [
"has_flashinfer",
"flashinfer_trtllm_fp8_block_scale_moe",
"flashinfer_cutlass_fused_moe",
+ "flashinfer_cutedsl_grouped_gemm_nt_masked",
"flashinfer_fp4_quantize",
+ "silu_and_mul_scaled_nvfp4_experts_quantize",
+ "scaled_fp4_grouped_quantize",
"nvfp4_block_scale_interleave",
"trtllm_fp4_block_scale_moe",
"autotune",
@@ -480,6 +521,7 @@ __all__ = [
"has_flashinfer_comm",
"has_flashinfer_all2all",
"has_flashinfer_cutlass_fused_moe",
+ "has_flashinfer_cutedsl_grouped_gemm_nt_masked",
"has_nvidia_artifactory",
"supports_trtllm_attention",
"can_use_trtllm_attention",
diff --git a/vllm/utils/gc_utils.py b/vllm/utils/gc_utils.py
index 160ac9ac263a9..c56b1794230e9 100644
--- a/vllm/utils/gc_utils.py
+++ b/vllm/utils/gc_utils.py
@@ -53,6 +53,7 @@ class GCDebugger:
self.config = config
# Start time in micro second of this GC cycle
self.start_time_ns: int = time.monotonic_ns()
+ self.num_objects: int = 0
# If config.top_objects is positive,
# compute top collected objects by object types
self.gc_top_collected_objects: str = ""
@@ -68,8 +69,10 @@ class GCDebugger:
# Before GC started, record GC start time
# and top collected objects
self.start_time_ns = time.monotonic_ns()
+ objects = gc.get_objects(generation)
+ self.num_objects = len(objects)
self.gc_top_collected_objects = _compute_top_gc_collected_objects(
- gc.get_objects(generation), self.config.top_objects
+ objects, self.config.top_objects
)
elif phase == "stop":
# After GC finished, Record GC elapsed time and
@@ -77,9 +80,10 @@ class GCDebugger:
elpased_ms = (time.monotonic_ns() - self.start_time_ns) / 1e6
logger.info(
"GC took %.3fms to complete. "
- "Collected %s objects in GC generation %d.%s",
+ "Collected %s objects (out of %d) in GC generation %d.%s",
elpased_ms,
str(info.get("collected", "?")),
+ self.num_objects,
generation,
(
f" Top collected objects: \n{self.gc_top_collected_objects}"
diff --git a/vllm/utils/system_utils.py b/vllm/utils/system_utils.py
index 5968884e232a4..cc872040b6c5f 100644
--- a/vllm/utils/system_utils.py
+++ b/vllm/utils/system_utils.py
@@ -22,7 +22,7 @@ from .platform_utils import cuda_is_initialized, xpu_is_initialized
logger = init_logger(__name__)
-CYAN = "\033[1;36m"
+CYAN = "\033[0;36m"
RESET = "\033[0;0m"
@@ -142,7 +142,10 @@ def set_process_title(
def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
"""Add colored prefix to file output for log decoration."""
- prefix = f"{CYAN}({worker_name} pid={pid}){RESET} "
+ if envs.NO_COLOR:
+ prefix = f"({worker_name} pid={pid}) "
+ else:
+ prefix = f"{CYAN}({worker_name} pid={pid}){RESET} "
file_write = file.write
def write_with_prefix(s: str):
diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py
index 7c094e14cff72..3661dfd09047a 100644
--- a/vllm/utils/torch_utils.py
+++ b/vllm/utils/torch_utils.py
@@ -426,8 +426,7 @@ def aux_stream() -> torch.cuda.Stream | None:
from vllm.platforms import current_platform
- # TODO: validate this works properly on ROCm platform.
- if _aux_stream is None and current_platform.is_cuda():
+ if _aux_stream is None and current_platform.is_cuda_alike():
_aux_stream = torch.cuda.Stream()
return _aux_stream
diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py
index f1254352c0585..590bf91b0d057 100644
--- a/vllm/v1/attention/backends/cpu_attn.py
+++ b/vllm/v1/attention/backends/cpu_attn.py
@@ -25,7 +25,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
-_CPU_ARCH_PREFER_MIXED_BATCH = (CpuArchEnum.X86,)
+_CPU_ARCH_PREFER_MIXED_BATCH = (CpuArchEnum.X86, CpuArchEnum.ARM)
class CPUAttentionBackend(AttentionBackend):
@@ -491,6 +491,9 @@ def _get_attn_isa(dtype: torch.dtype, block_size: int) -> str:
if supports_amx and dtype in (torch.bfloat16,) and block_size % 32 == 0:
return "amx"
elif block_size % 32 == 0:
- return "vec"
+ if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
+ return "neon"
+ else:
+ return "vec"
else:
return "vec16"
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index a5d4435000d4d..a9a4af5ac1183 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -32,7 +32,7 @@ if is_flash_attn_varlen_func_available():
get_scheduler_metadata,
reshape_and_cache_flash,
)
-from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
from vllm.config.cache import CacheDType
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
@@ -56,11 +56,26 @@ logger = init_logger(__name__)
class FlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
- # NOTE(tdoublep): while in principle, FA supports
- # MultipleOf(16), these are the block sizes that do not
- # suffer from the NaN propagation problem described here:
- # https://github.com/Dao-AILab/flash-attention/issues/1974
- supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
+
+ @staticmethod
+ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+ vllm_config = get_current_vllm_config()
+ model_config = vllm_config.model_config
+ cache_config = vllm_config.cache_config
+ if (
+ model_config
+ and model_config.is_hybrid
+ and (
+ cache_config.mamba_ssm_cache_dtype == "float32"
+ or cache_config.mamba_cache_dtype == "float32"
+ )
+ ):
+ # NOTE(tdoublep): while in principle, FA supports
+ # MultipleOf(16), these are the block sizes that do not
+ # suffer from the NaN propagation problem described here:
+ # https://github.com/Dao-AILab/flash-attention/issues/1974
+ return [16, 32, 64]
+ return [MultipleOf(16)]
@staticmethod
def get_name() -> str:
@@ -99,12 +114,20 @@ class FlashAttentionBackend(AttentionBackend):
return (2, num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
- def get_kv_cache_stride_order() -> tuple[int, ...]:
+ def get_kv_cache_stride_order(
+ include_num_layers_dimension: bool = False,
+ ) -> tuple[int, ...]:
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout = get_kv_cache_layout()
- if cache_layout == "NHD":
+ if cache_layout == "NHD" and include_num_layers_dimension:
+ # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
+ return (2, 0, 1, 3, 4, 5)
+ elif cache_layout == "NHD":
stride_order = (0, 1, 2, 3, 4)
+ elif cache_layout == "HND" and include_num_layers_dimension:
+ # (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
+ return (2, 4, 0, 1, 3, 5)
elif cache_layout == "HND":
stride_order = (0, 1, 3, 2, 4)
else:
@@ -119,8 +142,8 @@ class FlashAttentionBackend(AttentionBackend):
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
@classmethod
- def get_supported_head_sizes(cls) -> list[int]:
- return [32, 64, 96, 128, 160, 192, 224, 256]
+ def supports_head_size(cls, head_size: int) -> bool:
+ return head_size % 8 == 0 and head_size <= 256
@classmethod
def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
@@ -265,8 +288,8 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
self.dcp_world_size = 1
self.dcp_rank = 0
- self.dcp_kv_cache_interleave_size = (
- self.parallel_config.dcp_kv_cache_interleave_size
+ self.cp_kv_cache_interleave_size = (
+ self.parallel_config.cp_kv_cache_interleave_size
)
self.use_full_cuda_graph = (
@@ -388,7 +411,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
dcp_context_kv_lens_cpu,
self.dcp_world_size,
self.dcp_rank,
- self.dcp_kv_cache_interleave_size,
+ self.cp_kv_cache_interleave_size,
)
dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
max_dcp_context_kv_len = dcp_context_kv_lens.max().item()
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 4da1637d96eb6..8159f4096107f 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -16,7 +16,6 @@ from flashinfer import (
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
from flashinfer.utils import FP4Tensor
-from typing_extensions import override
from vllm import envs
from vllm.attention.backends.abstract import (
@@ -275,10 +274,6 @@ class BatchDCPPrefillWrapper:
class FlashInferBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
- # Note: Not sure for all platforms,
- # but on Blackwell, only support a page size of
- # 16, 32, 64
- supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
@@ -286,6 +281,12 @@ class FlashInferBackend(AttentionBackend):
"fp8_e5m2",
]
+ @staticmethod
+ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+ # Note: Not sure for all platforms, but on Blackwell,
+ # only support a page size of 16, 32, 64.
+ return [16, 32, 64]
+
@staticmethod
def get_name() -> str:
return "FLASHINFER"
@@ -309,12 +310,20 @@ class FlashInferBackend(AttentionBackend):
return (num_blocks, 2, block_size, num_kv_heads, head_size)
@staticmethod
- def get_kv_cache_stride_order() -> tuple[int, ...]:
+ def get_kv_cache_stride_order(
+ include_num_layers_dimension: bool = False,
+ ) -> tuple[int, ...]:
# `stride_order` indicates the permutation that gets us from
# `get_kv_cache_shape` to the actual memory layout we want.
cache_layout = get_kv_cache_layout()
- if cache_layout == "NHD":
+ if cache_layout == "NHD" and include_num_layers_dimension:
+ # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
+ return (1, 0, 2, 3, 4, 5)
+ elif cache_layout == "NHD":
stride_order = (0, 1, 2, 3, 4)
+ elif cache_layout == "HND" and include_num_layers_dimension:
+ # (num_blocks, 2, num_kv_heads, num_layers, block_size, head_size)
+ return (1, 2, 4, 0, 3, 5)
elif cache_layout == "HND":
stride_order = (0, 1, 3, 2, 4)
else:
@@ -558,7 +567,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
@classmethod
- @override
def get_cudagraph_support(
cls: type["FlashInferMetadataBuilder"],
vllm_config: VllmConfig,
@@ -585,6 +593,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
return self._workspace_buffer
+ def set_workspace_buffer(self, workspace_buffer: torch.Tensor):
+ self._workspace_buffer = workspace_buffer
+
def _get_prefill_wrapper(
self,
) -> BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper:
diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py
index 1900c50849eca..004baa2d09cde 100644
--- a/vllm/v1/attention/backends/linear_attn.py
+++ b/vllm/v1/attention/backends/linear_attn.py
@@ -7,6 +7,7 @@ import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (
+ AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills,
@@ -35,6 +36,8 @@ class LinearAttentionMetadata:
class LinearAttentionMetadataBuilder(AttentionMetadataBuilder[LinearAttentionMetadata]):
reorder_batch_threshold: int = 1
+ _cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
+
def __init__(
self,
kv_cache_spec: AttentionSpec,
diff --git a/vllm/v1/attention/backends/mla/aiter_triton_mla.py b/vllm/v1/attention/backends/mla/aiter_triton_mla.py
new file mode 100644
index 0000000000000..8a92152a0ca53
--- /dev/null
+++ b/vllm/v1/attention/backends/mla/aiter_triton_mla.py
@@ -0,0 +1,74 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from vllm.v1.attention.backends.mla.common import MLACommonBackend
+from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
+ AiterMLAImpl,
+ AiterMLAMetadataBuilder,
+)
+
+
+class AiterTritonMLABackend(MLACommonBackend):
+ @staticmethod
+ def get_name() -> str:
+ return "AITER_TRITON_MLA"
+
+ @staticmethod
+ def get_impl_cls() -> type["AiterTritonMLAImpl"]:
+ return AiterTritonMLAImpl
+
+ @staticmethod
+ def get_builder_cls() -> type["AiterMLAMetadataBuilder"]:
+ return AiterMLAMetadataBuilder
+
+
+class AiterTritonMLAImpl(AiterMLAImpl):
+ def __init__(
+ self,
+ num_heads: int,
+ head_size: int,
+ scale: float,
+ num_kv_heads: int,
+ alibi_slopes: list[float] | None,
+ sliding_window: int | None,
+ kv_cache_dtype: str,
+ logits_soft_cap: float | None,
+ attn_type: str,
+ kv_sharing_target_layer_name: str | None,
+ # MLA Specific Arguments
+ **mla_args,
+ ) -> None:
+ super().__init__(
+ num_heads,
+ head_size,
+ scale,
+ num_kv_heads,
+ alibi_slopes,
+ sliding_window,
+ kv_cache_dtype,
+ logits_soft_cap,
+ attn_type,
+ kv_sharing_target_layer_name,
+ **mla_args,
+ )
+ from aiter.ops.triton.mha import flash_attn_varlen_func
+
+ self.flash_attn_varlen_func = flash_attn_varlen_func
+
+ def _flash_attn_varlen_diff_headdims(
+ self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
+ ):
+ result = self.flash_attn_varlen_func(
+ q,
+ k,
+ v,
+ softmax_scale=softmax_scale,
+ return_lse=return_softmax_lse,
+ **kwargs,
+ )
+ # Transpose the LSE if Triton MHA is used:
+ # (q.shape[0], num_q_heads) to (num_q_heads, q.shape[0])
+ if type(result) is tuple and return_softmax_lse:
+ output, lse = result
+ lse = lse.T.contiguous()
+ return (output, lse)
+ return result
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 2ccdd1f143ce8..87a3aac21d2c3 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -308,6 +308,15 @@ class MLACommonBackend(AttentionBackend):
) -> tuple[int, ...]:
return (num_blocks, block_size, head_size)
+ @staticmethod
+ def get_kv_cache_stride_order(
+ include_num_layers_dimension: bool = False,
+ ) -> tuple[int, ...]:
+ # `stride_order` indicates the permutation that gets
+ # us from `get_kv_cache_shape` to the actual memory layout we want.
+ # (num_blocks, num_layers, block_size, head_size)
+ return (1, 0, 2, 3) if include_num_layers_dimension else (0, 1, 2)
+
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [576]
@@ -331,6 +340,8 @@ class MLACommonPrefillMetadata:
max_seq_lens: list[int]
seq_lens: torch.Tensor
workspace: torch.Tensor
+ token_to_seq: torch.Tensor
+ chunk_total_token: list[int]
# for mla DCP
padded_local_chunk_seq_lens: list[list[int]] | None = None
@@ -536,7 +547,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
- self.dcp_local_block_size = parallel_config.dcp_kv_cache_interleave_size
+ self.dcp_local_block_size = parallel_config.cp_kv_cache_interleave_size
self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size
# Don't try to access the runner on AMD
@@ -755,6 +766,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens
+ dcp_local_seq_lens_cpu = common_attn_metadata.dcp_local_seq_lens_cpu
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
@@ -829,6 +841,19 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
torch.cumsum(
chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32
)
+ chunk_total_token = cu_seq_lens_cpu[:, -1]
+
+ max_token_num_over_chunk = chunk_total_token.max().item()
+ token_to_seq_tensor_cpu = torch.zeros(
+ [num_chunks, max_token_num_over_chunk], dtype=torch.int32
+ )
+ range_idx = torch.arange(num_prefills, dtype=torch.int32)
+ for i in range(num_chunks):
+ chunk_token_to_seq_tensor = torch.repeat_interleave(
+ range_idx, chunk_seq_lens[i]
+ )
+ chunk_len = chunk_token_to_seq_tensor.shape[0]
+ token_to_seq_tensor_cpu[i, :chunk_len] = chunk_token_to_seq_tensor
if self.dcp_world_size > 1:
local_context_lens_allranks = get_dcp_local_seq_lens(
@@ -896,6 +921,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
seq_lens=chunk_seq_lens,
+ token_to_seq=token_to_seq_tensor_cpu.to(
+ device, non_blocking=True
+ ),
+ chunk_total_token=chunk_total_token.tolist(),
workspace=self.chunked_prefill_workspace,
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
local_context_lens_allranks=local_context_lens_allranks.tolist(),
@@ -912,6 +941,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
seq_lens=chunk_seq_lens,
+ token_to_seq=token_to_seq_tensor_cpu.to(
+ device, non_blocking=True
+ ),
+                    chunk_total_token=chunk_total_token.tolist(),
workspace=self.chunked_prefill_workspace,
)
@@ -944,18 +977,20 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
decode_metadata = None
if num_decodes > 0:
+ dcp_tot_seq_lens_device = None
+ if self.dcp_world_size > 1:
+ dcp_tot_seq_lens_device = seq_lens[:num_decodes]
+ seq_lens_cpu = dcp_local_seq_lens_cpu
+ seq_lens = dcp_local_seq_lens
+
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:num_decodes, ...],
seq_lens_cpu=seq_lens_cpu[:num_decodes],
- seq_lens_device=dcp_local_seq_lens[:num_decodes]
- if self.dcp_world_size > 1 and dcp_local_seq_lens is not None
- else seq_lens[:num_decodes],
+ seq_lens_device=seq_lens[:num_decodes],
query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1],
query_start_loc_device=query_start_loc[: num_decodes + 1],
num_decode_tokens=num_decode_tokens,
- dcp_tot_seq_lens_device=seq_lens[:num_decodes]
- if self.dcp_world_size > 1
- else None,
+ dcp_tot_seq_lens_device=dcp_tot_seq_lens_device,
)
attn_metadata = self.metadata_cls(
@@ -1286,8 +1321,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
get_current_vllm_config()
)
)
- self.dcp_kv_cache_interleave_size: int = (
- get_current_vllm_config().parallel_config.dcp_kv_cache_interleave_size
+ self.cp_kv_cache_interleave_size: int = (
+ get_current_vllm_config().parallel_config.cp_kv_cache_interleave_size
)
def _flash_attn_varlen_diff_headdims(
@@ -1626,16 +1661,15 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
output = None
iters = len(prefill_metadata.chunked_context.seq_tot)
workspace = prefill_metadata.chunked_context.workspace
-
for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i]
-
ops.gather_and_maybe_dequant_cache(
src_cache=kv_c_and_k_pe_cache,
dst=workspace,
block_table=prefill_metadata.block_table,
cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
- batch_size=attn_metadata.num_prefills,
+ token_to_seq=prefill_metadata.chunked_context.token_to_seq[i],
+ num_tokens=prefill_metadata.chunked_context.chunk_total_token[i],
kv_cache_dtype=self.kv_cache_dtype,
scale=k_scale,
seq_starts=prefill_metadata.chunked_context.starts[i],
diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py
index 60cb5022a55eb..5e3fbc0abf083 100644
--- a/vllm/v1/attention/backends/mla/cutlass_mla.py
+++ b/vllm/v1/attention/backends/mla/cutlass_mla.py
@@ -36,13 +36,16 @@ class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
class CutlassMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
- supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [128]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
"fp8_e4m3",
]
+ @staticmethod
+ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+ return [128]
+
@staticmethod
def get_name() -> str:
return "CUTLASS_MLA"
diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py
index 7794e89cc0a94..d369814c10b6f 100644
--- a/vllm/v1/attention/backends/mla/flashattn_mla.py
+++ b/vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -41,9 +41,12 @@ logger = init_logger(__name__)
class FlashAttnMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
- supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
+ @staticmethod
+ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+ return [MultipleOf(16)]
+
@staticmethod
def get_name() -> str:
return "FLASH_ATTN_MLA"
@@ -173,7 +176,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
) -> FlashAttnMLADecodeMetadata:
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
max_query_len = query_lens_cpu.max().item()
- max_seq_len = seq_lens_device.max().item()
+ max_seq_len = seq_lens_cpu.max().item()
# For Flash Attention MLA + full cudagraph
max_num_splits = 0
diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py
index 52bb19e039e45..f02a4bb1ef35a 100644
--- a/vllm/v1/attention/backends/mla/flashinfer_mla.py
+++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py
@@ -35,13 +35,16 @@ class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
class FlashInferMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
- supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [32, 64]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
"fp8_e4m3",
]
+ @staticmethod
+ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+ return [32, 64]
+
@staticmethod
def get_name() -> str:
return "FLASHINFER_MLA"
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index 3aab1f9bb7fb6..74a4cd8430250 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -39,13 +39,16 @@ logger = init_logger(__name__)
class FlashMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
- supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
"fp8_e4m3",
]
+ @staticmethod
+ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+ return [64]
+
@staticmethod
def get_name() -> str:
return "FLASHMLA"
diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py
index bb8d914d15719..1eee1d225293b 100644
--- a/vllm/v1/attention/backends/mla/flashmla_sparse.py
+++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py
@@ -55,9 +55,12 @@ structured as:
class FlashMLASparseBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
- supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "fp8_ds_mla"]
+ @staticmethod
+ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+ return [64]
+
@staticmethod
def get_name() -> str:
return "FLASHMLA_SPARSE"
@@ -168,7 +171,7 @@ def _convert_req_index_to_global_index_kernel(
inblock_off = tok % BLOCK_SIZE
# Guard block_table access
- valid_block = block_id < max_num_blocks_per_req
+ valid_block = (block_id < max_num_blocks_per_req) & (block_id >= 0)
bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1
base = tl.load(bt_ptr, mask=valid_block, other=0)
diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py
index 37aa5dad89a0e..77f1ba00d5b04 100644
--- a/vllm/v1/attention/backends/mla/indexer.py
+++ b/vllm/v1/attention/backends/mla/indexer.py
@@ -11,7 +11,8 @@ from vllm.attention.backends.abstract import (
)
from vllm.config import VllmConfig
from vllm.logger import init_logger
-from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata
+from vllm.platforms import current_platform
+from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, is_deep_gemm_supported
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
@@ -23,7 +24,9 @@ logger = init_logger(__name__)
class DeepseekV32IndexerBackend(AttentionBackend):
- supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
+ @staticmethod
+ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+ return [1 if current_platform.is_rocm() else 64]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
@@ -45,7 +48,11 @@ class DeepseekV32IndexerBackend(AttentionBackend):
return (num_blocks, block_size, head_size)
@staticmethod
- def get_kv_cache_stride_order() -> tuple[int, ...]:
+ def get_kv_cache_stride_order(
+ include_num_layers_dimension: bool = False,
+ ) -> tuple[int, ...]:
+ if include_num_layers_dimension:
+ return (0, 1, 2, 3)
return (0, 1, 2)
@@ -328,10 +335,10 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
-
- self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
- seq_lens, self.kv_cache_spec.block_size, self.num_sms
- )
+ if is_deep_gemm_supported():
+ self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
+ seq_lens, self.kv_cache_spec.block_size, self.num_sms
+ )
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
block_table=common_attn_metadata.block_table_tensor[:num_decodes, ...],
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index e1864526f02cc..00a0a77a1c2f7 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -7,9 +7,8 @@ from typing import ClassVar
import torch
from vllm._aiter_ops import rocm_aiter_ops
-from vllm.attention.backends.abstract import AttentionLayer
+from vllm.attention.backends.abstract import AttentionLayer, MultipleOf
from vllm.config import VllmConfig
-from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonDecodeMetadata,
@@ -22,6 +21,10 @@ from vllm.v1.kv_cache_interface import AttentionSpec
class AiterMLABackend(MLACommonBackend):
+ @staticmethod
+ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+ return [1]
+
@staticmethod
def get_name() -> str:
return "ROCM_AITER_MLA"
@@ -46,6 +49,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
paged_kv_last_page_len: torch.Tensor | None = None
# The query indptr, shape : [num_decode + 1]
qo_indptr: torch.Tensor | None = None
+ # The dtype of MLA out tensor
+ attn_out_dtype: torch.dtype = torch.bfloat16
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
@@ -71,9 +76,9 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
)
self.compilation_config = vllm_config.compilation_config
- max_num_pages_per_req = cdiv(
- vllm_config.model_config.max_model_len, self.kv_cache_spec.block_size
- )
+ self.decode_attn_out_dtype = vllm_config.model_config.dtype
+ # kernel block size is always 1.
+ max_num_pages_per_req = vllm_config.model_config.max_model_len
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
max_num_pages = max_num_reqs * max_num_pages_per_req
@@ -82,11 +87,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
# so we can only use the persistent buffer if a cudagraph is actually
# being used.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
- self.block_table_remapping = torch.zeros(
- [max_num_reqs, max_num_pages_per_req * self.kv_cache_spec.block_size],
- dtype=torch.int32,
- device=device,
- )
self.paged_kv_indptr = torch.zeros(
max_num_reqs + 1, dtype=torch.int32, device=device
)
@@ -111,36 +111,16 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
num_decode_tokens: int,
dcp_tot_seq_lens_device: torch.Tensor | None,
) -> AiterMLADecodeMetadata:
- page_size = self.kv_cache_spec.block_size
+ # kernel block size is always 1, although the kv block size is not 1.
device = self.device
num_reqs = seq_lens_device.size(0)
- bs, _ = block_table_tensor.shape
- block_table_tensor = (
- block_table_tensor.unsqueeze(-1).expand(-1, -1, page_size) * page_size
- )
- block_table_tensor = (
- block_table_tensor
- + torch.arange(
- 0,
- page_size,
- device=block_table_tensor.device,
- dtype=block_table_tensor.dtype,
- )[None, None, :]
- )
- block_table_tensor = block_table_tensor.view(bs, -1)
- # after remapping, we assume the block size already equals to 1
-
- max_blk_size_per_req = block_table_tensor.shape[-1]
mask = torch.arange(
block_table_tensor.size(1), dtype=block_table_tensor.dtype, device=device
).unsqueeze(0) < seq_lens_device.unsqueeze(1)
paged_kv_indices = block_table_tensor[mask]
- paged_kv_last_page_len = seq_lens_device % page_size
- paged_kv_last_page_len = torch.where(
- paged_kv_last_page_len == 0, page_size, paged_kv_last_page_len
- )
+ paged_kv_last_page_len = torch.where(seq_lens_device == 0, 1, seq_lens_device)
paged_kv_indptr = torch.cat(
[
@@ -151,12 +131,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
num_actual_pages = paged_kv_indices.size(0)
- self.block_table_remapping[:num_reqs, :max_blk_size_per_req].copy_(
- block_table_tensor, non_blocking=True
- )
- block_table_tensor = self.block_table_remapping[
- :num_reqs, :max_blk_size_per_req
- ]
self.paged_kv_indices[:num_actual_pages].copy_(
paged_kv_indices, non_blocking=True
@@ -191,6 +165,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
paged_kv_last_page_len=paged_kv_last_page_len,
qo_indptr=qo_indptr,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
+ attn_out_dtype=self.decode_attn_out_dtype,
)
return attn_metadata
@@ -271,7 +246,11 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
assert isinstance(q, torch.Tensor)
B = q.shape[0]
o = torch.zeros(
- B, self.num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device
+ B,
+ self.num_heads,
+ self.kv_lora_rank,
+ dtype=attn_metadata.decode.attn_out_dtype,
+ device=q.device,
)
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
@@ -289,6 +268,8 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
attn_metadata.decode.paged_kv_indptr,
attn_metadata.decode.paged_kv_indices,
attn_metadata.decode.paged_kv_last_page_len,
+ q_scale=layer._q_scale,
+ kv_scale=layer._k_scale,
)
return o, None
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
new file mode 100644
index 0000000000000..c0e7f0e380b98
--- /dev/null
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
@@ -0,0 +1,325 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, ClassVar, Optional
+
+import numpy as np
+import torch
+
+from vllm import _custom_ops as ops
+from vllm._aiter_ops import rocm_aiter_ops
+from vllm.attention.backends.abstract import (
+ AttentionBackend,
+ AttentionLayer,
+ AttentionMetadata,
+)
+from vllm.attention.backends.utils import get_mla_dims
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.v1.attention.backends.mla.common import (
+ MLACommonBaseImpl,
+)
+from vllm.v1.attention.backends.mla.flashmla_sparse import (
+ triton_convert_req_index_to_global_index,
+)
+from vllm.v1.attention.backends.utils import (
+ AttentionCGSupport,
+ AttentionMetadataBuilder,
+ CommonAttentionMetadata,
+)
+from vllm.v1.kv_cache_interface import AttentionSpec
+
+if TYPE_CHECKING:
+ from vllm.model_executor.models.deepseek_v2 import Indexer
+logger = init_logger(__name__)
+
+
+class ROCMAiterMLASparseBackend(AttentionBackend):
+ accept_output_buffer: bool = True
+
+ @staticmethod
+ def get_name() -> str:
+ return "ROCM_AITER_MLA_SPARSE"
+
+ @staticmethod
+ def get_metadata_cls() -> type[AttentionMetadata]:
+ return ROCMAiterMLASparseMetadata
+
+ @staticmethod
+ def get_builder_cls() -> type["ROCMAiterMLASparseMetadataBuilder"]:
+ return ROCMAiterMLASparseMetadataBuilder
+
+ @staticmethod
+ def get_impl_cls() -> type["ROCMAiterMLASparseImpl"]:
+ return ROCMAiterMLASparseImpl
+
+ @staticmethod
+ def get_kv_cache_shape(
+ num_blocks: int,
+ block_size: int,
+ num_kv_heads: int, # assumed to be 1 for MLA
+ head_size: int,
+ cache_dtype_str: str = "auto",
+ ) -> tuple[int, ...]:
+ return (num_blocks, block_size, head_size)
+
+ @classmethod
+ def get_supported_dtypes(cls) -> list[torch.dtype]:
+ return [torch.bfloat16]
+
+ @classmethod
+ def get_supported_head_sizes(cls) -> list[int]:
+ return [576]
+
+
+@dataclass
+class ROCMAiterMLASparseMetadata:
+ num_reqs: int
+ max_query_len: int
+ max_seq_len: int
+
+ num_actual_tokens: int # Number of tokens excluding padding.
+ query_start_loc: torch.Tensor
+ slot_mapping: torch.Tensor
+
+ block_table: torch.Tensor
+ req_id_per_token: torch.Tensor
+ block_size: int = 1
+ topk_tokens: int = 2048
+
+
+@dataclass
+class ROCMAiterMLASparseMetadataBuilder(
+ AttentionMetadataBuilder[ROCMAiterMLASparseMetadata]
+):
+ cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
+
+ def __init__(
+ self,
+ kv_cache_spec: AttentionSpec,
+ layer_names: list[str],
+ vllm_config: VllmConfig,
+ device: torch.device,
+ ):
+ self.kv_cache_spec = kv_cache_spec
+ self.model_config = vllm_config.model_config
+ parallel_config = vllm_config.parallel_config
+ self.device = device
+
+ self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
+ self.mla_dims = get_mla_dims(self.model_config)
+ self.topk_tokens = vllm_config.model_config.hf_config.index_topk
+ self.topk_tokens_tensor = torch.tensor(
+ [self.topk_tokens], device=device, dtype=torch.int32
+ )
+ self.max_model_len_tensor = torch.tensor(
+ [self.model_config.max_model_len], device=device, dtype=torch.int32
+ )
+ # this is ignored by `flash_mla_with_kvcache` if indices is not None
+ self.dummy_block_table = torch.empty(
+ (1, 1), dtype=torch.int32, device=self.device
+ )
+
+ self.req_id_per_token_buffer = torch.empty(
+ (vllm_config.scheduler_config.max_num_batched_tokens,),
+ dtype=torch.int32,
+ device=device,
+ )
+
+ def build(
+ self,
+ common_prefix_len: int,
+ common_attn_metadata: CommonAttentionMetadata,
+ fast_build: bool = False,
+ ) -> ROCMAiterMLASparseMetadata:
+ num_tokens = common_attn_metadata.num_actual_tokens
+ starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32)
+ seg_lengths = np.diff(starts)
+ req_id_per_token = np.repeat(
+ np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
+ )
+ # Zero-fill for cudagraphs
+ self.req_id_per_token_buffer.fill_(0)
+ self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
+ torch.from_numpy(req_id_per_token), non_blocking=True
+ )
+ req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
+
+ metadata = ROCMAiterMLASparseMetadata(
+ num_reqs=common_attn_metadata.num_reqs,
+ max_query_len=common_attn_metadata.max_query_len,
+ max_seq_len=common_attn_metadata.max_seq_len,
+ num_actual_tokens=common_attn_metadata.num_actual_tokens,
+ query_start_loc=common_attn_metadata.query_start_loc,
+ slot_mapping=common_attn_metadata.slot_mapping,
+ block_table=common_attn_metadata.block_table_tensor,
+ req_id_per_token=req_id_per_token,
+ block_size=self.kv_cache_spec.block_size,
+ topk_tokens=self.topk_tokens,
+ )
+ return metadata
+
+
+# Taken from
+# https://github.com/deepseek-ai/FlashMLA/blob/main/tests/test_flash_mla_prefill.py#L72
+def reference_mla_sparse_prefill(
+ q: torch.Tensor, kv: torch.Tensor, indices: torch.Tensor, sm_scale: float, d_v: int
+) -> tuple[torch.Tensor, torch.Tensor]:
+ import math
+
+ def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor:
+ return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e)
+
+ skv = kv.shape[0]
+ sq = q.shape[0]
+ topk = indices.shape[-1]
+ dqk = q.shape[-1]
+ indices = indices[:, 0, :] # [s_q, topk]
+ invalid_indices_mask = (indices < 0) | (indices >= skv)
+ indices[invalid_indices_mask] = 0
+ qs = q # [s_q, h_q, d_qk]
+ kvs = kv[:, 0, :][indices].view(sq, topk, dqk) # [s_q, topk, d_qk]
+
+ attn_score = (qs @ kvs.transpose(1, 2)).float() # [s_q, h_q, topk]
+ attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float("-inf"))
+ attn_score *= sm_scale * math.log2(math.e)
+ lse = log2sumexp2(attn_score, dim=-1) # [s_q, h_q]
+ attn_score = torch.exp2(attn_score - lse.unsqueeze(-1)) # [s_q, h_q, topk]
+ result = attn_score.to(q.dtype) @ kvs[:, :, :d_v]
+ return (result, lse)
+
+
+class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
+ def __init__(
+ self,
+ num_heads: int,
+ head_size: int,
+ scale: float,
+ num_kv_heads: int,
+ alibi_slopes: list[float] | None,
+ sliding_window: int | None,
+ kv_cache_dtype: str,
+ logits_soft_cap: float | None,
+ attn_type: str,
+ kv_sharing_target_layer_name: str | None,
+ # MLA Specific Arguments
+ topk_indice_buffer: torch.Tensor | None = None,
+ indexer: Optional["Indexer"] = None,
+ **mla_args,
+ ) -> None:
+ super().__init__(
+ num_heads,
+ head_size,
+ scale,
+ num_kv_heads,
+ alibi_slopes,
+ sliding_window,
+ kv_cache_dtype,
+ logits_soft_cap,
+ attn_type,
+ kv_sharing_target_layer_name,
+ **mla_args,
+ )
+ self.softmax_scale = scale
+ assert indexer is not None
+ self.topk_indices_buffer = indexer.topk_indices_buffer
+ self.is_fp8bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
+
+ def _forward_bf16_kv(
+ self,
+ q: torch.Tensor,
+ kv_c_and_k_pe_cache: torch.Tensor,
+ topk_indices: torch.Tensor,
+ attn_metadata: ROCMAiterMLASparseMetadata,
+ ) -> torch.Tensor:
+ num_tokens = q.shape[0]
+ kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
+ -1, 1, kv_c_and_k_pe_cache.shape[-1]
+ )
+
+ topk_indices = topk_indices.view(num_tokens, 1, -1)
+ output = reference_mla_sparse_prefill(
+ q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale, 512
+ )[0]
+ return output[:, : self.num_heads, :]
+
+ def forward(
+ self,
+ layer: AttentionLayer,
+ q: torch.Tensor,
+ k_c_normed: torch.Tensor, # key in unified attn
+ k_pe: torch.Tensor, # value in unified attn
+ kv_cache: torch.Tensor,
+ attn_metadata: ROCMAiterMLASparseMetadata,
+ output: torch.Tensor | None = None,
+ output_scale: torch.Tensor | None = None,
+ output_block_scale: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
+ # MQA 576/512 approach for both prefill and decode
+
+ assert output is not None, "Output tensor must be provided."
+
+ if output_scale is not None or output_block_scale is not None:
+ raise NotImplementedError(
+ "fused output quantization is not yet supported for ROCMAiterMLASparse"
+ )
+
+ if attn_metadata is None:
+ # The zero fill is required when used with DP + EP
+ # to ensure all ranks within a DP group compute the
+ # same expert outputs.
+ return output.fill_(0)
+
+ num_actual_toks = attn_metadata.num_actual_tokens
+
+ # Inputs and outputs may be padded for CUDA graphs
+
+ q = q[:num_actual_toks, ...]
+ k_c_normed = k_c_normed[:num_actual_toks, ...]
+ k_pe = k_pe[:num_actual_toks, ...]
+
+ q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+ # Convert from (B, N, P) to (N, B, P)
+ q_nope = q_nope.transpose(0, 1)
+ if self.is_fp8bmm_enabled:
+ # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
+ ql_nope = rocm_aiter_ops.triton_fp8_bmm(
+ q_nope, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
+ )
+ else:
+ # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+ ql_nope = torch.bmm(q_nope, self.W_UK_T)
+ # Convert from (N, B, L) to (B, N, L)
+ ql_nope = ql_nope.transpose(0, 1)
+
+ topk_indices = self.topk_indices_buffer[:num_actual_toks]
+
+ topk_indices_global = triton_convert_req_index_to_global_index(
+ attn_metadata.req_id_per_token,
+ attn_metadata.block_table,
+ topk_indices,
+ BLOCK_SIZE=attn_metadata.block_size,
+ NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
+ )
+
+ q = torch.cat([ql_nope, q_pe], dim=-1)
+
+ # write the latent and rope to kv cache
+ if kv_cache.numel() > 0:
+ ops.concat_and_cache_mla(
+ k_c_normed,
+ k_pe.squeeze(1),
+ kv_cache,
+ attn_metadata.slot_mapping.flatten(),
+ kv_cache_dtype=self.kv_cache_dtype,
+ scale=layer._k_scale,
+ )
+
+ attn_out = self._forward_bf16_kv(
+ q, kv_cache, topk_indices_global, attn_metadata
+ )
+
+ self._v_up_proj(attn_out, out=output[:num_actual_toks])
+ return output
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index ea611848b0e81..ea911af3d19ce 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -447,7 +447,10 @@ class AiterFlashAttentionMetadataBuilder(
class AiterFlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
- supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
+
+ @staticmethod
+ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+ return [MultipleOf(16)]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
@@ -514,12 +517,9 @@ class AiterFlashAttentionImpl(AttentionImpl):
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
- if attn_type != AttentionType.DECODER:
+ if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
raise NotImplementedError(
- "Encoder self-attention and "
- "encoder/decoder cross-attention "
- "are not implemented for "
- "FlashAttentionImpl"
+ "Encoder self-attention is not implemented for FlashAttentionImpl"
)
def extend_forward(
@@ -675,7 +675,14 @@ class AiterFlashAttentionImpl(AttentionImpl):
# performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = kv_cache.unbind(0)
- if self.kv_sharing_target_layer_name is None:
+ # key and value may be None in the case of cross attention. They are
+ # calculated once based on the output from the encoder and then cached
+ # in KV cache.
+ if (
+ self.kv_sharing_target_layer_name is None
+ and key is not None
+ and value is not None
+ ):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping
@@ -701,8 +708,10 @@ class AiterFlashAttentionImpl(AttentionImpl):
# decode:extend:prefill
query = query[:num_actual_tokens]
- key = key[:num_actual_tokens]
- value = value[:num_actual_tokens]
+ if key is not None:
+ key = key[:num_actual_tokens]
+ if value is not None:
+ value = value[:num_actual_tokens]
output_actual_tokens = output[:num_actual_tokens]
diff --git a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py
index b2639c0df0412..16fb52ab501c1 100644
--- a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py
+++ b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py
@@ -142,7 +142,14 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
key_cache, value_cache = kv_cache.unbind(0)
- if self.kv_sharing_target_layer_name is None:
+ # key and value may be None in the case of cross attention. They are
+ # calculated once based on the output from the encoder and then cached
+ # in KV cache.
+ if (
+ self.kv_sharing_target_layer_name is None
+ and key is not None
+ and value is not None
+ ):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
ops.reshape_and_cache_flash(
@@ -169,7 +176,10 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
- descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+ descale_shape = (
+ cu_seqlens_q.shape[0] - 1,
+ key.shape[1] if key is not None else self.num_kv_heads,
+ )
self.unified_attention(
q=query[:num_actual_tokens],
diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py
index 6dfdfc19ccba1..868143cc192e7 100644
--- a/vllm/v1/attention/backends/rocm_attn.py
+++ b/vllm/v1/attention/backends/rocm_attn.py
@@ -238,12 +238,9 @@ class RocmAttentionImpl(AttentionImpl):
RocmAttentionBackend.validate_head_size(head_size)
- if attn_type != AttentionType.DECODER:
+ if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
raise NotImplementedError(
- "Encoder self-attention and "
- "encoder/decoder cross-attention "
- "are not implemented for "
- "RocmAttentionImpl"
+ "Encoder self-attention is not implemented for RocmAttentionImpl"
)
self.fp8_dtype = current_platform.fp8_dtype()
diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py
index 1bf38ed225a4c..523f759e05a21 100644
--- a/vllm/v1/attention/backends/tree_attn.py
+++ b/vllm/v1/attention/backends/tree_attn.py
@@ -31,7 +31,10 @@ logger = init_logger(__name__)
class TreeAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
- supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
+
+ @staticmethod
+ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+ return [MultipleOf(16)]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 889c79db18ef5..d051a89f03bb4 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -154,7 +154,6 @@ class TritonAttentionBackend(AttentionBackend):
torch.bfloat16,
torch.float32,
]
- supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
@@ -162,6 +161,10 @@ class TritonAttentionBackend(AttentionBackend):
"fp8_e5m2",
]
+ @staticmethod
+ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+ return [MultipleOf(16)]
+
@staticmethod
def get_name() -> str:
return "TRITON_ATTN"
@@ -244,14 +247,11 @@ class TritonAttentionImpl(AttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
- if attn_type != AttentionType.DECODER:
+ if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
raise NotImplementedError(
- "Encoder self-attention and "
- "encoder/decoder cross-attention "
- "are not implemented for "
- "TritonAttentionImpl"
+ "Encoder self-attention is not implemented for TritonAttentionImpl"
)
-
+ self.attn_type = attn_type
self.fp8_dtype = current_platform.fp8_dtype()
self.sinks = sinks
@@ -312,7 +312,11 @@ class TritonAttentionImpl(AttentionImpl):
num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = kv_cache.unbind(1)
- if self.kv_sharing_target_layer_name is None:
+ if (
+ self.kv_sharing_target_layer_name is None
+ and key is not None
+ and value is not None
+ ):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
if self.kv_cache_dtype.startswith("fp8"):
@@ -346,7 +350,7 @@ class TritonAttentionImpl(AttentionImpl):
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
- descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+ descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
unified_attention(
q=query[:num_actual_tokens],
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 578153cda7863..540a8e2b1d016 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -92,6 +92,7 @@ class CommonAttentionMetadata:
encoder_seq_lens: np.ndarray | None = None
dcp_local_seq_lens: torch.Tensor | None = None
+ dcp_local_seq_lens_cpu: torch.Tensor | None = None
"""Sequence lengths of the local rank in decode context parallelism world"""
@@ -1079,9 +1080,9 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
def get_dcp_local_seq_lens(
seq_lens: torch.Tensor,
- dcp_world_size: int = 1,
+ dcp_size: int = 1,
dcp_rank: int | None = None,
- dcp_kv_cache_interleave_size: int = 1,
+ cp_kv_cache_interleave_size: int = 1,
) -> torch.Tensor:
"""While using dcp, kv_cache size stored on each rank may be different,
use this function to calculate split decode seq_lens of each dcp rank.
@@ -1090,7 +1091,7 @@ def get_dcp_local_seq_lens(
num_requests = seq_lens.size(0)
if dcp_rank is None:
rank_offsets = (
- torch.arange(dcp_world_size, dtype=torch.int32)
+ torch.arange(dcp_size, dtype=torch.int32)
.unsqueeze(0)
.repeat(num_requests, 1)
)
@@ -1101,15 +1102,15 @@ def get_dcp_local_seq_lens(
)
base = (
seq_lens_tiled
- // dcp_kv_cache_interleave_size
- // dcp_world_size
- * dcp_kv_cache_interleave_size
+ // cp_kv_cache_interleave_size
+ // dcp_size
+ * cp_kv_cache_interleave_size
)
- remainder = seq_lens_tiled - base * dcp_world_size
+ remainder = seq_lens_tiled - base * dcp_size
remainder = torch.clip(
- remainder - rank_offsets * dcp_kv_cache_interleave_size,
+ remainder - rank_offsets * cp_kv_cache_interleave_size,
0,
- dcp_kv_cache_interleave_size,
+ cp_kv_cache_interleave_size,
)
dcp_local_seq_lens = base + remainder
return dcp_local_seq_lens.squeeze(1)
diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py
deleted file mode 100644
index d15d79417cc61..0000000000000
--- a/vllm/v1/attention/backends/xformers.py
+++ /dev/null
@@ -1,417 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Attention layer with XFormersAttention."""
-
-from dataclasses import dataclass
-from typing import ClassVar, Optional
-
-import torch
-
-from vllm.attention.backends.abstract import (
- AttentionBackend,
- AttentionImpl,
- AttentionType,
- MultipleOf,
-)
-from vllm.attention.ops.triton_unified_attention import unified_attention
-from vllm.config import VllmConfig
-from vllm.logger import init_logger
-from vllm.v1.attention.backends.utils import (
- AttentionMetadataBuilder,
- CommonAttentionMetadata,
- split_decodes_and_prefills,
-)
-from vllm.v1.kv_cache_interface import AttentionSpec
-
-try:
- from xformers import ops as xops
- from xformers.ops.fmha.attn_bias import (
- AttentionBias,
- PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
- )
-
- XFORMERS_AVAILABLE = True
-except ImportError:
- XFORMERS_AVAILABLE = False
-
-from vllm import _custom_ops as ops
-
-logger = init_logger(__name__)
-
-
-class XFormersAttentionBackend(AttentionBackend):
- accept_output_buffer: bool = True
- supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
- supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
-
- @classmethod
- def get_supported_head_sizes(cls) -> list[int]:
- return [
- 32,
- 40,
- 48,
- 56,
- 64,
- 72,
- 80,
- 88,
- 96,
- 104,
- 112,
- 120,
- 128,
- 136,
- 144,
- 152,
- 160,
- 168,
- 176,
- 184,
- 192,
- 200,
- 208,
- 216,
- 224,
- 232,
- 240,
- 248,
- 256,
- ]
-
- @staticmethod
- def get_name() -> str:
- return "XFORMERS"
-
- @staticmethod
- def get_impl_cls() -> type["XFormersAttentionImpl"]:
- return XFormersAttentionImpl
-
- @staticmethod
- def get_kv_cache_shape(
- num_blocks: int,
- block_size: int,
- num_kv_heads: int,
- head_size: int,
- cache_dtype_str: str = "auto",
- ) -> tuple[int, ...]:
- if block_size % 16 != 0:
- raise ValueError("Block size must be a multiple of 16.")
- return (2, num_blocks, block_size, num_kv_heads, head_size)
-
- @staticmethod
- def get_builder_cls() -> type["XFormersAttentionMetadataBuilder"]:
- return XFormersAttentionMetadataBuilder
-
- @staticmethod
- def use_cascade_attention(*args, **kwargs) -> bool:
- return False
-
-
-@dataclass
-class XFormersAttentionMetadata:
- num_actual_tokens: int # Number of tokens excluding padding.
- max_query_len: int
- query_start_loc: torch.Tensor
- max_seq_len: int
- seq_lens: torch.Tensor
- block_table: torch.Tensor
- slot_mapping: torch.Tensor
-
- num_prefill_tokens: int = 0
- num_decode_tokens: int = 0
- num_prefills: int = 0
- num_decodes: int = 0
-
- # Biases for different attention types.
- attn_bias: Optional["AttentionBias"] = None
-
- # Self-attention prefill/decode metadata cache
- _cached_prefill_metadata: Optional["XFormersAttentionMetadata"] = None
- _cached_decode_metadata: Optional["XFormersAttentionMetadata"] = None
-
- @property
- def prefill_metadata(self) -> Optional["XFormersAttentionMetadata"]:
- if self.num_prefills == 0:
- return None
-
- if self._cached_prefill_metadata is not None:
- # Recover cached prefill-phase attention
- # metadata structure
- return self._cached_prefill_metadata
-
- q_start_loc = self.query_start_loc[self.num_decodes :]
- q_seqlens = torch.diff(q_start_loc)
- kv_seqlens = self.seq_lens[self.num_decodes :]
- # Construct & cache prefill-phase attention metadata structure
- self._cached_prefill_metadata = XFormersAttentionMetadata(
- num_actual_tokens=self.num_prefill_tokens,
- max_query_len=int(q_seqlens.max().item()),
- query_start_loc=q_start_loc - q_start_loc[0],
- max_seq_len=int(kv_seqlens.max().item()),
- seq_lens=kv_seqlens,
- block_table=self.block_table[self.num_decodes :],
- slot_mapping=self.slot_mapping[self.num_decode_tokens :],
- )
- return self._cached_prefill_metadata
-
- @property
- def decode_metadata(self) -> Optional["XFormersAttentionMetadata"]:
- if self.num_decode_tokens == 0:
- return None
-
- if self._cached_decode_metadata is not None:
- # Recover cached decode-phase attention
- # metadata structure
- return self._cached_decode_metadata
-
- q_start_loc = self.query_start_loc
- q_seqlens = torch.diff(q_start_loc)
- decode_kv_seqlens = self.seq_lens[: self.num_decodes]
- # Construct & cache decode-phase attention metadata structure
- self._cached_decode_metadata = XFormersAttentionMetadata(
- num_actual_tokens=self.num_decode_tokens,
- max_query_len=int(q_seqlens[: self.num_decodes].max().item()),
- query_start_loc=q_start_loc[: self.num_decodes + 1],
- max_seq_len=int(decode_kv_seqlens.max().item()),
- seq_lens=decode_kv_seqlens,
- block_table=self.block_table[: self.num_decodes],
- slot_mapping=self.slot_mapping[: self.num_decode_tokens],
- attn_bias=self.attn_bias,
- )
- return self._cached_decode_metadata
-
-
-class XFormersAttentionMetadataBuilder(
- AttentionMetadataBuilder[XFormersAttentionMetadata]
-):
- reorder_batch_threshold: int = 1
-
- def __init__(
- self,
- kv_cache_spec: AttentionSpec,
- layer_names: list[str],
- vllm_config: VllmConfig,
- device: torch.device,
- ):
- super().__init__(kv_cache_spec, layer_names, vllm_config, device)
-
- assert XFORMERS_AVAILABLE
- self.block_size = kv_cache_spec.block_size
- self._num_decodes = 0
- self._num_decode_tokens = 0
-
- def build(
- self,
- common_prefix_len: int,
- common_attn_metadata: CommonAttentionMetadata,
- fast_build: bool = False,
- ) -> XFormersAttentionMetadata:
- num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
- split_decodes_and_prefills(
- common_attn_metadata, decode_threshold=self.reorder_batch_threshold
- )
- )
-
- num_actual_tokens = common_attn_metadata.num_actual_tokens
- q_start_loc = common_attn_metadata.query_start_loc
- q_seqlens = torch.diff(q_start_loc)
- max_query_len = common_attn_metadata.max_query_len
- kv_seqlens = common_attn_metadata.seq_lens
- max_seq_len = common_attn_metadata.max_seq_len
- block_table = common_attn_metadata.block_table_tensor
- slot_mapping = common_attn_metadata.slot_mapping
-
- bias = None
- if num_decodes > 0:
- # Construct the decoder bias.
- decode_q_seqlens = q_seqlens[:num_decodes]
- decode_kv_seqlens = kv_seqlens[:num_decodes]
- bias = PagedBlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
- q_seqlen=decode_q_seqlens.tolist(),
- kv_seqlen=decode_kv_seqlens.tolist(),
- page_size=self.block_size,
- block_tables=block_table[:num_decodes],
- device=block_table.device,
- )
-
- return XFormersAttentionMetadata(
- num_actual_tokens=num_actual_tokens,
- num_prefill_tokens=num_prefill_tokens,
- num_decode_tokens=num_decode_tokens,
- num_prefills=num_prefills,
- num_decodes=num_decodes,
- max_query_len=max_query_len,
- query_start_loc=q_start_loc,
- max_seq_len=max_seq_len,
- seq_lens=kv_seqlens,
- block_table=block_table,
- slot_mapping=slot_mapping,
- attn_bias=bias,
- )
-
-
-class XFormersAttentionImpl(AttentionImpl):
- def __init__(
- self,
- num_heads: int,
- head_size: int,
- scale: float,
- num_kv_heads: int,
- alibi_slopes: list[float] | None,
- sliding_window: int | None,
- kv_cache_dtype: str,
- logits_soft_cap: float | None = None,
- attn_type: AttentionType = AttentionType.DECODER,
- kv_sharing_target_layer_name: str | None = None,
- ) -> None:
- if kv_sharing_target_layer_name is not None:
- raise NotImplementedError("KV sharing is not supported in V0.")
- if alibi_slopes is not None:
- raise NotImplementedError("XFormers does not support alibi slopes yet.")
- self.num_heads = num_heads
- self.head_size = head_size
- self.scale = float(scale)
- self.num_kv_heads = num_kv_heads
- self.num_queries_per_kv = self.num_heads // self.num_kv_heads
- self.kv_cache_dtype = kv_cache_dtype
- self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
- if alibi_slopes is not None:
- alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
- self.alibi_slopes = alibi_slopes
- if sliding_window is None:
- self.sliding_window = (-1, -1)
- else:
- self.sliding_window = (sliding_window - 1, 0)
- if logits_soft_cap is None:
- # Setting logits_soft_cap to 0 means no soft cap.
- logits_soft_cap = 0
- self.logits_soft_cap = logits_soft_cap
-
- if attn_type != AttentionType.DECODER:
- raise NotImplementedError(
- "Encoder self-attention and "
- "encoder/decoder cross-attention "
- "are not implemented for "
- "XFormersAttentionImpl."
- )
-
- def forward(
- self,
- layer: torch.nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: XFormersAttentionMetadata,
- output: torch.Tensor | None = None,
- output_scale: torch.Tensor | None = None,
- output_block_scale: torch.Tensor | None = None,
- ) -> torch.Tensor:
- """Forward pass with XFormers.
-
- Args:
- query: shape = [num_tokens, num_heads, head_size]
- key: shape = [num_tokens, num_kv_heads, head_size]
- value: shape = [num_tokens, num_kv_heads, head_size]
- kv_cache: shape =
- [2, num_blocks, block_size, num_kv_heads, head_size]
- attn_metadata: Metadata for attention.
- Returns:
- shape = [num_tokens, num_heads * head_size]
- """
- assert output is not None, "Output tensor must be provided."
-
- if output_scale is not None or output_block_scale is not None:
- raise NotImplementedError(
- "fused output quantization is not yet supported"
- " for XFormersAttentionImpl"
- )
-
- if attn_metadata is None:
- # Profiling run.
- return output.fill_(0)
-
- # Cache the input KVs.
- key_cache, value_cache = kv_cache.unbind(0)
- if self.kv_sharing_target_layer_name is None:
- # Reshape the input keys and values and store them in the cache.
- # Skip this if sharing KV cache with an earlier attention layer.
- # NOTE(woosuk): Here, key and value are padded while slot_mapping is
- # not padded. However, we don't need to do key[:num_actual_tokens]
- # and value[:num_actual_tokens] because the reshape_and_cache_flash
- # op uses the slot_mapping's shape to determine the number of
- # actual tokens.
- ops.reshape_and_cache_flash(
- key,
- value,
- key_cache,
- value_cache,
- attn_metadata.slot_mapping,
- self.kv_cache_dtype,
- layer._k_scale,
- layer._v_scale,
- )
-
- num_actual_tokens = attn_metadata.num_actual_tokens
- num_decode_tokens = attn_metadata.num_decode_tokens
- if prefill_meta := attn_metadata.prefill_metadata:
- descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, key.shape[1])
- unified_attention(
- q=query[num_decode_tokens:num_actual_tokens],
- k=key_cache,
- v=value_cache,
- out=output[num_decode_tokens:num_actual_tokens],
- cu_seqlens_q=prefill_meta.query_start_loc,
- max_seqlen_q=prefill_meta.max_query_len,
- seqused_k=prefill_meta.seq_lens,
- max_seqlen_k=prefill_meta.max_seq_len,
- softmax_scale=self.scale,
- causal=True,
- alibi_slopes=self.alibi_slopes,
- window_size=self.sliding_window,
- block_table=prefill_meta.block_table,
- softcap=self.logits_soft_cap,
- q_descale=None, # Not supported
- k_descale=layer._k_scale.expand(descale_shape),
- v_descale=layer._v_scale.expand(descale_shape),
- )
-
- if decode_meta := attn_metadata.decode_metadata:
- # Query for decode. KV is not needed because it is already cached.
- decode_query = query[:num_decode_tokens]
- # Reshape query to [1, B_T, G, H, D].
- q = decode_query.view(
- 1, -1, self.num_kv_heads, self.num_queries_per_kv, self.head_size
- )
- # Reshape the k and v caches to [1, Bkv_T, G, H, D]
- cache_k = key_cache.view(
- 1, -1, self.num_kv_heads, 1, self.head_size
- ).expand(
- 1,
- -1,
- self.num_kv_heads,
- self.num_queries_per_kv,
- self.head_size,
- )
- cache_v = value_cache.view(
- 1, -1, self.num_kv_heads, 1, self.head_size
- ).expand(
- 1,
- -1,
- self.num_kv_heads,
- self.num_queries_per_kv,
- self.head_size,
- )
-
- attn_bias = decode_meta.attn_bias
- output[:num_decode_tokens] = xops.memory_efficient_attention_forward(
- q,
- cache_k,
- cache_v,
- attn_bias=attn_bias,
- p=0.0,
- scale=self.scale,
- ).view(decode_query.shape)
-
- # Reshape the output tensor.
- return output
diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py
index 137e5e0cdb6d2..1531b61f88fe2 100644
--- a/vllm/v1/core/kv_cache_coordinator.py
+++ b/vllm/v1/core/kv_cache_coordinator.py
@@ -27,6 +27,7 @@ class KVCacheCoordinator(ABC):
enable_caching: bool,
enable_kv_cache_events: bool,
dcp_world_size: int,
+ pcp_world_size: int,
):
self.kv_cache_config = kv_cache_config
self.max_model_len = max_model_len
@@ -44,6 +45,7 @@ class KVCacheCoordinator(ABC):
block_pool=self.block_pool,
kv_cache_group_id=i,
dcp_world_size=dcp_world_size,
+ pcp_world_size=pcp_world_size,
)
for i, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups)
)
@@ -210,6 +212,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
use_eagle: bool,
enable_kv_cache_events: bool,
dcp_world_size: int,
+ pcp_world_size: int,
):
super().__init__(
kv_cache_config,
@@ -218,6 +221,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
False,
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
+ pcp_world_size=pcp_world_size,
)
self.num_single_type_manager = len(self.single_type_managers)
@@ -250,6 +254,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
enable_caching: bool,
enable_kv_cache_events: bool,
dcp_world_size: int,
+ pcp_world_size: int,
):
super().__init__(
kv_cache_config,
@@ -258,12 +263,16 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
enable_caching,
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
+ pcp_world_size=pcp_world_size,
)
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec
self.block_size = self.kv_cache_spec.block_size
self.dcp_world_size = dcp_world_size
+ self.pcp_world_size = pcp_world_size
if dcp_world_size > 1:
self.block_size *= dcp_world_size
+ if pcp_world_size > 1:
+ self.block_size *= pcp_world_size
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
"UnitaryKVCacheCoordinator assumes only one kv cache group"
)
@@ -281,6 +290,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
kv_cache_spec=self.kv_cache_spec,
use_eagle=self.use_eagle,
dcp_world_size=self.dcp_world_size,
+ pcp_world_size=self.pcp_world_size,
)
return hit_blocks, len(hit_blocks[0]) * self.block_size
@@ -302,6 +312,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
enable_caching: bool,
enable_kv_cache_events: bool,
dcp_world_size: int,
+ pcp_world_size: int,
):
super().__init__(
kv_cache_config,
@@ -310,8 +321,10 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
enable_caching,
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
+ pcp_world_size=pcp_world_size,
)
assert dcp_world_size == 1, "DCP not support hybrid attn now."
+ assert pcp_world_size == 1, "PCP not support hybrid attn now."
self.verify_and_split_kv_cache_groups()
def verify_and_split_kv_cache_groups(self) -> None:
@@ -452,6 +465,7 @@ def get_kv_cache_coordinator(
enable_caching: bool,
enable_kv_cache_events: bool,
dcp_world_size: int,
+ pcp_world_size: int,
) -> KVCacheCoordinator:
if not enable_caching:
return KVCacheCoordinatorNoPrefixCache(
@@ -460,6 +474,7 @@ def get_kv_cache_coordinator(
use_eagle,
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
+ pcp_world_size=pcp_world_size,
)
if len(kv_cache_config.kv_cache_groups) == 1:
return UnitaryKVCacheCoordinator(
@@ -469,6 +484,7 @@ def get_kv_cache_coordinator(
enable_caching,
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
+ pcp_world_size=pcp_world_size,
)
return HybridKVCacheCoordinator(
kv_cache_config,
@@ -477,4 +493,5 @@ def get_kv_cache_coordinator(
enable_caching,
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
+ pcp_world_size=pcp_world_size,
)
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
index 7f405fc248ac2..2012c3fef88bc 100644
--- a/vllm/v1/core/kv_cache_manager.py
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -100,6 +100,7 @@ class KVCacheManager:
log_stats: bool = False,
enable_kv_cache_events: bool = False,
dcp_world_size: int = 1,
+ pcp_world_size: int = 1,
) -> None:
self.max_model_len = max_model_len
@@ -124,12 +125,9 @@ class KVCacheManager:
0
].kv_cache_spec.block_size
- if dcp_world_size > 1:
+ if dcp_world_size * pcp_world_size > 1:
assert len(kv_cache_config.kv_cache_groups) == 1
- # Note(hc): need revisit. When both DCP and any future
- # PCP are enabled, the block_size may need to be scaled
- # by a factor of dcp_size × pcp_size?
- self.block_size *= dcp_world_size
+ self.block_size *= dcp_world_size * pcp_world_size
self.coordinator = get_kv_cache_coordinator(
kv_cache_config=kv_cache_config,
@@ -138,6 +136,7 @@ class KVCacheManager:
enable_caching=self.enable_caching,
enable_kv_cache_events=enable_kv_cache_events,
dcp_world_size=dcp_world_size,
+ pcp_world_size=pcp_world_size,
)
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
self.block_pool = self.coordinator.block_pool
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index 6e026215d4022..a0033fa650baa 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -971,7 +971,16 @@ def _get_kv_cache_groups_uniform_page_size(
# is the minimum number of layers among all attention types. Need a better
# strategy if we want to support more complex patterns (e.g., 20 full + 30
# sw, where the group size should be 10).
- group_size = min([len(layers) for layers in same_type_layers.values()])
+ min_num_layers = min([len(layers) for layers in same_type_layers.values()])
+ group_size = min_num_layers
+ max_num_layers = max([len(layers) for layers in same_type_layers.values()])
+ if max_num_layers < min_num_layers * 1.25:
+ # If the number of layers is not much larger than the minimum number of layers,
+ # use the maximum number of layers as the group size to avoid too many padding
+ # layers. A typical example is gpt-oss-20b + eagle, with 12 sw + 13 full. We
+ # pad it to (13 sw, 13 full) instead of (12 sw, 24 full). 1.25 is just a
+ # magic number to avoid too many padding layers.
+ group_size = max_num_layers
grouped_layers = []
for layers in same_type_layers.values():
num_padding_layers = group_size - len(layers) % group_size
@@ -1219,11 +1228,16 @@ def _report_kv_cache_config(
// len(kv_cache_config.kv_cache_groups)
* min_block_size
)
- if vllm_config.parallel_config.decode_context_parallel_size > 1:
- num_tokens *= vllm_config.parallel_config.decode_context_parallel_size
+ dcp_size = vllm_config.parallel_config.decode_context_parallel_size
+ pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
+ if pcp_size * dcp_size > 1:
+ num_tokens *= pcp_size * dcp_size
logger.info(
- "Multiplying the GPU KV cache size by the dcp_world_size %d.",
- vllm_config.parallel_config.decode_context_parallel_size,
+ "Multiplying the GPU KV cache size by the cp_world_size %d "
+ "(pcp_world_size %d * dcp_world_size %d).",
+ pcp_size * dcp_size,
+ pcp_size,
+ dcp_size,
)
num_tokens_str = f"{num_tokens:,}"
logger.info_once("GPU KV cache size: %s tokens", num_tokens_str, scope="local")
@@ -1231,10 +1245,11 @@ def _report_kv_cache_config(
max_concurrency = get_max_concurrency_for_kv_cache_config(
vllm_config, kv_cache_config
)
- logger.info(
+ logger.info_once(
"Maximum concurrency for %s tokens per request: %.2fx",
max_model_len_str,
max_concurrency,
+ scope="local",
)
diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py
index 20fdb3446404b..7902513dce49a 100644
--- a/vllm/v1/core/sched/output.py
+++ b/vllm/v1/core/sched/output.py
@@ -44,11 +44,15 @@ class NewRequestData:
lora_request: LoRARequest | None
prompt_embeds: "torch.Tensor | None" = None
+ # Only used for v2 model runner.
+ prefill_token_ids: list[int] | None = None
+
@classmethod
def from_request(
cls,
request: Request,
block_ids: tuple[list[int], ...],
+ prefill_token_ids: list[int] | None = None,
) -> "NewRequestData":
return cls(
req_id=request.request_id,
@@ -60,6 +64,7 @@ class NewRequestData:
num_computed_tokens=request.num_computed_tokens,
lora_request=request.lora_request,
prompt_embeds=request.prompt_embeds,
+ prefill_token_ids=prefill_token_ids,
)
def __repr__(self) -> str:
@@ -68,6 +73,7 @@ class NewRequestData:
f"NewRequestData("
f"req_id={self.req_id},"
f"prompt_token_ids={self.prompt_token_ids},"
+ f"prefill_token_ids={self.prefill_token_ids},"
f"mm_features={self.mm_features},"
f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids},"
@@ -183,6 +189,10 @@ class SchedulerOutput:
# freed from the encoder cache.
free_encoder_mm_hashes: list[str]
+ # Request IDs that are preempted in this step.
+ # Only used for v2 model runner.
+ preempted_req_ids: set[str] | None = None
+
# Whether the scheduled requests have all the output tokens they
# need to perform grammar bitmask computation.
pending_structured_output_tokens: bool = False
@@ -193,6 +203,20 @@ class SchedulerOutput:
# EC Cache Connector metadata
ec_connector_metadata: ECConnectorMetadata | None = None
+ @classmethod
+ def make_empty(cls) -> "SchedulerOutput":
+ return cls(
+ scheduled_new_reqs=[],
+ scheduled_cached_reqs=CachedRequestData.make_empty(),
+ num_scheduled_tokens={},
+ total_num_scheduled_tokens=0,
+ scheduled_spec_decode_tokens={},
+ scheduled_encoder_inputs={},
+ num_common_prefix_blocks=[],
+ finished_req_ids=set(),
+ free_encoder_mm_hashes=[],
+ )
+
@dataclass
class GrammarOutput:
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index e3260b3dae797..4ba7539b0eadb 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -6,6 +6,7 @@ from collections import defaultdict
from collections.abc import Iterable
from typing import Any
+from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.ec_transfer.ec_connector.base import (
ECConnectorMetadata,
@@ -121,6 +122,7 @@ class Scheduler(SchedulerInterface):
self.block_size = block_size
self.dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
+ self.pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size
# req_id -> Request
self.requests: dict[str, Request] = {}
@@ -178,11 +180,12 @@ class Scheduler(SchedulerInterface):
self.kv_cache_manager = KVCacheManager(
kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len,
- enable_caching=bool(self.cache_config.enable_prefix_caching),
+ enable_caching=self.cache_config.enable_prefix_caching,
use_eagle=self.use_eagle,
log_stats=self.log_stats,
enable_kv_cache_events=self.enable_kv_cache_events,
dcp_world_size=self.dcp_world_size,
+ pcp_world_size=self.pcp_world_size,
)
sink_len = getattr(vllm_config.model_config.hf_config, "param_sink_number", 0)
if sink_len > 0:
@@ -190,6 +193,7 @@ class Scheduler(SchedulerInterface):
num_sink_block = sink_len // self.block_size
self.kv_cache_manager.block_pool.free_block_queue.popleft_n(num_sink_block)
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
+ self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
def schedule(self) -> SchedulerOutput:
# NOTE(woosuk) on the scheduling algorithm:
@@ -471,6 +475,7 @@ class Scheduler(SchedulerInterface):
skipped_waiting_requests.prepend_request(request)
continue
+ request.num_external_computed_tokens = ext_tokens
num_external_computed_tokens = ext_tokens
# Total computed tokens (local + external).
@@ -508,9 +513,9 @@ class Scheduler(SchedulerInterface):
not self.scheduler_config.enable_chunked_prefill
and num_new_tokens > token_budget
):
- self.waiting.pop_request()
- skipped_waiting_requests.prepend_request(request)
- continue
+ # If chunked_prefill is disabled,
+ # we can stop the scheduling here.
+ break
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
@@ -577,9 +582,6 @@ class Scheduler(SchedulerInterface):
new_computed_blocks + new_blocks,
num_external_computed_tokens,
)
- self._update_connector_prefix_cache_stats(
- request, num_external_computed_tokens
- )
# Request was already popped from self.waiting
# unless it was re-added above due to new_blocks being None.
@@ -591,6 +593,8 @@ class Scheduler(SchedulerInterface):
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
continue
+ self._update_connector_prefix_cache_stats(request)
+
req_index += 1
self.running.append(request)
if self.log_stats:
@@ -661,12 +665,25 @@ class Scheduler(SchedulerInterface):
)
# Construct the scheduler output.
- new_reqs_data = [
- NewRequestData.from_request(
- req, req_to_new_blocks[req.request_id].get_block_ids()
- )
- for req in scheduled_new_reqs
- ]
+ if self.use_v2_model_runner:
+ scheduled_new_reqs = scheduled_new_reqs + scheduled_resumed_reqs
+ scheduled_resumed_reqs = []
+ new_reqs_data = [
+ NewRequestData.from_request(
+ req,
+ req_to_new_blocks[req.request_id].get_block_ids(),
+ req._all_token_ids,
+ )
+ for req in scheduled_new_reqs
+ ]
+ else:
+ new_reqs_data = [
+ NewRequestData.from_request(
+ req, req_to_new_blocks[req.request_id].get_block_ids()
+ )
+ for req in scheduled_new_reqs
+ ]
+
with record_function_or_nullcontext("schedule: make_cached_request_data"):
cached_reqs_data = self._make_cached_request_data(
scheduled_running_reqs,
@@ -688,6 +705,7 @@ class Scheduler(SchedulerInterface):
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
scheduled_encoder_inputs=scheduled_encoder_inputs,
num_common_prefix_blocks=num_common_prefix_blocks,
+ preempted_req_ids={req.request_id for req in preempted_reqs},
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between
@@ -1016,8 +1034,8 @@ class Scheduler(SchedulerInterface):
continue
req_index = model_runner_output.req_id_to_index[req_id]
- generated_token_ids: list[int] = (
- sampled_token_ids[req_index].tolist() if sampled_token_ids else []
+ generated_token_ids = (
+ sampled_token_ids[req_index] if sampled_token_ids else []
)
scheduled_spec_token_ids = (
@@ -1367,15 +1385,13 @@ class Scheduler(SchedulerInterface):
# KV Connector Related Methods
########################################################################
- def _update_connector_prefix_cache_stats(
- self, request: Request, num_external_tokens: int
- ) -> None:
+ def _update_connector_prefix_cache_stats(self, request: Request) -> None:
if self.connector_prefix_cache_stats is None:
return
self.connector_prefix_cache_stats.record(
num_tokens=request.num_tokens,
- num_hits=num_external_tokens,
+ num_hits=request.num_external_computed_tokens,
preempted=request.num_preemptions > 0,
)
@@ -1558,9 +1574,11 @@ class Scheduler(SchedulerInterface):
marked_invalid_block = True
# Truncate the computed tokens at the first failed block
request.num_computed_tokens = idx * self.block_size
- total_affected_tokens += (
+ num_affected_tokens = (
req_num_computed_tokens - request.num_computed_tokens
)
+ total_affected_tokens += num_affected_tokens
+ request.num_external_computed_tokens -= num_affected_tokens
if is_affected:
if not marked_invalid_block:
diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py
index 5aba227fa0176..ee5ae21d02843 100644
--- a/vllm/v1/core/single_type_kv_cache_manager.py
+++ b/vllm/v1/core/single_type_kv_cache_manager.py
@@ -33,6 +33,7 @@ class SingleTypeKVCacheManager(ABC):
block_pool: BlockPool,
kv_cache_group_id: int,
dcp_world_size: int = 1,
+ pcp_world_size: int = 1,
) -> None:
"""
Initializes the SingleTypeKVCacheManager.
@@ -43,8 +44,9 @@ class SingleTypeKVCacheManager(ABC):
"""
self.block_size = kv_cache_spec.block_size
self.dcp_world_size = dcp_world_size
- if self.dcp_world_size > 1:
- self.block_size *= dcp_world_size
+ self.pcp_world_size = pcp_world_size
+ if dcp_world_size * pcp_world_size > 1:
+ self.block_size *= dcp_world_size * pcp_world_size
self.kv_cache_spec = kv_cache_spec
self.block_pool = block_pool
@@ -213,6 +215,7 @@ class SingleTypeKVCacheManager(ABC):
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
dcp_world_size: int = 1,
+ pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
"""
Get the longest cache hit prefix of the blocks that is not longer than
@@ -304,6 +307,7 @@ class FullAttentionManager(SingleTypeKVCacheManager):
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
dcp_world_size: int = 1,
+ pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(
kv_cache_spec,
@@ -316,8 +320,8 @@ class FullAttentionManager(SingleTypeKVCacheManager):
[] for _ in range(len(kv_cache_group_ids))
)
block_size = kv_cache_spec.block_size
- if dcp_world_size > 1:
- block_size *= dcp_world_size
+ if dcp_world_size * pcp_world_size > 1:
+ block_size *= dcp_world_size * pcp_world_size
max_num_blocks = max_length // block_size
for block_hash in itertools.islice(block_hashes, max_num_blocks):
# block_hashes is a chain of block hashes. If a block hash is not
@@ -364,11 +368,13 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
dcp_world_size: int = 1,
+ pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(kv_cache_spec, SlidingWindowSpec), (
"SlidingWindowManager can only be used for sliding window groups"
)
assert dcp_world_size == 1, "DCP not support sliding window attn now."
+ assert pcp_world_size == 1, "PCP not support sliding window attn now."
# The number of contiguous blocks needed for prefix cache hit.
# -1 since the input token itself is also included in the window
@@ -478,6 +484,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
dcp_world_size: int = 1,
+ pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
"""
For chunked local attention, we need to find the longest cache hit
@@ -518,6 +525,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
"Hybrid KV cache is not supported for " + "eagle + chunked local attention."
)
assert dcp_world_size == 1, "DCP not support chunked local attn now."
+ assert pcp_world_size == 1, "PCP not support chunked local attn now."
max_num_blocks = max_length // kv_cache_spec.block_size
if max_length > 0:
local_attention_start_idx = (
@@ -613,11 +621,13 @@ class MambaManager(SingleTypeKVCacheManager):
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
dcp_world_size: int = 1,
+ pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(kv_cache_spec, MambaSpec), (
"MambaManager can only be used for mamba groups"
)
assert dcp_world_size == 1, "DCP not support mamba now."
+ assert pcp_world_size == 1, "PCP not support mamba now."
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
[] for _ in range(len(kv_cache_group_ids))
)
@@ -707,6 +717,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
dcp_world_size: int = 1,
+ pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(kv_cache_spec, CrossAttentionSpec), (
"CrossAttentionManager can only be used for cross-attention groups"
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index 3f621d77c0241..ce2aae77108da 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -72,6 +72,14 @@ class EngineCoreRequest(
trace_headers: Mapping[str, str] | None = None
+ @property
+ def params(self) -> SamplingParams | PoolingParams:
+ """Return the processed params (sampling or pooling)."""
+ if self.sampling_params is not None:
+ return self.sampling_params
+ assert self.pooling_params is not None
+ return self.pooling_params
+
class EngineCoreEventType(enum.IntEnum):
"""The type of engine core request event."""
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index c160c7cbcab4a..55087baadff97 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -152,6 +152,10 @@ class AsyncLLM(EngineClient):
)
self.logger_manager.log_engine_initialized()
+ # Pause / resume state for async RL workflows.
+ self._pause_cond = asyncio.Condition()
+ self._paused = False
+
self.output_handler: asyncio.Task | None = None
try:
# Start output handler eagerly if we are in the asyncio eventloop.
@@ -160,11 +164,23 @@ class AsyncLLM(EngineClient):
except RuntimeError:
pass
- if envs.VLLM_TORCH_PROFILER_DIR:
+ if (
+ envs.VLLM_TORCH_PROFILER_DIR
+ and not envs.VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM
+ ):
logger.info(
"Torch profiler enabled. AsyncLLM CPU traces will be collected under %s", # noqa: E501
envs.VLLM_TORCH_PROFILER_DIR,
)
+ if envs.VLLM_PROFILER_MAX_ITERS > 0 or envs.VLLM_PROFILER_DELAY_ITERS > 0:
+ logger.warning_once(
+ "Torch profiler received max_iters or delay_iters setting. These "
+ "are not compatible with the AsyncLLM profiler and will be ignored "
+ "for the AsyncLLM process. Engine process profiling will still "
+ "respect these settings. Consider setting "
+ "VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM=1 to disable "
+ "AsyncLLM profiling."
+ )
worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm"
self.profiler = torch.profiler.profile(
activities=[
@@ -305,14 +321,15 @@ class AsyncLLM(EngineClient):
elif isinstance(prompt, Mapping):
prompt_text = cast(str | None, prompt.get("prompt"))
+ # Use cloned params that may have been updated in process_inputs()
+ params = request.params
+
if is_pooling or params.n == 1:
await self._add_request(request, prompt_text, None, 0, queue)
return queue
- # Get the updated SamplingParams from the request, which
- # were cloned/updated in processor.process_inputs above.
- parent_params = request.sampling_params
- assert parent_params is not None
+ parent_params = params
+ assert isinstance(parent_params, SamplingParams)
# Fan out child requests (for n>1).
parent_request = ParentRequest(request_id, parent_params)
@@ -392,6 +409,10 @@ class AsyncLLM(EngineClient):
# to handle startup failure gracefully in the OpenAI server.
self._run_output_handler()
+ # Wait until generation is resumed if the engine is paused.
+ async with self._pause_cond:
+ await self._pause_cond.wait_for(lambda: not self._paused)
+
if tokenization_kwargs is None:
tokenization_kwargs = {}
truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
@@ -539,6 +560,58 @@ class AsyncLLM(EngineClient):
if self.log_requests:
logger.info("Aborted request(s) %s.", ",".join(request_ids))
+ async def pause_generation(
+ self,
+ *,
+ wait_for_inflight_requests: bool = False,
+ clear_cache: bool = True,
+ ) -> None:
+ """
+ Pause generation to allow model weight updates.
+
+ New generation/encoding requests are blocked until resume.
+
+ Args:
+ wait_for_inflight_requests: When ``True`` waits for in-flight
+ requests to finish before pausing. When ``False`` (default),
+ immediately aborts any in-flight requests.
+ clear_cache: Whether to clear KV cache and prefix cache after
+ draining. Set to ``False`` to preserve cache for faster resume.
+ Default is ``True`` (clear caches).
+ """
+
+ async with self._pause_cond:
+ if self._paused:
+ return
+ self._paused = True
+
+ if not wait_for_inflight_requests:
+ request_ids = list(self.output_processor.request_states.keys())
+ if request_ids:
+ await self.abort(request_ids)
+
+ # Wait for running requests to drain before clearing cache.
+ if self.output_processor.has_unfinished_requests():
+ await self.output_processor.wait_for_requests_to_drain()
+
+ # Clear cache
+ if clear_cache:
+ await self.reset_prefix_cache()
+ await self.reset_mm_cache()
+
+ async def resume_generation(self) -> None:
+ """Resume generation after :meth:`pause_generation`."""
+
+ async with self._pause_cond:
+ self._paused = False
+ self._pause_cond.notify_all() # Wake up all waiting requests
+
+ async def is_paused(self) -> bool:
+ """Return whether the engine is currently paused."""
+
+ async with self._pause_cond:
+ return self._paused
+
async def encode(
self,
prompt: PromptType,
@@ -570,6 +643,10 @@ class AsyncLLM(EngineClient):
# to handle startup failure gracefully in the OpenAI server.
self._run_output_handler()
+ # Respect pause state before accepting new requests.
+ async with self._pause_cond:
+ await self._pause_cond.wait_for(lambda: not self._paused)
+
if tokenization_kwargs is None:
tokenization_kwargs = {}
_validate_truncation_size(
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 3a25827cec385..8657a95b5e6e7 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -128,6 +128,7 @@ class EngineCore:
scheduler_block_size = (
vllm_config.cache_config.block_size
* vllm_config.parallel_config.decode_context_parallel_size
+ * vllm_config.parallel_config.prefill_context_parallel_size
)
self.scheduler: SchedulerInterface = Scheduler(
@@ -205,6 +206,8 @@ class EngineCore:
# Mark the startup heap as static so that it's ignored by GC.
# Reduces pause times of oldest generation collections.
freeze_gc_heap()
+ # If enable, attach GC debugger after static variable freeze.
+ maybe_attach_gc_debug_callback()
def _initialize_kv_caches(
self, vllm_config: VllmConfig
@@ -644,9 +647,6 @@ class EngineCoreProc(EngineCore):
assert addresses.coordinator_input is not None
logger.info("Waiting for READY message from DP Coordinator...")
- # If enable, attach GC debugger after static variable freeze.
- maybe_attach_gc_debug_callback()
-
# Enable environment variable cache (e.g. assume no more
# environment variable overrides after this point)
enable_envs_cache()
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index e403cea87788b..dffe05445ee46 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -250,6 +250,9 @@ class LLMEngine:
elif isinstance(prompt, Mapping):
prompt_text = cast(str | None, prompt.get("prompt"))
+ # Use cloned params that may have been updated in process_inputs()
+ params = request.params
+
n = params.n if isinstance(params, SamplingParams) else 1
if n == 1:
@@ -262,10 +265,10 @@ class LLMEngine:
# Fan out child requests (for n>1).
parent_req = ParentRequest(request_id, params)
for idx in range(n):
- request_id, params = parent_req.get_child_info(idx)
+ request_id, child_params = parent_req.get_child_info(idx)
child_request = request if idx == n - 1 else copy(request)
child_request.request_id = request_id
- child_request.sampling_params = params
+ child_request.sampling_params = child_params
# Make a new RequestState and queue.
self.output_processor.add_request(
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index bdbbfe2595f81..0453c4a77f0cd 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -350,6 +350,8 @@ class OutputProcessor:
self.parent_requests: dict[str, ParentRequest] = {}
self.lora_states = LoRARequestStates(log_stats)
self.tracer: Tracer | None = None
+ self._requests_drained = asyncio.Event()
+ self._requests_drained.set()
def get_num_unfinished_requests(self):
return len(self.request_states)
@@ -357,6 +359,11 @@ class OutputProcessor:
def has_unfinished_requests(self) -> bool:
return len(self.request_states) > 0
+ async def wait_for_requests_to_drain(self) -> None:
+ if not self.request_states:
+ return
+ await self._requests_drained.wait()
+
def propagate_error(self, e: Exception):
"""Propagate error to all generate() tasks."""
@@ -396,6 +403,8 @@ class OutputProcessor:
child_reqs = self.abort_requests(child_reqs)
request_ids_to_abort.extend(child_reqs)
self.parent_requests.pop(request_id, None)
+ if not self.request_states:
+ self._requests_drained.set()
return request_ids_to_abort
def add_request(
@@ -420,6 +429,8 @@ class OutputProcessor:
log_stats=self.log_stats,
stream_interval=self.stream_interval,
)
+ if self._requests_drained.is_set():
+ self._requests_drained.clear()
self.request_states[request_id] = req_state
if parent_req:
self.parent_requests[parent_req.request_id] = parent_req
@@ -511,6 +522,8 @@ class OutputProcessor:
parent_req = req_state.parent_req
if parent_req and not parent_req.child_requests:
self.parent_requests.pop(parent_req.request_id, None)
+ if not self.request_states:
+ self._requests_drained.set()
if not engine_core_output.finished:
# If req not finished in EngineCore, but Detokenizer
# detected stop string, abort needed in EngineCore.
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 4cb911d8e22b7..af4f0e410e253 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -20,6 +20,7 @@ from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.metrics.stats import MultiModalCacheStats
@@ -142,9 +143,6 @@ class Processor:
self,
params: SamplingParams,
) -> None:
- # Best of not yet supported.
- if params.best_of is not None and params.best_of > 1:
- raise ValueError("vLLM V1 does not yet support best_of.")
# Logits processors not supported.
if params.logits_processors:
raise ValueError(
@@ -303,12 +301,24 @@ class Processor:
# allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars.
+ if isinstance(self.tokenizer, MistralTokenizer):
+ raise ValueError(
+ "Mistral tokenizer is not supported for the 'guidance' "
+ "structured output backend. Please use ['xgrammar', 'outlines'] "
+ "backends or tokenizer_mode='hf' instead."
+ )
validate_guidance_grammar(params, tokenizer=None)
elif backend == "outlines":
# outlines backend
validate_structured_output_request_outlines(params)
elif backend == "lm-format-enforcer":
# lm format enforcer backend
+ if isinstance(self.tokenizer, MistralTokenizer):
+ raise ValueError(
+ "Mistral tokenizer is not supported for the 'lm-format-enforcer' "
+ "structured output backend. Please use ['xgrammar', 'outlines'] "
+ "backends or tokenizer_mode='hf' instead."
+ )
validate_structured_output_request_lm_format_enforcer(params)
else:
# NOTE: backend must be "auto" here, because we have
@@ -323,9 +333,15 @@ class Processor:
except ValueError:
# The request either failed validation
# or includes some jsonschema feature(s) that
- # are not supported in xgrammar. Fall back to guidance.
- validate_guidance_grammar(params, tokenizer=None)
- params.structured_outputs._backend = "guidance"
+ # are not supported in xgrammar.
+ if isinstance(self.tokenizer, MistralTokenizer):
+ # Fall back to outlines if the tokenizer is Mistral
+ validate_structured_output_request_outlines(params)
+ params.structured_outputs._backend = "outlines"
+ else:
+ # Fall back to guidance by default.
+ validate_guidance_grammar(params, tokenizer=None)
+ params.structured_outputs._backend = "guidance"
# Remember that this backend was set automatically
params.structured_outputs._backend_was_auto = True
diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py
index ad2ece50f9815..7e8ebe25c4603 100644
--- a/vllm/v1/executor/multiproc_executor.py
+++ b/vllm/v1/executor/multiproc_executor.py
@@ -35,6 +35,7 @@ from vllm.distributed.parallel_state import (
get_dp_group,
get_ep_group,
get_inner_dp_world_group,
+ get_pcp_group,
get_pp_group,
get_tp_group,
)
@@ -110,12 +111,14 @@ class MultiprocExecutor(Executor):
f"({self.parallel_config.nnodes_within_dp}). "
)
self.local_world_size = self.parallel_config.local_world_size
- tensor_parallel_size = self.parallel_config.tensor_parallel_size
- pp_parallel_size = self.parallel_config.pipeline_parallel_size
- assert self.world_size == tensor_parallel_size * pp_parallel_size, (
+ tp_size = self.parallel_config.tensor_parallel_size
+ pp_size = self.parallel_config.pipeline_parallel_size
+ pcp_size = self.parallel_config.prefill_context_parallel_size
+ assert self.world_size == tp_size * pp_size * pcp_size, (
f"world_size ({self.world_size}) must be equal to the "
- f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
- f"_parallel_size ({pp_parallel_size}). "
+ f"tensor_parallel_size ({tp_size}) x pipeline"
+ f"_parallel_size ({pp_size}) x prefill_context"
+ f"_parallel_size ({pcp_size}). "
)
# Set multiprocessing envs
@@ -424,7 +427,11 @@ class MultiprocExecutor(Executor):
# 16-23, PP rank 2
# 24-31, PP rank 3
# so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3)
- return self.world_size - self.parallel_config.tensor_parallel_size
+ return (
+ self.world_size
+ - self.parallel_config.tensor_parallel_size
+ * self.parallel_config.prefill_context_parallel_size
+ )
@dataclass
@@ -828,6 +835,8 @@ class WorkerProc:
dp_rank = get_dp_group().rank_in_group
pp_size = get_pp_group().world_size
pp_rank = get_pp_group().rank_in_group
+ pcp_size = get_pcp_group().world_size
+ pcp_rank = get_pcp_group().rank_in_group
tp_size = get_tp_group().world_size
tp_rank = get_tp_group().rank_in_group
dcp_size = get_dcp_group().world_size
@@ -837,6 +846,8 @@ class WorkerProc:
process_name += f"_DP{dp_rank}"
if pp_size > 1:
process_name += f"_PP{pp_rank}"
+ if pcp_size > 1:
+ process_name += f"_PCP{pcp_rank}"
if tp_size > 1:
process_name += f"_TP{tp_rank}"
if dcp_size > 1:
diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py
index 5ff99acac52a2..aa3ca82a5d4a3 100644
--- a/vllm/v1/kv_cache_interface.py
+++ b/vllm/v1/kv_cache_interface.py
@@ -95,10 +95,11 @@ class FullAttentionSpec(AttentionSpec):
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_model_len = vllm_config.model_config.max_model_len
dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
+ pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size
# Note(hc): each dcp rank only need save
# (max_model_len//dcp_world_size) tokens locally.
- if dcp_world_size > 1:
- max_model_len = cdiv(max_model_len, dcp_world_size)
+ if dcp_world_size * pcp_world_size > 1:
+ max_model_len = cdiv(max_model_len, dcp_world_size * pcp_world_size)
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
@classmethod
diff --git a/vllm/v1/kv_offload/cpu.py b/vllm/v1/kv_offload/cpu.py
index 4b1bbe6f0cc2a..86747299eb107 100644
--- a/vllm/v1/kv_offload/cpu.py
+++ b/vllm/v1/kv_offload/cpu.py
@@ -4,8 +4,8 @@ from collections.abc import Iterator
import torch
-from vllm.config import VllmConfig, get_layers_from_vllm_config
-from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.attention import AttentionBackend
+from vllm.config import VllmConfig
from vllm.platforms import current_platform
from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
from vllm.v1.kv_offload.arc_manager import ARCOffloadingManager
@@ -63,7 +63,9 @@ class CPUOffloadingSpec(OffloadingSpec):
return self._manager
def get_handlers(
- self, kv_caches: dict[str, torch.Tensor]
+ self,
+ kv_caches: dict[str, torch.Tensor],
+ attn_backends: dict[str, type[AttentionBackend]],
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
if not self._handler:
if not current_platform.is_cuda_alike():
@@ -71,15 +73,6 @@ class CPUOffloadingSpec(OffloadingSpec):
"CPU Offloading is currently only supported on CUDA-alike GPUs"
)
- layer_names = list(kv_caches.keys())
- layers = get_layers_from_vllm_config(
- self.vllm_config, AttentionLayerBase, layer_names
- )
- attn_backends = {
- layer_name: layers[layer_name].get_attn_backend()
- for layer_name in layer_names
- }
-
self._handler = CpuGpuOffloadingHandler(
attn_backends=attn_backends,
gpu_block_size=self.gpu_block_size,
diff --git a/vllm/v1/kv_offload/spec.py b/vllm/v1/kv_offload/spec.py
index a3c539a47d458..c1813a4ff4ea9 100644
--- a/vllm/v1/kv_offload/spec.py
+++ b/vllm/v1/kv_offload/spec.py
@@ -11,6 +11,7 @@ from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
from vllm.v1.kv_offload.worker.worker import OffloadingHandler
if TYPE_CHECKING:
+ from vllm.attention import AttentionBackend
from vllm.config import VllmConfig
logger = init_logger(__name__)
@@ -48,13 +49,16 @@ class OffloadingSpec(ABC):
@abstractmethod
def get_handlers(
- self, kv_caches: dict[str, torch.Tensor]
+ self,
+ kv_caches: dict[str, torch.Tensor],
+ attn_backends: dict[str, type["AttentionBackend"]],
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
"""
Get offloading handlers along with their respective src and dst types.
Args:
kv_caches: A dictionary of layer_name -> gpu_kv_cache tensor.
+ attn_backends: A dictionary of layer_name -> AttentionBackend.
Yields:
Tuples of (src_type, dst_type, offloading_handler).
diff --git a/vllm/v1/kv_offload/worker/cpu_gpu.py b/vllm/v1/kv_offload/worker/cpu_gpu.py
index 0f2ec4a1b41f3..bb163f0043fc6 100644
--- a/vllm/v1/kv_offload/worker/cpu_gpu.py
+++ b/vllm/v1/kv_offload/worker/cpu_gpu.py
@@ -83,10 +83,18 @@ class CpuGpuOffloadingHandler(OffloadingHandler):
self.gpu_tensors.append(gpu_tensor)
gpu_shape = gpu_tensor.shape
- test_shape = attn_backends[layer_name].get_kv_cache_shape(
+ attn_backend = attn_backends[layer_name]
+ test_shape = attn_backend.get_kv_cache_shape(
num_blocks=1234, block_size=16, num_kv_heads=8, head_size=256
)
- if test_shape[0] == 1234:
+
+ if len(gpu_shape) != len(test_shape):
+ # cross-layers tensor
+ # shape is (num_blocks, ...)
+ assert len(gpu_shape) == len(test_shape) + 1
+ num_blocks_idx = 0
+ self.kv_dim_before_num_blocks.append(False)
+ elif test_shape[0] == 1234:
# shape is (num_blocks, ...)
num_blocks_idx = 0
self.kv_dim_before_num_blocks.append(False)
@@ -135,22 +143,20 @@ class CpuGpuOffloadingHandler(OffloadingHandler):
assert src_blocks.ndim == 1
assert dst_blocks.ndim == 1
- dst_sub_blocks_to_skip = -src_blocks.size % dst_block_size_factor
src_sub_block_count = src_blocks.size * src_block_size_factor
+ dst_sub_block_count = dst_blocks.size * dst_block_size_factor
+ src_sub_blocks_to_skip = -dst_blocks.size % src_block_size_factor
- assert (
- src_sub_block_count
- == dst_blocks.size * dst_block_size_factor - dst_sub_blocks_to_skip
- )
+ assert dst_sub_block_count == src_sub_block_count - src_sub_blocks_to_skip
- src_to_dst = np.empty((src_sub_block_count, 2), dtype=np.int64)
- expand_block_ids(src_blocks, src_block_size_factor, src_to_dst[:, 0])
+ src_to_dst = np.empty((dst_sub_block_count, 2), dtype=np.int64)
expand_block_ids(
- dst_blocks,
- dst_block_size_factor,
- src_to_dst[:, 1],
- skip_count=dst_sub_blocks_to_skip,
+ src_blocks,
+ src_block_size_factor,
+ src_to_dst[:, 0],
+ skip_count=src_sub_blocks_to_skip,
)
+ expand_block_ids(dst_blocks, dst_block_size_factor, src_to_dst[:, 1])
src_to_dst_tensor = torch.from_numpy(src_to_dst)
event = self.events_pool.pop() if self.events_pool else torch.Event()
diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py
index cb36e7973650e..bd18a152ffc08 100644
--- a/vllm/v1/metrics/loggers.py
+++ b/vllm/v1/metrics/loggers.py
@@ -16,7 +16,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPrometheus,
)
from vllm.logger import init_logger
-from vllm.plugins import load_plugins_by_group
+from vllm.plugins import STAT_LOGGER_PLUGINS_GROUP, load_plugins_by_group
from vllm.v1.engine import FinishReason
from vllm.v1.metrics.prometheus import unregister_vllm_metrics
from vllm.v1.metrics.stats import (
@@ -67,7 +67,7 @@ class StatLoggerBase(ABC):
def load_stat_logger_plugin_factories() -> list[StatLoggerFactory]:
factories: list[StatLoggerFactory] = []
- for name, plugin_class in load_plugins_by_group("vllm.stat_logger_plugins").items():
+ for name, plugin_class in load_plugins_by_group(STAT_LOGGER_PLUGINS_GROUP).items():
if not isinstance(plugin_class, type) or not issubclass(
plugin_class, StatLoggerBase
):
@@ -440,57 +440,6 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
# Setting default values
self.record_sleep_state()
- # GPU cache
- #
- # Deprecated in 0.9.2 - Renamed as vllm:kv_cache_usage_perc
- # With 0.11.x you can enable with --show-hidden-metrics-for-version=0.10
- # TODO: remove in 0.12.0
- if self.show_hidden_metrics:
- gauge_gpu_cache_usage = self._gauge_cls(
- name="vllm:gpu_cache_usage_perc",
- documentation=(
- "GPU KV-cache usage. 1 means 100 percent usage."
- "DEPRECATED: Use vllm:kv_cache_usage_perc instead."
- ),
- multiprocess_mode="mostrecent",
- labelnames=labelnames,
- )
- self.gauge_gpu_cache_usage = make_per_engine(
- gauge_gpu_cache_usage, engine_indexes, model_name
- )
-
- # Deprecated in 0.9.2 - Renamed as vllm:prefix_cache_queries
- # With 0.11.x you can enable with --show-hidden-metrics-for-version=0.10
- # TODO: remove in 0.12.0
- if self.show_hidden_metrics:
- counter_gpu_prefix_cache_queries = self._counter_cls(
- name="vllm:gpu_prefix_cache_queries",
- documentation=(
- "GPU prefix cache queries, in terms of number of queried"
- "tokens. DEPRECATED: Use vllm:prefix_cache_queries instead."
- ),
- labelnames=labelnames,
- )
- self.counter_gpu_prefix_cache_queries = make_per_engine(
- counter_gpu_prefix_cache_queries, engine_indexes, model_name
- )
-
- # Deprecated in 0.9.2 - Renamed as vllm:prefix_cache_hits
- # With 0.11.x you can enable with --show-hidden-metrics-for-version=0.10
- # TODO: remove in 0.12.0
- if self.show_hidden_metrics:
- counter_gpu_prefix_cache_hits = self._counter_cls(
- name="vllm:gpu_prefix_cache_hits",
- documentation=(
- "GPU prefix cache hits, in terms of number of cached "
- "tokens. DEPRECATED: Use vllm:prefix_cache_hits instead."
- ),
- labelnames=labelnames,
- )
- self.counter_gpu_prefix_cache_hits = make_per_engine(
- counter_gpu_prefix_cache_hits, engine_indexes, model_name
- )
-
gauge_kv_cache_usage = self._gauge_cls(
name="vllm:kv_cache_usage_perc",
documentation="KV-cache usage. 1 means 100 percent usage.",
@@ -735,39 +684,41 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
)
# Deprecated in 0.11 - Renamed as vllm:inter_token_latency_seconds
- # TODO: in 0.12, only enable if show_hidden_metrics=True
- histogram_time_per_output_token = self._histogram_cls(
- name="vllm:time_per_output_token_seconds",
- documentation=(
- "Histogram of time per output token in seconds."
- "DEPRECATED: Use vllm:inter_token_latency_seconds instead."
- ),
- buckets=[
- 0.01,
- 0.025,
- 0.05,
- 0.075,
- 0.1,
- 0.15,
- 0.2,
- 0.3,
- 0.4,
- 0.5,
- 0.75,
- 1.0,
- 2.5,
- 5.0,
- 7.5,
- 10.0,
- 20.0,
- 40.0,
- 80.0,
- ],
- labelnames=labelnames,
- )
- self.histogram_time_per_output_token = make_per_engine(
- histogram_time_per_output_token, engine_indexes, model_name
- )
+ # With 0.12.x you can enable with --show-hidden-metrics-for-version=0.11
+ # TODO: remove in 0.13.0
+ if self.show_hidden_metrics:
+ histogram_time_per_output_token = self._histogram_cls(
+ name="vllm:time_per_output_token_seconds",
+ documentation=(
+ "Histogram of time per output token in seconds."
+ "DEPRECATED: Use vllm:inter_token_latency_seconds instead."
+ ),
+ buckets=[
+ 0.01,
+ 0.025,
+ 0.05,
+ 0.075,
+ 0.1,
+ 0.15,
+ 0.2,
+ 0.3,
+ 0.4,
+ 0.5,
+ 0.75,
+ 1.0,
+ 2.5,
+ 5.0,
+ 7.5,
+ 10.0,
+ 20.0,
+ 40.0,
+ 80.0,
+ ],
+ labelnames=labelnames,
+ )
+ self.histogram_time_per_output_token = make_per_engine(
+ histogram_time_per_output_token, engine_indexes, model_name
+ )
histogram_inter_token_latency = self._histogram_cls(
name="vllm:inter_token_latency_seconds",
@@ -966,20 +917,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
self.gauge_scheduler_waiting[engine_idx].set(
scheduler_stats.num_waiting_reqs
)
- if self.show_hidden_metrics:
- self.gauge_gpu_cache_usage[engine_idx].set(
- scheduler_stats.kv_cache_usage
- )
self.gauge_kv_cache_usage[engine_idx].set(scheduler_stats.kv_cache_usage)
- if self.show_hidden_metrics:
- self.counter_gpu_prefix_cache_queries[engine_idx].inc(
- scheduler_stats.prefix_cache_stats.queries
- )
- self.counter_gpu_prefix_cache_hits[engine_idx].inc(
- scheduler_stats.prefix_cache_stats.hits
- )
-
self.counter_prefix_cache_queries[engine_idx].inc(
scheduler_stats.prefix_cache_stats.queries
)
@@ -1050,7 +989,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
self.histogram_time_to_first_token[engine_idx].observe(ttft)
for itl in iteration_stats.inter_token_latencies_iter:
self.histogram_inter_token_latency[engine_idx].observe(itl)
- self.histogram_time_per_output_token[engine_idx].observe(itl)
+ if self.show_hidden_metrics:
+ self.histogram_time_per_output_token[engine_idx].observe(itl)
for finished_request in iteration_stats.finished_requests:
self.counter_request_success[finished_request.finish_reason][
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
index c0b2835c3124c..e32d5bb608b1d 100644
--- a/vllm/v1/outputs.py
+++ b/vllm/v1/outputs.py
@@ -158,7 +158,7 @@ class ModelRunnerOutput:
# num_generated_tokens is the number of tokens
# generated in the current step. It can be different for
# each request due to speculative/jump decoding.
- sampled_token_ids: list[np.ndarray]
+ sampled_token_ids: list[list[int]]
# [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1]
@@ -220,7 +220,7 @@ def make_empty_encoder_model_runner_output(
req_id_to_index: dict[str, int] = {rid: idx for idx, rid in enumerate(req_ids)}
# No tokens generated yet ⇒ one empty list per request
- sampled_token_ids: list[list[int]] = [np.array([0]) for _ in req_ids]
+ sampled_token_ids: list[list[int]] = [[0] for _ in req_ids]
# Pooler outputs are not available yet ⇒ use None placeholders
pooler_output: list[torch.Tensor | None] = [None for _ in req_ids]
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index 3d92906fbf4b1..366cdadf5a583 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -121,6 +121,9 @@ class Request:
# The number of requests being preempted by the scheduler
self.num_preemptions = 0
+ # The number of tokens that have been computed remotely.
+ self.num_external_computed_tokens = 0
+
self.block_hashes: list[BlockHash] = []
self.get_hash_new_full_blocks: Callable[[], list[BlockHash]] | None = None
if block_hasher is not None:
diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py
index c6c7e924175f7..5b2d130b0ea42 100644
--- a/vllm/v1/sample/ops/topk_topp_sampler.py
+++ b/vllm/v1/sample/ops/topk_topp_sampler.py
@@ -60,13 +60,20 @@ class TopKTopPSampler(nn.Module):
logprobs_mode not in ("processed_logits", "processed_logprobs")
and rocm_aiter_ops.is_enabled()
):
- import aiter.ops.sampling # noqa: F401
+ try:
+ import aiter.ops.sampling # noqa: F401
- self.aiter_ops = torch.ops.aiter
- logger.info_once(
- "Using aiter sampler on ROCm (lazy import, sampling-only)."
- )
- self.forward = self.forward_hip
+ self.aiter_ops = torch.ops.aiter
+ logger.info_once(
+ "Using aiter sampler on ROCm (lazy import, sampling-only)."
+ )
+ self.forward = self.forward_hip
+ except ImportError:
+ logger.warning_once(
+ "aiter.ops.sampling is not available on ROCm. "
+ "Falling back to forward_native implementation."
+ )
+ self.forward = self.forward_native
else:
self.forward = self.forward_native
diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py
index f31a0cddda9ae..926305d25f56b 100644
--- a/vllm/v1/sample/rejection_sampler.py
+++ b/vllm/v1/sample/rejection_sampler.py
@@ -3,7 +3,6 @@
from dataclasses import replace
-import numpy as np
import torch
import torch.nn as nn
@@ -205,7 +204,7 @@ class RejectionSampler(nn.Module):
def parse_output(
output_token_ids: torch.Tensor,
vocab_size: int,
- ) -> list[np.ndarray]:
+ ) -> list[list[int]]:
"""Parse the output of the rejection sampler.
Args:
output_token_ids: The sampled token IDs in shape
@@ -221,7 +220,10 @@ class RejectionSampler(nn.Module):
valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
output_token_ids_np < vocab_size
)
- return [row[valid_mask[i]] for i, row in enumerate(output_token_ids_np)]
+ outputs = [
+ row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
+ ]
+ return outputs
def apply_logits_processors(
self,
diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py
index 39c63fe31ad2c..c75b4f0543c0d 100644
--- a/vllm/v1/sample/sampler.py
+++ b/vllm/v1/sample/sampler.py
@@ -81,7 +81,10 @@ class Sampler(nn.Module):
if logprobs_mode == "raw_logprobs":
raw_logprobs = self.compute_logprobs(logits)
elif logprobs_mode == "raw_logits":
- raw_logprobs = logits.clone()
+ if logits.dtype == torch.float32:
+ raw_logprobs = logits.clone()
+ else:
+ raw_logprobs = logits.to(torch.float32)
# Use float32 for the logits.
logits = logits.to(torch.float32)
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 5bf2503c3027d..784ccbc04932f 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -40,6 +40,7 @@ from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import _SAMPLING_EPS
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer
+from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
logger = init_logger(__name__)
@@ -65,6 +66,7 @@ class EagleProposer:
self.dtype = vllm_config.model_config.dtype
self.max_model_len = vllm_config.model_config.max_model_len
self.block_size = vllm_config.cache_config.block_size
+ self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
self.token_arange_np = np.arange(self.max_num_tokens)
@@ -83,12 +85,15 @@ class EagleProposer:
self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None
self.attn_layer_names: list[str] = []
self.indexer_layer_names: list[str] = []
+ self.eagle3_use_aux_hidden_state: bool = (
+ self._get_eagle3_use_aux_hidden_state_from_config()
+ )
self.use_cuda_graph = False
- compilation_config = self.vllm_config.compilation_config
- if compilation_config.mode == CompilationMode.VLLM_COMPILE:
- cudagraph_mode = compilation_config.cudagraph_mode
+ self.compilation_config = self.vllm_config.compilation_config
+ if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
+ cudagraph_mode = self.compilation_config.cudagraph_mode
if cudagraph_mode != CUDAGraphMode.NONE and not cudagraph_mode.has_mode(
CUDAGraphMode.PIECEWISE
):
@@ -103,22 +108,24 @@ class EagleProposer:
and not self.speculative_config.enforce_eager
)
- self.cudagraph_batch_sizes = (
- (sorted(self.vllm_config.compilation_config.cudagraph_capture_sizes))
- if self.use_cuda_graph
- else []
- )
-
- self.use_cuda_graph = self.use_cuda_graph and bool(self.cudagraph_batch_sizes)
# persistent buffers for cuda graph
self.input_ids = torch.zeros(
self.max_num_tokens, dtype=torch.int32, device=device
)
self.uses_mrope = self.vllm_config.model_config.uses_mrope
if self.uses_mrope:
- # M-RoPE need (3, max_num_tokens)
+ # NOTE: `mrope_positions` is implemented with one additional dummy
+ # position on purpose to make it non-contiguous so that it can work
+ # with torch compile.
+ # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
+
+ # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
+ # the modality of inputs. For text-only inputs, each dimension has
+ # identical position IDs, making M-RoPE functionally equivalent to
+ # 1D-RoPE.
+ # See page 5 of https://arxiv.org/abs/2409.12191
self.mrope_positions = torch.zeros(
- (3, self.max_num_tokens), dtype=torch.int64, device=device
+ (3, self.max_num_tokens + 1), dtype=torch.int64, device=device
)
else:
# RoPE need (max_num_tokens,)
@@ -266,12 +273,24 @@ class EagleProposer:
assert draft_indexer_metadata is not None
per_layer_attn_metadata[layer_name] = draft_indexer_metadata
+ num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
+ num_tokens_unpadded=num_tokens,
+ num_tokens_padded=num_tokens,
+ )
+
cudagraph_runtime_mode = CUDAGraphMode.NONE
- if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]:
- num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
+ if (
+ self.use_cuda_graph
+ and num_tokens_dp_padded
+ <= self.compilation_config.max_cudagraph_capture_size
+ ):
+ num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens_dp_padded)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
else:
- num_input_tokens = num_tokens
+ num_input_tokens = num_tokens_dp_padded
+ if num_tokens_across_dp is not None:
+ num_tokens_across_dp[self.dp_rank] = num_input_tokens
+
# copy inputs to buffer for cudagraph
self._set_positions(num_tokens, target_positions)
self.hidden_states[:num_tokens] = target_hidden_states
@@ -295,6 +314,7 @@ class EagleProposer:
per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
+ num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
):
ret_hidden_states = self.model(
@@ -357,12 +377,23 @@ class EagleProposer:
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
- if self.use_cuda_graph and batch_size <= self.cudagraph_batch_sizes[-1]:
- input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
+ batch_size_dp_padded, batch_size_across_dp = self._pad_batch_across_dp(
+ num_tokens_unpadded=batch_size,
+ num_tokens_padded=batch_size,
+ )
+
+ if (
+ self.use_cuda_graph
+ and batch_size_dp_padded
+ <= self.compilation_config.max_cudagraph_capture_size
+ ):
+ input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size_dp_padded)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
else:
- input_batch_size = batch_size
+ input_batch_size = batch_size_dp_padded
cudagraph_runtime_mode = CUDAGraphMode.NONE
+ if batch_size_across_dp is not None:
+ batch_size_across_dp[self.dp_rank] = input_batch_size
common_attn_metadata.num_actual_tokens = batch_size
common_attn_metadata.max_query_len = 1
@@ -463,6 +494,7 @@ class EagleProposer:
per_layer_attn_metadata,
self.vllm_config,
num_tokens=input_batch_size,
+ num_tokens_across_dp=batch_size_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
):
ret_hidden_states = self.model(
@@ -487,7 +519,7 @@ class EagleProposer:
def prepare_next_token_ids_cpu(
self,
- sampled_token_ids: list[np.ndarray],
+ sampled_token_ids: list[list[int]],
requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch,
num_scheduled_tokens: dict[str, int],
@@ -502,7 +534,7 @@ class EagleProposer:
req_ids = gpu_input_batch.req_ids
next_token_ids: list[int] = []
for i, token_ids in enumerate(sampled_token_ids):
- if token_ids.shape[0] > 0:
+ if token_ids:
# Common case.
next_token_id = token_ids[-1]
else:
@@ -513,9 +545,10 @@ class EagleProposer:
seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id]
next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
- return torch.tensor(
+ next_token_ids = torch.tensor(
next_token_ids, dtype=torch.int32, device=self.input_ids.device
)
+ return next_token_ids
def prepare_next_token_ids_padded(
self,
@@ -767,7 +800,10 @@ class EagleProposer:
self.positions[:num_tokens] = tree_positions.view(-1)
self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1)
- if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]:
+ if (
+ self.use_cuda_graph
+ and num_tokens <= self.compilation_config.max_cudagraph_capture_size
+ ):
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
else:
@@ -1019,8 +1055,11 @@ class EagleProposer:
elif (
isinstance(target_embed_tokens.weight, torch.Tensor)
and isinstance(self.model.model.embed_tokens.weight, torch.Tensor)
- and torch.equal(
- target_embed_tokens.weight, self.model.model.embed_tokens.weight
+ and torch.allclose(
+ target_embed_tokens.weight.cpu(),
+ self.model.model.embed_tokens.weight.cpu(),
+ rtol=1e-5,
+ atol=1e-7,
)
):
share_embeddings = True
@@ -1098,33 +1137,56 @@ class EagleProposer:
self,
num_tokens: int,
use_cudagraphs=True,
+ is_graph_capturing=False,
) -> None:
# Determine if CUDA graphs should be used for this run.
cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph
- if cudagraphs_enabled and num_tokens <= self.cudagraph_batch_sizes[-1]:
- num_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
- with set_forward_context(
- None,
- self.vllm_config,
- num_tokens=num_tokens,
- cudagraph_runtime_mode=(
- CUDAGraphMode.PIECEWISE if cudagraphs_enabled else CUDAGraphMode.NONE
- ),
+ # FIXME: when using tree-based specdec, adjust number of forward-passes
+ # according to the depth of the tree.
+ for fwd_idx in range(
+ self.num_speculative_tokens if not is_graph_capturing else 1
):
- if self.supports_mm_inputs:
- input_ids = None
- inputs_embeds = self.inputs_embeds[:num_tokens]
- else:
- input_ids = self.input_ids[:num_tokens]
- inputs_embeds = None
+ if fwd_idx <= 1:
+ num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
+ num_tokens_unpadded=num_tokens,
+ num_tokens_padded=num_tokens,
+ )
+ if (
+ cudagraphs_enabled
+ and num_tokens_dp_padded
+ <= self.compilation_config.max_cudagraph_capture_size
+ ):
+ num_input_tokens = self.vllm_config.pad_for_cudagraph(
+ num_tokens_dp_padded
+ )
+ else:
+ num_input_tokens = num_tokens_dp_padded
+ if num_tokens_across_dp is not None:
+ num_tokens_across_dp[self.dp_rank] = num_input_tokens
- self.model(
- input_ids=input_ids,
- positions=self._get_positions(num_tokens),
- hidden_states=self.hidden_states[:num_tokens],
- inputs_embeds=inputs_embeds,
- )
+ with set_forward_context(
+ None,
+ self.vllm_config,
+ num_tokens=num_input_tokens,
+ num_tokens_across_dp=num_tokens_across_dp,
+ cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE
+ if cudagraphs_enabled
+ else CUDAGraphMode.NONE,
+ ):
+ if self.supports_mm_inputs:
+ input_ids = None
+ inputs_embeds = self.inputs_embeds[:num_input_tokens]
+ else:
+ input_ids = self.input_ids[:num_input_tokens]
+ inputs_embeds = None
+
+ self.model(
+ input_ids=input_ids,
+ positions=self._get_positions(num_input_tokens),
+ hidden_states=self.hidden_states[:num_input_tokens],
+ inputs_embeds=inputs_embeds,
+ )
def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
"""Find and return the attention metadata builders for EAGLE layers.
@@ -1151,6 +1213,22 @@ class EagleProposer:
)
return builder
+ def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
+ """
+ Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary
+ hidden states and directly uses the last layer output just like eagle1.
+ They might indicate this by setting "use_aux_hidden_state" to False
+ inside the "eagle_config" dict of their hf_config.
+ """
+ if self.method != "eagle3":
+ return False
+ # Assume that eagle3 heads use aux hidden states by default
+ use_aux_hidden_state = True
+ eagle_config = getattr(self.draft_model_config.hf_config, "eagle_config", None)
+ if eagle_config is not None:
+ use_aux_hidden_state = eagle_config.get("use_aux_hidden_state", True)
+ return use_aux_hidden_state
+
def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
"""
Validate that all eagle layers belong to the same KVCacheGroup.
@@ -1174,6 +1252,28 @@ class EagleProposer:
== 1
), "All eagle layers should belong to the same kv cache group"
+ def _pad_batch_across_dp(
+ self,
+ num_tokens_unpadded: int,
+ num_tokens_padded: int,
+ ) -> tuple[int, torch.Tensor]:
+ # TODO(Flechman): support DBO ubatching
+ ubatch_slices, num_toks_across_dp = coordinate_batch_across_dp(
+ num_tokens_unpadded=num_tokens_unpadded,
+ parallel_config=self.vllm_config.parallel_config,
+ allow_microbatching=False,
+ allow_dp_padding=self.use_cuda_graph,
+ num_tokens_padded=num_tokens_padded,
+ uniform_decode=None,
+ num_scheduled_tokens_per_request=None,
+ )
+ assert ubatch_slices is None, "DBO ubatching not implemented for EAGLE"
+
+ num_tokens_dp_padded = num_tokens_padded
+ if num_toks_across_dp is not None:
+ num_tokens_dp_padded = int(num_toks_across_dp[self.dp_rank].item())
+ return num_tokens_dp_padded, num_toks_across_dp
+
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py
index 378937dba9882..e2f83cb24aa90 100644
--- a/vllm/v1/spec_decode/ngram_proposer.py
+++ b/vllm/v1/spec_decode/ngram_proposer.py
@@ -54,7 +54,7 @@ class NgramProposer:
# Trigger Numba JIT compilation for N-gram proposer.
# This usually takes less than 1 second.
self.propose(
- [np.array([])] * 1024,
+ [[]] * 1024,
[""] * 1024,
np.zeros(1024, dtype=np.int32),
np.zeros((1024, self.max_model_len), dtype=np.int32),
@@ -131,7 +131,7 @@ class NgramProposer:
def propose(
self,
- sampled_token_ids: list[np.ndarray],
+ sampled_token_ids: list[list[int]],
req_ids: list[str],
num_tokens_no_spec: np.ndarray,
token_ids_cpu: np.ndarray,
@@ -140,7 +140,7 @@ class NgramProposer:
# find which requests need ngram proposals
valid_ngram_requests = []
for i, sampled_ids in enumerate(sampled_token_ids):
- num_sampled_ids = sampled_ids.shape[0]
+ num_sampled_ids = len(sampled_ids)
if not num_sampled_ids:
# Skip speculative decoding.
continue
diff --git a/vllm/v1/spec_decode/suffix_decoding.py b/vllm/v1/spec_decode/suffix_decoding.py
index d76e0ffe778d4..049e335db3254 100644
--- a/vllm/v1/spec_decode/suffix_decoding.py
+++ b/vllm/v1/spec_decode/suffix_decoding.py
@@ -1,7 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import numpy as np
-
from vllm.config import VllmConfig
from vllm.v1.worker.gpu_input_batch import InputBatch
@@ -34,16 +32,16 @@ class SuffixDecodingProposer:
def propose(
self,
input_batch: InputBatch,
- sampled_token_ids: list[np.ndarray],
+ sampled_token_ids: list[list[int]],
) -> list[list[int]]:
"""
Propose speculative tokens for each request in the input batch. Suffix Decoding
will speculate a dynamic number of tokens for each request every decoding step,
so each entry in the returned list may have different lengths.
"""
- draft_token_ids: list[np.ndarray] = []
+ draft_token_ids: list[list[int]] = []
for i, sampled_ids in enumerate(sampled_token_ids):
- if sampled_ids.shape[0] == 0:
+ if not sampled_ids:
# Skip speculative decoding for partial prefills.
draft_token_ids.append([])
continue
@@ -72,7 +70,7 @@ class SuffixDecodingProposer:
self.suffix_cache.start_request(req_id, prompt_token_ids)
# Append the newly sampled ids to the suffix cache for this request.
- self.suffix_cache.add_active_response(req_id, sampled_ids.tolist())
+ self.suffix_cache.add_active_response(req_id, sampled_ids)
# Suffix decoding only uses the most recent tokens up to max_tree_depth, so
# we extract the pattern from the end of the input.
diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py
index 9f6c19e464308..37ec0fb97e06b 100644
--- a/vllm/v1/worker/block_table.py
+++ b/vllm/v1/worker/block_table.py
@@ -4,7 +4,7 @@
import numpy as np
import torch
-from vllm.distributed import get_dcp_group
+from vllm.distributed import get_dcp_group, get_pcp_group
from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
@@ -22,7 +22,7 @@ class BlockTable:
pin_memory: bool,
device: torch.device,
kernel_block_size: int,
- dcp_kv_cache_interleave_size: int,
+ cp_kv_cache_interleave_size: int,
):
"""
Args:
@@ -80,6 +80,13 @@ class BlockTable:
else:
self._kernel_block_arange = None
+ try:
+ self.pcp_world_size = get_pcp_group().world_size
+ self.pcp_rank = get_pcp_group().rank_in_group
+ except AssertionError:
+ # PCP might not be initialized in testing
+ self.pcp_world_size = 1
+ self.pcp_rank = 0
try:
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
@@ -87,7 +94,7 @@ class BlockTable:
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
- self.dcp_kv_cache_interleave_size = dcp_kv_cache_interleave_size
+ self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
def append_row(
self,
@@ -131,14 +138,16 @@ class BlockTable:
# NOTE(woosuk): We can't simply use `token_indices // block_size`
# here because M (max_model_len) is not necessarily divisible by
# block_size.
- if self.dcp_world_size > 1:
+ total_cp_world_size = self.pcp_world_size * self.dcp_world_size
+ total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
+ if total_cp_world_size > 1:
# Note(hc): The DCP implement store kvcache with an interleave
# style, the kvcache for the token whose token_idx is i is
# always stored on the GPU whose dcp_rank equals i % cp_world_size:
# Use a "virtual block" which equals to world_size * block_size
# for block_table_indices calculation.
- virtual_block_size = self.block_size * self.dcp_world_size
+ virtual_block_size = self.block_size * total_cp_world_size
block_table_indices = (
req_indices * self.max_num_blocks_per_req
+ positions // virtual_block_size
@@ -150,16 +159,16 @@ class BlockTable:
virtual_block_offsets = positions % virtual_block_size
mask = (
virtual_block_offsets
- // self.dcp_kv_cache_interleave_size
- % self.dcp_world_size
- == self.dcp_rank
+ // self.cp_kv_cache_interleave_size
+ % total_cp_world_size
+ == total_cp_rank
)
# Calculate local block_offsets
block_offsets = (
virtual_block_offsets
- // (self.dcp_world_size * self.dcp_kv_cache_interleave_size)
- * self.dcp_kv_cache_interleave_size
- + virtual_block_offsets % self.dcp_kv_cache_interleave_size
+ // (total_cp_world_size * self.cp_kv_cache_interleave_size)
+ * self.cp_kv_cache_interleave_size
+ + virtual_block_offsets % self.cp_kv_cache_interleave_size
)
# Calculate slot_mapping
slot_mapping = block_numbers * self.block_size + block_offsets
@@ -253,12 +262,17 @@ class MultiGroupBlockTable:
block_sizes: list[int],
kernel_block_sizes: list[int],
num_speculative_tokens: int = 0,
- dcp_kv_cache_interleave_size: int = 1,
+ cp_kv_cache_interleave_size: int = 1,
) -> None:
# Note(hc): each dcp rank only store
# (max_model_len//dcp_world_size) tokens in kvcache,
# so the block_size which used for calc max_num_blocks_per_req
# must be multiplied by dcp_world_size.
+ try:
+ pcp_world_size = get_pcp_group().world_size
+ except AssertionError:
+ # PCP might not be initialized in testing
+ pcp_world_size = 1
try:
dcp_world_size = get_dcp_group().world_size
except AssertionError:
@@ -271,19 +285,21 @@ class MultiGroupBlockTable:
f"must match block_sizes length ({len(block_sizes)})"
)
+ total_cp_world_size = dcp_world_size * pcp_world_size
+
self.block_tables = [
BlockTable(
block_size,
max_num_reqs,
max(
- cdiv(max_model_len, block_size * dcp_world_size),
+ cdiv(max_model_len, block_size * total_cp_world_size),
1 + num_speculative_tokens,
),
max_num_batched_tokens,
pin_memory,
device,
kernel_block_size,
- dcp_kv_cache_interleave_size,
+ cp_kv_cache_interleave_size,
)
for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
]
diff --git a/vllm/v1/worker/cpu_worker.py b/vllm/v1/worker/cpu_worker.py
index 4420a057d1e58..b080fea1d2dd6 100644
--- a/vllm/v1/worker/cpu_worker.py
+++ b/vllm/v1/worker/cpu_worker.py
@@ -3,6 +3,7 @@
import os
import platform
from collections.abc import Callable
+from typing import Any
import torch
@@ -37,6 +38,9 @@ class CPUWorker(Worker):
self.parallel_config.disable_custom_all_reduce = True
+ # Torch profiler. Enabled and configured through env vars:
+ # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
+ self.profiler: Any | None = None
if envs.VLLM_TORCH_PROFILER_DIR:
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
@@ -80,13 +84,13 @@ class CPUWorker(Worker):
self.local_omp_cpuid = "nobind"
else:
local_dp_rank = self.parallel_config.data_parallel_rank_local
- omp_cpuids = omp_cpuids.split("|")
+ omp_cpuids_list = omp_cpuids.split("|")
if local_dp_rank is not None:
world_size = self.parallel_config.world_size
- omp_cpuids = omp_cpuids[
+ omp_cpuids_list = omp_cpuids_list[
local_dp_rank * world_size : (local_dp_rank + 1) * world_size
]
- self.local_omp_cpuid = omp_cpuids[self.rank]
+ self.local_omp_cpuid = omp_cpuids_list[self.rank]
if self.local_omp_cpuid != "nobind":
ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
@@ -120,7 +124,7 @@ class CPUWorker(Worker):
pass
def determine_available_memory(self) -> int:
- return self.cache_config.cpu_kvcache_space_bytes # type: ignore
+ return self.cache_config.cpu_kvcache_space_bytes or 0
def compile_or_warm_up_model(self) -> None:
# Reset the seed to ensure that the random state is not affected by
diff --git a/vllm/v1/worker/gpu/README.md b/vllm/v1/worker/gpu/README.md
new file mode 100644
index 0000000000000..093f524b3250f
--- /dev/null
+++ b/vllm/v1/worker/gpu/README.md
@@ -0,0 +1,4 @@
+# [Experimental] Model Runner V2
+
+This directory contains the new model runner which is under active development.
+Ping [Woosuk Kwon](https://github.com/WoosukKwon) for any changes.
diff --git a/vllm/v1/worker/gpu/__init__.py b/vllm/v1/worker/gpu/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/vllm/v1/worker/gpu/async_utils.py b/vllm/v1/worker/gpu/async_utils.py
new file mode 100644
index 0000000000000..421fb29a7f87f
--- /dev/null
+++ b/vllm/v1/worker/gpu/async_utils.py
@@ -0,0 +1,92 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from contextlib import contextmanager
+
+import torch
+
+from vllm.v1.outputs import (
+ AsyncModelRunnerOutput,
+ LogprobsTensors,
+ ModelRunnerOutput,
+ SamplerOutput,
+)
+
+
+class AsyncOutput(AsyncModelRunnerOutput):
+ def __init__(
+ self,
+ model_runner_output: ModelRunnerOutput,
+ sampler_output: SamplerOutput,
+ num_sampled_tokens: torch.Tensor,
+ copy_stream: torch.cuda.Stream,
+ copy_event: torch.cuda.Event,
+ ):
+ self.model_runner_output = model_runner_output
+ self.sampler_output = sampler_output
+ self.num_sampled_tokens = num_sampled_tokens
+ self.copy_stream = copy_stream
+ self.copy_event = copy_event
+
+ default_stream = torch.cuda.current_stream()
+ with torch.cuda.stream(self.copy_stream):
+ self.copy_stream.wait_stream(default_stream)
+
+ # NOTE(woosuk): We must ensure that CPU tensors are not freed
+ # before the device-to-host copy is fully completed. For instance,
+ # operations like
+ # self.sampled_token_np = ...to("cpu", non_blocking=True).numpy()
+ # are unsafe because the underlying CPU tensor can be prematurely freed and
+ # reused by other tensors before the asynchronous copy finishes, potentially
+ # causing race conditions. To prevent this, we delay freeing by holding
+ # references until the copy event signals completion.
+ # Likewise, we also need to keep the reference to the GPU tensors.
+ # This is done by keeping the reference to sampler_output and
+ # model_runner_output.
+ self.sampled_token_ids = sampler_output.sampled_token_ids.to(
+ "cpu", non_blocking=True
+ )
+ if sampler_output.logprobs_tensors is not None:
+ self.logprobs_tensors: LogprobsTensors | None = (
+ sampler_output.logprobs_tensors.to_cpu_nonblocking()
+ )
+ else:
+ self.logprobs_tensors = None
+ self.num_sampled_tokens = num_sampled_tokens.to("cpu", non_blocking=True)
+ self.prompt_logprobs_dict: dict[str, LogprobsTensors | None] = {}
+ if self.model_runner_output.prompt_logprobs_dict:
+ for k, v in self.model_runner_output.prompt_logprobs_dict.items():
+ if v is not None:
+ self.prompt_logprobs_dict[k] = v.to_cpu_nonblocking()
+ else:
+ self.prompt_logprobs_dict[k] = None
+ self.copy_event.record(self.copy_stream)
+
+ def get_output(self) -> ModelRunnerOutput:
+ self.copy_event.synchronize()
+ num_sampled_tokens_np = self.num_sampled_tokens.numpy()
+
+ # NOTE(woosuk): The following code is to ensure compatibility with
+ # the existing model runner.
+ # Going forward, we should keep the data structures as NumPy arrays
+ # rather than Python lists.
+ sampled_token_ids: list[list[int]] = self.sampled_token_ids.tolist()
+ num_reqs = len(sampled_token_ids)
+ for i in range(num_reqs):
+ del sampled_token_ids[i][num_sampled_tokens_np[i] :]
+ self.model_runner_output.sampled_token_ids = sampled_token_ids
+
+ if self.logprobs_tensors is not None:
+ self.model_runner_output.logprobs = self.logprobs_tensors.tolists()
+ self.model_runner_output.prompt_logprobs_dict = self.prompt_logprobs_dict
+ return self.model_runner_output
+
+
+@contextmanager
+def async_barrier(event: torch.cuda.Event | None):
+ if event is not None:
+ event.synchronize()
+ try:
+ yield
+ finally:
+ if event is not None:
+ event.record()
diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py
new file mode 100644
index 0000000000000..4510a1c5ca1e9
--- /dev/null
+++ b/vllm/v1/worker/gpu/attn_utils.py
@@ -0,0 +1,191 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Sequence
+from typing import Any, cast
+
+import numpy as np
+import torch
+
+from vllm.attention.backends.abstract import AttentionBackend
+from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.v1.attention.backends.utils import (
+ AttentionMetadataBuilder,
+ CommonAttentionMetadata,
+)
+from vllm.v1.kv_cache_interface import (
+ AttentionSpec,
+ KVCacheConfig,
+ KVCacheSpec,
+)
+from vllm.v1.utils import CpuGpuBuffer
+from vllm.v1.worker.utils import bind_kv_cache
+
+
+def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
+ kv_cache_spec: dict[str, KVCacheSpec] = {}
+ layer_type = cast(type[Any], AttentionLayerBase)
+ attn_layers = get_layers_from_vllm_config(vllm_config, layer_type)
+ for layer_name, attn_module in attn_layers.items():
+ # Skip modules that don't need KV cache (e.g. encoder-only attention)
+ if spec := attn_module.get_kv_cache_spec(vllm_config):
+ kv_cache_spec[layer_name] = spec
+ return kv_cache_spec
+
+
+def init_attn_backend(
+ kv_cache_config: KVCacheConfig,
+ vllm_config: VllmConfig,
+ device: torch.device,
+):
+ attn_backends: dict[str, type[AttentionBackend]] = {}
+ attn_metadata_builders: list[AttentionMetadataBuilder] = []
+ flashinfer_workspace: torch.Tensor | None = None
+ for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
+ layer_names = kv_cache_group_spec.layer_names
+ any_layer_name = next(iter(layer_names))
+
+ layer_type = cast(type[Any], AttentionLayerBase)
+ attn_layers = get_layers_from_vllm_config(vllm_config, layer_type, layer_names)
+ attn_backend = attn_layers[any_layer_name].get_attn_backend()
+ for layer_name in layer_names:
+ attn_backends[layer_name] = attn_backend
+
+ attn_metadata_builder = attn_backend.get_builder_cls()(
+ kv_cache_group_spec.kv_cache_spec,
+ layer_names,
+ vllm_config,
+ device,
+ )
+ attn_metadata_builders.append(attn_metadata_builder) # type: ignore
+
+ if "FLASHINFER" in attn_backend.get_name():
+ if flashinfer_workspace is None:
+ flashinfer_workspace = attn_metadata_builder._get_workspace_buffer()
+ else:
+ attn_metadata_builder.set_workspace_buffer(flashinfer_workspace)
+ return attn_backends, attn_metadata_builders
+
+
+def _allocate_kv_cache(
+ kv_cache_config: KVCacheConfig,
+ device: torch.device,
+):
+ kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
+ for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
+ tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device)
+ for layer_name in kv_cache_tensor.shared_by:
+ kv_cache_raw_tensors[layer_name] = tensor
+
+ layer_names = set()
+ for group in kv_cache_config.kv_cache_groups:
+ for layer_name in group.layer_names:
+ layer_names.add(layer_name)
+ assert layer_names == set(kv_cache_raw_tensors.keys()), (
+ "Some layers are not correctly initialized"
+ )
+ return kv_cache_raw_tensors
+
+
+def _reshape_kv_cache(
+ kv_cache_config: KVCacheConfig,
+ kv_cache_raw_tensors: dict[str, torch.Tensor],
+ attn_backends: dict[str, AttentionBackend],
+) -> dict[str, torch.Tensor]:
+ kv_caches: dict[str, torch.Tensor] = {}
+ for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
+ kv_cache_spec = kv_cache_group_spec.kv_cache_spec
+ assert isinstance(kv_cache_spec, AttentionSpec)
+ for layer_name in kv_cache_group_spec.layer_names:
+ raw_tensor = kv_cache_raw_tensors[layer_name]
+ assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
+ num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
+
+ attn_backend = attn_backends[layer_name]
+ kv_cache_shape = attn_backend.get_kv_cache_shape(
+ num_blocks,
+ kv_cache_spec.block_size,
+ kv_cache_spec.num_kv_heads,
+ kv_cache_spec.head_size,
+ )
+
+ # FIXME(woosuk): Add kv_cache_stride_order to all attention backends.
+ try:
+ kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
+ assert len(kv_cache_stride_order) == len(kv_cache_shape)
+ except (AttributeError, NotImplementedError):
+ kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
+
+ kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
+ inv_order = [
+ kv_cache_stride_order.index(i)
+ for i in range(len(kv_cache_stride_order))
+ ]
+
+ dtype = kv_cache_spec.dtype
+ raw_tensor = raw_tensor.view(dtype)
+ raw_tensor = raw_tensor.view(kv_cache_shape)
+ kv_caches[layer_name] = raw_tensor.permute(*inv_order)
+ return kv_caches
+
+
+def init_kv_cache(
+ runner_kv_caches: list[torch.Tensor],
+ forward_context: dict[str, Any],
+ kv_cache_config: KVCacheConfig,
+ attn_backends: dict[str, AttentionBackend],
+ device: torch.device,
+) -> None:
+ kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
+ kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors, attn_backends)
+ bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
+
+
+def build_attn_metadata(
+ attn_metadata_builders: list[AttentionMetadataBuilder],
+ num_reqs: int,
+ num_tokens: int,
+ query_start_loc: CpuGpuBuffer,
+ seq_lens: torch.Tensor,
+ seq_lens_np: np.ndarray,
+ num_computed_tokens_cpu: torch.Tensor | None,
+ block_tables: Sequence[torch.Tensor],
+ slot_mappings: torch.Tensor,
+ kv_cache_config: KVCacheConfig,
+) -> dict[str, Any]:
+ query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
+ query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1]
+ max_query_len = int(query_start_loc.np[: num_reqs + 1].max())
+ seq_lens = seq_lens[:num_reqs]
+ seq_lens_cpu = torch.from_numpy(seq_lens_np)
+ max_seq_len = int(seq_lens_np.max())
+
+ attn_metadata: dict[str, Any] = {}
+ kv_cache_groups = kv_cache_config.kv_cache_groups
+ for i, kv_cache_spec in enumerate(kv_cache_groups):
+ block_table = block_tables[i]
+ slot_mapping = slot_mappings[i]
+
+ common_attn_metadata = CommonAttentionMetadata(
+ query_start_loc=query_start_loc_gpu,
+ query_start_loc_cpu=query_start_loc_cpu,
+ seq_lens=seq_lens,
+ seq_lens_cpu=seq_lens_cpu,
+ max_seq_len=max_seq_len,
+ num_computed_tokens_cpu=num_computed_tokens_cpu,
+ num_reqs=num_reqs,
+ num_actual_tokens=num_tokens,
+ max_query_len=max_query_len,
+ block_table_tensor=block_table,
+ slot_mapping=slot_mapping,
+ causal=True,
+ )
+
+ attn_metadata_builder = attn_metadata_builders[i]
+ metadata = attn_metadata_builder.build(
+ common_prefix_len=0,
+ common_attn_metadata=common_attn_metadata,
+ )
+ for layer_name in kv_cache_spec.layer_names:
+ attn_metadata[layer_name] = metadata
+ return attn_metadata
diff --git a/vllm/v1/worker/gpu/block_table.py b/vllm/v1/worker/gpu/block_table.py
new file mode 100644
index 0000000000000..b31e9b179d26c
--- /dev/null
+++ b/vllm/v1/worker/gpu/block_table.py
@@ -0,0 +1,314 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Iterable
+
+import torch
+
+from vllm.attention.backends.utils import PAD_SLOT_ID
+from vllm.triton_utils import tl, triton
+from vllm.utils.math_utils import cdiv
+from vllm.v1.utils import CpuGpuBuffer
+
+
+class BlockTables:
+ def __init__(
+ self,
+ block_sizes: list[int],
+ max_num_reqs: int,
+ max_num_batched_tokens: int,
+ max_model_len: int,
+ device: torch.device,
+ pin_memory: bool,
+ ):
+ self.block_sizes = block_sizes
+ self.max_num_reqs = max_num_reqs
+ self.max_num_batched_tokens = max_num_batched_tokens
+ self.max_model_len = max_model_len
+ self.device = device
+ self.pin_memory = pin_memory
+
+ self.num_kv_cache_groups = len(self.block_sizes)
+ # num_kv_cache_groups x [max_num_reqs, max_num_blocks]
+ self.block_tables: list[torch.Tensor] = []
+ for i in range(self.num_kv_cache_groups):
+ block_size = self.block_sizes[i]
+ max_num_blocks = cdiv(self.max_model_len, block_size)
+ block_table = torch.zeros(
+ self.max_num_reqs,
+ max_num_blocks,
+ dtype=torch.int32,
+ device=self.device,
+ )
+ self.block_tables.append(block_table)
+ self.block_table_ptrs = self._make_ptr_tensor(self.block_tables)
+
+ # Block tables used for model's forward pass.
+ # num_kv_cache_groups x [max_num_reqs, max_num_blocks]
+ self.input_block_tables: list[torch.Tensor] = [
+ torch.zeros_like(block_table) for block_table in self.block_tables
+ ]
+ self.input_block_table_ptrs = self._make_ptr_tensor(self.input_block_tables)
+
+ self.block_table_strides = torch.tensor(
+ [b.stride(0) for b in self.block_tables],
+ dtype=torch.int64,
+ device=self.device,
+ )
+ self.block_sizes_tensor = torch.tensor(
+ self.block_sizes, dtype=torch.int32, device=self.device
+ )
+ self.num_blocks = torch.zeros(
+ self.num_kv_cache_groups,
+ self.max_num_reqs,
+ dtype=torch.int32,
+ device=self.device,
+ )
+ self.slot_mappings = torch.zeros(
+ self.num_kv_cache_groups,
+ self.max_num_batched_tokens,
+ dtype=torch.int64,
+ device=self.device,
+ )
+
+ # Misc buffers.
+ self.req_indices = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
+ self.overwrite = self._make_buffer(self.max_num_reqs, dtype=torch.bool)
+ self.cu_num_new_blocks = self._make_buffer(
+ self.num_kv_cache_groups, self.max_num_reqs + 1, dtype=torch.int32
+ )
+
+ def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
+ return CpuGpuBuffer(
+ *args, dtype=dtype, pin_memory=self.pin_memory, device=self.device
+ )
+
+ def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
+ # NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
+ ptrs_tensor_cpu = torch.tensor(
+ [t.data_ptr() for t in x],
+ dtype=torch.uint64,
+ device="cpu",
+ pin_memory=self.pin_memory,
+ )
+ return ptrs_tensor_cpu.to(self.device, non_blocking=True)
+
+ def append_block_ids(
+ self,
+ # [num_reqs]
+ req_indices: list[int],
+ # [num_kv_cache_groups, num_reqs + 1]
+ cu_num_new_blocks: tuple[list[int], ...],
+ # [num_kv_cache_groups, num_new_blocks]
+ new_block_ids: tuple[list[int], ...],
+ # [num_reqs]
+ overwrite: list[bool],
+ ) -> None:
+ num_reqs = len(req_indices)
+ self.req_indices.np[:num_reqs] = req_indices
+ self.overwrite.np[:num_reqs] = overwrite
+ for i in range(self.num_kv_cache_groups):
+ self.cu_num_new_blocks.np[i, : num_reqs + 1] = cu_num_new_blocks[i]
+
+ # NOTE(woosuk): Here, we cannot use a fixed-size buffer because there's
+ # no clear upper bound to the number of new blocks in a single step.
+ # NOTE(woosuk): The buffer has to be cached, because otherwise we cannot
+ # guarantee that the buffer is not freed before the copy is completed.
+ self.new_block_ids_cpu = torch.empty(
+ self.num_kv_cache_groups,
+ max(len(x) for x in new_block_ids),
+ dtype=torch.int32,
+ device="cpu",
+ pin_memory=self.pin_memory,
+ )
+ new_block_ids_np = self.new_block_ids_cpu.numpy()
+ for i in range(self.num_kv_cache_groups):
+ new_block_ids_np[i, : len(new_block_ids[i])] = new_block_ids[i]
+ new_block_ids_gpu = self.new_block_ids_cpu.to(self.device, non_blocking=True)
+
+ _append_block_ids_kernel[(self.num_kv_cache_groups, num_reqs)](
+ self.req_indices.copy_to_gpu(num_reqs),
+ self.cu_num_new_blocks.copy_to_gpu(),
+ self.cu_num_new_blocks.gpu.stride(0),
+ new_block_ids_gpu,
+ new_block_ids_gpu.stride(0),
+ self.overwrite.copy_to_gpu(num_reqs),
+ self.block_table_strides,
+ self.block_table_ptrs,
+ self.num_blocks,
+ self.num_blocks.stride(0),
+ BLOCK_SIZE=1024, # type: ignore
+ )
+
+ def gather_block_tables(
+ self,
+ idx_mapping: torch.Tensor,
+ ) -> tuple[torch.Tensor, ...]:
+ num_reqs = idx_mapping.shape[0]
+ _gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)](
+ idx_mapping,
+ self.block_table_ptrs,
+ self.input_block_table_ptrs,
+ self.block_table_strides,
+ self.num_blocks,
+ self.num_blocks.stride(0),
+ BLOCK_SIZE=1024, # type: ignore
+ )
+ return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
+
+ def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
+ return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
+
+ def compute_slot_mappings(
+ self,
+ query_start_loc: torch.Tensor,
+ positions: torch.Tensor,
+ ) -> torch.Tensor:
+ num_reqs = query_start_loc.shape[0] - 1
+ num_tokens = positions.shape[0]
+ num_groups = self.num_kv_cache_groups
+ _compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
+ num_tokens,
+ self.max_num_batched_tokens,
+ query_start_loc,
+ positions,
+ self.input_block_table_ptrs,
+ self.block_table_strides,
+ self.block_sizes_tensor,
+ self.slot_mappings,
+ self.slot_mappings.stride(0),
+ PAD_ID=PAD_SLOT_ID,
+ BLOCK_SIZE=1024, # type: ignore
+ )
+ return self.slot_mappings[:, :num_tokens]
+
+ def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
+ self.slot_mappings.fill_(PAD_SLOT_ID)
+ return self.slot_mappings[:, :num_tokens]
+
+
+@triton.jit
+def _append_block_ids_kernel(
+ # Inputs
+ req_indices, # [num_reqs]
+ cu_num_new_blocks_ptr, # [num_kv_cache_groups, num_reqs + 1]
+ cu_num_new_blocks_stride,
+ new_block_ids_ptr, # [num_kv_cache_groups, num_new_blocks]
+ new_block_ids_stride,
+ overwrite, # [num_reqs]
+ block_table_strides, # [num_kv_cache_groups]
+ # Outputs
+ block_table_ptrs, # [num_kv_cache_groups]
+ num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
+ num_blocks_stride,
+ # Constants
+ BLOCK_SIZE: tl.constexpr,
+):
+ group_id = tl.program_id(0)
+ batch_idx = tl.program_id(1)
+ req_idx = tl.load(req_indices + batch_idx)
+ do_overwrite = tl.load(overwrite + batch_idx)
+
+ group_new_blocks_ptr = cu_num_new_blocks_ptr + group_id * cu_num_new_blocks_stride
+ start_idx = tl.load(group_new_blocks_ptr + batch_idx)
+ end_idx = tl.load(group_new_blocks_ptr + batch_idx + 1)
+ num_new_blocks = end_idx - start_idx
+
+ group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
+ dst_start_idx = tl.load(group_num_blocks_ptr + req_idx) if not do_overwrite else 0
+ dst_end_idx = dst_start_idx + num_new_blocks
+ tl.store(group_num_blocks_ptr + req_idx, dst_end_idx)
+
+ # Destination
+ block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
+ block_table_stride = tl.load(block_table_strides + group_id)
+ row_ptr = block_table_ptr + req_idx * block_table_stride
+
+ group_new_block_ids_ptr = new_block_ids_ptr + group_id * new_block_ids_stride
+ for i in range(0, num_new_blocks, BLOCK_SIZE):
+ offset = i + tl.arange(0, BLOCK_SIZE)
+ block_ids = tl.load(
+ group_new_block_ids_ptr + start_idx + offset, mask=offset < num_new_blocks
+ )
+ tl.store(
+ row_ptr + dst_start_idx + offset, block_ids, mask=offset < num_new_blocks
+ )
+
+
+@triton.jit
+def _gather_block_tables_kernel(
+ batch_idx_to_req_idx, # [batch_size]
+ src_block_table_ptrs, # [num_kv_cache_groups]
+ dst_block_table_ptrs, # [num_kv_cache_groups]
+ block_table_strides, # [num_kv_cache_groups]
+ num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
+ num_blocks_stride,
+ BLOCK_SIZE: tl.constexpr,
+):
+ # kv cache group id
+ group_id = tl.program_id(0)
+ batch_idx = tl.program_id(1)
+ req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
+
+ group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
+ num_blocks = tl.load(group_num_blocks_ptr + req_idx)
+
+ stride = tl.load(block_table_strides + group_id)
+ src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32)
+ src_row_ptr = src_block_table_ptr + req_idx * stride
+ dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
+ dst_row_ptr = dst_block_table_ptr + batch_idx * stride
+
+ for i in tl.range(0, num_blocks, BLOCK_SIZE):
+ offset = i + tl.arange(0, BLOCK_SIZE)
+ block_ids = tl.load(src_row_ptr + offset, mask=offset < num_blocks)
+ tl.store(dst_row_ptr + offset, block_ids, mask=offset < num_blocks)
+
+
+@triton.jit
+def _compute_slot_mappings_kernel(
+ num_tokens,
+ max_num_tokens,
+ cu_num_tokens, # [num_reqs + 1]
+ pos, # [num_tokens]
+ block_table_ptrs, # [num_kv_cache_groups]
+ block_table_strides, # [num_kv_cache_groups]
+ page_sizes, # [num_kv_cache_groups]
+ slot_mappings_ptr, # [num_kv_cache_groups, max_num_tokens]
+ slot_mappings_stride,
+ PAD_ID: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr,
+):
+ # kv cache group id
+ group_id = tl.program_id(0)
+ req_idx = tl.program_id(1)
+ slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride
+
+ if req_idx == tl.num_programs(1) - 1:
+ # Pad remaining slots to PAD_ID. This is needed for CUDA graphs.
+ for i in range(num_tokens, max_num_tokens, BLOCK_SIZE):
+ offset = i + tl.arange(0, BLOCK_SIZE)
+ tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens)
+ return
+
+ block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
+ block_table_stride = tl.load(block_table_strides + group_id)
+ page_size = tl.load(page_sizes + group_id)
+
+ start_idx = tl.load(cu_num_tokens + req_idx)
+ end_idx = tl.load(cu_num_tokens + req_idx + 1)
+ for i in range(start_idx, end_idx, BLOCK_SIZE):
+ offset = i + tl.arange(0, BLOCK_SIZE)
+ positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
+ block_indices = positions // page_size
+ block_numbers = tl.load(
+ block_table_ptr + req_idx * block_table_stride + block_indices
+ )
+ slot_ids = block_numbers * page_size + positions % page_size
+ tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
+
+
+@triton.jit
+def _load_ptr(ptr_to_ptr, elem_dtype):
+ ptr = tl.load(ptr_to_ptr)
+ ptr = tl.cast(ptr, tl.pointer_type(elem_dtype))
+ return tl.multiple_of(ptr, 16)
diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py
new file mode 100644
index 0000000000000..ba783e2d0c6fb
--- /dev/null
+++ b/vllm/v1/worker/gpu/cudagraph_utils.py
@@ -0,0 +1,205 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from unittest.mock import patch
+
+import numpy as np
+import torch
+import torch.nn as nn
+from tqdm import tqdm
+
+from vllm.config import VllmConfig
+from vllm.config.compilation import CUDAGraphMode
+from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
+from vllm.forward_context import set_forward_context
+from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
+from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
+from vllm.v1.worker.gpu.block_table import BlockTables
+from vllm.v1.worker.gpu.input_batch import InputBuffers
+
+
+class CudaGraphManager:
+ def __init__(
+ self,
+ vllm_config: VllmConfig,
+ device: torch.device,
+ ):
+ self.vllm_config = vllm_config
+ self.scheduler_config = vllm_config.scheduler_config
+ self.device = device
+
+ self.max_model_len = vllm_config.model_config.max_model_len
+ self.max_num_reqs = self.scheduler_config.max_num_seqs
+ self.dp_size = vllm_config.parallel_config.data_parallel_size
+ self.compilation_config = vllm_config.compilation_config
+ assert self.compilation_config is not None
+
+ if self.compilation_config.cudagraph_mode is None:
+ self.cudagraph_mode = CUDAGraphMode.NONE
+ else:
+ self.cudagraph_mode = self.compilation_config.cudagraph_mode
+ if self.compilation_config.cudagraph_capture_sizes is not None:
+ cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
+ # Limit the cudagraph sizes to the max decode batch size.
+ self.cudagraph_sizes = [
+ x for x in cudagraph_sizes if x <= self.max_num_reqs
+ ]
+ else:
+ self.cudagraph_sizes = []
+ self.padded_sizes = self._init_padded_sizes()
+
+ self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
+ self.pool = torch.cuda.graph_pool_handle()
+ self.hidden_states: torch.Tensor | None = None
+
+ def _init_padded_sizes(self) -> dict[int, int]:
+ if not self.cudagraph_mode.has_full_cudagraphs():
+ # Full cuda graphs are not used.
+ return {}
+ if not self.cudagraph_sizes:
+ return {}
+
+ padded_sizes: dict[int, int] = {}
+ for i in range(1, self.cudagraph_sizes[-1] + 1):
+ for x in self.cudagraph_sizes:
+ if i <= x:
+ padded_sizes[i] = x
+ break
+ return padded_sizes
+
+ def needs_capture(self) -> bool:
+ return len(self.padded_sizes) > 0
+
+ def get_cudagraph_size(
+ self,
+ scheduler_output: SchedulerOutput,
+ num_tokens_after_padding: int,
+ ) -> int | None:
+ if not self.cudagraph_mode.has_full_cudagraphs():
+ return None
+ if self.cudagraph_mode != CUDAGraphMode.FULL:
+ # TODO(woosuk): Support uniform decode with multiple tokens (spec decoding).
+ all_decode = all(
+ x == 1 for x in scheduler_output.num_scheduled_tokens.values()
+ )
+ if not all_decode:
+ # Prefill is included.
+ return None
+ return self.padded_sizes.get(num_tokens_after_padding)
+
+ def capture_graph(
+ self,
+ batch_size: int,
+ model: nn.Module,
+ input_buffers: InputBuffers,
+ block_tables: BlockTables,
+ attn_metadata_builders: list[AttentionMetadataBuilder],
+ kv_cache_config: KVCacheConfig,
+ ) -> None:
+ assert batch_size not in self.graphs
+
+ # Prepare dummy inputs.
+ input_ids = input_buffers.input_ids.gpu[:batch_size]
+ positions = input_buffers.positions[:batch_size]
+
+ input_buffers.query_start_loc.np[: batch_size + 1] = np.arange(batch_size + 1)
+ input_buffers.query_start_loc.np[batch_size:] = batch_size
+ input_buffers.query_start_loc.copy_to_gpu()
+ # HACK(woosuk): To optimize warmup time, we use 1 (instead of max_model_len)
+ # for seq_lens. This leads to a mismatch between seq_lens (GPU) and
+ # seq_lens_np (CPU), which might cause issues in some attention backends.
+ input_buffers.seq_lens[:batch_size] = 1
+ input_buffers.seq_lens[batch_size:] = 0
+
+ input_block_tables = [x[:batch_size] for x in block_tables.input_block_tables]
+ slot_mappings = block_tables.slot_mappings[:, :batch_size]
+
+ attn_metadata = build_attn_metadata(
+ attn_metadata_builders=attn_metadata_builders,
+ num_reqs=batch_size,
+ num_tokens=batch_size,
+ query_start_loc=input_buffers.query_start_loc,
+ seq_lens=input_buffers.seq_lens,
+ seq_lens_np=np.full(batch_size, self.max_model_len, dtype=np.int32),
+ num_computed_tokens_cpu=None, # FIXME
+ block_tables=input_block_tables,
+ slot_mappings=slot_mappings,
+ kv_cache_config=kv_cache_config,
+ )
+ if self.dp_size > 1:
+ num_tokens_across_dp = torch.full(
+ (self.dp_size,),
+ batch_size,
+ dtype=torch.int32,
+ device="cpu",
+ )
+ else:
+ num_tokens_across_dp = None
+
+ # Warm up.
+ with set_forward_context(
+ attn_metadata,
+ self.vllm_config,
+ num_tokens=batch_size,
+ cudagraph_runtime_mode=CUDAGraphMode.NONE,
+ num_tokens_across_dp=num_tokens_across_dp,
+ ):
+ hidden_states = model(
+ input_ids=input_ids,
+ positions=positions,
+ )
+ if self.hidden_states is None:
+ self.hidden_states = torch.empty_like(hidden_states)
+
+ # Capture the graph.
+ graph = torch.cuda.CUDAGraph()
+ with (
+ patch("torch.cuda.empty_cache", lambda: None),
+ set_forward_context(
+ attn_metadata,
+ self.vllm_config,
+ num_tokens=batch_size,
+ cudagraph_runtime_mode=CUDAGraphMode.NONE,
+ num_tokens_across_dp=num_tokens_across_dp,
+ ),
+ torch.cuda.graph(graph, self.pool),
+ ):
+ hidden_states = model(
+ input_ids=input_ids,
+ positions=positions,
+ )
+ self.hidden_states[:batch_size] = hidden_states
+ self.graphs[batch_size] = graph
+
+ @torch.inference_mode()
+ def capture(
+ self,
+ model: nn.Module,
+ input_buffers: InputBuffers,
+ block_tables: BlockTables,
+ attn_metadata_builders: list[AttentionMetadataBuilder],
+ kv_cache_config: KVCacheConfig,
+ ) -> None:
+ assert self.needs_capture()
+ # Capture larger graphs first.
+ sizes_to_capture = sorted(self.cudagraph_sizes, reverse=True)
+ if is_global_first_rank():
+ sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
+
+ with graph_capture(device=self.device):
+ for batch_size in sizes_to_capture:
+ self.capture_graph(
+ batch_size,
+ model,
+ input_buffers,
+ block_tables,
+ attn_metadata_builders,
+ kv_cache_config,
+ )
+
+ def run(self, batch_size: int) -> torch.Tensor:
+ assert batch_size in self.graphs
+ self.graphs[batch_size].replay()
+ assert self.hidden_states is not None
+ return self.hidden_states[:batch_size]
diff --git a/vllm/v1/worker/gpu/dp_utils.py b/vllm/v1/worker/gpu/dp_utils.py
new file mode 100644
index 0000000000000..9bfc7f25bef3a
--- /dev/null
+++ b/vllm/v1/worker/gpu/dp_utils.py
@@ -0,0 +1,22 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+import torch.distributed as dist
+
+from vllm.distributed.parallel_state import get_dp_group
+
+
+def get_batch_metadata_across_dp(
+    num_tokens: int,
+    cudagraph_size: int,
+    dp_size: int,
+    dp_rank: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Gather (num_tokens, cudagraph_size) from every data-parallel rank.
+
+    Each rank writes its own values into its slot of a zero-initialized
+    [2, dp_size] CPU tensor; a sum all-reduce then fills in every other
+    rank's slot (all other slots are zero, so sum == gather).
+
+    Returns:
+        Two int32 CPU tensors of shape [dp_size]: per-rank token counts and
+        per-rank CUDA graph sizes.
+    """
+    assert dp_size > 1
+    # Use CPU group to avoid CPU-GPU synchronization.
+    group = get_dp_group().cpu_group
+    tensor = torch.zeros(2, dp_size, dtype=torch.int32, device="cpu")
+    tensor[0][dp_rank] = num_tokens
+    tensor[1][dp_rank] = cudagraph_size
+    dist.all_reduce(tensor, group=group)
+    return tensor[0], tensor[1]
diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py
new file mode 100644
index 0000000000000..2a7048ae3c0e0
--- /dev/null
+++ b/vllm/v1/worker/gpu/input_batch.py
@@ -0,0 +1,397 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from dataclasses import dataclass
+from typing import Any
+
+import numba
+import numpy as np
+import torch
+
+from vllm.triton_utils import tl, triton
+from vllm.utils import random_uuid
+from vllm.utils.math_utils import cdiv
+from vllm.v1.utils import CpuGpuBuffer
+
+
+class InputBuffers:
+    """Persistent per-step input buffers, sized for the worst-case batch.
+
+    Most buffers are paired CPU/GPU (`CpuGpuBuffer`) so the CPU side can be
+    filled and then copied to the GPU; `positions` and `seq_lens` live only
+    on the GPU because they are produced by GPU kernels.
+    """
+
+    def __init__(
+        self,
+        max_num_reqs: int,
+        max_num_tokens: int,
+        hidden_size: int,
+        vocab_size: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        pin_memory: bool,
+    ):
+        # NOTE(review): `hidden_size` and `dtype` are accepted but unused in
+        # this visible body — presumably reserved for future buffers; confirm.
+        self.max_num_reqs = max_num_reqs
+        self.max_num_tokens = max_num_tokens
+        self.device = device
+        self.pin_memory = pin_memory
+
+        self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32)
+        self.input_ids = self._make_buffer(max_num_tokens, dtype=torch.int32)
+        self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
+        self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
+        self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
+        self.cu_num_logits = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
+
+        # Spec decoding.
+        self.next_prefill_tokens = self._make_buffer(max_num_reqs, dtype=torch.int32)
+
+        # Structured outputs. The bitmask packs vocab_size bits into int32
+        # words, hence the cdiv(vocab_size, 32) width.
+        self.bitmask_indices = self._make_buffer(max_num_reqs, dtype=torch.int32)
+        self.grammar_bitmask = self._make_buffer(
+            max_num_reqs, cdiv(vocab_size, 32), dtype=torch.int32
+        )
+
+    def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
+        # Helper: a CPU/GPU buffer pair on this object's device/pinning config.
+        return CpuGpuBuffer(
+            *args, dtype=dtype, pin_memory=self.pin_memory, device=self.device
+        )
+
+
+@dataclass
+class InputBatch:
+    """One prepared model-forward batch (views into InputBuffers).
+
+    Tensors are GPU-resident; `*_np` fields are the matching CPU views.
+    """
+
+    # batch_idx -> req_id
+    req_ids: list[str]
+    num_reqs: int
+
+    # batch_idx -> req_state_idx
+    idx_mapping: torch.Tensor
+    idx_mapping_np: np.ndarray
+
+    # [num_reqs]
+    # batch_idx -> num_scheduled_tokens
+    num_scheduled_tokens: np.ndarray
+    # sum(num_scheduled_tokens)
+    num_tokens: int
+    num_tokens_after_padding: int
+    num_draft_tokens: int
+
+    # [num_reqs + 1]
+    query_start_loc: torch.Tensor
+    query_start_loc_np: np.ndarray
+    # [num_reqs]
+    seq_lens: torch.Tensor
+    seq_lens_np: np.ndarray
+
+    # [num_tokens_after_padding]
+    input_ids: torch.Tensor
+    # [num_tokens_after_padding]
+    positions: torch.Tensor
+
+    # layer_name -> Metadata
+    attn_metadata: dict[str, Any]
+
+    # [total_num_logits]
+    logits_indices: torch.Tensor
+    # [num_reqs + 1]
+    cu_num_logits: torch.Tensor
+
+    @classmethod
+    def make_dummy(
+        cls,
+        num_reqs: int,
+        num_tokens: int,
+        input_buffers: InputBuffers,
+        device: torch.device,
+    ) -> "InputBatch":
+        """Build a synthetic batch for profiling / CUDA graph capture.
+
+        Tokens are split as evenly as possible across `num_reqs` requests,
+        with the remainder assigned to the last request. `attn_metadata` is
+        left as None; callers that need it fill it in separately.
+        """
+        assert 0 < num_reqs <= num_tokens
+        # Unique req_ids so dummy batches never collide with real requests.
+        req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)]
+        idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
+        idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
+        num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
+        num_scheduled_tokens[-1] += num_tokens % num_reqs
+        assert int(num_scheduled_tokens.sum()) == num_tokens
+
+        input_buffers.query_start_loc.np[0] = 0
+        input_buffers.query_start_loc.np[1 : num_reqs + 1] = np.cumsum(
+            num_scheduled_tokens
+        )
+        # Pad the tail so query_start_loc stays non-decreasing past num_reqs.
+        input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens
+        query_start_loc_np = input_buffers.query_start_loc.np[: num_reqs + 1]
+        query_start_loc = input_buffers.query_start_loc.copy_to_gpu()[: num_reqs + 1]
+        # seq_len equals to query_len
+        seq_lens_np = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
+        seq_lens_np[-1] += num_tokens % num_reqs
+        input_buffers.seq_lens[:num_reqs] = num_tokens // num_reqs
+        input_buffers.seq_lens[num_reqs - 1] += num_tokens % num_reqs
+        input_buffers.seq_lens[num_reqs:] = 0
+        seq_lens = input_buffers.seq_lens[:num_reqs]
+
+        input_ids = input_buffers.input_ids.copy_to_gpu(num_tokens)
+        positions = input_buffers.positions[:num_tokens]
+        # attn_metadata = defaultdict(lambda: None)
+        # Sample one logit per request: the last token of each query.
+        logits_indices = query_start_loc[1:] - 1
+        cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32)
+        return cls(
+            req_ids=req_ids,
+            num_reqs=num_reqs,
+            idx_mapping=idx_mapping,
+            idx_mapping_np=idx_mapping_np,
+            num_scheduled_tokens=num_scheduled_tokens,
+            num_tokens=num_tokens,
+            num_tokens_after_padding=num_tokens,
+            num_draft_tokens=0,
+            query_start_loc=query_start_loc,
+            query_start_loc_np=query_start_loc_np,
+            seq_lens=seq_lens,
+            seq_lens_np=seq_lens_np,
+            input_ids=input_ids,
+            positions=positions,
+            attn_metadata=None,  # type: ignore
+            logits_indices=logits_indices,
+            cu_num_logits=cu_num_logits,
+        )
+
+
+@numba.njit(cache=True)
+def _prepare_prefill_inputs(
+    idx_mapping: np.ndarray,  # [B]
+    query_lens: np.ndarray,  # [B]
+    query_start_loc: np.ndarray,  # [B + 1]
+    prefill_token_ids: np.ndarray,  # [N, max_model_len]
+    num_computed_prefill_tokens: np.ndarray,  # [N]
+    input_ids: np.ndarray,  # [num_input_tokens]
+) -> None:
+    # JIT-compiled copy loop: for each request in the batch, copy the
+    # scheduled slice of its prefill tokens (starting at the number already
+    # computed) into the flat input_ids buffer at the request's query span.
+    # Mutates input_ids in place; returns nothing.
+    num_reqs = idx_mapping.shape[0]
+    query_starts = query_start_loc[:num_reqs]
+    query_ends = query_start_loc[1 : num_reqs + 1]
+    starts = num_computed_prefill_tokens[idx_mapping]
+    ends = starts + query_lens
+    for i in range(num_reqs):
+        input_ids[query_starts[i] : query_ends[i]] = prefill_token_ids[
+            idx_mapping[i], starts[i] : ends[i]
+        ]
+
+
+def prepare_prefill_inputs(
+    idx_mapping: np.ndarray,
+    num_scheduled_tokens: np.ndarray,
+    query_start_loc: np.ndarray,
+    prefill_token_ids: np.ndarray,
+    num_computed_prefill_tokens: np.ndarray,
+    input_ids: np.ndarray,
+) -> None:
+    """Copy each request's scheduled prefill tokens into `input_ids` (CPU).
+
+    Thin wrapper around the numba-compiled `_prepare_prefill_inputs`;
+    `num_scheduled_tokens` plays the role of per-request query lengths.
+    """
+    _prepare_prefill_inputs(
+        idx_mapping,
+        num_scheduled_tokens,
+        query_start_loc,
+        prefill_token_ids,
+        num_computed_prefill_tokens,
+        input_ids,
+    )
+
+
+@triton.jit
+def _prepare_pos_seq_lens_kernel(
+    pos_ptr,
+    seq_lens_ptr,
+    idx_mapping_ptr,
+    query_start_loc_ptr,
+    num_computed_tokens_ptr,
+    max_num_reqs,
+    BLOCK_SIZE: tl.constexpr,
+):
+    # One program per request, plus one extra program (see the launcher)
+    # that zero-pads seq_lens[num_reqs:max_num_reqs] for full CUDA graphs.
+    req_id = tl.program_id(0)
+    num_reqs = tl.num_programs(0) - 1
+    if req_id == num_reqs:
+        # Pad unused seq_lens as 0 for full CUDA graphs.
+        for i in tl.range(num_reqs, max_num_reqs, BLOCK_SIZE):
+            block = i + tl.arange(0, BLOCK_SIZE)
+            mask = block < max_num_reqs
+            tl.store(seq_lens_ptr + block, 0, mask=mask)
+        return
+
+    # Translate batch index -> persistent request-state index.
+    req_state_idx = tl.load(idx_mapping_ptr + req_id)
+    num_computed_tokens = tl.load(num_computed_tokens_ptr + req_state_idx)
+
+    start = tl.load(query_start_loc_ptr + req_id)
+    end = tl.load(query_start_loc_ptr + req_id + 1)
+    query_len = end - start
+
+    # seq_len = tokens already computed + tokens scheduled this step.
+    seq_len = num_computed_tokens + query_len
+    tl.store(seq_lens_ptr + req_id, seq_len)
+
+    # Positions for this request's tokens: num_computed_tokens .. seq_len-1.
+    for i in tl.range(0, query_len, BLOCK_SIZE):
+        block = i + tl.arange(0, BLOCK_SIZE)
+        mask = block < query_len
+        pos = num_computed_tokens + block
+        tl.store(pos_ptr + start + block, pos, mask=mask)
+
+
+def prepare_pos_seq_lens(
+    idx_mapping: torch.Tensor,
+    query_start_loc: torch.Tensor,
+    num_computed_tokens: torch.Tensor,
+    pos: torch.Tensor,
+    seq_lens: torch.Tensor,
+) -> None:
+    """Fill per-token positions and per-request seq_lens on the GPU.
+
+    Writes `pos` and `seq_lens` in place via a Triton kernel; unused tail
+    entries of `seq_lens` are zeroed so padded (full-CUDA-graph) batches see
+    consistent values.
+    """
+    num_reqs = idx_mapping.shape[0]
+    # NOTE(woosuk): We do +1 because the last thread block is used
+    # to pad unused seq_lens as 0 for full CUDA graphs.
+    _prepare_pos_seq_lens_kernel[(num_reqs + 1,)](
+        pos,
+        seq_lens,
+        idx_mapping,
+        query_start_loc,
+        num_computed_tokens,
+        seq_lens.shape[0],
+        BLOCK_SIZE=1024,
+    )
+
+
+@triton.jit
+def _combine_sampled_and_draft_tokens_kernel(
+    input_ids_ptr,
+    idx_mapping_ptr,
+    last_sampled_tokens_ptr,
+    query_start_loc_ptr,
+    seq_lens_ptr,
+    prefill_len_ptr,
+    draft_tokens_ptr,
+    draft_tokens_stride,
+    cu_num_logits_ptr,
+    logits_indices_ptr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    # One program per request: computes the request's logits indices and,
+    # for decode requests, patches the last sampled token and any draft
+    # tokens into the tail of the request's input_ids span.
+    batch_idx = tl.program_id(0)
+    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
+
+    # Get the number of logits and draft tokens.
+    cu_num_logits_start = tl.load(cu_num_logits_ptr + batch_idx)
+    cu_num_logits_end = tl.load(cu_num_logits_ptr + batch_idx + 1)
+    num_logits = cu_num_logits_end - cu_num_logits_start
+    num_draft_tokens = num_logits - 1
+
+    # Compute the logits indices: the last num_logits token positions of
+    # this request's query span.
+    block = tl.arange(0, BLOCK_SIZE)
+    query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
+    logits_start = query_end - num_logits
+    tl.store(
+        logits_indices_ptr + cu_num_logits_start + block,
+        logits_start + block,
+        mask=block < num_logits,
+    )
+
+    seq_len = tl.load(seq_lens_ptr + batch_idx)
+    prefill_len = tl.load(prefill_len_ptr + req_state_idx)
+    if seq_len <= prefill_len:
+        # Handling prefill tokens. No sampled or draft tokens.
+        return
+
+    # Write the last sampled token ID to input_ids.
+    last_token_id = tl.load(last_sampled_tokens_ptr + req_state_idx)
+    tl.store(input_ids_ptr + query_end - num_logits, last_token_id)
+
+    # Write the draft tokens (if any) to input_ids.
+    if num_draft_tokens > 0:
+        mask = block < num_draft_tokens
+        draft_tokens = tl.load(
+            draft_tokens_ptr + req_state_idx * draft_tokens_stride + block,
+            mask=mask,
+        )
+        tl.store(
+            input_ids_ptr + query_end - num_draft_tokens + block,
+            draft_tokens,
+            mask=mask,
+        )
+
+
+def combine_sampled_and_draft_tokens(
+    input_ids: torch.Tensor,
+    idx_mapping: torch.Tensor,
+    last_sampled_tokens: torch.Tensor,
+    query_start_loc: torch.Tensor,
+    seq_lens: torch.Tensor,
+    prefill_len: torch.Tensor,
+    draft_tokens: torch.Tensor,
+    cu_num_logits: torch.Tensor,
+    num_logits: int,
+) -> torch.Tensor:
+    """Patch sampled/draft token IDs into `input_ids` and build logits indices.
+
+    Mutates `input_ids` in place (decode requests get their last sampled
+    token and draft tokens written into the tail of their query span) and
+    returns a new int64 tensor of `num_logits` logits indices.
+    """
+    num_reqs = seq_lens.shape[0]
+    num_speculative_steps = draft_tokens.shape[-1]
+
+    logits_indices = torch.empty(
+        num_logits,
+        dtype=torch.int64,
+        device=input_ids.device,
+    )
+    _combine_sampled_and_draft_tokens_kernel[(num_reqs,)](
+        input_ids,
+        idx_mapping,
+        last_sampled_tokens,
+        query_start_loc,
+        seq_lens,
+        prefill_len,
+        draft_tokens,
+        draft_tokens.stride(0),
+        cu_num_logits,
+        logits_indices,
+        # NOTE(woosuk): Add 1 to ensure the block can cover the last sampled token
+        # in addition to all draft tokens.
+        BLOCK_SIZE=triton.next_power_of_2(num_speculative_steps + 1),
+    )
+    return logits_indices
+
+
+@triton.jit
+def _post_update_kernel(
+    idx_mapping_ptr,
+    num_computed_tokens_ptr,
+    last_sampled_tokens_ptr,
+    sampled_tokens_ptr,
+    sampled_tokens_stride,
+    num_sampled_ptr,
+    num_rejected_ptr,
+    query_start_loc_ptr,
+):
+    # One program per request: after sampling, record the newest sampled
+    # token and advance the request's computed-token counter.
+    req_id = tl.program_id(0)
+    req_state_idx = tl.load(idx_mapping_ptr + req_id)
+
+    num_sampled = tl.load(num_sampled_ptr + req_id)
+    if num_sampled > 0:
+        # Persist the last accepted token for the next step's input.
+        token_id = tl.load(
+            sampled_tokens_ptr + req_id * sampled_tokens_stride + num_sampled - 1
+        )
+        tl.store(last_sampled_tokens_ptr + req_state_idx, token_id)
+
+    query_start = tl.load(query_start_loc_ptr + req_id)
+    query_end = tl.load(query_start_loc_ptr + req_id + 1)
+    query_len = query_start_loc_ptr and query_end - query_start
+    num_rejected = tl.load(num_rejected_ptr + req_id)
+
+    # Advance by the scheduled tokens minus those rejected by spec decoding.
+    num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
+    num_computed += query_len - num_rejected
+    tl.store(num_computed_tokens_ptr + req_state_idx, num_computed)
+
+
+def post_update(
+    # [num_reqs]
+    idx_mapping: torch.Tensor,
+    # [max_num_reqs]
+    num_computed_tokens: torch.Tensor,
+    # [max_num_reqs]
+    last_sampled_tokens: torch.Tensor,
+    # [num_reqs, num_speculative_steps + 1]
+    sampled_tokens: torch.Tensor,
+    # [num_reqs]
+    num_sampled: torch.Tensor,
+    # [num_reqs]
+    num_rejected: torch.Tensor,
+    # [num_reqs + 1]
+    query_start_loc: torch.Tensor,
+) -> None:
+    """Update persistent request state after a sampling step (in place).
+
+    Stores each request's last sampled token and bumps its
+    `num_computed_tokens` by (query_len - num_rejected) via a Triton kernel.
+    """
+    num_reqs = idx_mapping.shape[0]
+    _post_update_kernel[(num_reqs,)](
+        idx_mapping,
+        num_computed_tokens,
+        last_sampled_tokens,
+        sampled_tokens,
+        sampled_tokens.stride(0),
+        num_sampled,
+        num_rejected,
+        query_start_loc,
+        # Scalar work per program; a single warp is sufficient.
+        num_warps=1,
+    )
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
new file mode 100644
index 0000000000000..e34a45f979807
--- /dev/null
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -0,0 +1,1020 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import gc
+import time
+from copy import deepcopy
+from typing import Any
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from vllm.config import VllmConfig
+from vllm.config.compilation import CUDAGraphMode
+from vllm.forward_context import set_forward_context
+from vllm.logger import init_logger
+from vllm.model_executor.model_loader import get_model_loader
+from vllm.utils.mem_constants import GiB_bytes
+from vllm.utils.mem_utils import DeviceMemoryProfiler
+from vllm.utils.platform_utils import is_pin_memory_available
+from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
+from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
+from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.outputs import (
+ EMPTY_MODEL_RUNNER_OUTPUT,
+ LogprobsTensors,
+ ModelRunnerOutput,
+)
+from vllm.v1.sample.sampler import SamplerOutput
+from vllm.v1.worker.gpu.async_utils import AsyncOutput, async_barrier
+from vllm.v1.worker.gpu.attn_utils import (
+ build_attn_metadata,
+ get_kv_cache_spec,
+ init_attn_backend,
+ init_kv_cache,
+)
+from vllm.v1.worker.gpu.block_table import BlockTables
+from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
+from vllm.v1.worker.gpu.dp_utils import get_batch_metadata_across_dp
+from vllm.v1.worker.gpu.input_batch import (
+ InputBatch,
+ InputBuffers,
+ combine_sampled_and_draft_tokens,
+ post_update,
+ prepare_pos_seq_lens,
+ prepare_prefill_inputs,
+)
+from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs
+from vllm.v1.worker.gpu.spec_decode import init_speculator
+from vllm.v1.worker.gpu.spec_decode.rejection_sample import (
+ get_num_rejected,
+ rejection_sample,
+)
+from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
+from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask
+from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
+from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
+
+logger = init_logger(__name__)
+
+
+class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        device: torch.device,
+    ):
+        """Set up configs, buffers, sampler, and CUDA-graph manager.
+
+        Does NOT load the model or allocate the KV cache; those happen later
+        in `load_model` / `initialize_kv_cache`.
+        """
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
+        self.cache_config = vllm_config.cache_config
+        self.compilation_config = vllm_config.compilation_config
+        self.lora_config = vllm_config.lora_config
+        self.load_config = vllm_config.load_config
+        self.parallel_config = vllm_config.parallel_config
+        self.scheduler_config = vllm_config.scheduler_config
+        self.speculative_config = vllm_config.speculative_config
+        self.observability_config = vllm_config.observability_config
+
+        self.device = device
+        self.pin_memory = is_pin_memory_available()
+        self.dtype = self.model_config.dtype
+        self.kv_cache_dtype = self.dtype
+        if self.cache_config.cache_dtype != "auto":
+            # Quantized KV cache.
+            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
+                self.cache_config.cache_dtype
+            ]
+        self.is_pooling_model = False
+
+        self.vocab_size = self.model_config.get_vocab_size()
+        self.max_model_len = self.model_config.max_model_len
+        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
+        self.max_num_reqs = self.scheduler_config.max_num_seqs
+        self.hidden_size = self.model_config.get_hidden_size()
+
+        self.dp_size = self.parallel_config.data_parallel_size
+        self.dp_rank = self.parallel_config.data_parallel_rank
+
+        # Events/streams used to overlap output copies and (when async
+        # scheduling is on) input prep with model execution.
+        self.use_async_scheduling = self.scheduler_config.async_scheduling
+        self.output_copy_stream = torch.cuda.Stream(self.device)
+        self.output_copy_event = torch.cuda.Event()
+        if self.use_async_scheduling:
+            self.input_prep_event = torch.cuda.Event()
+            self.structured_outputs_event = torch.cuda.Event()
+            self.spec_decode_event = torch.cuda.Event()
+        else:
+            self.input_prep_event = None
+            self.structured_outputs_event = None
+            self.spec_decode_event = None
+
+        # Speculative decoding is enabled iff a speculative_config is given.
+        if self.speculative_config is not None:
+            self.do_spec_decode = True
+            self.num_speculative_steps = self.speculative_config.num_speculative_tokens
+            self.speculator = init_speculator(self.vllm_config, self.device)
+        else:
+            self.do_spec_decode = False
+            self.num_speculative_steps = 0
+            self.speculator = None
+
+        self.req_states = RequestState(
+            max_num_reqs=self.max_num_reqs,
+            max_model_len=self.max_model_len,
+            max_num_batched_tokens=self.max_num_tokens,
+            num_speculative_steps=self.num_speculative_steps,
+            vocab_size=self.vocab_size,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        )
+        self.input_buffers = InputBuffers(
+            max_num_reqs=self.max_num_reqs,
+            max_num_tokens=self.max_num_tokens,
+            hidden_size=self.hidden_size,
+            vocab_size=self.vocab_size,
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        )
+        self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
+
+        # CUDA graphs.
+        self.cudagraph_manager = CudaGraphManager(
+            vllm_config=self.vllm_config,
+            device=self.device,
+        )
+
+    def get_supported_tasks(self) -> tuple[str]:
+        """This runner only supports text generation (no pooling)."""
+        return ("generate",)
+
+    def load_model(self, *args, **kwargs) -> None:
+        """Load model weights (plus LoRA / speculator) and log memory/time.
+
+        Sets `self.model` and `self.model_memory_usage`; extra *args/**kwargs
+        are accepted for interface compatibility and ignored here.
+        """
+        time_before_load = time.perf_counter()
+        with DeviceMemoryProfiler() as m:
+            model_loader = get_model_loader(self.vllm_config.load_config)
+            logger.info("Loading model from scratch...")
+
+            self.model = model_loader.load_model(
+                vllm_config=self.vllm_config,
+                model_config=self.vllm_config.model_config,
+            )
+            if self.lora_config:
+                self.model = self.load_lora_model(
+                    self.model,
+                    self.vllm_config,
+                    self.device,
+                )
+            if self.do_spec_decode:
+                self.speculator.load_model(self.model)
+        time_after_load = time.perf_counter()
+
+        self.model_memory_usage = m.consumed_memory
+        logger.info(
+            "Model loading took %.4f GiB and %.6f seconds",
+            m.consumed_memory / GiB_bytes,
+            time_after_load - time_before_load,
+        )
+
+    def get_model(self) -> nn.Module:
+        """Return the loaded model (valid only after `load_model`)."""
+        return self.model
+
+    def get_kv_cache_spec(self):
+        """Delegate KV-cache spec computation to the attn_utils helper."""
+        return get_kv_cache_spec(self.vllm_config)
+
+    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
+        """Allocate KV caches, block tables, and attention backends.
+
+        Deep-copies the config so later local mutation cannot affect the
+        caller's copy. Raises NotImplementedError for any attention backend
+        other than FLASH_ATTN.
+        """
+        kv_cache_config = deepcopy(kv_cache_config)
+        self.kv_cache_config = kv_cache_config
+        block_sizes = [
+            kv_cache_group.kv_cache_spec.block_size
+            for kv_cache_group in kv_cache_config.kv_cache_groups
+        ]
+
+        self.block_tables = BlockTables(
+            block_sizes=block_sizes,
+            max_num_reqs=self.max_num_reqs,
+            max_num_batched_tokens=self.max_num_tokens,
+            max_model_len=self.max_model_len,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        )
+
+        self.attn_backends, self.attn_metadata_builders = init_attn_backend(
+            self.kv_cache_config,
+            self.vllm_config,
+            self.device,
+        )
+        # TODO(woosuk): Support other backends.
+        if not all(b.get_name() == "FLASH_ATTN" for b in self.attn_backends.values()):
+            raise NotImplementedError("Only FLASH_ATTN backend is supported currently.")
+
+        self.kv_caches: list[torch.Tensor] = []
+        init_kv_cache(
+            self.kv_caches,
+            self.compilation_config.static_forward_context,
+            self.kv_cache_config,
+            self.attn_backends,
+            self.device,
+        )
+        # Attention groups are not supported.
+        self.attn_groups = []  # type: ignore
+
+    def prepare_dummy_attn_metadata(self, input_batch: InputBatch) -> None:
+        """Attach placeholder attention metadata to a dummy batch (in place).
+
+        Uses dummy block tables / slot mappings and zero computed tokens so
+        attention kernels can run during warmup and profiling.
+        """
+        block_tables = self.block_tables.get_dummy_block_tables(input_batch.num_reqs)
+        slot_mappings = self.block_tables.get_dummy_slot_mappings(
+            input_batch.num_tokens
+        )
+        num_computed_tokens = torch.zeros(
+            input_batch.num_reqs, dtype=torch.int32, device=self.device
+        )
+        attn_metadata = build_attn_metadata(
+            attn_metadata_builders=self.attn_metadata_builders,
+            num_reqs=input_batch.num_reqs,
+            num_tokens=input_batch.num_tokens,
+            query_start_loc=self.input_buffers.query_start_loc,
+            seq_lens=self.input_buffers.seq_lens,
+            seq_lens_np=input_batch.seq_lens_np,
+            num_computed_tokens_cpu=num_computed_tokens,
+            block_tables=block_tables,
+            slot_mappings=slot_mappings,
+            kv_cache_config=self.kv_cache_config,
+        )
+        input_batch.attn_metadata = attn_metadata
+
+    @torch.inference_mode()
+    def _dummy_run(
+        self,
+        num_tokens: int,
+        *args,
+        skip_attn: bool = True,
+        **kwargs,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Run the model on a synthetic batch of `num_tokens` tokens.
+
+        Used for memory profiling and warmup. With `skip_attn=True` no
+        attention metadata is built. Returns (hidden_states,
+        sample_hidden_states) where the latter gathers the per-request
+        logits positions.
+        """
+        num_reqs = min(num_tokens, self.max_num_reqs)
+        input_batch = InputBatch.make_dummy(
+            num_reqs=num_reqs,
+            num_tokens=num_tokens,
+            input_buffers=self.input_buffers,
+            device=self.device,
+        )
+        if not skip_attn:
+            self.prepare_dummy_attn_metadata(input_batch)
+
+        if self.dp_size == 1:
+            num_tokens_across_dp: torch.Tensor | None = None
+        else:
+            # Pretend every DP rank runs the same token count.
+            num_tokens_across_dp = torch.full(
+                (self.dp_size,), num_tokens, dtype=torch.int32, device="cpu"
+            )
+        num_sampled_tokens = np.ones(input_batch.num_reqs, dtype=np.int32)
+        with (
+            self.maybe_dummy_run_with_lora(
+                self.lora_config,
+                input_batch.num_scheduled_tokens,
+                num_sampled_tokens,
+            ),
+            set_forward_context(
+                input_batch.attn_metadata,
+                self.vllm_config,
+                num_tokens=num_tokens,
+                num_tokens_across_dp=num_tokens_across_dp,
+            ),
+        ):
+            hidden_states = self.model(
+                input_ids=input_batch.input_ids,
+                positions=input_batch.positions,
+            )
+        sample_hidden_states = hidden_states[input_batch.logits_indices]
+        return hidden_states, sample_hidden_states
+
+    @torch.inference_mode()
+    def _dummy_sampler_run(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> None:
+        """Exercise logits computation + sampler with dummy metadata.
+
+        Used during profiling so sampler memory is accounted for.
+        """
+        num_reqs = hidden_states.shape[0]
+        sampling_metadata = SamplingMetadata.make_dummy(
+            num_reqs=num_reqs,
+            device=self.device,
+        )
+        logits = self.model.compute_logits(hidden_states)
+        self.sampler(logits, sampling_metadata)
+
+    @torch.inference_mode()
+    def _dummy_speculator_run(
+        self,
+        hidden_states: torch.Tensor,
+        aux_hidden_states: list[torch.Tensor] | None,
+    ) -> None:
+        """Exercise the draft proposer with a dummy batch (profiling only).
+
+        Simulates one accepted token and zero rejections per request.
+        """
+        num_tokens = hidden_states.shape[0]
+        num_reqs = min(num_tokens, self.max_num_reqs)
+        input_batch = InputBatch.make_dummy(
+            num_reqs=num_reqs,
+            num_tokens=num_tokens,
+            input_buffers=self.input_buffers,
+            device=self.device,
+        )
+        sampling_metadata = SamplingMetadata.make_dummy(
+            num_reqs=num_reqs,
+            device=self.device,
+        )
+        num_sampled = torch.ones(num_reqs, dtype=torch.int32, device=self.device)
+        num_rejected = torch.zeros(num_reqs, dtype=torch.int32, device=self.device)
+        self.propose_draft(
+            input_batch=input_batch,
+            sampling_metadata=sampling_metadata,
+            last_hidden_states=hidden_states,
+            aux_hidden_states=aux_hidden_states,
+            num_sampled=num_sampled,
+            num_rejected=num_rejected,
+        )
+
+    @torch.inference_mode()
+    def profile_run(self) -> None:
+        """Run the largest possible dummy batch to measure peak GPU memory.
+
+        Drives the model, sampler, and (if enabled) speculator, then frees
+        the temporaries so the memory profiler sees only the peak.
+        """
+        hidden_states, sample_hidden_states = self._dummy_run(
+            self.max_num_tokens,
+            skip_attn=True,
+        )
+        self._dummy_sampler_run(sample_hidden_states)
+        if self.do_spec_decode:
+            self._dummy_speculator_run(hidden_states, None)
+        torch.cuda.synchronize()
+        del hidden_states, sample_hidden_states
+        gc.collect()
+
+    def reset_mm_cache(self) -> None:
+        """No-op: this runner has no multimodal cache."""
+        pass
+
+    def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
+        """Return the input token count unchanged.
+
+        Would differ from the scheduled count under sequence parallelism,
+        which is not supported here.
+        """
+        # SP is not supported yet.
+        return num_scheduled_tokens
+
+    @torch.inference_mode()
+    def capture_model(self) -> int:
+        """Capture CUDA graphs for all configured sizes.
+
+        Returns the GPU memory (bytes, measured as the drop in free memory)
+        consumed by the captured graphs; 0 if capture is disabled.
+        """
+        if not self.cudagraph_manager.needs_capture():
+            logger.warning(
+                "Skipping CUDA graph capture. To turn on CUDA graph capture, "
+                "ensure `cudagraph_mode` was not manually set to `NONE`"
+            )
+            return 0
+
+        start_time = time.perf_counter()
+        gc.collect()
+        torch.cuda.empty_cache()
+        start_free_gpu_memory = torch.cuda.mem_get_info()[0]
+
+        with self.maybe_setup_dummy_loras(self.lora_config):
+            self.cudagraph_manager.capture(
+                model=self.model,
+                input_buffers=self.input_buffers,
+                block_tables=self.block_tables,
+                attn_metadata_builders=self.attn_metadata_builders,
+                kv_cache_config=self.kv_cache_config,
+            )
+
+        end_time = time.perf_counter()
+        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
+        elapsed_time = end_time - start_time
+        cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
+        # This usually takes 5~20 seconds.
+        logger.info(
+            "Graph capturing finished in %.0f secs, took %.2f GiB",
+            elapsed_time,
+            cuda_graph_size / (1 << 30),
+        )
+        return cuda_graph_size
+
+    def warmup_for_prefill(self) -> None:
+        """Trigger FlashInfer JIT compilation with a dummy prefill run.
+
+        Only runs when every attention backend is a FlashInfer variant;
+        otherwise this is a no-op.
+        """
+        # For FlashInfer, we would like to execute a dummy prefill run
+        # to trigger JIT compilation.
+        if all("FLASHINFER" in b.get_name() for b in self.attn_backends.values()):
+            self._dummy_run(self.max_num_tokens, skip_attn=False)
+            torch.cuda.synchronize()
+
+    def update_states(self, scheduler_output: SchedulerOutput) -> None:
+        """Sync request state and block tables with the scheduler's output.
+
+        Removes preempted/finished requests, registers new requests, and
+        appends newly allocated KV blocks for both new and cached requests
+        in one batched `append_block_ids` call.
+        """
+        if scheduler_output.preempted_req_ids is not None:
+            for req_id in scheduler_output.preempted_req_ids:
+                self.req_states.remove_request(req_id)
+        for req_id in scheduler_output.finished_req_ids:
+            self.req_states.remove_request(req_id)
+
+        # TODO(woosuk): Change SchedulerOutput.
+        # Accumulators for a single batched block-table update, one entry
+        # (per KV cache group) per request that received new blocks.
+        req_indices: list[int] = []
+        cu_num_new_blocks = tuple(
+            [0] for _ in range(self.block_tables.num_kv_cache_groups)
+        )
+        new_block_ids: tuple[list[int], ...] = tuple(
+            [] for _ in range(self.block_tables.num_kv_cache_groups)
+        )
+        overwrite: list[bool] = []
+
+        # Add new requests.
+        for new_req_data in scheduler_output.scheduled_new_reqs:
+            assert new_req_data.prompt_token_ids is not None
+            assert new_req_data.prefill_token_ids is not None
+            assert new_req_data.sampling_params is not None
+            req_id = new_req_data.req_id
+            self.req_states.add_request(
+                req_id=req_id,
+                prompt_len=len(new_req_data.prompt_token_ids),
+                prefill_token_ids=new_req_data.prefill_token_ids,
+                num_computed_tokens=new_req_data.num_computed_tokens,
+                sampling_params=new_req_data.sampling_params,
+                lora_request=new_req_data.lora_request,
+            )
+
+            req_index = self.req_states.req_id_to_index[req_id]
+            req_indices.append(req_index)
+            for i, block_ids in enumerate(new_req_data.block_ids):
+                x = cu_num_new_blocks[i][-1]
+                cu_num_new_blocks[i].append(x + len(block_ids))
+                new_block_ids[i].extend(block_ids)
+            # New requests overwrite any stale block-table rows.
+            overwrite.append(True)
+        # Update the GPU tensors for request states.
+        if scheduler_output.scheduled_new_reqs:
+            self.req_states.prefill_len.copy_to_gpu()
+
+        # Add new blocks for the existing requests.
+        cached_reqs = scheduler_output.scheduled_cached_reqs
+        for i, req_id in enumerate(cached_reqs.req_ids):
+            req_index = self.req_states.req_id_to_index[req_id]
+
+            req_new_block_ids = cached_reqs.new_block_ids[i]
+            if req_new_block_ids is not None:
+                req_indices.append(req_index)
+                for group_id, block_ids in enumerate(req_new_block_ids):
+                    x = cu_num_new_blocks[group_id][-1]
+                    cu_num_new_blocks[group_id].append(x + len(block_ids))
+                    new_block_ids[group_id].extend(block_ids)
+                # Existing requests append to their block-table rows.
+                overwrite.append(False)
+
+        if req_indices:
+            self.block_tables.append_block_ids(
+                req_indices=req_indices,
+                cu_num_new_blocks=cu_num_new_blocks,
+                new_block_ids=new_block_ids,
+                overwrite=overwrite,
+            )
+
+    def prepare_inputs(
+        self,
+        scheduler_output: SchedulerOutput,
+        num_tokens_after_padding: int,
+    ) -> InputBatch:
+        """Build the InputBatch for one forward pass from scheduler output.
+
+        Fills the persistent InputBuffers (token IDs, positions, seq_lens,
+        query_start_loc), patches in last-sampled/draft tokens on the GPU,
+        and constructs per-layer attention metadata.
+        `num_tokens_after_padding` is the (possibly CUDA-graph-padded) token
+        count the model will actually run with.
+        """
+        num_tokens = scheduler_output.total_num_scheduled_tokens
+        assert num_tokens > 0
+        num_reqs = len(scheduler_output.num_scheduled_tokens)
+
+        # Decode first, then prefill.
+        # batch_idx -> req_id
+        # Sorting ascending by scheduled-token count puts 1-token decode
+        # requests before multi-token prefills.
+        req_ids = sorted(
+            scheduler_output.num_scheduled_tokens.keys(),
+            key=lambda k: scheduler_output.num_scheduled_tokens[k],
+        )
+        num_scheduled_tokens = np.array(
+            [scheduler_output.num_scheduled_tokens[i] for i in req_ids], dtype=np.int32
+        )
+
+        idx_mapping_list = [
+            self.req_states.req_id_to_index[req_id] for req_id in req_ids
+        ]
+        idx_mapping = self.input_buffers.idx_mapping
+        idx_mapping.np[:num_reqs] = idx_mapping_list
+        idx_mapping_np = idx_mapping.np[:num_reqs]
+        idx_mapping = idx_mapping.copy_to_gpu(num_reqs)
+
+        # Get the number of draft tokens for each request.
+        if not scheduler_output.scheduled_spec_decode_tokens:
+            # No draft token scheduled (common case).
+            total_num_draft_tokens = 0
+            total_num_logits = num_reqs
+            # One logit per request -> cumulative counts are just 0..num_reqs.
+            cu_num_logits = torch.arange(
+                num_reqs + 1, device=self.device, dtype=torch.int32
+            )
+        else:
+            draft_tokens = scheduler_output.scheduled_spec_decode_tokens
+            num_draft_tokens = np.array(
+                [
+                    len(draft_tokens[req_id]) if req_id in draft_tokens else 0
+                    for req_id in req_ids
+                ],
+                dtype=np.int32,
+            )
+            total_num_draft_tokens = int(num_draft_tokens.sum())
+            total_num_logits = num_reqs + total_num_draft_tokens
+
+            # Each request samples 1 + num_draft_tokens logits.
+            np.cumsum(
+                num_draft_tokens + 1,
+                out=self.input_buffers.cu_num_logits.np[1 : num_reqs + 1],
+            )
+            cu_num_logits = self.input_buffers.cu_num_logits.copy_to_gpu(num_reqs + 1)
+
+        # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
+        block_tables = self.block_tables.gather_block_tables(idx_mapping)
+
+        # Get query_start_loc.
+        np.cumsum(
+            num_scheduled_tokens,
+            out=self.input_buffers.query_start_loc.np[1 : num_reqs + 1],
+        )
+        # Pad for full CUDA graph mode.
+        # Some attention backends like FA3 require query_start_loc to be non-decreasing.
+        self.input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens
+        self.input_buffers.query_start_loc.copy_to_gpu()
+        query_start_loc_gpu = self.input_buffers.query_start_loc.gpu[: num_reqs + 1]
+        query_start_loc_np = self.input_buffers.query_start_loc.np[: num_reqs + 1]
+
+        # Copy prefill tokens from CPU to GPU.
+        prepare_prefill_inputs(
+            idx_mapping_np,
+            num_scheduled_tokens,
+            query_start_loc_np,
+            self.req_states.prefill_token_ids,
+            self.req_states.num_computed_prefill_tokens,
+            self.input_buffers.input_ids.np,
+        )
+        self.input_buffers.input_ids.copy_to_gpu(num_tokens)
+
+        # Prepare positions and seq_lens.
+        prepare_pos_seq_lens(
+            idx_mapping,
+            query_start_loc_gpu,
+            self.req_states.num_computed_tokens,
+            self.input_buffers.positions,
+            self.input_buffers.seq_lens,
+        )
+        seq_lens = self.input_buffers.seq_lens[:num_reqs]
+
+        # Some input token ids are directly read from the last sampled tokens
+        # and draft tokens. Also, get the logits indices to sample tokens from.
+        logits_indices = combine_sampled_and_draft_tokens(
+            self.input_buffers.input_ids.gpu,
+            idx_mapping,
+            self.req_states.last_sampled_tokens,
+            query_start_loc_gpu,
+            seq_lens,
+            self.req_states.prefill_len.gpu,
+            self.req_states.draft_tokens,
+            cu_num_logits,
+            total_num_logits,
+        )
+
+        # Compute slot mappings: [num_kv_cache_groups, num_tokens]
+        slot_mappings = self.block_tables.compute_slot_mappings(
+            query_start_loc_gpu, self.input_buffers.positions[:num_tokens]
+        )
+
+        # Get num_computed_tokens.
+        # HACK(woosuk): Here, we use num_computed_tokens on GPU instead of
+        # num_computed_tokens_cpu. This works for most cases.
+        num_computed_tokens = self.req_states.num_computed_tokens[idx_mapping]
+        # HACK(woosuk): Only GPU has the exact seq_lens because at this point
+        # CPU does not know how many draft tokens are accepted/rejected in the
+        # previous step. Therefore, we use max_model_len to be safe.
+        # NOTE(woosuk): This only works for FA3 backend.
+        seq_lens_np = np.full(num_reqs, self.max_model_len, dtype=np.int32)
+
+        # Layer name -> attention metadata.
+        attn_metadata = build_attn_metadata(
+            attn_metadata_builders=self.attn_metadata_builders,
+            num_reqs=num_reqs,
+            num_tokens=num_tokens,
+            query_start_loc=self.input_buffers.query_start_loc,
+            seq_lens=self.input_buffers.seq_lens,
+            seq_lens_np=seq_lens_np,
+            num_computed_tokens_cpu=num_computed_tokens,
+            block_tables=block_tables,
+            slot_mappings=slot_mappings,
+            kv_cache_config=self.kv_cache_config,
+        )
+
+        input_ids = self.input_buffers.input_ids.gpu[:num_tokens_after_padding]
+        positions = self.input_buffers.positions[:num_tokens_after_padding]
+        return InputBatch(
+            req_ids=req_ids,
+            num_reqs=num_reqs,
+            idx_mapping=idx_mapping,
+            idx_mapping_np=idx_mapping_np,
+            num_scheduled_tokens=num_scheduled_tokens,
+            num_tokens=num_tokens,
+            num_tokens_after_padding=num_tokens_after_padding,
+            num_draft_tokens=total_num_draft_tokens,
+            query_start_loc=query_start_loc_gpu,
+            query_start_loc_np=query_start_loc_np,
+            seq_lens=seq_lens,
+            seq_lens_np=seq_lens_np,
+            input_ids=input_ids,
+            positions=positions,
+            attn_metadata=attn_metadata,
+            logits_indices=logits_indices,
+            cu_num_logits=cu_num_logits,
+        )
+
    def sample(
        self,
        hidden_states: torch.Tensor,
        input_batch: InputBatch,
        sampling_metadata: SamplingMetadata,
        grammar_output: GrammarOutput | None,
    ) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]:
        """Sample next tokens for the batch from the model's hidden states.

        Returns:
            A tuple of:
            - SamplerOutput holding the sampled token ids (and logprobs if
              requested).
            - num_sampled: per-request number of tokens actually produced
              (0 for requests still in chunked prefill).
            - num_rejected: per-request number of rejected draft tokens
              (all zeros when no draft tokens were scheduled).
        """
        # Only positions listed in logits_indices need logits.
        sample_hidden_states = hidden_states[input_batch.logits_indices]
        logits = self.model.compute_logits(sample_hidden_states)
        if grammar_output is not None:
            # Apply grammar bitmask to the logits in-place.
            # TODO(woosuk): Make compatible with spec decoding.
            assert input_batch.num_draft_tokens == 0
            with async_barrier(self.structured_outputs_event):
                apply_grammar_bitmask(
                    logits,
                    input_batch.req_ids,
                    grammar_output.structured_output_request_ids,
                    grammar_output.grammar_bitmask,
                    self.input_buffers,
                )

        # Sample tokens and compute logprobs (if needed).
        sampler_output = self.sampler(logits, sampling_metadata)

        # Get the number of sampled tokens.
        # A request whose sequence is still shorter than its prefill length is
        # mid chunked-prefill and does not emit a token this step.
        prefill_len = self.req_states.prefill_len.gpu[input_batch.idx_mapping]
        is_chunked_prefilling = input_batch.seq_lens < prefill_len
        if input_batch.num_draft_tokens == 0:
            # No draft tokens (common case).
            # 0 if chunked-prefilling, 1 if not.
            num_sampled = (~is_chunked_prefilling).int()
            num_rejected = torch.zeros_like(num_sampled)
        else:
            # Draft tokens for spec decoding: keep the longest prefix of draft
            # tokens that agrees with the target model's samples.
            input_ids = input_batch.input_ids[input_batch.logits_indices]
            sampled_tokens, num_sampled = rejection_sample(
                sampler_output.sampled_token_ids,
                input_ids,
                input_batch.cu_num_logits,
                self.num_speculative_steps,
            )
            # Zero out counts for requests that are still chunked-prefilling.
            num_sampled *= ~is_chunked_prefilling
            num_rejected = get_num_rejected(
                input_batch.cu_num_logits,
                num_sampled,
            )
            sampler_output.sampled_token_ids = sampled_tokens
        # TODO(woosuk): Support logprobs with spec decoding.
        return sampler_output, num_sampled, num_rejected
+
    def compute_prompt_logprobs(
        self,
        hidden_states: torch.Tensor,
        input_batch: InputBatch,
    ) -> dict[str, LogprobsTensors]:
        """Compute prompt logprobs for the requests that asked for them.

        Returns a dict mapping req_id -> LogprobsTensors. A request appears
        in the dict only once its whole prompt has been processed; logprobs
        for earlier prompt chunks are buffered in the request's extra_data
        (in_progress_prompt_logprobs) and merged at the end.
        """
        idx_mapping_np = input_batch.idx_mapping_np
        needs_prompt_logprobs = self.req_states.needs_prompt_logprobs[idx_mapping_np]
        if not np.any(needs_prompt_logprobs):
            # No request asks for prompt logprobs.
            return {}

        prompt_lens = self.req_states.prompt_len[idx_mapping_np]
        # NOTE(woosuk): -1 because the last prompt token's hidden state is not
        # needed for prompt logprobs.
        computed_prefill = self.req_states.num_computed_prefill_tokens[idx_mapping_np]
        includes_prompt = computed_prefill < prompt_lens - 1
        # NOTE(woosuk): If the request was resumed after preemption, its prompt
        # logprobs must have been computed before preemption. Skip.
        resumed_after_prompt = (
            prompt_lens < self.req_states.prefill_len.np[idx_mapping_np]
        )
        needs_prompt_logprobs &= includes_prompt & ~resumed_after_prompt
        if not np.any(needs_prompt_logprobs):
            return {}

        # Just to be safe, clone the input ids.
        n = input_batch.num_tokens
        # Shift the input ids by one: position i scores the token at i + 1.
        token_ids = torch.empty_like(input_batch.input_ids[:n])
        token_ids[: n - 1] = input_batch.input_ids[1:n]
        # To avoid out-of-bound access, set the last token id to 0.
        token_ids[n - 1] = 0

        # Handle chunked prompts: for a chunked request, the token after its
        # last scheduled one is not in input_ids, so fetch it from the stored
        # prefill tokens.
        pos_after_step = computed_prefill + input_batch.num_scheduled_tokens
        is_prompt_chunked = pos_after_step < prompt_lens
        prefill_token_ids = self.req_states.prefill_token_ids
        query_start_loc = self.input_buffers.query_start_loc.np
        for i, req_id in enumerate(input_batch.req_ids):
            if not needs_prompt_logprobs[i]:
                continue
            if not is_prompt_chunked[i]:
                continue
            # The prompt is chunked. Get the next prompt token.
            req_idx = input_batch.idx_mapping_np[i]
            next_prompt_token = int(prefill_token_ids[req_idx, pos_after_step[i]])
            idx = int(query_start_loc[i + 1] - 1)
            # Set the next prompt token.
            # NOTE(woosuk): This triggers a GPU operation.
            token_ids[idx] = next_prompt_token

        # NOTE(woosuk): We mask out logprobs for negative tokens.
        prompt_logprobs, prompt_ranks = compute_prompt_logprobs(
            token_ids,
            hidden_states[:n],
            self.model.compute_logits,
        )

        prompt_token_ids = token_ids.unsqueeze(-1)
        prompt_logprobs_dict: dict[str, LogprobsTensors] = {}
        for i, req_id in enumerate(input_batch.req_ids):
            if not needs_prompt_logprobs[i]:
                continue

            # Slice out this request's span of the flattened token dimension.
            start_idx = query_start_loc[i]
            end_idx = query_start_loc[i + 1]
            assert start_idx < end_idx, (
                f"start_idx ({start_idx}) >= end_idx ({end_idx})"
            )
            logprobs = LogprobsTensors(
                logprob_token_ids=prompt_token_ids[start_idx:end_idx],
                logprobs=prompt_logprobs[start_idx:end_idx],
                selected_token_ranks=prompt_ranks[start_idx:end_idx],
            )

            req_extra_data = self.req_states.extra_data[req_id]
            prompt_logprobs_list = req_extra_data.in_progress_prompt_logprobs
            if is_prompt_chunked[i]:
                # Prompt is chunked. Do not return the logprobs yet.
                prompt_logprobs_list.append(logprobs)
                continue

            if prompt_logprobs_list:
                # Merge the in-progress logprobs.
                prompt_logprobs_list.append(logprobs)
                logprobs = LogprobsTensors(
                    logprob_token_ids=torch.cat(
                        [x.logprob_token_ids for x in prompt_logprobs_list]
                    ),
                    logprobs=torch.cat([x.logprobs for x in prompt_logprobs_list]),
                    selected_token_ranks=torch.cat(
                        [x.selected_token_ranks for x in prompt_logprobs_list]
                    ),
                )
                prompt_logprobs_list.clear()

            prompt_logprobs_dict[req_id] = logprobs
        return prompt_logprobs_dict
+
    def postprocess(
        self,
        input_batch: InputBatch,
        sampled_tokens: torch.Tensor,
        num_sampled: torch.Tensor,
        num_rejected: torch.Tensor,
    ) -> None:
        """Update request states after sampling.

        Advances the GPU-side computed-token counts and records the last
        sampled token per request, then advances the CPU-side prefill
        progress counters.
        """
        # Update the number of computed tokens.
        post_update(
            input_batch.idx_mapping,
            self.req_states.num_computed_tokens,
            self.req_states.last_sampled_tokens,
            sampled_tokens,
            num_sampled,
            num_rejected,
            input_batch.query_start_loc,
        )

        # Update the number of computed prefill tokens, capped at the
        # prefill length so the count never runs past the prompt.
        idx_mapping_np = input_batch.idx_mapping_np
        computed_prefill = self.req_states.num_computed_prefill_tokens
        # TODO(woosuk): Simplify this.
        computed_prefill[idx_mapping_np] = np.minimum(
            computed_prefill[idx_mapping_np] + input_batch.num_scheduled_tokens,
            self.req_states.prefill_len.np[idx_mapping_np],
        )
+
    @torch.inference_mode()
    def propose_draft(
        self,
        input_batch: InputBatch,
        sampling_metadata: SamplingMetadata,
        last_hidden_states: torch.Tensor,
        aux_hidden_states: list[torch.Tensor] | None,
        num_sampled: torch.Tensor,
        num_rejected: torch.Tensor,
    ) -> torch.Tensor:
        """Run the speculator to propose draft tokens for the next step.

        The returned draft tokens are also written into
        self.req_states.draft_tokens for the scheduled requests.
        """
        num_reqs = input_batch.num_reqs
        idx_mapping_np = input_batch.idx_mapping_np
        with async_barrier(self.spec_decode_event):
            # Gather each request's next prefill token on CPU and copy it to
            # the GPU; the speculator uses it for requests that are still
            # chunked-prefilling (no sampled token to continue from).
            self.input_buffers.next_prefill_tokens.np[:num_reqs] = (
                self.req_states.prefill_token_ids[
                    idx_mapping_np,
                    self.req_states.num_computed_prefill_tokens[idx_mapping_np],
                ]
            )
            next_prefill_tokens = self.input_buffers.next_prefill_tokens.copy_to_gpu(
                num_reqs
            )

        assert self.speculator is not None
        draft_tokens = self.speculator.propose(
            input_batch,
            sampling_metadata,
            last_hidden_states,
            aux_hidden_states,
            num_sampled,
            num_rejected,
            self.req_states.last_sampled_tokens,
            next_prefill_tokens,
        )
        # Cache the proposals so the next step can validate them.
        self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
        return draft_tokens
+
+ def get_cudagraph_and_dp_padding(
+ self,
+ scheduler_output: SchedulerOutput,
+ ) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
+ total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+ if self.dp_size == 1:
+ # No DP. Only consider CUDA graphs.
+ if total_num_scheduled_tokens == 0:
+ # Special case: no tokens to run.
+ return CUDAGraphMode.NONE, 0, None
+
+ cudagraph_size = self.cudagraph_manager.get_cudagraph_size(
+ scheduler_output, total_num_scheduled_tokens
+ )
+ if cudagraph_size is not None:
+ # Use full CUDA graph.
+ return CUDAGraphMode.FULL, cudagraph_size, None
+ # Fall back to eager mode.
+ # TODO(woosuk): Support piecewise CUDA graphs.
+ return CUDAGraphMode.NONE, total_num_scheduled_tokens, None
+
+ # Consider DP padding and CUDA graph.
+ if total_num_scheduled_tokens == 0:
+ # Special handling is needed for 0.
+ cudagraph_size_before_dp: int | None = 0
+ else:
+ cudagraph_size_before_dp = self.cudagraph_manager.get_cudagraph_size(
+ scheduler_output, total_num_scheduled_tokens
+ )
+ if cudagraph_size_before_dp is None:
+ cudagraph_size_before_dp = -1
+
+ assert cudagraph_size_before_dp is not None
+ num_tokens_across_dp, cudagraph_size_across_dp = get_batch_metadata_across_dp(
+ total_num_scheduled_tokens,
+ cudagraph_size_before_dp,
+ self.dp_size,
+ self.dp_rank,
+ )
+ if all(cudagraph_size_across_dp >= 0):
+ # If all ranks can use CUDA graph, pad to the maximum number of tokens
+ # across DP and use CUDA graph.
+ num_tokens_after_padding = int(cudagraph_size_across_dp.max().item())
+ cudagraph_mode = CUDAGraphMode.FULL
+ else:
+ # If any of the ranks cannot use CUDA graph, use eager mode for all ranks.
+ # No padding is needed except for ranks that have no tokens to run.
+ num_tokens_across_dp = torch.clamp(num_tokens_across_dp, min=1)
+ num_tokens_after_padding = num_tokens_across_dp[self.dp_rank]
+ cudagraph_mode = CUDAGraphMode.NONE
+ return cudagraph_mode, num_tokens_after_padding, num_tokens_across_dp
+
    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: SchedulerOutput,
        intermediate_tensors: Any | None = None,
        dummy_run: bool = False,
    ) -> ModelRunnerOutput | None:
        """Run the model forward pass for the scheduled tokens.

        Returns EMPTY_MODEL_RUNNER_OUTPUT when there is nothing to run.
        Otherwise stashes (hidden_states, input_batch, sampling_metadata) in
        self.execute_model_state and returns None; sampling happens later in
        sample_tokens().
        """
        assert intermediate_tensors is None
        if scheduler_output.total_num_scheduled_tokens == 0 and not dummy_run:
            # No need to run the model.
            with async_barrier(self.input_prep_event):
                self.update_states(scheduler_output)
            return EMPTY_MODEL_RUNNER_OUTPUT

        # NOTE: Call this before the async barrier so CPU all-reduce and
        # GPU execution can overlap.
        cudagraph_mode, num_tokens_after_padding, num_tokens_across_dp = (
            self.get_cudagraph_and_dp_padding(scheduler_output)
        )
        with async_barrier(self.input_prep_event):
            self.update_states(scheduler_output)
            if num_tokens_after_padding == 0:
                # All DP ranks have zero tokens to run.
                return EMPTY_MODEL_RUNNER_OUTPUT

            if not dummy_run:
                # Common case.
                # Prepare all the inputs and copy to the input buffers.
                input_batch = self.prepare_inputs(
                    scheduler_output,
                    num_tokens_after_padding,
                )

                # NOTE(woosuk): Sampling metadata should be built under the async
                # barrier to avoid race conditions.
                pos = input_batch.positions[input_batch.logits_indices]
                sampling_metadata = self.req_states.make_sampling_metadata(
                    input_batch.idx_mapping_np, pos
                )
                if input_batch.num_draft_tokens > 0:
                    # Expand per-request metadata to per-logit entries for
                    # speculative decoding.
                    sampling_metadata = self.req_states.expand_sampling_metadata(
                        sampling_metadata, input_batch.cu_num_logits
                    )

                if self.lora_config:
                    # Activate LoRA adapters.
                    lora_inputs = self.req_states.make_lora_inputs(
                        input_batch.req_ids,
                        input_batch.idx_mapping_np,
                        input_batch.num_scheduled_tokens,
                    )
                    self._set_active_loras(*lora_inputs)
            else:
                # No actual tokens to run. A dummy run for DP.
                num_reqs = min(num_tokens_after_padding, self.max_num_reqs)
                input_batch = InputBatch.make_dummy(
                    num_reqs=num_reqs,
                    num_tokens=num_tokens_after_padding,
                    input_buffers=self.input_buffers,
                    device=self.device,
                )
                self.prepare_dummy_attn_metadata(input_batch)
                sampling_metadata = None

        # Run model.
        if cudagraph_mode == CUDAGraphMode.FULL:
            # Run CUDA graph.
            # NOTE(woosuk): Here, we don't need to pass the input tensors,
            # because they are already copied to the CUDA graph input buffers.
            hidden_states = self.cudagraph_manager.run(
                input_batch.num_tokens_after_padding
            )
        else:
            # Run PyTorch model in eager mode.
            # TODO(woosuk): Support piecewise CUDA graph.
            with set_forward_context(
                input_batch.attn_metadata,
                self.vllm_config,
                num_tokens=input_batch.num_tokens_after_padding,
                cudagraph_runtime_mode=cudagraph_mode,
                num_tokens_across_dp=num_tokens_across_dp,
            ):
                hidden_states = self.model(
                    input_ids=input_batch.input_ids,
                    positions=input_batch.positions,
                )

        # Defer sampling to sample_tokens().
        self.execute_model_state = hidden_states, input_batch, sampling_metadata
        return None
+
    @torch.inference_mode()
    def sample_tokens(
        self,
        grammar_output: GrammarOutput | None,
    ) -> AsyncOutput | ModelRunnerOutput:
        """Sample tokens from the state stashed by execute_model().

        Also computes prompt logprobs, updates request states, and (when
        speculative decoding is enabled) proposes draft tokens for the next
        step. Returns an AsyncOutput under async scheduling, otherwise the
        resolved ModelRunnerOutput.
        """
        assert self.execute_model_state is not None
        hidden_states, input_batch, sampling_metadata = self.execute_model_state
        # Consume the stashed state so it cannot be reused.
        self.execute_model_state = None  # type: ignore
        assert sampling_metadata is not None

        sampler_output, num_sampled, num_rejected = self.sample(
            hidden_states, input_batch, sampling_metadata, grammar_output
        )
        prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch)

        # Prepare the model runner output.
        model_runner_output = ModelRunnerOutput(
            req_ids=input_batch.req_ids,
            # NOTE(woosuk): req_id_to_index is unused in this model runner.
            # Only for compatibility with the existing model runner and scheduler.
            req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
            sampled_token_ids=None,  # type: ignore
            logprobs=None,
            prompt_logprobs_dict=prompt_logprobs_dict,  # type: ignore
            pooler_output=[],
            kv_connector_output=None,
            num_nans_in_logits=None,
        )
        async_output = AsyncOutput(
            model_runner_output=model_runner_output,
            sampler_output=sampler_output,
            num_sampled_tokens=num_sampled,
            copy_stream=self.output_copy_stream,
            copy_event=self.output_copy_event,
        )

        # Postprocess results and update request states.
        # NOTE: This is intentionally done after creating the AsyncOutput,
        # ensuring that `copy_event` is recorded before calling postprocess.
        # This sequencing may slightly reduce latency as async D2H copy does not
        # need to wait for the postprocess to finish.
        self.postprocess(
            input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
        )
        if self.do_spec_decode:
            _ = self.propose_draft(
                input_batch,
                sampling_metadata,
                hidden_states,
                None,  # aux_hidden_states
                num_sampled,
                num_rejected,
            )

        if self.use_async_scheduling:
            return async_output
        return async_output.get_output()
diff --git a/vllm/v1/worker/gpu/sampler.py b/vllm/v1/worker/gpu/sampler.py
new file mode 100644
index 0000000000000..d8676079ab951
--- /dev/null
+++ b/vllm/v1/worker/gpu/sampler.py
@@ -0,0 +1,330 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Callable
+
+import torch
+
+from vllm.config.model import LogprobsMode
+from vllm.triton_utils import tl, triton
+from vllm.v1.outputs import LogprobsTensors, SamplerOutput
+from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
+from vllm.v1.worker.gpu.states import SamplingMetadata
+
+
class Sampler:
    """Samples next tokens from logits and optionally gathers top-k logprobs.

    `logprobs_mode` selects which logits the returned logprobs are computed
    from: the raw model logits ("raw_logprobs") or the temperature/top-k/top-p
    processed logits ("processed_logprobs").
    """

    def __init__(
        self,
        logprobs_mode: LogprobsMode = "raw_logprobs",
    ):
        if logprobs_mode not in ("processed_logprobs", "raw_logprobs"):
            raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
        self.logprobs_mode = logprobs_mode

    def __call__(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        max_num_logprobs = sampling_metadata.max_num_logprobs
        if max_num_logprobs is None:
            # No logprobs requested: sampling only.
            sampled, _ = self.sample(logits, sampling_metadata, return_logits=False)
            logprobs_tensors = None
        else:
            if self.logprobs_mode == "processed_logprobs":
                # Keep the processed logits so the logprobs reflect them.
                sampled, logits = self.sample(
                    logits, sampling_metadata, return_logits=True
                )
            else:
                assert self.logprobs_mode == "raw_logprobs"
                sampled, _ = self.sample(logits, sampling_metadata, return_logits=False)
            logprobs_tensors = compute_topk_logprobs(
                logits,
                max_num_logprobs,
                sampled,
            )

        # These are GPU tensors. The sampled tokens are reshaped to
        # [num_requests, 1]: one generated token per request.
        return SamplerOutput(
            sampled_token_ids=sampled.view(-1, 1),
            logprobs_tensors=logprobs_tensors,
        )

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        return_logits: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        # Greedy requests carry temperature 0; substitute 1.0 so the division
        # below is a no-op for them (argmax is scale-invariant).
        greedy_mask = sampling_metadata.temperature == 0
        safe_temp = torch.where(greedy_mask, 1.0, sampling_metadata.temperature)
        processed = logits / safe_temp.view(-1, 1)
        processed = apply_top_k_top_p(
            processed, sampling_metadata.top_k, sampling_metadata.top_p
        )

        # Temperature was already applied above, so the kernel skips it.
        sampled = gumbel_sample(
            processed,
            sampling_metadata.temperature,
            sampling_metadata.seeds,
            sampling_metadata.pos,
            apply_temperature=False,
        )
        if return_logits:
            return sampled, processed
        return sampled, None
+
+
@triton.jit
def _gumbel_sample_kernel(
    local_argmax_ptr,
    local_argmax_stride,
    local_max_ptr,
    local_max_stride,
    logits_ptr,
    logits_stride,
    seeds_ptr,
    pos_ptr,
    temp_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
    APPLY_TEMPERATURE: tl.constexpr,
):
    # One program per (request, vocab block). Each program writes its block's
    # argmax index and max value of (logits + Gumbel noise); the host-side
    # launcher reduces across blocks to get the sampled token.
    req_idx = tl.program_id(0)
    block_idx = tl.program_id(1)
    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block < vocab_size
    logits = tl.load(
        logits_ptr + req_idx * logits_stride + block,
        mask=mask,
        other=float("-inf"),
    )
    logits = logits.to(tl.float32)

    # temp == 0 means greedy: skip noise entirely and take plain argmax.
    temp = tl.load(temp_ptr + req_idx).to(tl.float32)
    if temp != 0.0:
        # Calculate the seed for gumbel noise. Deriving it from (seed, pos)
        # makes sampling deterministic per request and per position.
        seed = tl.load(seeds_ptr + req_idx)
        pos = tl.load(pos_ptr + req_idx)
        gumbel_seed = tl.randint(seed, pos)

        # Generate gumbel noise. The 1e-20 terms guard against log(0).
        r = tl.rand(gumbel_seed, block).to(tl.float64)
        gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20)
        gumbel_noise = gumbel_noise.to(tl.float32)

        # Apply temperature.
        if APPLY_TEMPERATURE:
            # NOTE(woosuk): Use div_rn to match the behavior of torch.
            logits = tl.div_rn(logits, temp)

        # Apply gumbel noise. Padding lanes stay at -inf so they never win.
        logits = tl.where(mask, logits + gumbel_noise, float("-inf"))

    # Per-block reduction; token_id is the global vocab index.
    idx = tl.argmax(logits, axis=0)
    token_id = block_idx * BLOCK_SIZE + idx
    value = tl.max(logits, axis=0)
    tl.store(local_argmax_ptr + req_idx * local_argmax_stride + block_idx, token_id)
    tl.store(local_max_ptr + req_idx * local_max_stride + block_idx, value)
+
+
def gumbel_sample(
    logits: torch.Tensor,  # [num_reqs, vocab_size]
    temperature: torch.Tensor,  # [num_reqs]
    seed: torch.Tensor,  # [num_reqs]
    pos: torch.Tensor,  # [num_reqs]
    apply_temperature: bool,
) -> torch.Tensor:
    """Sample one token per request via the Gumbel-max trick.

    The kernel computes per-block (argmax, max) of logits + seeded Gumbel
    noise; the final argmax is obtained here by reducing across blocks.
    Requests with temperature 0 are sampled greedily (no noise).
    """
    num_reqs, vocab_size = logits.shape
    BLOCK_SIZE = 1024
    num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
    device = logits.device
    # Per-(request, block) partial results, reduced across blocks below.
    # NOTE(woosuk): int64 so the result can be used directly for indexing.
    block_argmax = torch.empty(
        num_reqs, num_blocks, dtype=torch.int64, device=device
    )
    block_max = torch.empty(
        num_reqs, num_blocks, dtype=torch.float32, device=device
    )
    grid = (num_reqs, num_blocks)
    _gumbel_sample_kernel[grid](
        block_argmax,
        block_argmax.stride(0),
        block_max,
        block_max.stride(0),
        logits,
        logits.stride(0),
        seed,
        pos,
        temperature,
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
        APPLY_TEMPERATURE=apply_temperature,
    )
    # Pick the block holding the global max, then gather its argmax.
    best_block = block_max.argmax(dim=-1, keepdim=True)
    return block_argmax.gather(dim=-1, index=best_block).view(-1)
+
+
@triton.jit
def _topk_log_softmax_kernel(
    output_ptr,
    logits_ptr,
    logits_stride,
    topk_ids_ptr,
    topk,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
    PADDED_TOPK: tl.constexpr,
):
    # One program per request: compute log_softmax(logits)[topk_ids] without
    # materializing the full softmax row.
    req_idx = tl.program_id(0)
    row_ptr = logits_ptr + req_idx * logits_stride

    # Pass 1: row maximum, for numerically stable exponentiation.
    max_val = float("-inf")
    for i in range(0, vocab_size, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
        max_val = tl.max(tl.maximum(logits, max_val))
    max_val = max_val.to(tl.float32)  # type: ignore

    # Pass 2: sum of exp(logit - max) -> log-sum-exp.
    se = 0.0
    for i in range(0, vocab_size, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        logits = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
        # NOTE(woosuk): Make sure that logits and all following operations use FP32.
        logits = logits.to(tl.float32)
        e = tl.exp(logits - max_val)
        # Zero out padding lanes (their "other" value 0.0 would contribute
        # exp(-max_val) otherwise).
        e = tl.where(block < vocab_size, e, 0.0)
        se += tl.sum(e)
    lse = tl.log(se)

    # Gather the requested token logits and convert to logprobs.
    k_offset = tl.arange(0, PADDED_TOPK)
    k_mask = k_offset < topk
    topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask, other=0)

    logits = tl.load(row_ptr + topk_ids, mask=k_mask)
    logits = logits.to(tl.float32)
    o = logits - max_val - lse
    tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)
+
+
@triton.jit
def _ranks_kernel(
    output_ptr,
    logits_ptr,
    logits_stride,
    token_ids_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    # One program per request: rank of the chosen token, i.e. the number of
    # vocab entries with a strictly greater logit (0 == argmax token).
    req_idx = tl.program_id(0)
    row_ptr = logits_ptr + req_idx * logits_stride

    token_id = tl.load(token_ids_ptr + req_idx)
    x = tl.load(row_ptr + token_id)

    # Count entries above x, block by block. Padding lanes load -inf and
    # therefore never count.
    n = 0
    for i in range(0, vocab_size, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
        n += tl.sum((logits > x).to(tl.int32))
    tl.store(output_ptr + req_idx, n)
+
+
def compute_token_logprobs(
    logits: torch.Tensor,
    token_ids: torch.Tensor,
) -> torch.Tensor:
    """Return log_softmax(logits) gathered at `token_ids`, per request.

    Computes only the requested entries (shape [batch, num_logprobs]) rather
    than the full vocab-sized logprob matrix.
    """
    batch_size, vocab_size = logits.shape
    token_ids = token_ids.to(torch.int64)
    num_logprobs = token_ids.shape[1]
    out = torch.empty(
        (batch_size, num_logprobs),
        dtype=torch.float32,
        device=logits.device,
    )
    grid = (batch_size,)
    _topk_log_softmax_kernel[grid](
        out,
        logits,
        logits.stride(0),
        token_ids,
        num_logprobs,
        vocab_size,
        BLOCK_SIZE=1024,  # type: ignore
        PADDED_TOPK=triton.next_power_of_2(num_logprobs),
    )
    return out
+
+
def compute_topk_logprobs(
    logits: torch.Tensor,
    num_logprobs: int,
    sampled_token_ids: torch.Tensor,
) -> LogprobsTensors:
    """Logprobs of the sampled token plus the top-`num_logprobs` tokens.

    Also computes the rank of each sampled token within its logits row.
    """
    assert num_logprobs >= 0
    batch_size, vocab_size = logits.shape
    sampled_column = sampled_token_ids.unsqueeze(-1)
    if num_logprobs == 0:
        # Only the sampled token itself is scored.
        logprob_token_ids = sampled_column
    else:
        topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
        logprob_token_ids = torch.cat((sampled_column, topk_indices), dim=1)

    # NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
    # logprobs tensor. Instead, we only compute and return the logprobs of
    # the topk + 1 tokens.
    logprobs = compute_token_logprobs(logits, logprob_token_ids)

    token_ranks = torch.empty(
        batch_size,
        dtype=torch.int64,
        device=logits.device,
    )
    grid = (batch_size,)
    _ranks_kernel[grid](
        token_ranks,
        logits,
        logits.stride(0),
        sampled_token_ids,
        vocab_size,
        BLOCK_SIZE=8192,  # type: ignore
    )
    return LogprobsTensors(
        logprob_token_ids=logprob_token_ids,
        logprobs=logprobs,
        selected_token_ranks=token_ranks,
    )
+
+
def compute_prompt_logprobs(
    prompt_token_ids: torch.Tensor,
    prompt_hidden_states: torch.Tensor,
    logits_fn: Callable[[torch.Tensor], torch.Tensor],
    chunk_size: int = 1024,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute the logprob and rank of each prompt token.

    Args:
        prompt_token_ids: [num_tokens] token ids to score.
        prompt_hidden_states: [num_tokens, hidden_size] hidden states.
        logits_fn: Maps hidden states to logits.
        chunk_size: Number of tokens whose logits are materialized at once.
            Generalized from the previously hard-coded 1024; the default
            preserves the original behavior.

    Returns:
        (logprobs, ranks), each of shape [num_tokens].
    """
    assert chunk_size > 0
    # Since materializing the full prompt logits can take too much memory,
    # we compute them in chunks.
    logprobs: list[torch.Tensor] = []
    ranks: list[torch.Tensor] = []
    prompt_token_ids = prompt_token_ids.to(torch.int64)
    num_tokens = prompt_token_ids.shape[0]
    for start_idx in range(0, num_tokens, chunk_size):
        end_idx = start_idx + chunk_size
        # NOTE(woosuk): logits_fn can be slow because it involves all-gather.
        chunk_logits = logits_fn(prompt_hidden_states[start_idx:end_idx])
        chunk_result = compute_topk_logprobs(
            chunk_logits,
            0,  # num_logprobs: only the given token's own logprob is needed.
            prompt_token_ids[start_idx:end_idx],
        )
        logprobs.append(chunk_result.logprobs)
        ranks.append(chunk_result.selected_token_ranks)

    # Skip the concat when a single chunk covered everything.
    if len(logprobs) == 1:
        return logprobs[0], ranks[0]
    return torch.cat(logprobs, dim=0), torch.cat(ranks, dim=0)
diff --git a/vllm/v1/worker/gpu/spec_decode/__init__.py b/vllm/v1/worker/gpu/spec_decode/__init__.py
new file mode 100644
index 0000000000000..15b85204e05ce
--- /dev/null
+++ b/vllm/v1/worker/gpu/spec_decode/__init__.py
@@ -0,0 +1,18 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+
+from vllm.config import VllmConfig
+
+
def init_speculator(
    vllm_config: VllmConfig,
    device: torch.device,
):
    """Create the draft-token speculator configured in `vllm_config`.

    Raises:
        NotImplementedError: for speculative methods other than eagle-style.
    """
    speculative_config = vllm_config.speculative_config
    assert speculative_config is not None
    if speculative_config.use_eagle():
        # Imported here so the eagle module is only loaded when eagle-style
        # speculation is actually configured.
        from vllm.v1.worker.gpu.spec_decode.eagle import EagleSpeculator

        return EagleSpeculator(vllm_config, device)
    raise NotImplementedError(f"{speculative_config.method} is not supported yet.")
diff --git a/vllm/v1/worker/gpu/spec_decode/eagle.py b/vllm/v1/worker/gpu/spec_decode/eagle.py
new file mode 100644
index 0000000000000..3c8621cc69c97
--- /dev/null
+++ b/vllm/v1/worker/gpu/spec_decode/eagle.py
@@ -0,0 +1,209 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+import torch.nn as nn
+
+from vllm.config import VllmConfig
+from vllm.config.compilation import CUDAGraphMode
+from vllm.forward_context import set_forward_context
+from vllm.model_executor.model_loader import get_model
+from vllm.triton_utils import tl, triton
+from vllm.v1.worker.gpu.input_batch import InputBatch
+from vllm.v1.worker.gpu.sampler import gumbel_sample
+from vllm.v1.worker.gpu.states import SamplingMetadata
+
+
class EagleSpeculator:
    """Eagle-style draft-token proposer that shares buffers with the target
    model's batch layout so no CPU-GPU sync is needed between steps."""

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        self.vllm_config = vllm_config
        self.device = device

        self.speculative_config = vllm_config.speculative_config
        assert self.speculative_config is not None
        # e.g. "eagle", "eagle3", or "mtp" (affects the head's output shape).
        self.method = self.speculative_config.method
        self.num_speculative_steps = self.speculative_config.num_speculative_tokens
        self.draft_model_config = self.speculative_config.draft_model_config

        self.scheduler_config = vllm_config.scheduler_config
        self.max_num_reqs = self.scheduler_config.max_num_seqs
        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens

        # Persistent input buffers sized for the largest possible batch.
        self.input_ids = torch.zeros(
            self.max_num_tokens, dtype=torch.int32, device=device
        )
        self.positions = torch.zeros(
            self.max_num_tokens, dtype=torch.int64, device=device
        )

    def load_model(self, target_model: nn.Module) -> None:
        """Load the eagle head and share the target model's lm_head with it."""
        from vllm.compilation.backends import set_model_tag

        with set_model_tag("eagle_head"):
            self.model = get_model(
                vllm_config=self.vllm_config, model_config=self.draft_model_config
            )

        # Reuse the target model's lm_head instead of keeping a separate copy.
        share_lm_head = True
        if share_lm_head and hasattr(target_model, "lm_head"):
            if hasattr(self.model, "lm_head"):
                del self.model.lm_head
            self.model.lm_head = target_model.lm_head

    @torch.inference_mode()
    def propose(
        self,
        input_batch: InputBatch,
        sampling_metadata: SamplingMetadata,
        # [num_tokens, hidden_size]
        last_hidden_states: torch.Tensor,
        # num_layers x [num_tokens, hidden_size]
        aux_hidden_states: list[torch.Tensor] | None,
        # [num_reqs]
        num_sampled: torch.Tensor,
        # [num_reqs]
        num_rejected: torch.Tensor,
        # [max_num_reqs, 1]
        last_sampled: torch.Tensor,
        # [num_reqs]
        next_prefill_tokens: torch.Tensor,
    ) -> torch.Tensor:
        """Propose draft tokens for each request; returns [num_reqs, 1]."""
        # NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the
        # number of rejected tokens, we maintain the size of eagle's input_ids and
        # hidden_states the same as the target model's. This means, we pad each
        # request's query length to include any rejected positions. By doing so,
        # we can also reuse the attention metadata (e.g., query_start_loc,
        # seq_lens) of the target model.
        if aux_hidden_states:
            # eagle3 consumes a projection of several target-model layers.
            assert self.method == "eagle3"
            hidden_states = self.model.combine_hidden_states(
                torch.cat(aux_hidden_states, dim=-1)
            )
        else:
            hidden_states = last_hidden_states

        # Get the input ids and last token indices for the speculator.
        last_token_indices = prepare_eagle_inputs(
            self.input_ids,
            input_batch,
            num_sampled,
            num_rejected,
            last_sampled,
            next_prefill_tokens,
        )
        input_ids = self.input_ids[: input_batch.num_tokens_after_padding]

        # Prefill: Run the eagle speculator with eager mode.
        with set_forward_context(
            input_batch.attn_metadata,
            self.vllm_config,
            num_tokens=input_batch.num_tokens_after_padding,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
        ):
            ret_hidden_states = self.model(
                input_ids=input_ids,
                positions=input_batch.positions,
                hidden_states=hidden_states,
            )
        # The head's return convention differs by method.
        if self.method == "mtp":
            last_hidden_states = ret_hidden_states
            hidden_states = ret_hidden_states
        else:
            last_hidden_states, hidden_states = ret_hidden_states
        sample_hidden_states = last_hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states)

        num_reqs = input_batch.num_reqs
        # Index of each request's first logit; used to broadcast per-request
        # sampling params to the draft positions.
        cu_num_logits = input_batch.cu_num_logits[:num_reqs]
        temperature = sampling_metadata.temperature[cu_num_logits]
        seed = sampling_metadata.seeds[cu_num_logits]
        # NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
        # used for draft and target sampling.
        pos = input_batch.positions[last_token_indices] + 1
        # NOTE(woosuk): For draft sampling, we only consider the temperature
        # and ignore the other sampling parameters such as top_k and top_p,
        # for simplicity and performance.
        # While this may slightly degrade the acceptance rate, it does not
        # affect the output distribution after rejection sampling.
        draft_tokens = gumbel_sample(
            logits, temperature, seed, pos, apply_temperature=True
        )
        if self.num_speculative_steps == 1:
            # Early exit.
            return draft_tokens.view(-1, 1)
        raise NotImplementedError("num_speculative_steps > 1 is not supported yet.")
+
+
@triton.jit
def _prepare_eagle_inputs_kernel(
    last_token_indices_ptr,
    eagle_input_ids_ptr,
    target_input_ids_ptr,
    idx_mapping_ptr,
    last_sampled_ptr,
    next_prefill_tokens_ptr,
    num_sampled_ptr,
    num_rejected_ptr,
    query_start_loc_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    # One program per request: build the eagle head's input ids by shifting
    # the target model's input ids left by one within the request's span and
    # writing the "next" token (last sampled token, or the next prefill token
    # while the request is still chunked-prefilling) at the end.
    batch_idx = tl.program_id(0)
    query_start = tl.load(query_start_loc_ptr + batch_idx)
    query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
    query_len = query_end - query_start

    # Get the true query length and next token after accounting for rejected tokens.
    num_rejected = tl.load(num_rejected_ptr + batch_idx)
    query_len -= num_rejected

    num_sampled = tl.load(num_sampled_ptr + batch_idx)
    if num_sampled > 0:
        # idx_mapping translates batch slot -> persistent request state index.
        req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
        next_token = tl.load(last_sampled_ptr + req_state_idx).to(tl.int32)
    else:
        # Chunked prefilling.
        # Get the next prefill token.
        next_token = tl.load(next_prefill_tokens_ptr + batch_idx)

    # Shift target_input_ids by one: indices [1, query_len) land at [0, query_len - 1).
    for i in range(1, query_len, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        mask = block < query_len
        input_ids = tl.load(target_input_ids_ptr + query_start + block, mask=mask)
        tl.store(eagle_input_ids_ptr + query_start + block - 1, input_ids, mask=mask)

    # Record where this request's draft logits should be taken from, and
    # place the next token at that position.
    last_token_index = query_start + query_len - 1
    tl.store(last_token_indices_ptr + batch_idx, last_token_index)
    tl.store(eagle_input_ids_ptr + last_token_index, next_token)
+
+
def prepare_eagle_inputs(
    eagle_input_ids: torch.Tensor,
    input_batch: InputBatch,
    # [num_reqs]
    num_sampled: torch.Tensor,
    # [num_reqs]
    num_rejected: torch.Tensor,
    # [max_num_reqs, 1]
    last_sampled: torch.Tensor,
    # [max_num_reqs]
    next_prefill_tokens: torch.Tensor,
) -> torch.Tensor:
    """Fill `eagle_input_ids` for the eagle head and return, per request,
    the index of the token whose hidden state feeds draft sampling."""
    batch_size = input_batch.num_reqs
    out_indices = torch.empty(
        batch_size,
        dtype=torch.int64,
        device=eagle_input_ids.device,
    )
    grid = (batch_size,)
    _prepare_eagle_inputs_kernel[grid](
        out_indices,
        eagle_input_ids,
        input_batch.input_ids,
        input_batch.idx_mapping,
        last_sampled,
        next_prefill_tokens,
        num_sampled,
        num_rejected,
        input_batch.query_start_loc,
        BLOCK_SIZE=1024,
    )
    return out_indices
diff --git a/vllm/v1/worker/gpu/spec_decode/rejection_sample.py b/vllm/v1/worker/gpu/spec_decode/rejection_sample.py
new file mode 100644
index 0000000000000..43c6ac518bccc
--- /dev/null
+++ b/vllm/v1/worker/gpu/spec_decode/rejection_sample.py
@@ -0,0 +1,83 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+
+from vllm.triton_utils import tl, triton
+
+
+@triton.jit
+def _rejection_sample_kernel(
+    sampled_ptr,  # [num_reqs, num_speculative_steps + 1]
+    sampled_stride,
+    num_sampled_ptr,  # [num_reqs]
+    target_sampled_ptr,  # [num_draft_tokens + num_reqs]
+    input_ids_ptr,  # [num_draft_tokens + num_reqs]
+    cu_num_logits_ptr,  # [num_reqs + 1]
+):
+    # One program per request: walk the request's draft tokens left to right,
+    # keeping the target model's sample at each position and stopping after
+    # the first position where the draft token disagrees with it.
+    req_idx = tl.program_id(0)
+    start_idx = tl.load(cu_num_logits_ptr + req_idx)
+    end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
+    num_tokens = end_idx - start_idx
+
+    num_sampled = 0
+    rejected = False
+    for i in range(num_tokens - 1):
+        if not rejected:
+            target_sampled = tl.load(target_sampled_ptr + start_idx + i)
+            # Draft token for position i sits one slot ahead in input_ids.
+            draft_sampled = tl.load(input_ids_ptr + start_idx + i + 1)
+            # The target's sample at position i is always emitted; the
+            # mismatch only stops acceptance of subsequent positions.
+            tl.store(sampled_ptr + req_idx * sampled_stride + i, target_sampled)
+            num_sampled += 1
+            if target_sampled != draft_sampled:
+                rejected = True
+    if not rejected:
+        # Every draft token was accepted: emit the bonus token sampled at
+        # the final position as well.
+        target_sampled = tl.load(target_sampled_ptr + start_idx + num_tokens - 1)
+        tl.store(
+            sampled_ptr + req_idx * sampled_stride + num_tokens - 1, target_sampled
+        )
+        num_sampled += 1
+    tl.store(num_sampled_ptr + req_idx, num_sampled)
+
+
+def rejection_sample(
+    # [num_draft_tokens + num_reqs]
+    target_sampled: torch.Tensor,
+    # [num_draft_tokens + num_reqs]
+    input_ids: torch.Tensor,
+    # [num_reqs + 1]
+    cu_num_logits: torch.Tensor,
+    num_speculative_steps: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Accept or reject draft tokens against the target model's samples.
+
+    Returns:
+        sampled: [num_reqs, num_speculative_steps + 1] tensor; only the
+            first ``num_sampled[i]`` entries of row i are valid (the buffer
+            is allocated with ``torch.empty`` and the rest is never written).
+        num_sampled: [num_reqs] int32 count of accepted tokens per request.
+    """
+    num_reqs = cu_num_logits.shape[0] - 1
+    sampled = torch.empty(
+        num_reqs,
+        num_speculative_steps + 1,
+        dtype=target_sampled.dtype,
+        device=target_sampled.device,
+    )
+    num_sampled = torch.empty(
+        num_reqs,
+        dtype=torch.int32,
+        device=target_sampled.device,
+    )
+    # Grid: one program per request; the per-request loop is sequential,
+    # so a single warp is sufficient.
+    _rejection_sample_kernel[(num_reqs,)](
+        sampled,
+        sampled.stride(0),
+        num_sampled,
+        target_sampled,
+        input_ids,
+        cu_num_logits,
+        num_warps=1,
+    )
+    return sampled, num_sampled
+
+
+@torch.compile(dynamic=True)
+def get_num_rejected(
+    cu_num_logits: torch.Tensor,
+    num_sampled: torch.Tensor,
+) -> torch.Tensor:
+    """Return the per-request number of rejected draft tokens.
+
+    ``num_sampled == 0`` marks a chunked prefill (no token was sampled),
+    for which the rejection count is forced to zero.
+    """
+    # Per-request span length from the cumulative-sum layout.
+    num_logits = cu_num_logits[1:] - cu_num_logits[:-1]
+    num_rejected = num_logits - num_sampled
+    # No token is rejected for chunked prefills.
+    num_rejected *= num_sampled > 0
+    return num_rejected
diff --git a/vllm/v1/worker/gpu/states.py b/vllm/v1/worker/gpu/states.py
new file mode 100644
index 0000000000000..513d45d95d7cd
--- /dev/null
+++ b/vllm/v1/worker/gpu/states.py
@@ -0,0 +1,366 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from dataclasses import dataclass, field
+
+import numpy as np
+import torch
+
+from vllm.lora.request import LoRARequest
+from vllm.sampling_params import SamplingParams
+from vllm.triton_utils import tl, triton
+from vllm.v1.outputs import LogprobsTensors
+from vllm.v1.utils import CpuGpuBuffer
+
+_NP_INT64_MIN = np.iinfo(np.int64).min
+_NP_INT64_MAX = np.iinfo(np.int64).max
+NO_LORA_ID = 0
+
+
+@dataclass
+class SamplingMetadata:
+    """Sampling parameters for a batch, stored as per-request GPU tensors."""
+
+    temperature: torch.Tensor
+
+    # None disables the corresponding filter for the entire batch.
+    top_p: torch.Tensor | None
+    top_k: torch.Tensor | None
+
+    # Per-request RNG seeds and token positions.
+    seeds: torch.Tensor
+    pos: torch.Tensor
+
+    # None means no logprobs, 0 means sampled token logprobs only
+    max_num_logprobs: int | None
+
+    @classmethod
+    def make_dummy(
+        cls,
+        num_reqs: int,
+        device: torch.device,
+    ) -> "SamplingMetadata":
+        """Build placeholder metadata (e.g. for warmup/profiling runs)."""
+        assert num_reqs > 0
+        temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device)
+        # Make at least one request non-greedy so the sampler path is exercised.
+        temperature[0] = 0.5
+        # TODO(woosuk): Use top-p and top-k for dummy sampler.
+        # Currently, they are disabled because of memory usage.
+        # top_p = torch.full((num_reqs,), 0.95, dtype=torch.float32, device=device)
+        # top_k = torch.full((num_reqs,), 20, dtype=torch.int32, device=device)
+        top_p = None
+        top_k = None
+        seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
+        pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
+        max_num_logprobs = 20
+
+        return cls(
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            seeds=seeds,
+            pos=pos,
+            max_num_logprobs=max_num_logprobs,
+        )
+
+
+class RequestState:
+    """Persistent per-request state kept across scheduling steps.
+
+    Each request is assigned a stable slot index (popped from
+    ``free_indices``) on add and returns it on removal; all per-request
+    arrays/tensors below are indexed by that slot, so rows belonging to
+    freed slots hold stale data until reused.
+    """
+
+    def __init__(
+        self,
+        max_num_reqs: int,
+        max_model_len: int,
+        max_num_batched_tokens: int,
+        num_speculative_steps: int,
+        vocab_size: int,
+        device: torch.device,
+        pin_memory: bool,
+    ):
+        self.max_num_reqs = max_num_reqs
+        self.max_model_len = max_model_len
+        self.max_num_batched_tokens = max_num_batched_tokens
+        self.num_speculative_steps = num_speculative_steps
+        self.vocab_size = vocab_size
+        self.device = device
+        self.pin_memory = pin_memory
+
+        # Bidirectional req_id <-> slot-index maps, plus the free-slot pool.
+        self.req_id_to_index: dict[str, int] = {}
+        self.index_to_req_id: dict[int, str] = {}
+        self.free_indices = list(range(max_num_reqs))
+        self.extra_data: dict[str, ExtraData] = {}
+
+        # Prompt length and prefill token ids (CPU-side numpy storage).
+        self.prompt_len = np.zeros(self.max_num_reqs, dtype=np.int32)
+        self.prefill_token_ids = np.zeros(
+            (self.max_num_reqs, self.max_model_len),
+            dtype=np.int32,
+        )
+        self.prefill_len = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
+
+        # Number of computed tokens.
+        self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
+        self.num_computed_tokens = torch.zeros(
+            self.max_num_reqs, dtype=torch.int32, device=device
+        )
+
+        # Last sampled tokens.
+        self.last_sampled_tokens = torch.zeros(
+            self.max_num_reqs,
+            1,
+            dtype=torch.int64,
+            device=device,
+        )
+
+        # Draft tokens.
+        self.draft_tokens = torch.zeros(
+            self.max_num_reqs,
+            self.num_speculative_steps,
+            dtype=torch.int64,
+            device=device,
+        )
+
+        # LoRA.
+        self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32)
+        self.lora_ids.fill(NO_LORA_ID)
+
+        # Sampling parameters.
+        self.temperature = self._make_param(self.max_num_reqs, torch.float32)
+        self.top_p = self._make_param(self.max_num_reqs, torch.float32)
+        self.top_k = self._make_param(self.max_num_reqs, torch.int32)
+        self.seeds = self._make_param(self.max_num_reqs, torch.int64)
+
+        self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
+        # -1 means no logprobs are requested.
+        self.num_logprobs.fill(-1)
+        self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
+
+    def _make_param(self, size: int, dtype: torch.dtype) -> "Param":
+        # Helper: a Param bound to this state's device/pinning settings.
+        return Param(size, dtype=dtype, device=self.device, pin_memory=self.pin_memory)
+
+    def _make_buffer(self, size: int, dtype: torch.dtype) -> CpuGpuBuffer:
+        # Helper: a CpuGpuBuffer bound to this state's device/pinning settings.
+        return CpuGpuBuffer(
+            size, dtype=dtype, device=self.device, pin_memory=self.pin_memory
+        )
+
+    @property
+    def num_reqs(self) -> int:
+        """Number of requests currently tracked."""
+        return len(self.req_id_to_index)
+
+    def add_request(
+        self,
+        req_id: str,
+        prompt_len: int,
+        prefill_token_ids: list[int],
+        num_computed_tokens: int,
+        sampling_params: SamplingParams,
+        lora_request: LoRARequest | None,
+    ) -> None:
+        """Allocate a slot for a new request and populate its state."""
+        assert len(self.free_indices) > 0, "No free indices"
+        req_idx = self.free_indices.pop()
+        self.req_id_to_index[req_id] = req_idx
+        self.index_to_req_id[req_idx] = req_id
+        self.extra_data[req_id] = ExtraData(lora_request)
+
+        self.prompt_len[req_idx] = prompt_len
+        # prefill_token_ids may extend beyond the prompt, but never undershoot it.
+        prefill_len = len(prefill_token_ids)
+        assert prefill_len >= prompt_len, (
+            f"prefill_len {prefill_len} < prompt_len {prompt_len}"
+        )
+        self.prefill_len.np[req_idx] = prefill_len
+        self.prefill_token_ids[req_idx, :prefill_len] = prefill_token_ids
+
+        self.num_computed_prefill_tokens[req_idx] = num_computed_tokens
+        # FIXME(woosuk): This triggers a GPU operation whenever adding a new request.
+        # Optimize this.
+        self.num_computed_tokens[req_idx] = num_computed_tokens
+
+        if lora_request is not None:
+            self.lora_ids[req_idx] = lora_request.lora_int_id
+        else:
+            self.lora_ids[req_idx] = NO_LORA_ID
+
+        self.temperature.np[req_idx] = sampling_params.temperature
+        self.top_p.np[req_idx] = sampling_params.top_p
+        # Out-of-range top_k is normalized to vocab_size (i.e. no filtering).
+        if 0 < sampling_params.top_k < self.vocab_size:
+            top_k = sampling_params.top_k
+        else:
+            top_k = self.vocab_size
+        self.top_k.np[req_idx] = top_k
+
+        # Requests without an explicit seed get a random one.
+        if sampling_params.seed is not None:
+            seed = sampling_params.seed
+        else:
+            seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX)
+        self.seeds.np[req_idx] = seed
+
+        # -1 encodes "no logprobs requested" (see __init__).
+        if sampling_params.logprobs is not None:
+            num_logprobs = sampling_params.logprobs
+        else:
+            num_logprobs = -1
+        self.num_logprobs[req_idx] = num_logprobs
+
+        # For now, only support prompt logprobs for the prompt tokens.
+        needs_prompt_logprobs = sampling_params.prompt_logprobs is not None
+        self.needs_prompt_logprobs[req_idx] = needs_prompt_logprobs
+
+    def remove_request(self, req_id: str) -> None:
+        """Release the request's slot; a no-op for unknown request ids."""
+        self.extra_data.pop(req_id, None)
+        req_idx = self.req_id_to_index.pop(req_id, None)
+        if req_idx is None:
+            # Request not found.
+            return
+        self.index_to_req_id.pop(req_idx, None)
+        self.free_indices.append(req_idx)
+
+    def make_sampling_metadata(
+        self,
+        idx_mapping: np.ndarray,
+        pos: torch.Tensor,
+    ) -> SamplingMetadata:
+        """Gather sampling params in batch order and upload them to the GPU.
+
+        ``top_p``/``top_k`` are set to None when no request in the batch
+        actually uses the corresponding filter.
+        """
+        temperature = self.temperature.np[idx_mapping]
+        temperature = self.temperature.copy_np_to_gpu(temperature)
+
+        top_p = self.top_p.np[idx_mapping]
+        no_top_p = np.all(top_p == 1.0)
+        top_p = self.top_p.copy_np_to_gpu(top_p) if not no_top_p else None
+
+        top_k = self.top_k.np[idx_mapping]
+        no_top_k = np.all(top_k == self.vocab_size)
+        top_k = self.top_k.copy_np_to_gpu(top_k) if not no_top_k else None
+
+        seeds = self.seeds.np[idx_mapping]
+        seeds = self.seeds.copy_np_to_gpu(seeds)
+
+        num_logprobs = self.num_logprobs[idx_mapping]
+        # Max of -1 across the batch means nobody asked for logprobs.
+        max_num_logprobs: int | None = int(np.max(num_logprobs))
+        if max_num_logprobs == -1:
+            max_num_logprobs = None
+
+        return SamplingMetadata(
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            seeds=seeds,
+            pos=pos,
+            max_num_logprobs=max_num_logprobs,
+        )
+
+    def expand_sampling_metadata(
+        self,
+        sampling_metadata: SamplingMetadata,
+        cu_num_logits: torch.Tensor,
+    ) -> SamplingMetadata:
+        """Expand per-request params to per-token params for spec decoding."""
+        # For draft tokens, we need to expand the sampling param tensors as
+        # each request samples multiple tokens in each step.
+        return expand_sampling_metadata(
+            sampling_metadata, cu_num_logits, self.num_speculative_steps
+        )
+
+    def make_lora_inputs(
+        self,
+        req_ids: list[str],
+        idx_mapping: np.ndarray,
+        num_scheduled_tokens: np.ndarray,
+    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
+        """Build LoRA id mappings (per-request and per-token) and the set of
+        LoRA requests active in this batch."""
+        lora_ids = self.lora_ids[idx_mapping]
+        prompt_lora_mapping = tuple(lora_ids)
+        # Repeat each request's LoRA id once per scheduled token.
+        token_lora_mapping = tuple(lora_ids.repeat(num_scheduled_tokens))
+
+        active_lora_requests: set[LoRARequest] = set()
+        for req_id in req_ids:
+            lora_request = self.extra_data[req_id].lora_request
+            if lora_request is not None:
+                active_lora_requests.add(lora_request)
+        return prompt_lora_mapping, token_lora_mapping, active_lora_requests
+
+
+class Param:
+    """A per-request parameter with two homes: a persistent numpy array
+    (``self.np``) holding the authoritative values, and a CpuGpuBuffer used
+    to stage gathered values on their way to the GPU."""
+
+    def __init__(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        pin_memory: bool,
+    ):
+        self.buffer = CpuGpuBuffer(
+            size,
+            dtype=dtype,
+            device=device,
+            pin_memory=pin_memory,
+        )
+        # Persistent CPU copy, same shape/dtype as the buffer's CPU side.
+        self.np = np.zeros_like(self.buffer.np)
+
+    def copy_np_to_gpu(self, x: np.ndarray) -> torch.Tensor:
+        """Stage ``x`` into the buffer's CPU side and copy it to the GPU.
+
+        Returns the GPU view of the first ``len(x)`` elements.
+        """
+        n = x.shape[0]
+        self.buffer.np[:n] = x
+        return self.buffer.copy_to_gpu(n)
+
+
+@dataclass
+class ExtraData:
+    # CPU-only per-request extras that don't fit the array-based state.
+    lora_request: LoRARequest | None
+    # Prompt-logprob tensor chunks; presumably accumulated across prefill
+    # steps — confirm against the consumer.
+    in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list)
+
+
+# NOTE(woosuk): Re-compilation can happen at runtime since top_p and top_k can be None.
+@triton.jit
+def _expand_sampling_metadata_kernel(
+    temp_ptr,
+    expanded_temp_ptr,
+    top_p_ptr,
+    expanded_top_p_ptr,
+    top_k_ptr,
+    expanded_top_k_ptr,
+    seeds_ptr,
+    expanded_seeds_ptr,
+    cu_num_logits_ptr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    # One program per request: broadcast that request's scalar sampling
+    # params across its span of token positions [start_idx, end_idx).
+    req_idx = tl.program_id(0)
+    start_idx = tl.load(cu_num_logits_ptr + req_idx)
+    end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
+    num_tokens = end_idx - start_idx
+
+    # BLOCK_SIZE is sized to cover the largest possible span (see caller).
+    block = tl.arange(0, BLOCK_SIZE)
+    mask = block < num_tokens
+
+    temp = tl.load(temp_ptr + req_idx)
+    tl.store(expanded_temp_ptr + start_idx + block, temp, mask=mask)
+
+    if top_p_ptr is not None:
+        top_p = tl.load(top_p_ptr + req_idx)
+        tl.store(expanded_top_p_ptr + start_idx + block, top_p, mask=mask)
+
+    if top_k_ptr is not None:
+        top_k = tl.load(top_k_ptr + req_idx)
+        tl.store(expanded_top_k_ptr + start_idx + block, top_k, mask=mask)
+
+    seed = tl.load(seeds_ptr + req_idx)
+    tl.store(expanded_seeds_ptr + start_idx + block, seed, mask=mask)
+
+
+def expand_sampling_metadata(
+    sampling_metadata: SamplingMetadata,
+    cu_num_logits: torch.Tensor,
+    num_speculative_steps: int,
+) -> SamplingMetadata:
+    """Expand per-request sampling params to per-token (per-logit) params.
+
+    Used with speculative decoding, where each request contributes multiple
+    logit positions per step. ``top_p``/``top_k`` stay None when absent in
+    the input metadata; ``pos`` and ``max_num_logprobs`` are passed through.
+    """
+    total_num_logits = sampling_metadata.pos.shape[0]
+    # Allocate an expanded tensor per present param; keep None for absent ones.
+    create_empty = lambda x: x.new_empty(total_num_logits) if x is not None else None
+    expanded_temp = create_empty(sampling_metadata.temperature)
+    expanded_top_p = create_empty(sampling_metadata.top_p)
+    expanded_top_k = create_empty(sampling_metadata.top_k)
+    expanded_seeds = create_empty(sampling_metadata.seeds)
+
+    num_reqs = cu_num_logits.shape[0] - 1
+    _expand_sampling_metadata_kernel[(num_reqs,)](
+        sampling_metadata.temperature,
+        expanded_temp,
+        sampling_metadata.top_p,
+        expanded_top_p,
+        sampling_metadata.top_k,
+        expanded_top_k,
+        sampling_metadata.seeds,
+        expanded_seeds,
+        cu_num_logits,
+        # Covers the maximum span: num_speculative_steps drafts + 1 bonus.
+        BLOCK_SIZE=triton.next_power_of_2(num_speculative_steps + 1),
+    )
+    return SamplingMetadata(
+        temperature=expanded_temp,
+        top_p=expanded_top_p,
+        top_k=expanded_top_k,
+        seeds=expanded_seeds,
+        pos=sampling_metadata.pos,
+        max_num_logprobs=sampling_metadata.max_num_logprobs,
+    )
diff --git a/vllm/v1/worker/gpu/structured_outputs.py b/vllm/v1/worker/gpu/structured_outputs.py
new file mode 100644
index 0000000000000..83051b0ed33ff
--- /dev/null
+++ b/vllm/v1/worker/gpu/structured_outputs.py
@@ -0,0 +1,76 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import numpy as np
+import torch
+
+from vllm.triton_utils import tl, triton
+from vllm.v1.worker.gpu.input_batch import InputBuffers
+
+
+def apply_grammar_bitmask(
+    logits: torch.Tensor,
+    req_ids: list[str],
+    grammar_req_ids: list[str],
+    grammar_bitmask: np.ndarray,
+    input_buffers: InputBuffers,
+) -> None:
+    """Mask out grammar-disallowed tokens in-place in ``logits``.
+
+    ``grammar_bitmask`` has one packed-bit row per entry of
+    ``grammar_req_ids``; requests in ``req_ids`` without a grammar row are
+    mapped to -1 and left untouched by the kernel.
+    """
+    # Stage the bitmask through the pinned input buffer onto the GPU.
+    input_buffers.grammar_bitmask.np[: grammar_bitmask.shape[0]] = grammar_bitmask
+    input_buffers.grammar_bitmask.copy_to_gpu(grammar_bitmask.shape[0])
+
+    batch_size = logits.shape[0]
+    grammar_req_id_to_idx = {req_id: i for i, req_id in enumerate(grammar_req_ids)}
+    # logits -> bitmask mapping
+    mapping = [grammar_req_id_to_idx.get(req_id, -1) for req_id in req_ids]
+    input_buffers.bitmask_indices.np[:batch_size] = mapping
+    input_buffers.bitmask_indices.copy_to_gpu(batch_size)
+
+    # 2D grid: one program per (logits row, vocab tile of BLOCK_SIZE tokens).
+    vocab_size = logits.shape[-1]
+    BLOCK_SIZE = 8192
+    grid = (batch_size, triton.cdiv(vocab_size, BLOCK_SIZE))
+    _apply_grammar_bitmask_kernel[grid](
+        logits,
+        logits.stride(0),
+        input_buffers.grammar_bitmask.gpu,
+        input_buffers.grammar_bitmask.gpu.stride(0),
+        input_buffers.bitmask_indices.gpu,
+        vocab_size,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+
+
+# Adapted from
+# https://github.com/mlc-ai/xgrammar/blob/main/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
+@triton.jit
+def _apply_grammar_bitmask_kernel(
+    logits_ptr,
+    logits_stride,
+    bitmask_ptr,
+    bitmask_stride,
+    bitmask_indices_ptr,
+    vocab_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    # Program (i, j): row i of logits, vocab tile j of BLOCK_SIZE tokens.
+    logits_idx = tl.program_id(0)
+    bitmask_idx = tl.load(bitmask_indices_ptr + logits_idx)
+    if bitmask_idx == -1:
+        # No bitmask to apply.
+        return
+
+    # Load the bitmask.
+    # The mask is bit-packed: 32 token bits per int32 word.
+    block_id = tl.program_id(1)
+    bitmask_offset = (block_id * BLOCK_SIZE) // 32 + tl.arange(0, BLOCK_SIZE // 32)
+    packed_bitmask = tl.load(
+        bitmask_ptr + bitmask_idx * bitmask_stride + bitmask_offset,
+        mask=bitmask_offset < bitmask_stride,
+    )
+    # Unpack the bitmask.
+    # A zero bit means the token is disallowed and must be masked out.
+    bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
+    bitmask = bitmask.reshape(BLOCK_SIZE)
+
+    # Apply the bitmask to the logits.
+    block_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    tl.store(
+        logits_ptr + logits_idx * logits_stride + block_offset,
+        -float("inf"),
+        mask=bitmask & (block_offset < vocab_size),
+    )
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 023b5edb2c340..e7991baeaa1b8 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -43,6 +43,8 @@ class CachedRequestState:
mrope_positions: torch.Tensor | None = None
mrope_position_delta: int | None = None
+ xdrope_positions: torch.Tensor | None = None
+
lora_request: LoRARequest | None = None
prompt_embeds: torch.Tensor | None = None
@@ -87,7 +89,7 @@ class InputBatch:
is_spec_decode: bool = False,
is_pooling_model: bool = False,
num_speculative_tokens: int = 0,
- dcp_kv_cache_interleave_size: int = 1,
+ cp_kv_cache_interleave_size: int = 1,
):
self.is_pooling_model = is_pooling_model
self.is_spec_decode = is_spec_decode
@@ -141,7 +143,7 @@ class InputBatch:
block_sizes=block_sizes,
kernel_block_sizes=kernel_block_sizes,
num_speculative_tokens=num_speculative_tokens,
- dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
+ cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
)
# Sampling-related.
@@ -219,9 +221,6 @@ class InputBatch:
self.generators: dict[int, torch.Generator] = {}
self.num_logprobs: dict[str, int] = {}
- # NOTE(rob): num_prompt_logprobs only includes reqs
- # that are currently in the prefill phase.
- self.num_prompt_logprobs: dict[str, int] = {}
# To accumulate prompt logprobs tensor chunks across prefill steps.
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
@@ -251,7 +250,7 @@ class InputBatch:
self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids
# Store last speculative tokens for sampler.
- self.spec_token_ids: list[list[int] | None] = []
+ self.spec_token_ids: list[list[int]] = [[] for _ in range(max_num_reqs)]
# This is updated each time the batch constituents change.
self.sampling_metadata = self._make_sampling_metadata()
@@ -313,7 +312,7 @@ class InputBatch:
else:
self._req_ids[req_index] = req_id
self.req_output_token_ids[req_index] = request.output_token_ids
- self.spec_token_ids[req_index] = []
+ self.spec_token_ids[req_index].clear()
self.req_id_to_index[req_id] = req_index
@@ -385,12 +384,6 @@ class InputBatch:
if sampling_params.logprobs == -1
else sampling_params.logprobs
)
- if sampling_params.prompt_logprobs is not None:
- self.num_prompt_logprobs[req_id] = (
- self.vocab_size
- if sampling_params.prompt_logprobs == -1
- else sampling_params.prompt_logprobs
- )
if sampling_params.allowed_token_ids:
self.has_allowed_token_ids.add(req_id)
@@ -462,7 +455,7 @@ class InputBatch:
self.batch_update_builder.removed_append(req_index)
self._req_ids[req_index] = None
self.req_output_token_ids[req_index] = None
- self.spec_token_ids[req_index] = None
+ self.spec_token_ids[req_index].clear()
# LoRA
lora_id = self.request_lora_mapping[req_index]
@@ -488,7 +481,6 @@ class InputBatch:
self.repetition_penalties_reqs.discard(req_id)
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
- self.num_prompt_logprobs.pop(req_id, None)
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
self.has_allowed_token_ids.discard(req_id)
@@ -535,7 +527,7 @@ class InputBatch:
# NOTE: the following is unsafe
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
# self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
- # instead, we need to temporiarily copy the data for one of the indices
+ # instead, we need to temporarily copy the data for one of the indices
# TODO(lucas): optimize this by only copying valid indices
tmp = self.token_ids_cpu[i1, ...].copy()
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
@@ -654,9 +646,15 @@ class InputBatch:
self.req_output_token_ids[last_req_index] = None
self.req_id_to_index[req_id] = empty_index
- spec_token_ids = self.spec_token_ids[last_req_index]
- self.spec_token_ids[empty_index] = spec_token_ids
- self.spec_token_ids[last_req_index] = None
+ if last_req_index != empty_index:
+ (
+ self.spec_token_ids[last_req_index],
+ self.spec_token_ids[empty_index],
+ ) = (
+ self.spec_token_ids[empty_index],
+ self.spec_token_ids[last_req_index],
+ )
+ self.spec_token_ids[last_req_index].clear()
num_tokens = self.num_tokens[last_req_index]
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
@@ -966,10 +964,6 @@ class InputBatch:
def max_num_logprobs(self) -> int | None:
return max(self.num_logprobs.values()) if self.num_logprobs else None
- @property
- def no_prompt_logprob(self) -> bool:
- return not self.num_prompt_logprobs
-
@property
def no_allowed_token_ids(self) -> bool:
return len(self.has_allowed_token_ids) == 0
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 506118d2d762b..74fd2a1e2a2c0 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -5,7 +5,7 @@ import gc
import itertools
import time
from collections import defaultdict
-from collections.abc import Iterator
+from collections.abc import Iterator, Sequence
from contextlib import contextmanager
from copy import copy, deepcopy
from functools import reduce
@@ -50,15 +50,21 @@ from vllm.distributed.parallel_state import (
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
-from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
+from vllm.model_executor.layers.rotary_embedding import (
+ MRotaryEmbedding,
+ XDRotaryEmbedding,
+)
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.interfaces import (
+ SupportsMRoPE,
SupportsMultiModal,
+ SupportsXDRoPE,
is_mixture_of_experts,
supports_eagle3,
supports_mrope,
supports_multimodal_pruning,
supports_transcription,
+ supports_xdrope,
)
from vllm.model_executor.models.interfaces_base import (
VllmModelForPooling,
@@ -126,6 +132,7 @@ from vllm.v1.outputs import (
)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
+from vllm.v1.sample.logits_processor.interface import LogitsProcessor
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.sampler import Sampler
@@ -219,16 +226,14 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
del self._sampled_token_ids
max_gen_len = self.sampled_token_ids_cpu.shape[-1]
if max_gen_len == 1:
- valid_sampled_token_ids: list[np.ndarray] = [
- row for row in self.sampled_token_ids_cpu.numpy()
- ]
+ valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist()
else:
valid_sampled_token_ids = RejectionSampler.parse_output(
self.sampled_token_ids_cpu,
self.vocab_size,
)
for i in self._invalid_req_indices:
- valid_sampled_token_ids[i] = np.array([])
+ valid_sampled_token_ids[i].clear()
output = self._model_runner_output
output.sampled_token_ids = valid_sampled_token_ids
@@ -324,7 +329,7 @@ class GPUModelRunner(
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope
- self.uses_custom_attention_masks = model_config.uses_custom_attention_masks
+ self.uses_xdrope_dim = model_config.uses_xdrope_dim
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
model_config
)
@@ -350,6 +355,9 @@ class GPUModelRunner(
# self.model: nn.Module # Set after load_model
# Initialize in initialize_kv_cache
self.kv_caches: list[torch.Tensor] = []
+ # Initialize in initialize_kv_cache_tensors
+ self.cross_layers_kv_cache: torch.Tensor | None = None
+ self.cross_layers_attn_backend: type[AttentionBackend] | None = None
# indexes: [kv_cache_group_id][attn_group]
self.attn_groups: list[list[AttentionGroup]] = []
# self.kv_cache_config: KVCacheConfig
@@ -373,7 +381,9 @@ class GPUModelRunner(
elif self.speculative_config.use_eagle():
self.drafter = EagleProposer(self.vllm_config, self.device, self)
if self.speculative_config.method == "eagle3":
- self.use_aux_hidden_state_outputs = True
+ self.use_aux_hidden_state_outputs = (
+ self.drafter.eagle3_use_aux_hidden_state
+ )
elif self.speculative_config.method == "medusa":
self.drafter = MedusaProposer(
vllm_config=self.vllm_config, device=self.device
@@ -391,6 +401,9 @@ class GPUModelRunner(
# Request states.
self.requests: dict[str, CachedRequestState] = {}
+ # NOTE(rob): num_prompt_logprobs only includes reqs
+ # that are currently in the prefill phase.
+ self.num_prompt_logprobs: dict[str, int] = {}
self.comm_stream = torch.cuda.Stream()
# Input Batch
@@ -402,7 +415,10 @@ class GPUModelRunner(
# solution, we initialize the input batch here, and re-initialize it
# in `initialize_kv_cache` if the block_sizes here is different from
# the block_sizes in the kv cache config.
- custom_logitsprocs = model_config.logits_processors
+ logits_processors = model_config.logits_processors
+ custom_logitsprocs: Sequence[str | type[LogitsProcessor]] = (
+ tuple(logits_processors) if logits_processors is not None else ()
+ )
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
             # We need to use the encoder length for encoder-decoder
@@ -426,7 +442,7 @@ class GPUModelRunner(
# uses output token ids so we set this conservatively.
logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
is_pooling_model=self.is_pooling_model,
- dcp_kv_cache_interleave_size=self.parallel_config.dcp_kv_cache_interleave_size,
+ cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size,
)
self.use_async_scheduling = self.scheduler_config.async_scheduling
@@ -502,6 +518,13 @@ class GPUModelRunner(
(3, self.max_num_tokens + 1), dtype=torch.int64
)
+ # Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
+ if self.uses_xdrope_dim > 0:
+ # Similar to mrope but use assigned dimension number for RoPE, 4 as default.
+ self.xdrope_positions = self._make_buffer(
+ (self.uses_xdrope_dim, self.max_num_tokens + 1), dtype=torch.int64
+ )
+
# None in the first PP rank. The rest are set after load_model.
self.intermediate_tensors: IntermediateTensors | None = None
@@ -583,10 +606,14 @@ class GPUModelRunner(
if isinstance(num_tokens, int):
if self.uses_mrope:
return self.mrope_positions.gpu[:, :num_tokens]
+ if self.uses_xdrope_dim > 0:
+ return self.xdrope_positions.gpu[:, :num_tokens]
return self.positions.gpu[:num_tokens]
else:
if self.uses_mrope:
return self.mrope_positions.gpu[:, num_tokens]
+ if self.uses_xdrope_dim > 0:
+ return self.xdrope_positions.gpu[:, num_tokens]
return self.positions.gpu[num_tokens]
def _make_buffer(
@@ -682,6 +709,7 @@ class GPUModelRunner(
# Remove finished requests from the cached states.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
+ self.num_prompt_logprobs.pop(req_id, None)
# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
# scheduled_req_ids overlap. This happens when a request is aborted and
@@ -750,10 +778,21 @@ class GPUModelRunner(
)
self.requests[req_id] = req_state
+ if sampling_params and sampling_params.prompt_logprobs is not None:
+ self.num_prompt_logprobs[req_id] = (
+ self.input_batch.vocab_size
+ if sampling_params.prompt_logprobs == -1
+ else sampling_params.prompt_logprobs
+ )
+
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope:
self._init_mrope_positions(req_state)
+ # Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
+ if self.uses_xdrope_dim > 0:
+ self._init_xdrope_positions(req_state)
+
reqs_to_add.append(req_state)
# Update the states of the running/resumed requests.
@@ -892,7 +931,8 @@ class GPUModelRunner(
# conform to the schema. This can result in
# scheduler_output.scheduled_spec_decode_tokens being empty,
# even when speculative decoding is enabled.
- self.input_batch.spec_token_ids[req_index] = spec_token_ids
+ self.input_batch.spec_token_ids[req_index].clear()
+ self.input_batch.spec_token_ids[req_index].extend(spec_token_ids)
# there are no draft tokens with async scheduling,
# we clear the spec_decoding info in scheduler_output and
@@ -956,14 +996,31 @@ class GPUModelRunner(
def _init_mrope_positions(self, req_state: CachedRequestState):
model = self.get_model()
assert supports_mrope(model), "M-RoPE support is not implemented."
+ assert req_state.prompt_token_ids is not None, (
+ "M-RoPE requires prompt_token_ids to be available."
+ )
+ mrope_model = cast(SupportsMRoPE, model)
req_state.mrope_positions, req_state.mrope_position_delta = (
- model.get_mrope_input_positions(
+ mrope_model.get_mrope_input_positions(
req_state.prompt_token_ids,
req_state.mm_features,
)
)
+    def _init_xdrope_positions(self, req_state: CachedRequestState):
+        """Precompute XD-RoPE positions for a new request's prompt tokens."""
+        model = self.get_model()
+        xdrope_model = cast(SupportsXDRoPE, model)
+        assert req_state.prompt_token_ids is not None, (
+            "XD-RoPE requires prompt_token_ids to be available."
+        )
+        assert supports_xdrope(model), "XD-RoPE support is not implemented."
+
+        # Positions for completion tokens are computed later, on the fly.
+        req_state.xdrope_positions = xdrope_model.get_xdrope_input_positions(
+            req_state.prompt_token_ids,
+            req_state.mm_features,
+        )
+
def _extract_mm_kwargs(
self,
scheduler_output: "SchedulerOutput",
@@ -1208,6 +1265,11 @@ class GPUModelRunner(
if self.uses_mrope:
self._calc_mrope_positions(scheduler_output)
+ # Calculate XD-RoPE positions.
+ # Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
+ if self.uses_xdrope_dim > 0:
+ self._calc_xdrope_positions(scheduler_output)
+
# Get token indices.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
@@ -1341,6 +1403,12 @@ class GPUModelRunner(
self.mrope_positions.cpu[:, :total_num_scheduled_tokens],
non_blocking=True,
)
+ elif self.uses_xdrope_dim > 0:
+ # Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
+ self.xdrope_positions.gpu[:, :total_num_scheduled_tokens].copy_(
+ self.xdrope_positions.cpu[:, :total_num_scheduled_tokens],
+ non_blocking=True,
+ )
else:
# Common case (1D positions)
self.positions.copy_to_gpu(total_num_scheduled_tokens)
@@ -1435,7 +1503,7 @@ class GPUModelRunner(
self.seq_lens.cpu[:num_reqs],
self.dcp_world_size,
self.dcp_rank,
- self.parallel_config.dcp_kv_cache_interleave_size,
+ self.parallel_config.cp_kv_cache_interleave_size,
)
self.dcp_local_seq_lens.copy_to_gpu(num_reqs)
@@ -1451,9 +1519,12 @@ class GPUModelRunner(
num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
:num_reqs
]
- dcp_local_seq_lens = (
- self.dcp_local_seq_lens.gpu[:num_reqs] if self.dcp_world_size > 1 else None
- )
+
+ dcp_local_seq_lens, dcp_local_seq_lens_cpu = None, None
+ if self.dcp_world_size > 1:
+ dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs]
+ dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[:num_reqs]
+
spec_decode_common_attn_metadata = None
if for_cudagraph_capture:
@@ -1521,6 +1592,7 @@ class GPUModelRunner(
causal=True,
encoder_seq_lens=encoder_seq_lens,
dcp_local_seq_lens=dcp_local_seq_lens,
+ dcp_local_seq_lens_cpu=dcp_local_seq_lens_cpu,
)
if self.speculative_config and spec_decode_common_attn_metadata is None:
@@ -1755,6 +1827,7 @@ class GPUModelRunner(
dst_start = mrope_pos_ptr
dst_end = mrope_pos_ptr + completion_part_len
+ assert req.mrope_position_delta is not None
MRotaryEmbedding.get_next_input_positions_tensor(
out=self.mrope_positions.np,
out_offset=dst_start,
@@ -1765,6 +1838,53 @@ class GPUModelRunner(
mrope_pos_ptr += completion_part_len
+ def _calc_xdrope_positions(self, scheduler_output: "SchedulerOutput"):
+ xdrope_pos_ptr = 0
+ for index, req_id in enumerate(self.input_batch.req_ids):
+ req = self.requests[req_id]
+ assert req.xdrope_positions is not None
+
+ num_computed_tokens = self.input_batch.num_computed_tokens_cpu[index]
+ num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
+ num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
+ req.prompt_token_ids, req.prompt_embeds
+ )
+
+ if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
+ prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens)
+ completion_part_len = max(0, num_scheduled_tokens - prompt_part_len)
+ else:
+ prompt_part_len = num_scheduled_tokens
+ completion_part_len = 0
+
+ assert num_scheduled_tokens == prompt_part_len + completion_part_len
+
+ if prompt_part_len > 0:
+ # prompt's xdrope_positions are pre-computed
+ dst_start = xdrope_pos_ptr
+ dst_end = xdrope_pos_ptr + prompt_part_len
+ src_start = num_computed_tokens
+ src_end = num_computed_tokens + prompt_part_len
+
+ self.xdrope_positions.cpu[:, dst_start:dst_end] = req.xdrope_positions[
+ :, src_start:src_end
+ ]
+ xdrope_pos_ptr += prompt_part_len
+
+ if completion_part_len > 0:
+ # compute completion's xdrope_positions on-the-fly
+ dst_start = xdrope_pos_ptr
+ dst_end = xdrope_pos_ptr + completion_part_len
+
+ XDRotaryEmbedding.get_next_input_positions_tensor(
+ out=self.xdrope_positions.np,
+ out_offset=dst_start,
+ context_len=num_computed_tokens + prompt_part_len,
+ num_new_tokens=completion_part_len,
+ )
+
+ xdrope_pos_ptr += completion_part_len
+
def _calc_spec_decode_metadata(
self,
num_draft_tokens: np.ndarray,
@@ -1900,20 +2020,24 @@ class GPUModelRunner(
for mm_input_id in encoder_input_ids:
mm_feature = req_state.mm_features[mm_input_id]
+ if mm_feature.data is None:
+ continue
mm_hash = mm_feature.identifier
mm_kwargs.append(mm_feature.data)
mm_hashes_pos.append((mm_hash, mm_feature.mm_position))
return mm_kwargs, mm_hashes_pos
- def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
+ def _execute_mm_encoder(
+ self, scheduler_output: "SchedulerOutput"
+ ) -> list[torch.Tensor]:
# Batch the multi-modal inputs using the helper method.
mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler(
scheduler_output
)
if not mm_kwargs:
- return
+ return []
# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
@@ -1923,7 +2047,7 @@ class GPUModelRunner(
# multimodal inputs. The proper solution should be reordering the
# encoder outputs.
model = cast(SupportsMultiModal, self.model)
- encoder_outputs = []
+ encoder_outputs: list[torch.Tensor] = []
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs,
device=self.device,
@@ -1931,7 +2055,7 @@ class GPUModelRunner(
merge_by_field_config=model.merge_by_field_config,
multimodal_cpu_fields=model.multimodal_cpu_fields,
):
- curr_group_outputs = []
+ curr_group_outputs: list[torch.Tensor] = []
# EVS-related change.
# (ekhvedchenia): Temporary hack to limit peak memory usage when
@@ -1973,7 +2097,7 @@ class GPUModelRunner(
# 2. A list or tuple (length: num_items) of tensors,
# each of shape (feature_size, hidden_size) in case the feature
# size is dynamic depending on the input multimodal items.
- curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
+ curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) # type: ignore[assignment]
sanity_check_mm_encoder_outputs(
curr_group_outputs,
@@ -1990,6 +2114,8 @@ class GPUModelRunner(
logger.debug("Finish execute for mm hash %s", mm_hash)
self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)
+ return encoder_outputs
+
def _gather_mm_embeddings(
self,
scheduler_output: "SchedulerOutput",
@@ -2003,6 +2129,7 @@ class GPUModelRunner(
req_start_idx = 0
should_sync_mrope_positions = False
+ should_sync_xdrope_positions = False
for req_id in self.input_batch.req_ids:
mm_embeds_req: list[torch.Tensor] = []
@@ -2076,40 +2203,12 @@ class GPUModelRunner(
self._calc_mrope_positions(scheduler_output)
self.mrope_positions.copy_to_gpu(total_num_scheduled_tokens)
+ if should_sync_xdrope_positions:
+ self._calc_xdrope_positions(scheduler_output)
+ self.xdrope_positions.copy_to_gpu(total_num_scheduled_tokens)
+
return mm_embeds, is_mm_embed
- def _extract_encoder_inputs(
- self,
- scheduler_output: "SchedulerOutput",
- ) -> dict[str, torch.Tensor]:
- """Extract encoder inputs for encoder-decoder models.
-
- This method extracts multimodal input features from scheduled encoder
- inputs and formats them for the encoder-decoder model forward pass.
- """
- # Batch the multi-modal inputs using the helper method.
- mm_kwargs, _ = self._batch_mm_kwargs_from_scheduler(scheduler_output)
-
- if not mm_kwargs:
- return {}
-
- # Group MM kwargs by modality and extract features
- model = cast(SupportsMultiModal, self.model)
- encoder_features = {}
- for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
- mm_kwargs,
- device=self.device,
- pin_memory=self.pin_memory,
- merge_by_field_config=model.merge_by_field_config,
- multimodal_cpu_fields=model.multimodal_cpu_fields,
- ):
- # Add the grouped features to encoder_features dict
- # This allows the model to receive them as kwargs (e.g.,
- # input_features=...)
- encoder_features.update(mm_kwargs_group)
-
- return encoder_features
-
def get_model(self) -> nn.Module:
# get raw model out of the cudagraph wrapper.
if isinstance(self.model, (CUDAGraphWrapper, UBatchWrapper)):
@@ -2173,7 +2272,7 @@ class GPUModelRunner(
def sync_and_slice_intermediate_tensors(
self,
num_tokens: int,
- intermediate_tensors: IntermediateTensors,
+ intermediate_tensors: IntermediateTensors | None,
sync_self: bool,
) -> IntermediateTensors:
assert self.intermediate_tensors is not None
@@ -2347,24 +2446,6 @@ class GPUModelRunner(
**self._init_model_kwargs(num_scheduled_tokens),
**self._extract_mm_kwargs(scheduler_output),
}
-
- # Generate custom attention masks for models that require them.
- # V1 pre-generates embeddings, so forward() skips prepare_attn_masks().
- # Check mm_features (mm_embeds is empty during decode).
- has_mm_features = any(
- req_state.mm_features for req_state in self.requests.values()
- )
- if (
- self.uses_custom_attention_masks
- and has_mm_features
- and hasattr(self.model, "generate_attention_masks")
- ):
- mask_kwargs = self.model.generate_attention_masks(
- self.input_ids.gpu[:num_scheduled_tokens],
- self.positions.gpu[:num_scheduled_tokens],
- mask_dtype=self.model.dtype,
- )
- model_kwargs.update(mask_kwargs)
elif self.enable_prompt_embeds and is_first_rank:
# Get the input embeddings for the tokens that are not input embeds,
# then put them into the appropriate positions.
@@ -2400,14 +2481,18 @@ class GPUModelRunner(
input_ids = self.input_ids.gpu[:num_input_tokens]
inputs_embeds = None
model_kwargs = self._init_model_kwargs(num_input_tokens)
+
if self.uses_mrope:
positions = self.mrope_positions.gpu[:, :num_input_tokens]
+ elif self.uses_xdrope_dim > 0:
+ positions = self.xdrope_positions.gpu[:, :num_input_tokens]
else:
positions = self.positions.gpu[:num_input_tokens]
if is_first_rank:
intermediate_tensors = None
else:
+ assert intermediate_tensors is not None
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_input_tokens, intermediate_tensors, True
)
@@ -2416,8 +2501,13 @@ class GPUModelRunner(
self.model_config.is_encoder_decoder
and scheduler_output.scheduled_encoder_inputs
):
- encoder_inputs = self._extract_encoder_inputs(scheduler_output)
- model_kwargs.update(encoder_inputs)
+ # Run the encoder, just like we do with other multimodal inputs.
+ # For an encoder-decoder model, our processing here is a bit
+ # simpler, because the outputs are just passed to the decoder.
+ # We are not doing any prompt replacement. We also will only
+ # ever have a single encoder input.
+ encoder_outputs = self._execute_mm_encoder(scheduler_output)
+ model_kwargs.update({"encoder_outputs": encoder_outputs})
return (
input_ids,
@@ -2464,7 +2554,7 @@ class GPUModelRunner(
) -> tuple[
dict[str, int],
LogprobsLists | None,
- list[np.ndarray],
+ list[list[int]],
dict[str, LogprobsTensors | None],
list[str],
dict[str, int],
@@ -2489,8 +2579,9 @@ class GPUModelRunner(
num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
sampled_token_ids = sampler_output.sampled_token_ids
+ logprobs_tensors = sampler_output.logprobs_tensors
invalid_req_indices = []
- valid_sampled_token_ids: list[np.ndarray]
+ cu_num_new_tokens: list[int] | None = None
if not self.use_async_scheduling:
# Get the valid generated tokens.
max_gen_len = sampled_token_ids.shape[-1]
@@ -2503,9 +2594,15 @@ class GPUModelRunner(
sampled_token_ids,
self.input_batch.vocab_size,
)
+ if logprobs_tensors:
+ # Needed for extracting logprobs when using speculative decoding.
+ # This must be done prior to discarding sampled tokens.
+ cu_num_new_tokens = [0]
+ for toks in valid_sampled_token_ids:
+ cu_num_new_tokens.append(cu_num_new_tokens[-1] + len(toks))
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
- valid_sampled_token_ids[int(i)] = np.array([])
+ valid_sampled_token_ids[int(i)].clear()
else:
valid_sampled_token_ids = []
invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
@@ -2530,29 +2627,15 @@ class GPUModelRunner(
# the sampled tokens back, because there's no direct communication
# between the first-stage worker and the last-stage worker.
req_ids = self.input_batch.req_ids
- logprobs_tensors = sampler_output.logprobs_tensors
- cu_num_accepted_tokens = (
- [0] if spec_decode_metadata and logprobs_tensors else None
- )
for req_idx in range(num_sampled_tokens):
- sampled_ids: np.ndarray | None
if self.use_async_scheduling:
- sampled_ids = (
- np.array([-1]) if req_idx not in invalid_req_indices_set else None
- )
+ sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None
else:
sampled_ids = valid_sampled_token_ids[req_idx]
- num_sampled_ids: int = (
- sampled_ids.shape[0] if sampled_ids is not None else 0
- )
+ num_sampled_ids: int = len(sampled_ids) if sampled_ids else 0
- if cu_num_accepted_tokens is not None:
- cu_num_accepted_tokens.append(
- cu_num_accepted_tokens[-1] + num_sampled_ids
- )
-
- if sampled_ids is None or num_sampled_ids == 0:
+ if not sampled_ids:
continue
start_idx = self.input_batch.num_tokens_no_spec[req_idx]
@@ -2573,7 +2656,7 @@ class GPUModelRunner(
req_state.output_token_ids.extend(sampled_ids)
logprobs_lists = (
- logprobs_tensors.tolists(cu_num_accepted_tokens)
+ logprobs_tensors.tolists(cu_num_new_tokens)
if not self.use_async_scheduling and logprobs_tensors is not None
else None
)
@@ -2701,7 +2784,7 @@ class GPUModelRunner(
scheduler_output, self.vllm_config
)
if self.cache_config.kv_sharing_fast_prefill:
- assert not self.input_batch.num_prompt_logprobs, (
+ assert not self.num_prompt_logprobs, (
"--kv-sharing-fast-prefill produces incorrect "
"logprobs for prompt tokens, tokens, please disable "
"it when the requests need prompt logprobs"
@@ -2776,14 +2859,14 @@ class GPUModelRunner(
uniform_decode = (
max_num_scheduled_tokens == self.uniform_decode_query_len
) and (num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
- batch_descriptor = BatchDescriptor(
+ batch_desc = BatchDescriptor(
num_tokens=num_input_tokens,
uniform_decode=uniform_decode,
has_lora=len(self.input_batch.lora_id_to_lora_request) > 0,
)
cudagraph_runtime_mode, batch_descriptor = (
self.cudagraph_dispatcher.dispatch(
- batch_descriptor,
+ batch_desc,
use_cascade_attn=cascade_attn_prefix_lens is not None,
)
)
@@ -2867,15 +2950,15 @@ class GPUModelRunner(
else:
logits = self.model.compute_logits(sample_hidden_states)
- model_output_broadcast_data = {}
+ model_output_broadcast_data: dict[str, Any] = {}
if logits is not None:
model_output_broadcast_data["logits"] = logits.contiguous()
- model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
+ broadcasted = get_pp_group().broadcast_tensor_dict(
model_output_broadcast_data, src=len(get_pp_group().ranks) - 1
)
- assert model_output_broadcast_data is not None
- logits = model_output_broadcast_data["logits"]
+ assert broadcasted is not None
+ logits = broadcasted["logits"]
self.execute_model_state = ExecuteModelState(
scheduler_output,
@@ -2900,7 +2983,7 @@ class GPUModelRunner(
if self.execute_model_state is None:
# Nothing to do (PP non-final rank case), output isn't used.
if not kv_connector_output:
- return None # noqa
+ return None # type: ignore[return-value]
# In case of PP with kv transfer, we need to pass through the
# kv_connector_output
@@ -2936,9 +3019,7 @@ class GPUModelRunner(
self.input_batch.prev_sampled_token_ids = None
- def propose_draft_token_ids(
- sampled_token_ids: torch.Tensor | list[np.ndarray],
- ) -> None:
+ def propose_draft_token_ids(sampled_token_ids):
assert spec_decode_common_attn_metadata is not None
with record_function_or_nullcontext("gpu_model_runner: draft"):
self._draft_token_ids = self.propose_draft_token_ids(
@@ -2952,33 +3033,37 @@ class GPUModelRunner(
spec_decode_common_attn_metadata,
)
+ spec_config = self.speculative_config
use_padded_batch_for_eagle = (
- self.speculative_config
- and self.speculative_config.use_eagle()
- and not self.speculative_config.disable_padded_drafter_batch
+ spec_config is not None
+ and spec_config.use_eagle()
+ and not spec_config.disable_padded_drafter_batch
)
effective_drafter_max_model_len = self.max_model_len
if effective_drafter_max_model_len is None:
effective_drafter_max_model_len = self.model_config.max_model_len
if (
- self.speculative_config
- and self.speculative_config.draft_model_config is not None
- and self.speculative_config.draft_model_config.max_model_len is not None
+ spec_config is not None
+ and spec_config.draft_model_config is not None
+ and spec_config.draft_model_config.max_model_len is not None
):
effective_drafter_max_model_len = (
- self.speculative_config.draft_model_config.max_model_len
+ spec_config.draft_model_config.max_model_len
)
input_fits_in_drafter = spec_decode_common_attn_metadata and (
spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens
<= effective_drafter_max_model_len
)
if use_padded_batch_for_eagle:
+ assert self.speculative_config is not None
+ assert isinstance(self.drafter, EagleProposer)
sampled_token_ids = sampler_output.sampled_token_ids
if input_fits_in_drafter:
# EAGLE speculative decoding can use the GPU sampled tokens
# as inputs, and does not need to wait for bookkeeping to finish.
propose_draft_token_ids(sampled_token_ids)
elif self.valid_sampled_token_count_event is not None:
+ assert spec_decode_common_attn_metadata is not None
next_token_ids, valid_sampled_tokens_count = (
self.drafter.prepare_next_token_ids_padded(
spec_decode_common_attn_metadata,
@@ -3107,16 +3192,18 @@ class GPUModelRunner(
def propose_draft_token_ids(
self,
scheduler_output: "SchedulerOutput",
- sampled_token_ids: torch.Tensor | list[np.ndarray],
+ sampled_token_ids: torch.Tensor | list[list[int]],
sampling_metadata: SamplingMetadata,
hidden_states: torch.Tensor,
sample_hidden_states: torch.Tensor,
aux_hidden_states: list[torch.Tensor] | None,
spec_decode_metadata: SpecDecodeMetadata | None,
common_attn_metadata: CommonAttentionMetadata,
- ) -> torch.Tensor | list[list[int]]:
+ ) -> list[list[int]] | torch.Tensor:
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
- if self.speculative_config.method == "ngram":
+ spec_config = self.speculative_config
+ assert spec_config is not None
+ if spec_config.method == "ngram":
assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, NgramProposer)
draft_token_ids = self.drafter.propose(
@@ -3126,11 +3213,11 @@ class GPUModelRunner(
self.input_batch.token_ids_cpu,
self.input_batch.spec_decode_unsupported_reqs,
)
- elif self.speculative_config.method == "suffix":
+ elif spec_config.method == "suffix":
assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, SuffixDecodingProposer)
draft_token_ids = self.drafter.propose(self.input_batch, sampled_token_ids)
- elif self.speculative_config.method == "medusa":
+ elif spec_config.method == "medusa":
assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, MedusaProposer)
@@ -3146,7 +3233,7 @@ class GPUModelRunner(
for num_draft, tokens in zip(
spec_decode_metadata.num_draft_tokens, sampled_token_ids
):
- indices.append(offset + tokens.shape[0] - 1)
+ indices.append(offset + len(tokens) - 1)
offset += num_draft + 1
indices = torch.tensor(indices, device=self.device)
hidden_states = sample_hidden_states[indices]
@@ -3155,10 +3242,10 @@ class GPUModelRunner(
target_hidden_states=hidden_states,
sampling_metadata=sampling_metadata,
)
- elif self.speculative_config.use_eagle():
+ elif spec_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
- if self.speculative_config.disable_padded_drafter_batch:
+ if spec_config.disable_padded_drafter_batch:
# When padded-batch is disabled, the sampled_token_ids should be
# the cpu-side list[list[int]] of valid sampled tokens for each
# request, with invalid requests having empty lists.
@@ -3208,7 +3295,7 @@ class GPUModelRunner(
else:
target_hidden_states = hidden_states[:num_scheduled_tokens]
else:
- if self.speculative_config.disable_padded_drafter_batch:
+ if spec_config.disable_padded_drafter_batch:
token_indices_to_sample = None
common_attn_metadata, token_indices = self.drafter.prepare_inputs(
common_attn_metadata,
@@ -3303,9 +3390,12 @@ class GPUModelRunner(
and is_mixture_of_experts(self.drafter.model)
and self.parallel_config.enable_eplb
):
+ spec_config = self.vllm_config.speculative_config
+ assert spec_config is not None
+ assert spec_config.draft_model_config is not None
logger.info_once(
"EPLB is enabled for drafter model %s.",
- self.vllm_config.speculative_config.draft_model_config.model,
+ spec_config.draft_model_config.model,
)
global_expert_load = (
@@ -3322,7 +3412,7 @@ class GPUModelRunner(
self.eplb_state = EplbState(self.parallel_config, self.device)
self.eplb_state.add_model(
self.drafter.model,
- self.vllm_config.speculative_config.draft_model_config,
+ spec_config.draft_model_config,
global_expert_load,
old_global_expert_indices,
rank_mapping,
@@ -3357,9 +3447,11 @@ class GPUModelRunner(
scope="local",
)
prepare_communication_buffer_for_model(self.model)
+ mm_config = self.model_config.multimodal_config
self.is_multimodal_pruning_enabled = (
supports_multimodal_pruning(self.get_model())
- and self.model_config.multimodal_config.is_multimodal_pruning_enabled()
+ and mm_config is not None
+ and mm_config.is_multimodal_pruning_enabled()
)
if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb:
@@ -3380,6 +3472,8 @@ class GPUModelRunner(
old_global_expert_indices,
rank_mapping,
)
+ if self.eplb_state.is_async:
+ self.eplb_state.start_async_loop(rank_mapping=rank_mapping)
if (
self.vllm_config.compilation_config.mode
@@ -3394,15 +3488,14 @@ class GPUModelRunner(
# CudagraphWraper and CudagraphDispatcher of vllm.
# wrap the model with full cudagraph wrapper if needed.
- if (
- self.compilation_config.cudagraph_mode.has_full_cudagraphs()
- and not self.parallel_config.enable_dbo
- ):
+ cudagraph_mode = self.compilation_config.cudagraph_mode
+ assert cudagraph_mode is not None
+ if cudagraph_mode.has_full_cudagraphs() and not self.parallel_config.enable_dbo:
self.model = CUDAGraphWrapper(
self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
)
elif self.parallel_config.enable_dbo:
- if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
+ if cudagraph_mode.has_full_cudagraphs():
self.model = UBatchWrapper(
self.model, self.vllm_config, CUDAGraphMode.FULL, self.device
)
@@ -3458,7 +3551,7 @@ class GPUModelRunner(
hidden_states: torch.Tensor,
num_scheduled_tokens: dict[str, int],
) -> dict[str, LogprobsTensors | None]:
- num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
+ num_prompt_logprobs_dict = self.num_prompt_logprobs
if not num_prompt_logprobs_dict:
return {}
@@ -3469,7 +3562,10 @@ class GPUModelRunner(
# maintainable loop over optimal performance.
completed_prefill_reqs = []
for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
- num_tokens = num_scheduled_tokens[req_id]
+ num_tokens = num_scheduled_tokens.get(req_id)
+ if num_tokens is None:
+ # This can happen if the request was preempted during the prefill stage.
+ continue
# Get metadata for this request.
request = self.requests[req_id]
@@ -3650,6 +3746,7 @@ class GPUModelRunner(
create_mixed_batch: bool = False,
remove_lora: bool = True,
activate_lora: bool = False,
+ is_graph_capturing: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Run a dummy forward pass to warm up/profile run or capture the
@@ -3748,6 +3845,31 @@ class GPUModelRunner(
dp_rank = self.parallel_config.data_parallel_rank
num_tokens_after_padding = int(num_tokens_across_dp[dp_rank])
+ # filter out the valid batch descriptor
+ _cg_mode, batch_descriptor = (
+ self.cudagraph_dispatcher.dispatch(
+ BatchDescriptor(
+ num_tokens=num_tokens_after_padding,
+ uniform_decode=uniform_decode,
+ has_lora=activate_lora and self.lora_config is not None,
+ )
+ )
+ if not is_profile
+ else (CUDAGraphMode.NONE, None)
+ )
+ if cudagraph_runtime_mode is not None:
+ # we allow forcing NONE when the dispatcher disagrees to support
+ # warm ups for cudagraph capture
+ assert (
+ cudagraph_runtime_mode == CUDAGraphMode.NONE
+ or cudagraph_runtime_mode == _cg_mode
+ ), (
+ f"Cudagraph runtime mode mismatch at dummy_run. "
+ f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}."
+ )
+ else:
+ cudagraph_runtime_mode = _cg_mode
+
attn_metadata: PerLayerAttnMetadata | None = None
# If force_attention is True, we always capture attention. Otherwise,
@@ -3803,6 +3925,8 @@ class GPUModelRunner(
if self.uses_mrope:
positions = self.mrope_positions.gpu[:, :num_tokens_after_padding]
+ elif self.uses_xdrope_dim > 0:
+ positions = self.xdrope_positions.gpu[:, :num_tokens_after_padding]
else:
positions = self.positions.gpu[:num_tokens_after_padding]
@@ -3822,31 +3946,6 @@ class GPUModelRunner(
num_tokens_after_padding, None, False
)
- # filter out the valid batch descriptor
- _cg_mode, batch_descriptor = (
- self.cudagraph_dispatcher.dispatch(
- BatchDescriptor(
- num_tokens=num_tokens_after_padding,
- uniform_decode=uniform_decode,
- has_lora=activate_lora and self.lora_config is not None,
- )
- )
- if not is_profile
- else (CUDAGraphMode.NONE, None)
- )
- if cudagraph_runtime_mode is not None:
- # we allow forcing NONE when the dispatcher disagrees to support
- # warm ups for cudagraph capture
- assert (
- cudagraph_runtime_mode == CUDAGraphMode.NONE
- or cudagraph_runtime_mode == _cg_mode
- ), (
- f"Cudagraph runtime mode mismatch at dummy_run. "
- f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}."
- )
- else:
- cudagraph_runtime_mode = _cg_mode
-
if ubatch_slices is not None:
# Adjust values to reflect a single ubatch.
# TODO(sage,lucas): this is cruft that should be addressed in
@@ -3883,7 +3982,7 @@ class GPUModelRunner(
if self.speculative_config and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
use_cudagraphs = (
- cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE
+ cudagraph_runtime_mode.has_mode(CUDAGraphMode.PIECEWISE)
and not self.speculative_config.enforce_eager
)
@@ -3897,6 +3996,7 @@ class GPUModelRunner(
self.drafter.dummy_run(
num_tokens,
use_cudagraphs=use_cudagraphs,
+ is_graph_capturing=is_graph_capturing,
)
# This is necessary to avoid blocking DP.
@@ -4082,7 +4182,8 @@ class GPUModelRunner(
def profile_run(self) -> None:
# Profile with multimodal encoder & encoder cache.
if self.supports_mm_inputs:
- if self.model_config.multimodal_config.skip_mm_profiling:
+ mm_config = self.model_config.multimodal_config
+ if mm_config is not None and mm_config.skip_mm_profiling:
logger.info(
"Skipping memory profiling for multimodal encoder and "
"encoder cache."
@@ -4328,6 +4429,7 @@ class GPUModelRunner(
skip_eplb=True,
remove_lora=False,
activate_lora=activate_lora,
+ is_graph_capturing=True,
)
self.maybe_remove_all_loras(self.lora_config)
@@ -4344,8 +4446,9 @@ class GPUModelRunner(
def get_attn_backends_for_group(
kv_cache_group_spec: KVCacheGroupSpec,
) -> tuple[dict[AttentionGroupKey, list[str]], set[type[AttentionBackend]]]:
+ layer_type = cast(type[Any], AttentionLayerBase)
layers = get_layers_from_vllm_config(
- self.vllm_config, AttentionLayerBase, kv_cache_group_spec.layer_names
+ self.vllm_config, layer_type, kv_cache_group_spec.layer_names
)
attn_backends = {}
attn_backend_layers = defaultdict(list)
@@ -4360,7 +4463,7 @@ class GPUModelRunner(
if layer_name in self.kv_sharing_fast_prefill_eligible_layers:
attn_backend = create_fast_prefill_custom_backend(
"FastPrefill",
- attn_backend,
+ attn_backend, # type: ignore[arg-type]
)
full_cls_name = attn_backend.full_cls_name()
@@ -4459,6 +4562,7 @@ class GPUModelRunner(
min_cg_backend_name = attn_backend.__name__
# Flexible resolve the cudagraph mode
cudagraph_mode = self.compilation_config.cudagraph_mode
+ assert cudagraph_mode is not None
# check cudagraph for mixed batch is supported
if (
cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL
@@ -4573,12 +4677,17 @@ class GPUModelRunner(
self.compilation_config.adjust_cudagraph_sizes_for_spec_decode(
self.uniform_decode_query_len, self.parallel_config.tensor_parallel_size
)
- self.cudagraph_batch_sizes = self.compilation_config.cudagraph_capture_sizes
+ capture_sizes = self.compilation_config.cudagraph_capture_sizes
+ self.cudagraph_batch_sizes = (
+ capture_sizes if capture_sizes is not None else []
+ )
# Trigger cudagraph dispatching keys initialization after
# resolved cudagraph mode.
+ cudagraph_mode = self.compilation_config.cudagraph_mode
+ assert cudagraph_mode is not None
self.cudagraph_dispatcher.initialize_cudagraph_keys(
- self.compilation_config.cudagraph_mode, self.uniform_decode_query_len
+ cudagraph_mode, self.uniform_decode_query_len
)
def calculate_reorder_batch_threshold(self) -> None:
@@ -4590,7 +4699,7 @@ class GPUModelRunner(
"""
min_none_high = lambda a, b: a if b is None else b if a is None else min(a, b)
- reorder_batch_thresholds = [
+ reorder_batch_thresholds: list[int | None] = [
group.get_metadata_builder().reorder_batch_threshold
for group in self._attn_group_iterator()
]
@@ -4599,7 +4708,7 @@ class GPUModelRunner(
if len(reorder_batch_thresholds) == 0:
self.reorder_batch_threshold = None
return
- self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds)
+ self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds) # type: ignore[assignment]
@staticmethod
def select_common_block_size(
@@ -4631,7 +4740,7 @@ class GPUModelRunner(
"""
for backend in backends:
is_supported = False
- for supported_size in backend.supported_kernel_block_sizes:
+ for supported_size in backend.get_supported_kernel_block_sizes():
if isinstance(supported_size, int):
if block_size == supported_size:
is_supported = True
@@ -4662,7 +4771,7 @@ class GPUModelRunner(
all_int_supported_sizes = set(
supported_size
for backend in backends
- for supported_size in backend.supported_kernel_block_sizes
+ for supported_size in backend.get_supported_kernel_block_sizes()
if isinstance(supported_size, int)
)
@@ -4944,12 +5053,30 @@ class GPUModelRunner(
Dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
"""
- # Initialize the memory buffer for KV cache
- kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
- # Change the memory buffer to the desired shape
- kv_caches = self._reshape_kv_cache_tensors(
- kv_cache_config, kv_cache_raw_tensors, kernel_block_sizes
- )
+
+ # Try creating KV caches optimized for kv-connector transfers
+ cache_dtype = self.cache_config.cache_dtype
+ if self.use_uniform_kv_cache(self.attn_groups, cache_dtype):
+ kv_caches, cross_layers_kv_cache, attn_backend = (
+ self.allocate_uniform_kv_caches(
+ kv_cache_config,
+ self.attn_groups,
+ cache_dtype,
+ self.device,
+ kernel_block_sizes,
+ )
+ )
+ self.cross_layers_kv_cache = cross_layers_kv_cache
+ self.cross_layers_attn_backend = attn_backend
+ else:
+ # Fallback to the general case
+ # Initialize the memory buffer for KV cache
+ kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
+
+ # Change the memory buffer to the desired shape
+ kv_caches = self._reshape_kv_cache_tensors(
+ kv_cache_config, kv_cache_raw_tensors, kernel_block_sizes
+ )
# Set up cross-layer KV cache sharing
for layer_name, target_layer_name in self.shared_kv_cache_layers.items():
@@ -5031,16 +5158,26 @@ class GPUModelRunner(
if has_kv_transfer_group():
kv_transfer_group = get_kv_transfer_group()
- kv_transfer_group.register_kv_caches(kv_caches)
+ if self.cross_layers_kv_cache is not None:
+ assert self.cross_layers_attn_backend is not None
+ kv_transfer_group.register_cross_layers_kv_cache(
+ self.cross_layers_kv_cache, self.cross_layers_attn_backend
+ )
+ else:
+ kv_transfer_group.register_kv_caches(kv_caches)
kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks)
if self.dcp_world_size > 1:
- layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
+ layer_type = cast(type[Any], AttentionLayerBase)
+ layers = get_layers_from_vllm_config(self.vllm_config, layer_type)
for layer in layers.values():
- assert layer.impl.need_to_return_lse_for_decode, (
+ layer_impl = getattr(layer, "impl", None)
+ if layer_impl is None:
+ continue
+ assert layer_impl.need_to_return_lse_for_decode, (
"DCP requires attention impls to return"
" the softmax lse for decode, but the impl "
- f"{layer.impl.__class__.__name__} "
+ f"{layer_impl.__class__.__name__} "
"does not return the softmax lse for decode."
)
@@ -5081,7 +5218,8 @@ class GPUModelRunner(
if has_ec_transfer() and get_ec_transfer().is_producer:
return {}
kv_cache_spec: dict[str, KVCacheSpec] = {}
- attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
+ layer_type = cast(type[Any], AttentionLayerBase)
+ attn_layers = get_layers_from_vllm_config(self.vllm_config, layer_type)
for layer_name, attn_module in attn_layers.items():
if isinstance(attn_module, Attention) and (
kv_tgt_layer := attn_module.kv_sharing_target_layer_name
@@ -5101,7 +5239,7 @@ class GPUModelRunner(
return kv_cache_spec
- def _to_list(self, sampled_token_ids: torch.Tensor) -> list[np.ndarray]:
+ def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
# This is a short term mitigation for issue mentioned in
# https://github.com/vllm-project/vllm/issues/22754.
# `tolist` would trigger a cuda wise stream sync, which
@@ -5114,4 +5252,4 @@ class GPUModelRunner(
pinned.copy_(sampled_token_ids, non_blocking=True)
self.transfer_event.record()
self.transfer_event.synchronize()
- return [row for row in pinned.numpy()]
+ return pinned.tolist()
diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py
index 9de123263755b..2ce2b64512560 100644
--- a/vllm/v1/worker/gpu_ubatch_wrapper.py
+++ b/vllm/v1/worker/gpu_ubatch_wrapper.py
@@ -121,18 +121,24 @@ class UBatchWrapper:
@staticmethod
def _create_sm_control_context(vllm_config: VllmConfig):
- comm_sms = envs.VLLM_DBO_COMM_SMS
+ comm_sms: int = envs.VLLM_DBO_COMM_SMS
set_comm_sms = lambda sms: None
if vllm_config.parallel_config.enable_expert_parallel:
# Currently only DeepEP highthroughput supports SM control so this
# only affects that case.
- all2all_manager = get_ep_group().device_communicator.all2all_manager
+ ep_group = get_ep_group()
+ device_communicator = ep_group.device_communicator
+ all2all_manager = None
+ if device_communicator is not None:
+ all2all_manager = device_communicator.all2all_manager
- if all2all_manager.max_sms_used() is not None:
- comm_sms = min(comm_sms, all2all_manager.max_sms_used())
+ if all2all_manager is not None:
+ max_sms_used = all2all_manager.max_sms_used()
+ if max_sms_used is not None:
+ comm_sms = min(comm_sms, max_sms_used)
- if comm_sms > 0:
+ if comm_sms > 0 and all2all_manager is not None:
set_comm_sms = lambda sms: all2all_manager.set_num_sms(sms)
# TODO(lucas): support other kernels besides DeepGEMM
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 315f01b68499a..6a4bfde5f972b 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -6,7 +6,7 @@ import gc
import os
from contextlib import AbstractContextManager, nullcontext
from types import NoneType
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast
import torch
import torch.distributed
@@ -26,6 +26,7 @@ from vllm.distributed.kv_transfer import (
has_kv_transfer_group,
)
from vllm.distributed.parallel_state import (
+ get_pcp_group,
get_pp_group,
get_tp_group,
)
@@ -35,12 +36,12 @@ from vllm.model_executor import set_random_seed
from vllm.model_executor.models.interfaces import is_mixture_of_experts
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
from vllm.platforms import current_platform
-from vllm.profiler.gpu_profiler import CudaProfilerWrapper
+from vllm.profiler.gpu_profiler import CudaProfilerWrapper, TorchProfilerWrapper
from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
-from vllm.v1.core.sched.output import GrammarOutput
+from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (
@@ -57,7 +58,6 @@ logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
- from vllm.v1.core.sched.output import SchedulerOutput
class Worker(WorkerBase):
@@ -86,41 +86,22 @@ class Worker(WorkerBase):
# Buffers saved before sleep
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
- # Torch profiler. Enabled and configured through env vars:
+ # Torch/CUDA profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
+ # VLLM_TORCH_CUDA_PROFILE=1
+ self.profiler: Any | None = None
if envs.VLLM_TORCH_PROFILER_DIR:
- torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
- logger.info(
- "Profiling enabled. Traces will be saved to: %s",
- torch_profiler_trace_dir,
- )
- logger.debug(
- "Profiler config: record_shapes=%s,"
- "profile_memory=%s,with_stack=%s,with_flops=%s",
- envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
- envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
- envs.VLLM_TORCH_PROFILER_WITH_STACK,
- envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
- )
- self.profiler = torch.profiler.profile(
- activities=[
- torch.profiler.ProfilerActivity.CPU,
- torch.profiler.ProfilerActivity.CUDA,
- ],
- record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
- profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
- with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
- with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
- on_trace_ready=torch.profiler.tensorboard_trace_handler(
- torch_profiler_trace_dir, worker_name=worker_name, use_gzip=True
- ),
+ self.profiler = TorchProfilerWrapper(
+ worker_name=worker_name, local_rank=self.local_rank
)
elif envs.VLLM_TORCH_CUDA_PROFILE:
self.profiler = CudaProfilerWrapper()
else:
self.profiler = None
+ self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
+
def sleep(self, level: int = 1) -> None:
from vllm.device_allocator.cumem import CuMemAllocator
@@ -168,17 +149,17 @@ class Worker(WorkerBase):
assert allocator.get_current_usage() == 0, (
"Sleep mode can only be used for one instance per process."
)
- context = allocator.use_memory_pool(tag=tag)
+ return allocator.use_memory_pool(tag=tag)
else:
- context = nullcontext()
- return context
+ return nullcontext()
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
def init_device(self):
- if self.device_config.device.type == "cuda":
+ device = self.device_config.device
+ if isinstance(device, torch.device) and device.type == "cuda":
# This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
if (
@@ -204,14 +185,14 @@ class Worker(WorkerBase):
assert self.local_rank < torch.cuda.device_count(), (
f"DP adjusted local rank {self.local_rank} is out of bounds. "
)
- visible_device_count = (
- torch.cuda.device_count() if torch.cuda.is_available() else 0
- )
- assert self.parallel_config.local_world_size <= visible_device_count, (
- f"local_world_size ({self.parallel_config.local_world_size}) must be "
- f"less than or equal to the number of visible devices "
- f"({visible_device_count})."
- )
+ visible_device_count = (
+ torch.cuda.device_count() if torch.cuda.is_available() else 0
+ )
+ assert self.parallel_config.local_world_size <= visible_device_count, (
+ f"local_world_size ({self.parallel_config.local_world_size}) must "
+ f"be less than or equal to the number of visible devices "
+ f"({visible_device_count})."
+ )
self.device = torch.device(f"cuda:{self.local_rank}")
current_platform.set_device(self.device)
@@ -257,9 +238,17 @@ class Worker(WorkerBase):
raise RuntimeError(f"Not support device type: {self.device_config.device}")
# Construct the model runner
- self.model_runner: GPUModelRunner = GPUModelRunner(
- self.vllm_config, self.device
- )
+ if self.use_v2_model_runner:
+ from vllm.v1.worker.gpu.model_runner import (
+ GPUModelRunner as GPUModelRunnerV2,
+ )
+
+ # HACK(woosuk): This is a temporary fix to avoid type errors.
+ self.model_runner: GPUModelRunner = GPUModelRunnerV2( # type: ignore
+ self.vllm_config, self.device
+ )
+ else:
+ self.model_runner = GPUModelRunner(self.vllm_config, self.device)
if self.rank == 0:
# If usage stat is enabled, collect relevant info.
@@ -397,23 +386,21 @@ class Worker(WorkerBase):
from vllm.device_allocator.cumem import CuMemAllocator
allocator = CuMemAllocator.get_instance()
- context = allocator.use_memory_pool(tag="kv_cache")
+ with allocator.use_memory_pool(tag="kv_cache"):
+ self.model_runner.initialize_kv_cache(kv_cache_config)
else:
- context = nullcontext()
- with context:
self.model_runner.initialize_kv_cache(kv_cache_config)
def compile_or_warm_up_model(self) -> None:
# warm up sizes that are not in cudagraph capture sizes,
# but users still want to compile for better performance,
# e.g. for the max-num-batched token size in chunked prefill.
- warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
+ compile_sizes = self.vllm_config.compilation_config.compile_sizes
+ warmup_sizes = compile_sizes.copy() if compile_sizes is not None else []
if not self.model_config.enforce_eager:
- warmup_sizes = [
- x
- for x in warmup_sizes
- if x not in self.vllm_config.compilation_config.cudagraph_capture_sizes
- ]
+ capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
+ if capture_sizes is not None:
+ warmup_sizes = [x for x in warmup_sizes if x not in capture_sizes]
# We skip EPLB here since we don't want to record dummy metrics
for size in sorted(warmup_sizes, reverse=True):
logger.info("Compile and warming up model for size %d", size)
@@ -525,10 +512,12 @@ class Worker(WorkerBase):
if not self.profiler:
return nullcontext()
+ self.profiler.step()
+
num_new = len(scheduler_output.scheduled_new_reqs)
num_cached = len(scheduler_output.scheduled_cached_reqs.req_ids)
- return torch.profiler.record_function(
+ return self.profiler.annotate_context_manager(
f"execute_new_{num_new}_cached_{num_cached}"
)
@@ -552,12 +541,12 @@ class Worker(WorkerBase):
)
}
if forward_pass and not get_pp_group().is_first_rank:
- intermediate_tensors = IntermediateTensors(
- get_pp_group().recv_tensor_dict(
- all_gather_group=get_tp_group(),
- all_gather_tensors=all_gather_tensors,
- )
+ tensor_dict = get_pp_group().recv_tensor_dict(
+ all_gather_group=get_tp_group(),
+ all_gather_tensors=all_gather_tensors,
)
+ assert tensor_dict is not None
+ intermediate_tensors = IntermediateTensors(tensor_dict)
with self.annotate_profile(scheduler_output):
output = self.model_runner.execute_model(
@@ -586,27 +575,19 @@ class Worker(WorkerBase):
def profile(self, is_start: bool = True):
if self.profiler is None:
- raise RuntimeError("Profiler is not enabled.")
+ raise RuntimeError("Profiling is not enabled.")
if is_start:
self.profiler.start()
else:
self.profiler.stop()
- if isinstance(self.profiler, torch.profiler.profile):
- rank = self.local_rank
- profiler_dir = envs.VLLM_TORCH_PROFILER_DIR
- profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt"
- sort_key = "self_cuda_time_total"
- table = self.profiler.key_averages().table(sort_by=sort_key)
-
- with open(profiler_out_file, "w") as f:
- print(table, file=f)
-
- # only print profiler results on rank 0
- if rank == 0:
- print(table)
def execute_dummy_batch(self) -> None:
- self.model_runner._dummy_run(1, uniform_decode=True)
+ if self.use_v2_model_runner:
+ self.model_runner.execute_model(
+ SchedulerOutput.make_empty(), dummy_run=True
+ )
+ else:
+ self.model_runner._dummy_run(1, uniform_decode=True)
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)
@@ -638,7 +619,7 @@ class Worker(WorkerBase):
assert self.model_runner.eplb_state is not None
self.model_runner.eplb_state.rearrange(
execute_shuffle=True,
- global_expert_load=None,
+ global_expert_loads=None,
rank_mapping=rank_mapping,
)
torch.cuda.synchronize()
@@ -694,7 +675,7 @@ class Worker(WorkerBase):
def _reconfigure_moe(
self, old_ep_size: int, new_ep_size: int
- ) -> torch.Tensor | None:
+ ) -> list[torch.Tensor] | None:
"""
Reconfigure MoE modules with provided reconfig_request
@@ -733,6 +714,7 @@ class Worker(WorkerBase):
module.global_num_experts = module.moe_config.num_experts
module.moe_parallel_config = FusedMoEParallelConfig.make(
tp_size_=get_tp_group().world_size,
+ pcp_size_=get_pcp_group().world_size,
dp_size_=get_dp_group().world_size,
vllm_parallel_config=parallel_config,
)
@@ -760,26 +742,29 @@ class Worker(WorkerBase):
num_local_physical_experts = num_local_experts
assert self.model_runner.eplb_state is not None
new_physical_experts = (
- self.model_runner.eplb_state.physical_to_logical_map.shape[1]
+ self.model_runner.eplb_state.physical_to_logical_map.shape[1] # type: ignore[attr-defined]
)
parallel_config.eplb_config.num_redundant_experts = (
new_physical_experts
- - self.model_runner.eplb_state.logical_replica_count.shape[1]
+ - self.model_runner.eplb_state.logical_replica_count.shape[1] # type: ignore[attr-defined]
)
global_expert_loads = None
else:
- num_local_physical_experts = torch.tensor(
+ num_local_physical_experts_tensor = torch.tensor(
[num_local_experts], dtype=torch.int32, device="cpu"
)
torch.distributed.broadcast(
- num_local_physical_experts, group=get_ep_group().cpu_group, group_src=0
+ num_local_physical_experts_tensor,
+ group=get_ep_group().cpu_group,
+ group_src=0,
)
- num_local_physical_experts = num_local_physical_experts.item()
+ num_local_physical_experts = int(num_local_physical_experts_tensor.item())
new_physical_experts = num_local_physical_experts * new_ep_size
assert self.model_runner.eplb_state is not None
- global_expert_loads = self.model_runner.eplb_state.rearrange(
+ global_expert_loads_any = self.model_runner.eplb_state.rearrange(
execute_shuffle=False
)
+ global_expert_loads = cast(list[torch.Tensor], global_expert_loads_any)
parallel_config.eplb_config.num_redundant_experts = (
new_physical_experts - global_expert_loads[0].shape[1]
)
@@ -863,6 +848,8 @@ class Worker(WorkerBase):
def shutdown(self) -> None:
if runner := getattr(self, "model_runner", None):
runner.ensure_kv_transfer_shutdown()
+ if self.profiler is not None:
+ self.profiler.shutdown()
def init_worker_distributed_environment(
@@ -879,13 +866,15 @@ def init_worker_distributed_environment(
init_batch_invariance()
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
+ init_method = distributed_init_method or "env://"
init_distributed_environment(
- parallel_config.world_size, rank, distributed_init_method, local_rank, backend
+ parallel_config.world_size, rank, init_method, local_rank, backend
)
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
+ parallel_config.prefill_context_parallel_size,
parallel_config.decode_context_parallel_size,
)
diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py
index db037a9fccd5c..ff047d8d03f0e 100644
--- a/vllm/v1/worker/kv_connector_model_runner_mixin.py
+++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py
@@ -11,7 +11,11 @@ from typing import (
TYPE_CHECKING, # noqa: UP035
)
+import torch
+
+from vllm.attention import AttentionBackend
from vllm.config import VllmConfig
+from vllm.config.cache import CacheDType
from vllm.distributed.kv_transfer import (
ensure_kv_transfer_shutdown,
get_kv_transfer_group,
@@ -21,11 +25,13 @@ from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
+from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig
from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT,
KVConnectorOutput,
ModelRunnerOutput,
)
+from vllm.v1.worker.utils import AttentionGroup
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
@@ -53,7 +59,7 @@ class KVConnectorModelRunnerMixin:
@staticmethod
def ensure_kv_transfer_shutdown() -> None:
# has_kv_transfer_group can be None during interpreter shutdown.
- if has_kv_transfer_group and has_kv_transfer_group():
+ if has_kv_transfer_group and has_kv_transfer_group(): # type: ignore[truthy-function]
ensure_kv_transfer_shutdown()
@staticmethod
@@ -142,3 +148,162 @@ class KVConnectorModelRunnerMixin:
if has_kv_transfer_group():
return get_kv_transfer_group().get_kv_connector_stats()
return None
+
+ @staticmethod
+ def use_uniform_kv_cache(
+ attn_groups: list[list[AttentionGroup]],
+ cache_dtype: CacheDType,
+ ) -> bool:
+ """
+ Determines whether a uniform KV layout should be used.
+ A uniform layout means all layers' KV caches will share the same
+ underlying tensor, where for a given block number, the respective
+ KV data for all layers will be contiguous.
+ This will allow efficient KV transfer of per-block KV data for all
+ layers at once.
+ Note this layout will only be applied given 3 conditions:
+ 1. The KV Cache config contains just a single group where all layers
+ have the same page size.
+ 2. A KV connector is configured, and the KV connector instance prefers
+ to use this layout (prefer_cross_layer_blocks() returns True)
+ 3. The flash attention backend supports this layout
+ (get_kv_cache_stride_order(True) includes a placement for a
+ num_layers dimension)
+
+ Note that the actual placement of the num_layers dimensions
+ in the unified layers tensors will be determined by the attention
+ backend.
+ Thus, the layers KV data may still not be contiguous per block
+ if the attention backend does not support it.
+
+ Args:
+ attn_groups: The list of attention groups for this model
+ cache_dtype: The KV cache dtype
+ Returns:
+ True if we should use a uniform KV cache layout.
+ """
+
+ if not has_kv_transfer_group():
+ return False
+ if not get_kv_transfer_group().prefer_cross_layer_blocks:
+ return False
+
+ if len(attn_groups) != 1 or len(attn_groups[0]) != 1:
+ return False
+
+ attn_group = attn_groups[0][0]
+ kv_cache_spec = attn_group.kv_cache_spec
+ if not isinstance(kv_cache_spec, AttentionSpec):
+ return False
+
+ attn_backend = attn_group.backend
+ kv_cache_shape = attn_backend.get_kv_cache_shape(
+ 1234,
+ kv_cache_spec.block_size,
+ kv_cache_spec.num_kv_heads,
+ kv_cache_spec.head_size,
+ cache_dtype_str=cache_dtype,
+ )
+
+ try:
+ kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
+ include_num_layers_dimension=True
+ )
+ except (AttributeError, NotImplementedError):
+ return False
+
+ # check that the attention backend includes a layers dimension
+ return len(kv_cache_stride_order) == len(kv_cache_shape) + 1
+
+ @staticmethod
+ def allocate_uniform_kv_caches(
+ kv_cache_config: KVCacheConfig,
+ attn_groups: list[list[AttentionGroup]],
+ cache_dtype: CacheDType,
+ device: torch.device,
+ kernel_block_sizes: list[int],
+ ) -> tuple[dict[str, torch.Tensor], torch.Tensor, type[AttentionBackend]]:
+ """
+ Initializes and reshapes KV caches for the simple case where all
+ layers have the same layout.
+
+ This function assumes use_uniform_kv_cache() returned True.
+
+ Args:
+ kv_cache_config: The KV cache config
+ attn_groups: The list of attention groups for this model
+ cache_dtype: The KV cache dtype
+ device: The torch device to allocate on.
+ kernel_block_sizes: The kernel block sizes for each KV cache group.
+ Returns:
+ A tuple (kv_caches, cross_layers_kv_cache, attn_backend) where:
+ kv_caches is a dict mapping between layer names to their
+ corresponding memory buffer for KV cache.
+ cross_layers_kv_cache is the cross layers kv cache tensor
+ attn_backend is the attention backend matching this tensor
+ """
+ attn_group = attn_groups[0][0]
+ kv_cache_spec = attn_group.kv_cache_spec
+ assert isinstance(kv_cache_spec, AttentionSpec)
+
+ tensor_sizes = set(
+ kv_cache_tensor.size for kv_cache_tensor in kv_cache_config.kv_cache_tensors
+ )
+ assert len(tensor_sizes) == 1
+ tensor_size = tensor_sizes.pop()
+
+ page_size = kv_cache_spec.page_size_bytes
+ assert tensor_size % page_size == 0
+ num_blocks = tensor_size // page_size
+ num_layers = len(kv_cache_config.kv_cache_tensors)
+ total_size = tensor_size * num_layers
+
+ assert len(kernel_block_sizes) == 1
+ kernel_block_size = kernel_block_sizes[0]
+ num_blocks_per_kv_block = kv_cache_spec.block_size // kernel_block_size
+ kernel_num_blocks = num_blocks * num_blocks_per_kv_block
+
+ attn_backend = attn_group.backend
+ kv_cache_shape = attn_backend.get_kv_cache_shape(
+ kernel_num_blocks,
+ kernel_block_size,
+ kv_cache_spec.num_kv_heads,
+ kv_cache_spec.head_size,
+ cache_dtype_str=cache_dtype,
+ )
+
+ # prepend a num_layers dimension into the shape
+ kv_cache_shape = (num_layers,) + kv_cache_shape
+
+ try:
+ kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
+ include_num_layers_dimension=True
+ )
+ assert len(kv_cache_stride_order) == len(kv_cache_shape)
+ except (AttributeError, NotImplementedError):
+ kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
+
+ kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
+
+ logger.info("Allocating a cross layer KV cache of shape %s", kv_cache_shape)
+
+ # allocate one contiguous buffer for all layers
+ cross_layers_kv_cache = (
+ torch.zeros(total_size, dtype=torch.int8, device=device)
+ .view(kv_cache_spec.dtype)
+ .view(kv_cache_shape)
+ )
+
+ # Maintain original KV shape view.
+ inv_order = [
+ kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
+ ]
+ permuted_kv_cache = cross_layers_kv_cache.permute(*inv_order)
+
+ kv_caches = {}
+ for i, kv_cache_tensor in enumerate(kv_cache_config.kv_cache_tensors):
+ tensor = permuted_kv_cache[i]
+ for layer_name in kv_cache_tensor.shared_by:
+ kv_caches[layer_name] = tensor
+
+ return kv_caches, cross_layers_kv_cache, attn_backend
diff --git a/vllm/v1/worker/tpu_input_batch.py b/vllm/v1/worker/tpu_input_batch.py
index 6bf4f91931849..2ed65ca9d31cd 100644
--- a/vllm/v1/worker/tpu_input_batch.py
+++ b/vllm/v1/worker/tpu_input_batch.py
@@ -149,9 +149,6 @@ class InputBatch:
self.generators: dict[int, torch.Generator] = {}
self.num_logprobs: dict[str, int] = {}
- # NOTE(rob): num_prompt_logprobs only includes reqs
- # that are currently in the prefill phase.
- self.num_prompt_logprobs: dict[str, int] = {}
# To accumulate prompt logprobs tensor chunks across prefill steps.
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
@@ -256,8 +253,6 @@ class InputBatch:
if sampling_params.logprobs is not None:
self.num_logprobs[req_id] = sampling_params.logprobs
- if sampling_params.prompt_logprobs is not None:
- self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
if sampling_params.logit_bias is not None:
self.logit_bias[req_index] = sampling_params.logit_bias
@@ -317,7 +312,6 @@ class InputBatch:
self.repetition_penalties_reqs.discard(req_id)
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
- self.num_prompt_logprobs.pop(req_id, None)
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
# LoRA
@@ -584,10 +578,6 @@ class InputBatch:
def max_num_logprobs(self) -> int | None:
return max(self.num_logprobs.values()) if self.num_logprobs else None
- @property
- def no_prompt_logprob(self) -> bool:
- return not self.num_prompt_logprobs
-
@property
def no_allowed_token_ids(self) -> bool:
return len(self.has_allowed_token_ids) == 0
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index e9eb7cad38f88..72d4474b89627 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -219,9 +219,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.hidden_size = model_config.get_hidden_size()
self.vocab_size = model_config.get_vocab_size()
- if self.lora_config is not None:
- self.vocab_size += self.lora_config.lora_extra_vocab_size
-
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope
@@ -250,6 +247,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Request states.
self.requests: dict[str, CachedRequestState] = {}
+ # NOTE(rob): num_prompt_logprobs only includes reqs
+ # that are currently in the prefill phase.
+ self.num_prompt_logprobs: dict[str, int] = {}
# Initialize input batch early to avoid AttributeError in _update_states
self.input_batch = InputBatch(
@@ -423,6 +423,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Remove finished requests from the cached states.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
+ self.num_prompt_logprobs.pop(req_id, None)
# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
@@ -480,6 +481,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
lora_request=new_req_data.lora_request,
)
+ if sampling_params and sampling_params.prompt_logprobs is not None:
+ self.num_prompt_logprobs[req_id] = (
+ self.input_batch.vocab_size
+ if sampling_params.prompt_logprobs == -1
+ else sampling_params.prompt_logprobs
+ )
+
req_ids_to_add.append(req_id)
# Update the states of the running/resumed requests.
@@ -575,7 +583,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
format. Layers that do not need KV cache are not included.
"""
- layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
+ layers = get_layers_from_vllm_config(
+ self.vllm_config,
+ AttentionLayerBase, # type: ignore[type-abstract]
+ )
block_size = self.vllm_config.cache_config.block_size
cache_dtype_str = self.vllm_config.cache_config.cache_dtype
@@ -728,7 +739,11 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_id = self.input_batch.req_ids[i]
assert req_id is not None
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
- if not use_max_model_len and num_tokens > self.most_model_len:
+ if (
+ not use_max_model_len
+ and self.most_model_len is not None
+ and num_tokens > self.most_model_len
+ ):
use_max_model_len = True
num_scheduled_tokens_per_req.append(num_tokens)
if use_max_model_len:
@@ -740,6 +755,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else:
end_index = num_reqs
else:
+ assert self.num_reqs_most_model_len is not None
if len(num_scheduled_tokens_per_req) > self.num_reqs_most_model_len:
num_scheduled_tokens_per_req = num_scheduled_tokens_per_req[
: self.num_reqs_most_model_len
@@ -832,6 +848,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
].to(self.device)
seq_lens = self.seq_lens_cpu[: self.num_reqs_max_model_len].to(self.device)
else:
+ assert self.num_reqs_most_model_len is not None
block_tables = self.block_table_cpu[
: self.num_reqs_most_model_len, : self.num_blocks_per_most_len_req
]
@@ -934,6 +951,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for mm_input_id in encoder_input_ids:
mm_feature = req_state.mm_features[mm_input_id]
+ if mm_feature.data is None:
+ continue
mm_hash = mm_feature.identifier
mm_kwargs.append(mm_feature.data)
mm_hashes_pos.append((mm_hash, mm_feature.mm_position))
@@ -1117,7 +1136,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
) -> ModelRunnerOutput:
if self.scheduler_output is None:
# Nothing to do (PP non-final rank case), output isn't used.
- return None # noqa
+ return None # type: ignore[return-value]
scheduler_output = self.scheduler_output
mm_embed_inputs = self.mm_embed_inputs
self.scheduler_output = None
@@ -1254,15 +1273,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
max_gen_len = selected_token_ids.shape[-1]
if max_gen_len == 1:
- valid_sampled_token_ids: list[np.ndarray] = [
- row for row in selected_token_ids.numpy()
- ]
+ valid_sampled_token_ids = selected_token_ids.tolist()
# Mask out the sampled tokens that should not be sampled.
# TODO: Keep in sync with gpu_model_runner.py, in particular
# the "else" case here
for i in discard_sampled_tokens_req_indices:
- valid_sampled_token_ids[i] = np.array([])
+ valid_sampled_token_ids[i].clear()
# Append sampled tokens
for i, req_state, seq_len in request_seq_lens:
@@ -1275,7 +1292,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
valid_mask = selected_token_ids != INVALID_TOKEN_ID
gen_lens = valid_mask.sum(dim=1).tolist()
valid_sampled_token_ids = [
- seq.numpy() for seq in selected_token_ids[valid_mask].split(gen_lens)
+ seq.tolist() for seq in selected_token_ids[valid_mask].split(gen_lens)
]
self.input_batch.num_tokens[:num_reqs] += gen_lens
for i, req_state, seq_len in request_seq_lens:
@@ -1699,7 +1716,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
) -> None:
# Profile with multimodal encoder & encoder cache.
if self.supports_mm_inputs:
- if self.model_config.multimodal_config.skip_mm_profiling:
+ mm_config = self.model_config.multimodal_config
+ if mm_config is not None and mm_config.skip_mm_profiling:
logger.info(
"Skipping memory profiling for multimodal encoder and "
"encoder cache."
@@ -2169,5 +2187,9 @@ def replace_set_lora(model):
if isinstance(module, BaseLayerWithLoRA):
module._original_set_lora = module.set_lora
module._original_reset_lora = module.reset_lora
- module.set_lora = _tpu_set_lora.__get__(module, module.__class__)
- module.reset_lora = _tpu_reset_lora.__get__(module, module.__class__)
+ module.set_lora = _tpu_set_lora.__get__( # type: ignore[method-assign]
+ module, module.__class__
+ )
+ module.reset_lora = _tpu_reset_lora.__get__( # type: ignore[method-assign]
+ module, module.__class__
+ )
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index a716a9c3aa822..e1a109eca0a88 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -106,9 +106,6 @@ class TPUWorker:
"Profiling enabled. Traces will be saved to: %s", self.profile_dir
)
- if self.model_config.seed is None:
- self.model_config.seed = 0
-
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
@@ -141,8 +138,7 @@ class TPUWorker:
# Set random seed.
set_random_seed(self.model_config.seed)
- if self.model_config.seed is not None:
- xm.set_rng_state(self.model_config.seed, self.device)
+ xm.set_rng_state(self.model_config.seed, self.device)
# Increase the cache size limit, which is the maximum number of
# dynamo graphs that can be compiled.
@@ -332,7 +328,7 @@ class TPUWorker:
world_size=parallel_config.world_size,
rank=rank,
local_rank=local_rank,
- distributed_init_method=distributed_init_method,
+ distributed_init_method=distributed_init_method or "env://",
backend=current_platform.dist_backend,
)
ensure_model_parallel_initialized(
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index 095407a8b9596..92e4ce3abdba3 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -280,7 +280,7 @@ def bind_kv_cache(
kv_caches: dict[str, torch.Tensor],
forward_context: dict[str, "Attention"],
runner_kv_caches: list[torch.Tensor],
- num_attn_module: int | None = 1,
+ num_attn_module: int = 1,
) -> None:
"""
Bind the allocated KV cache to both ModelRunner and forward context so
@@ -316,7 +316,7 @@ def bind_kv_cache(
# TODO - analyze where runner_kv_caches is used and the right
# way to ensure it properly reflects multiple attention layers
# in the same decoder block.
- if current_platform.is_cuda() or current_platform.is_xpu():
+ if current_platform.is_cuda_alike() or current_platform.is_xpu():
# We know that the GPU runner is not impacted by this
# case. Some test code depends on runner_kv_caches, but
# not in a way that's impacted by ignoring this.
@@ -362,5 +362,7 @@ def is_residual_scattered_for_sp(
or vllm_config.compilation_config.use_inductor_graph_partition
):
return True
-
- return num_input_tokens in vllm_config.compilation_config.compile_sizes
+ compile_sizes = vllm_config.compilation_config.compile_sizes
+ if compile_sizes is None:
+ return False
+ return num_input_tokens in compile_sizes
diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py
index 16f321c080779..57e7037e946ec 100644
--- a/vllm/v1/worker/worker_base.py
+++ b/vllm/v1/worker/worker_base.py
@@ -315,10 +315,12 @@ class WorkerWrapperBase:
def initialize_from_config(self, kv_cache_configs: list[Any]) -> None:
kv_cache_config = kv_cache_configs[self.global_rank]
+ assert self.vllm_config is not None
with set_current_vllm_config(self.vllm_config):
self.worker.initialize_from_config(kv_cache_config) # type: ignore
def init_device(self):
+ assert self.vllm_config is not None
with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during device initialization
self.worker.init_device() # type: ignore
diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py
index 26c6f8d06bdcd..4d7864e90496a 100644
--- a/vllm/v1/worker/xpu_worker.py
+++ b/vllm/v1/worker/xpu_worker.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
+from typing import Any
import torch
import torch.distributed
@@ -37,6 +38,7 @@ class XPUWorker(Worker):
# Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
+ self.profiler: Any | None = None
if envs.VLLM_TORCH_PROFILER_DIR:
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
@@ -148,7 +150,12 @@ class XPUWorker(Worker):
return int(available_kv_cache_memory)
def init_device(self):
- if self.device_config.device.type == "xpu" and current_platform.is_xpu():
+ device = self.device_config.device
+ if (
+ isinstance(device, torch.device)
+ and device.type == "xpu"
+ and current_platform.is_xpu()
+ ):
self.device = torch.device(f"xpu:{self.local_rank}")
current_platform.set_device(self.device)
current_platform.check_if_supports_dtype(self.model_config.dtype)