diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh
index 734a817fd1a06..b571618f48c2b 100755
--- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh
+++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh
@@ -128,7 +128,7 @@ run_and_track_test() {
# --- Actual Test Execution ---
run_and_track_test 1 "test_struct_output_generate.py" \
- "HF_HUB_DISABLE_XET=1 python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py -k \"not test_structured_output_with_reasoning_matrices\""
+ "python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py -k \"not test_structured_output_with_reasoning_matrices\""
run_and_track_test 2 "test_moe_pallas.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py"
run_and_track_test 3 "test_lora.py" \
@@ -139,6 +139,8 @@ run_and_track_test 5 "test_spmd_model_weight_loading.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py"
run_and_track_test 6 "test_kv_cache_update_kernel.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py"
+run_and_track_test 7 "test_tpu_int8.py" \
+ "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_int8.py"
# After all tests have been attempted, exit with the overall status.
if [ "$overall_script_exit_code" -ne 0 ]; then
diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
index 9e7b5a546243c..d55a786e41e8b 100755
--- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
@@ -134,7 +134,7 @@ run_and_track_test 1 "test_compilation.py" \
run_and_track_test 2 "test_basic.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_basic.py"
run_and_track_test 3 "test_accuracy.py::test_lm_eval_accuracy_v1_engine" \
- "HF_HUB_DISABLE_XET=1 python3 -m pytest -s -v /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine"
+ "python3 -m pytest -s -v /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine"
run_and_track_test 4 "test_quantization_accuracy.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py"
run_and_track_test 5 "examples/offline_inference/tpu.py" \
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index e139c6b30586e..ebcf51981ef33 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -57,9 +57,10 @@ steps:
- vllm/
- tests/mq_llm_engine
- tests/async_engine
- - tests/test_inputs
+ - tests/test_inputs.py
+ - tests/test_outputs.py
- tests/multimodal
- - tests/test_utils
+ - tests/utils_
- tests/worker
- tests/standalone_tests/lazy_imports.py
commands:
@@ -70,7 +71,7 @@ steps:
- pytest -v -s test_inputs.py
- pytest -v -s test_outputs.py
- pytest -v -s multimodal
- - pytest -v -s test_utils.py # Utils
+ - pytest -v -s utils_ # Utils
- pytest -v -s worker # Worker
- label: Python-only Installation Test
@@ -426,7 +427,6 @@ steps:
- label: Tensorizer Test # 11min
mirror_hardwares: [amdexperimental]
- soft_fail: true
source_file_dependencies:
- vllm/model_executor/model_loader
- tests/tensorizer_loader
@@ -535,8 +535,6 @@ steps:
- vllm/
- tests/models/language
commands:
- # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
- - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
- pip freeze | grep -E 'torch'
- pytest -v -s models/language -m core_model
@@ -547,8 +545,10 @@ steps:
- vllm/
- tests/models/language/generation
commands:
- # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
- - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
+ # Install fast path packages for testing against transformers
+ # Note: also needed to run plamo2 model in vLLM
+ - uv pip install --system --no-build-isolation 'git+https://github.com/state-spaces/mamba@v2.2.5'
+ - uv pip install --system --no-build-isolation 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.2'
- pytest -v -s models/language/generation -m hybrid_model
- label: Language Models Test (Extended Generation) # 1hr20min
diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index 5bc944296763d..a0a327319a468 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -9,7 +9,7 @@
/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
-/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth
+/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256
/vllm/multimodal @DarkLight1337 @ywang96
/vllm/vllm_flash_attn @LucasWilkinson
/vllm/lora @jeejeelee
@@ -20,7 +20,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
# Any change to the VllmConfig changes can have a large user-facing impact,
# so spam a lot of people
-/vllm/config.py @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor
+/vllm/config @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor @yewentao256 @ProExpertProg
# vLLM V1
/vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat
@@ -34,16 +34,16 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/tests/distributed/test_pipeline_parallel.py @youkaichao
/tests/distributed/test_same_node.py @youkaichao
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @simon-mo @aarnphm
-/tests/kernels @tlrmchlsmth @WoosukKwon
+/tests/kernels @tlrmchlsmth @WoosukKwon @yewentao256
/tests/models @DarkLight1337 @ywang96
/tests/multi_step @alexm-redhat @comaniac
/tests/multimodal @DarkLight1337 @ywang96
/tests/prefix_caching @comaniac @KuntaiDu
-/tests/quantization @mgoin @robertgshaw2-redhat
+/tests/quantization @mgoin @robertgshaw2-redhat @yewentao256
/tests/test_inputs.py @DarkLight1337 @ywang96
/tests/v1/entrypoints/llm/test_struct_output_generate.py @mgoin @russellb @aarnphm
/tests/v1/structured_output @mgoin @russellb @aarnphm
-/tests/weight_loading @mgoin @youkaichao
+/tests/weight_loading @mgoin @youkaichao @yewentao256
/tests/lora @jeejeelee
# Docs
diff --git a/.github/mergify.yml b/.github/mergify.yml
index d8ae509e0ac30..495d207d44260 100644
--- a/.github/mergify.yml
+++ b/.github/mergify.yml
@@ -118,6 +118,20 @@ pull_request_rules:
add:
- qwen
+- name: label-gpt-oss
+ description: Automatically apply gpt-oss label
+ conditions:
+ - or:
+ - files~=^examples/.*gpt[-_]?oss.*\.py
+ - files~=^tests/.*gpt[-_]?oss.*\.py
+ - files~=^vllm/model_executor/models/.*gpt[-_]?oss.*\.py
+ - files~=^vllm/model_executor/layers/.*gpt[-_]?oss.*\.py
+ - title~=(?i)gpt[-_]?oss
+ actions:
+ label:
+ add:
+ - gpt-oss
+
- name: label-rocm
description: Automatically apply rocm label
conditions:
diff --git a/.gitignore b/.gitignore
index 96b97a552c540..721dd7536bec2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,6 +4,9 @@
# vllm-flash-attn built from source
vllm/vllm_flash_attn/*
+# triton jit
+.triton
+
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
@@ -147,7 +150,8 @@ venv.bak/
# mkdocs documentation
/site
docs/argparse
-docs/examples
+docs/examples/*
+!docs/examples/README.md
# mypy
.mypy_cache/
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7625590e64287..43b9d9b6418a1 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -428,6 +428,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(SRCS
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu"
+ "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
diff --git a/README.md b/README.md
index 5348405b72d2c..fd8b02ac1f781 100644
--- a/README.md
+++ b/README.md
@@ -18,14 +18,15 @@ Easy, fast, and cheap LLM serving for everyone
*Latest News* 🔥
+- [2025/08] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA) focusing on large-scale LLM deployment! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) and the recording [here](https://www.chaspark.com/#/live/1166916873711665152).
- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing).
- [2025/05] vLLM is now a hosted project under PyTorch Foundation! Please find the announcement [here](https://pytorch.org/blog/pytorch-foundation-welcomes-vllm/).
-- [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing).
- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html).
Previous News
+- [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing).
- [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing).
- [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing).
- [2025/03] We hosted [the East Coast vLLM Meetup](https://lu.ma/7mu4k4xx)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1NHiv8EUFF1NLd3fEYODm56nDmL26lEeXCaDgyDlTsRs/edit#slide=id.g31441846c39_0_0).
@@ -121,6 +122,7 @@ Cash Donations:
Compute Resources:
+- Alibaba Cloud
- AMD
- Anyscale
- AWS
@@ -160,7 +162,7 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs
## Contact Us
-- For technical questions and feature requests, please use GitHub [Issues](https://github.com/vllm-project/vllm/issues) or [Discussions](https://github.com/vllm-project/vllm/discussions)
+- For technical questions and feature requests, please use GitHub [Issues](https://github.com/vllm-project/vllm/issues)
- For discussing with fellow users, please use the [vLLM Forum](https://discuss.vllm.ai)
- For coordinating contributions and development, please use [Slack](https://slack.vllm.ai)
- For security disclosures, please use GitHub's [Security Advisories](https://github.com/vllm-project/vllm/security/advisories) feature
diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py
index c7229dbb8e90d..1559ca2d92841 100644
--- a/benchmarks/backend_request_func.py
+++ b/benchmarks/backend_request_func.py
@@ -31,7 +31,7 @@ class RequestFuncInput:
model_name: Optional[str] = None
logprobs: Optional[int] = None
extra_body: Optional[dict] = None
- multi_modal_content: Optional[dict] = None
+ multi_modal_content: Optional[dict | list[dict]] = None
ignore_eos: bool = False
language: Optional[str] = None
@@ -364,7 +364,15 @@ async def async_request_openai_chat_completions(
) as session:
content = [{"type": "text", "text": request_func_input.prompt}]
if request_func_input.multi_modal_content:
- content.append(request_func_input.multi_modal_content)
+ mm_content = request_func_input.multi_modal_content
+ if isinstance(mm_content, list):
+ content.extend(mm_content)
+ elif isinstance(mm_content, dict):
+ content.append(mm_content)
+ else:
+ raise TypeError(
+ "multi_modal_content must be a dict or list[dict] for openai-chat"
+ )
payload = {
"model": request_func_input.model_name
if request_func_input.model_name
@@ -491,7 +499,10 @@ async def async_request_openai_audio(
buffer.seek(0)
return buffer
- with to_bytes(*request_func_input.multi_modal_content["audio"]) as f:
+ mm_audio = request_func_input.multi_modal_content
+ if not isinstance(mm_audio, dict) or "audio" not in mm_audio:
+ raise TypeError("multi_modal_content must be a dict containing 'audio'")
+ with to_bytes(*mm_audio["audio"]) as f:
form = aiohttp.FormData()
form.add_field("file", f, content_type="audio/wav")
for key, value in payload.items():
diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py
index 1ad6cef7a9dbc..ea684f18a7421 100644
--- a/benchmarks/benchmark_dataset.py
+++ b/benchmarks/benchmark_dataset.py
@@ -52,7 +52,7 @@ class SampleRequest:
prompt: Union[str, Any]
prompt_len: int
expected_output_len: int
- multi_modal_data: Optional[Union[MultiModalDataDict, dict]] = None
+ multi_modal_data: Optional[Union[MultiModalDataDict, dict, list[dict]]] = None
lora_request: Optional[LoRARequest] = None
diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
index 93b72211eb332..ae38caf7290b1 100644
--- a/benchmarks/benchmark_serving.py
+++ b/benchmarks/benchmark_serving.py
@@ -263,7 +263,14 @@ async def benchmark(
input_requests[0].multi_modal_data,
)
- assert test_mm_content is None or isinstance(test_mm_content, dict)
+ assert (
+ test_mm_content is None
+ or isinstance(test_mm_content, dict)
+ or (
+ isinstance(test_mm_content, list)
+ and all(isinstance(item, dict) for item in test_mm_content)
+ )
+ ), "multi_modal_data must be a dict or list[dict]"
test_input = RequestFuncInput(
model=model_id,
model_name=model_name,
diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py
index 72250e2fb6d2b..13bf1be836f6a 100644
--- a/benchmarks/kernels/benchmark_moe.py
+++ b/benchmarks/kernels/benchmark_moe.py
@@ -22,10 +22,10 @@ from vllm.utils import FlexibleArgumentParser
FP8_DTYPE = current_platform.fp8_dtype()
-def ensure_divisibility(numerator, denominator):
+def ensure_divisibility(numerator, denominator, text):
"""Ensure that numerator is divisible by the denominator."""
- assert numerator % denominator == 0, (
- "intermediate_size {} is not divisible by tp {}.".format(numerator, denominator)
+ assert numerator % denominator == 0, "{} {} is not divisible by tp {}.".format(
+ text, numerator, denominator
)
@@ -577,12 +577,10 @@ def main(args: argparse.Namespace):
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
- shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
- shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in (
"DeepseekV3ForCausalLM",
"DeepseekV2ForCausalLM",
@@ -591,17 +589,14 @@ def main(args: argparse.Namespace):
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
- shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"):
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
- shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"):
E = config.num_experts
topk = config.moe_topk[0]
intermediate_size = config.moe_intermediate_size[0]
- shard_intermediate_size = 2 * intermediate_size // args.tp_size
else:
# Support for llama4
config = config.get_text_config()
@@ -609,8 +604,14 @@ def main(args: argparse.Namespace):
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
+ enable_ep = bool(args.enable_expert_parallel)
+ if enable_ep:
+ ensure_divisibility(E, args.tp_size, "Number of experts")
+ E = E // args.tp_size
+ shard_intermediate_size = 2 * intermediate_size
+ else:
+ ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size")
shard_intermediate_size = 2 * intermediate_size // args.tp_size
- ensure_divisibility(intermediate_size, args.tp_size)
hidden_size = config.hidden_size
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
@@ -742,6 +743,7 @@ if __name__ == "__main__":
parser.add_argument(
"--tp-size", "-tp", "--tensor-parallel-size", type=int, default=2
)
+ parser.add_argument("--enable-expert-parallel", "-enable-ep", action="store_true")
parser.add_argument(
"--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
)
diff --git a/benchmarks/kernels/benchmark_mrope.py b/benchmarks/kernels/benchmark_mrope.py
new file mode 100644
index 0000000000000..b9147361708fd
--- /dev/null
+++ b/benchmarks/kernels/benchmark_mrope.py
@@ -0,0 +1,328 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# This script benchmarks the mrope kernel (mainly for Qwen2VL and Qwen2.5VL models).
+# It generates test data, runs benchmarks, and saves results to a CSV file.
+#
+# The CSV file (named with current date/time) contains these columns:
+# model_name, tp_size, num_tokens, num_heads, num_kv_heads, head_dim, max_position,
+# rope_theta, is_neox_style, rope_scaling, dtype, torch_mean, torch_median, torch_p99,
+# torch_min, torch_max, triton_mean, triton_median, triton_p99, triton_min, triton_max,
+# speedup
+#
+# == Usage Examples ==
+#
+# Single model benchmark:
+# python3 benchmark_mrope.py --model-name Qwen/Qwen2-VL-7B-Instruct --tp-size 1 \
+# --warmup-iter 10 --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
+#
+# All models benchmark:
+# python3 benchmark_mrope.py --model-name "" --tp-size 1 --warmup-iter 10 \
+# --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
+#
+# All models with different TP sizes:
+# python3 benchmark_mrope.py --model-name "" --tp-size 1 2 4 8 --warmup-iter 10 \
+# --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
+#
+# All models with different token counts:
+# python3 benchmark_mrope.py --model-name "" --tp-size 1 --warmup-iter 10 \
+# --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024 4096 16384
+import csv
+import os
+import time
+from datetime import datetime
+from typing import Any
+
+import numpy as np
+import torch
+
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.platforms import current_platform
+from vllm.transformers_utils.config import get_config
+from vllm.utils import FlexibleArgumentParser
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def generate_test_data(
+ num_tokens: int,
+ num_q_heads: int,
+ num_kv_heads: int,
+ head_size: int,
+ max_position_embeddings: int,
+ dtype: torch.dtype,
+ device: torch.device,
+):
+ """Generate test data for given configuration."""
+ # Create 2D positions (3, num_tokens) for multimodal case
+ positions = torch.randint(
+ 0, max_position_embeddings // 4, (3, num_tokens), device=device
+ )
+
+ # Create query and key tensors
+ query = torch.randn(num_tokens, num_q_heads * head_size, dtype=dtype, device=device)
+ key = torch.randn(num_tokens, num_kv_heads * head_size, dtype=dtype, device=device)
+
+ return positions, query, key
+
+
+def calculate_stats(times: list[float]) -> dict[str, float]:
+ """Calculate statistics from a list of times."""
+ times_array = np.array(times)
+ return {
+ "mean": np.mean(times_array),
+ "median": np.median(times_array),
+ "p99": np.percentile(times_array, 99),
+ "min": np.min(times_array),
+ "max": np.max(times_array),
+ }
+
+
+def benchmark_mrope(
+ model_name: str,
+ num_tokens: int,
+ head_dim: int,
+ tp_size: int,
+ num_heads: int,
+ num_kv_heads: int,
+ max_position: int = 8192,
+ rope_theta: float = 10000,
+ is_neox_style: bool = True,
+ rope_scaling: dict[str, Any] = None,
+ dtype: torch.dtype = torch.bfloat16,
+ seed: int = 0,
+ warmup_iter: int = 10,
+ benchmark_iter: int = 100,
+ csv_writer=None,
+):
+ current_platform.seed_everything(seed)
+ torch.set_default_device(device)
+ # the parameters to compute the q k v size based on tp_size
+ mrope_helper_class = get_rope(
+ head_size=head_dim,
+ rotary_dim=head_dim,
+ max_position=max_position,
+ base=rope_theta,
+ is_neox_style=is_neox_style,
+ rope_scaling=rope_scaling,
+ dtype=dtype,
+ ).to(device=device)
+
+ print(80 * "=")
+ print(
+ f"Evaluating model: {model_name} "
+ f"with tp_size: {tp_size} "
+ f"and num_tokens: {num_tokens}, "
+ f"dtype: {dtype}"
+ )
+
+ # create q k v input tensors
+ # create rotary pos emb input tensors
+ positions, query, key = generate_test_data(
+ num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device
+ )
+
+ # Warm up
+ for _ in range(warmup_iter):
+ mrope_helper_class.forward_native(
+ positions,
+ query.clone(),
+ key.clone(),
+ )
+
+ mrope_helper_class.forward_cuda(
+ positions,
+ query.clone(),
+ key.clone(),
+ )
+
+ torch.cuda.synchronize()
+
+ # Time reference implementation
+ torch_times = []
+ for _ in range(benchmark_iter):
+ query_clone = query.clone()
+ key_clone = key.clone()
+ torch.cuda.synchronize()
+ start_time = time.time()
+
+ mrope_helper_class.forward_native(
+ positions,
+ query_clone,
+ key_clone,
+ )
+
+ torch.cuda.synchronize()
+ torch_times.append(time.time() - start_time)
+
+ # Time triton kernel implementation
+ triton_times = []
+ for _ in range(benchmark_iter):
+ query_clone = query.clone()
+ key_clone = key.clone()
+ torch.cuda.synchronize()
+ start_time = time.time()
+ mrope_helper_class.forward_cuda(
+ positions,
+ query_clone,
+ key_clone,
+ )
+ torch.cuda.synchronize()
+ triton_times.append(time.time() - start_time)
+
+ # Calculate statistics
+ torch_stats = calculate_stats(torch_times)
+ triton_stats = calculate_stats(triton_times)
+ print(f"\nPerformance for config ({num_tokens}, {num_heads}, {num_kv_heads}):")
+
+ print(
+ f"Torch implementation: "
+ f"mean={torch_stats['mean']:.8f}s, "
+ f"median={torch_stats['median']:.8f}s, "
+ f"p99={torch_stats['p99']:.8f}s"
+ )
+
+ print(
+ f"Triton implementation: "
+ f"mean={triton_stats['mean']:.8f}s, "
+ f"median={triton_stats['median']:.8f}s, "
+ f"p99={triton_stats['p99']:.8f}s"
+ )
+
+ print(
+ f"Triton Speedup over Torch: {torch_stats['mean'] / triton_stats['mean']:.8f}x"
+ )
+
+ # Write to CSV
+ if csv_writer:
+ row = [
+ model_name,
+ tp_size,
+ num_tokens,
+ num_heads,
+ num_kv_heads,
+ head_dim,
+ max_position,
+ rope_theta,
+ is_neox_style,
+ str(rope_scaling),
+ str(dtype).split(".")[-1],
+ torch_stats["mean"],
+ torch_stats["median"],
+ torch_stats["p99"],
+ torch_stats["min"],
+ torch_stats["max"],
+ triton_stats["mean"],
+ triton_stats["median"],
+ triton_stats["p99"],
+ triton_stats["min"],
+ triton_stats["max"],
+ torch_stats["mean"] / triton_stats["mean"], # speedup
+ ]
+ csv_writer.writerow(row)
+
+ return torch_stats, triton_stats
+
+
+if __name__ == "__main__":
+ parser = FlexibleArgumentParser(
+ description="Benchmark the rotary embedding kernels."
+ )
+ parser.add_argument("--model-name", type=str, default="")
+ parser.add_argument("--tp-size", type=int, default=1)
+ parser.add_argument("--warmup-iter", type=int, default=10)
+ parser.add_argument("--benchmark-iter", type=int, default=100)
+ parser.add_argument("--dtype", type=str, choices=["bfloat16"], default="bfloat16")
+ parser.add_argument("--seed", type=int, default=0)
+ parser.add_argument("--num-tokens", type=int, nargs="+", required=False)
+ parser.add_argument("--trust-remote-code", action="store_true")
+ parser.add_argument("--output-csv", type=str, default="mrope_benchmark_results.csv")
+ args = parser.parse_args()
+ print(args)
+
+ # Create CSV file for results
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ csv_filename = f"{os.path.splitext(args.output_csv)[0]}_{timestamp}.csv"
+
+ with open(csv_filename, "w", newline="") as csvfile:
+ csv_writer = csv.writer(csvfile)
+ # Write header
+ header = [
+ "model_name",
+ "tp_size",
+ "num_tokens",
+ "num_heads",
+ "num_kv_heads",
+ "head_dim",
+ "max_position",
+ "rope_theta",
+ "is_neox_style",
+ "rope_scaling",
+ "dtype",
+ "torch_mean",
+ "torch_median",
+ "torch_p99",
+ "torch_min",
+ "torch_max",
+ "triton_mean",
+ "triton_median",
+ "triton_p99",
+ "triton_min",
+ "triton_max",
+ "speedup",
+ ]
+ csv_writer.writerow(header)
+
+ model_tp_dict = {}
+ if args.model_name == "":
+ model_tp_dict = {
+ "Qwen/Qwen2-VL-2B-Instruct": [1],
+ "Qwen/Qwen2-VL-7B-Instruct": [1],
+ "Qwen/Qwen2-VL-72B-Instruct": [2, 4, 8],
+ "Qwen/Qwen2.5-VL-3B-Instruct": [1, 2, 4, 8],
+ "Qwen/Qwen2.5-VL-7B-Instruct": [1, 2, 4, 8],
+ "Qwen/Qwen2.5-VL-72B-Instruct": [2, 4, 8],
+ }
+ else:
+ model_tp_dict[args.model_name] = [args.tp_size]
+
+ if args.num_tokens is None:
+ num_tokens_list = [2**i for i in range(0, 18)]
+ else:
+ num_tokens_list = args.num_tokens
+
+ for model_name, tp_list in model_tp_dict.items():
+ config = get_config(model_name, trust_remote_code=args.trust_remote_code)
+ for tp_size in tp_list:
+ # get the model config
+ total_num_kv_heads = config.num_key_value_heads
+ total_num_heads = config.num_attention_heads
+ num_heads = total_num_heads // tp_size
+ num_kv_heads = max(1, total_num_kv_heads // tp_size)
+ head_dim = config.hidden_size // total_num_heads
+ q_size = num_heads * head_dim
+ kv_size = num_kv_heads * head_dim
+ is_neox_style = True
+ rope_theta = config.rope_theta
+ max_position = config.max_position_embeddings
+
+ for num_tokens in num_tokens_list:
+ benchmark_mrope(
+ model_name=model_name,
+ num_tokens=num_tokens,
+ head_dim=head_dim,
+ tp_size=tp_size,
+ num_heads=num_heads,
+ num_kv_heads=num_kv_heads,
+ max_position=max_position,
+ rope_theta=rope_theta,
+ is_neox_style=is_neox_style,
+ rope_scaling=config.rope_scaling,
+ dtype=getattr(torch, args.dtype),
+ seed=args.seed,
+ warmup_iter=args.warmup_iter,
+ benchmark_iter=args.benchmark_iter,
+ csv_writer=csv_writer,
+ )
+
+ print(f"Benchmark results saved to {csv_filename}")
diff --git a/benchmarks/multi_turn/README.md b/benchmarks/multi_turn/README.md
new file mode 100644
index 0000000000000..ae0866ae60751
--- /dev/null
+++ b/benchmarks/multi_turn/README.md
@@ -0,0 +1,71 @@
+# Benchmark KV Cache Offloading with Multi-Turn Conversations
+
+The requirements (pip) for `benchmark_serving_multi_turn.py` can be found in `requirements.txt`
+
+First start serving your model
+
+```bash
+export MODEL_NAME=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/
+
+vllm serve $MODEL_NAME --disable-log-requests
+```
+
+## Synthetic Multi-Turn Conversations
+
+Download the following text file (used for generation of synthetic conversations)
+
+```bash
+wget https://www.gutenberg.org/ebooks/1184.txt.utf-8
+mv 1184.txt.utf-8 pg1184.txt
+```
+
+The filename `pg1184.txt` is used in `generate_multi_turn.json` (see `"text_files"`).
+
+But you may use other text files if you prefer (using this specific file is not required).
+
+Then run the benchmarking script
+
+```bash
+export MODEL_NAME=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/
+
+python benchmark_serving_multi_turn.py --model $MODEL_NAME --input-file generate_multi_turn.json \
+--num-clients 2 --max-active-conversations 6
+```
+
+You can edit the file `generate_multi_turn.json` to change the conversation parameters (number of turns, etc.).
+
+If successful, you will see the following output
+
+```bash
+----------------------------------------------------------------------------------------------------
+Statistics summary:
+runtime_sec = 215.810
+requests_per_sec = 0.769
+----------------------------------------------------------------------------------------------------
+ count mean std min 25% 50% 75% 90% 99% max
+ttft_ms 166.0 78.22 67.63 45.91 59.94 62.26 64.43 69.66 353.18 567.54
+tpot_ms 166.0 25.37 0.57 24.40 25.07 25.31 25.50 25.84 27.50 28.05
+latency_ms 166.0 2591.07 326.90 1998.53 2341.62 2573.01 2860.10 3003.50 3268.46 3862.94
+input_num_turns 166.0 7.43 4.57 1.00 3.00 7.00 11.00 13.00 17.00 17.00
+input_num_tokens 166.0 2006.20 893.56 522.00 1247.75 2019.00 2718.00 3233.00 3736.45 3899.00
+output_num_tokens 166.0 100.01 11.80 80.00 91.00 99.00 109.75 116.00 120.00 120.00
+output_num_chunks 166.0 99.01 11.80 79.00 90.00 98.00 108.75 115.00 119.00 119.00
+----------------------------------------------------------------------------------------------------
+```
+
+## ShareGPT Conversations
+
+To run with the ShareGPT data, download the following ShareGPT dataset:
+`https://huggingface.co/datasets/philschmid/sharegpt-raw/blob/main/sharegpt_20230401_clean_lang_split.json`
+
+Use the `convert_sharegpt_to_openai.py` script to convert the dataset to a format supported by `benchmark_serving_multi_turn.py`
+
+```bash
+python convert_sharegpt_to_openai.py sharegpt_20230401_clean_lang_split.json sharegpt_conv_128.json --seed=99 --max-items=128
+```
+
+The script will convert the ShareGPT dataset to a dataset with the standard user/assistant roles.
+
+The flag `--max-items=128` is used to sample 128 conversations from the original dataset (change as needed).
+
+Use the output JSON file `sharegpt_conv_128.json` as the `--input-file` for `benchmark_serving_multi_turn.py`.
diff --git a/benchmarks/multi_turn/bench_dataset.py b/benchmarks/multi_turn/bench_dataset.py
new file mode 100644
index 0000000000000..411b89dd23dc6
--- /dev/null
+++ b/benchmarks/multi_turn/bench_dataset.py
@@ -0,0 +1,493 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from abc import ABC, abstractmethod
+from statistics import mean
+from typing import Any, NamedTuple, Optional, Union
+
+import numpy as np # type: ignore
+import pandas as pd # type: ignore
+from bench_utils import (
+ TEXT_SEPARATOR,
+ Color,
+ logger,
+)
+from transformers import AutoTokenizer # type: ignore
+
+# Conversation ID is a string (e.g: "UzTK34D")
+ConvId = str
+
+# A list of dicts (dicts with keys "id" and "messages")
+ShareGptConversations = list[dict[str, Any]]
+
+# A list of dicts (dicts with keys "role" and "content")
+MessagesList = list[dict[str, str]]
+
+# Map conversation ID to conversation messages
+ConversationsMap = list[ConvId, MessagesList]
+
+
+class Distribution(ABC):
+ @abstractmethod
+ def sample(self, size: int = 1) -> np.ndarray:
+ pass
+
+
+class UniformDistribution(Distribution):
+ def __init__(
+ self,
+ min_val: Union[int, float],
+ max_val: Union[int, float],
+ is_integer: bool = True,
+ ) -> None:
+ self.min_val = min_val
+ self.max_val = max_val
+ self.is_integer = is_integer
+
+ def sample(self, size: int = 1) -> np.ndarray:
+ if self.is_integer:
+ return np.random.randint(
+ int(self.min_val), int(self.max_val + 1), size=size
+ )
+ else:
+ return np.random.uniform(self.min_val, self.max_val, size=size)
+
+ def __repr__(self) -> str:
+ return f"UniformDistribution[{self.min_val}, {self.max_val}]"
+
+
+class ConstantDistribution(Distribution):
+ def __init__(self, value: Union[int, float]) -> None:
+ self.value = value
+ self.max_val = value
+
+ def sample(self, size: int = 1) -> np.ndarray:
+ return np.full(shape=size, fill_value=self.value)
+
+ def __repr__(self) -> str:
+ return f"Constant[{self.value}]"
+
+
+class ZipfDistribution(Distribution):
+ def __init__(self, alpha: float, max_val: Optional[int] = None) -> None:
+ self.alpha = alpha
+ self.max_val = max_val
+
+ def sample(self, size: int = 1) -> np.ndarray:
+ samples = np.random.zipf(self.alpha, size=size)
+ if self.max_val:
+ samples = np.minimum(samples, self.max_val)
+ return samples
+
+ def __repr__(self) -> str:
+ return f"ZipfDistribution[{self.alpha}]"
+
+
+class PoissonDistribution(Distribution):
+ def __init__(self, alpha: float, max_val: Optional[int] = None) -> None:
+ self.alpha = alpha
+ self.max_val = max_val
+
+ def sample(self, size: int = 1) -> np.ndarray:
+ samples = np.random.poisson(self.alpha, size=size)
+ if self.max_val:
+ samples = np.minimum(samples, self.max_val)
+ return samples
+
+ def __repr__(self) -> str:
+ return f"PoissonDistribution[{self.alpha}]"
+
+
+class LognormalDistribution(Distribution):
+ def __init__(
+ self, mean: float, sigma: float, max_val: Optional[int] = None
+ ) -> None:
+ self.mean = mean
+ self.sigma = sigma
+ self.max_val = max_val
+
+ def sample(self, size: int = 1) -> np.ndarray:
+ samples = np.random.lognormal(mean=self.mean, sigma=self.sigma, size=size)
+ if self.max_val:
+ samples = np.minimum(samples, self.max_val)
+
+ return np.round(samples).astype(int)
+
+ def __repr__(self) -> str:
+ return f"LognormalDistribution[{self.mean}, {self.sigma}]"
+
+
+class GenConvArgs(NamedTuple):
+ num_conversations: int
+ text_files: list[str]
+ input_num_turns: Distribution
+ input_common_prefix_num_tokens: Distribution
+ input_prefix_num_tokens: Distribution
+ input_num_tokens: Distribution
+ output_num_tokens: Distribution
+ print_stats: bool
+
+
+def verify_field_exists(
+ conf: dict, field_name: str, section: str, subsection: str
+) -> None:
+ if field_name not in conf:
+ raise ValueError(
+ f"Missing field '{field_name}' in {section=} and {subsection=}"
+ )
+
+
+def get_random_distribution(
+ conf: dict, section: str, subsection: str, optional: bool = False
+) -> Distribution:
+ # section can be "prompt_input" or "prompt_output" (both required)
+ conf = conf[section]
+
+ if optional and subsection not in conf:
+ # Optional subsection, if not found assume the value is always 0
+ return ConstantDistribution(0)
+
+ # subsection can be "num_turns", "num_tokens" or "prefix_num_tokens"
+ if subsection not in conf:
+ raise ValueError(f"Missing subsection {subsection} in section {section}")
+
+ conf = conf[subsection]
+
+ distribution = conf.get("distribution")
+ if distribution is None:
+ raise ValueError(
+ f"Missing field 'distribution' in {section=} and {subsection=}"
+ )
+
+ if distribution == "constant":
+ verify_field_exists(conf, "value", section, subsection)
+ return ConstantDistribution(conf["value"])
+
+ elif distribution == "zipf":
+ verify_field_exists(conf, "alpha", section, subsection)
+ max_val = conf.get("max", None)
+ return ZipfDistribution(conf["alpha"], max_val=max_val)
+
+ elif distribution == "poisson":
+ verify_field_exists(conf, "alpha", section, subsection)
+ max_val = conf.get("max", None)
+ return PoissonDistribution(conf["alpha"], max_val=max_val)
+
+ elif distribution == "lognormal":
+ verify_field_exists(conf, "mean", section, subsection)
+ verify_field_exists(conf, "sigma", section, subsection)
+ max_val = conf.get("max", None)
+ return LognormalDistribution(conf["mean"], conf["sigma"], max_val=max_val)
+
+ elif distribution == "uniform":
+ verify_field_exists(conf, "min", section, subsection)
+ verify_field_exists(conf, "max", section, subsection)
+
+ min_value = conf["min"]
+ max_value = conf["max"]
+
+ assert min_value > 0
+ assert min_value <= max_value
+
+ is_integer = isinstance(min_value, int) and isinstance(max_value, int)
+ return UniformDistribution(min_value, max_value, is_integer)
+ else:
+ raise ValueError(f"Unknown distribution: {distribution}")
+
+
+def parse_input_json_file(conf: dict) -> GenConvArgs:
+ # Validate the input file
+ assert isinstance(conf, dict)
+ required_fields = [
+ "filetype",
+ "num_conversations",
+ "text_files",
+ "prompt_input",
+ "prompt_output",
+ ]
+ for field in required_fields:
+ assert field in conf, f"Missing field {field} in input {conf}"
+
+ assert conf["filetype"] == "generate_conversations"
+
+ assert conf["num_conversations"] > 0, "num_conversations should be larger than zero"
+
+ text_files = conf["text_files"]
+
+ assert isinstance(text_files, list), "Field 'text_files' should be a list"
+ assert len(text_files) > 0, (
+ "Field 'text_files' should be a list with at least one file"
+ )
+
+ # Parse the parameters for the prompt input/output workload
+ input_num_turns = get_random_distribution(conf, "prompt_input", "num_turns")
+ input_num_tokens = get_random_distribution(conf, "prompt_input", "num_tokens")
+ input_common_prefix_num_tokens = get_random_distribution(
+ conf, "prompt_input", "common_prefix_num_tokens", optional=True
+ )
+ input_prefix_num_tokens = get_random_distribution(
+ conf, "prompt_input", "prefix_num_tokens"
+ )
+ output_num_tokens = get_random_distribution(conf, "prompt_output", "num_tokens")
+
+ print_stats: bool = conf.get("print_stats", False)
+ assert isinstance(print_stats, bool), (
+ "Field 'print_stats' should be either 'true' or 'false'"
+ )
+
+ args = GenConvArgs(
+ num_conversations=conf["num_conversations"],
+ text_files=text_files,
+ input_num_turns=input_num_turns,
+ input_common_prefix_num_tokens=input_common_prefix_num_tokens,
+ input_prefix_num_tokens=input_prefix_num_tokens,
+ input_num_tokens=input_num_tokens,
+ output_num_tokens=output_num_tokens,
+ print_stats=print_stats,
+ )
+ return args
+
+
+def print_conv_stats(conversations: ConversationsMap, tokenizer: AutoTokenizer) -> None:
+ # Collect statistics
+ conv_stats: list[dict[Any, Any]] = []
+ req_stats: list[int] = []
+
+ print("\nCollecting statistics...")
+ for messages in conversations.values():
+ # messages is a list of dicts
+ user_tokens: list[int] = []
+ assistant_tokens: list[int] = []
+ request_tokens: list[int] = []
+
+ req_tokens = 0
+ for m in messages:
+ content = m["content"]
+ num_tokens = len(tokenizer(content).input_ids)
+
+ if m["role"] == "user":
+ user_tokens.append(num_tokens)
+ # New user prompt including all chat history
+ req_tokens += num_tokens
+ request_tokens.append(req_tokens)
+
+ elif m["role"] == "assistant":
+ assistant_tokens.append(num_tokens)
+ # Update assistant answer
+ # (will be part of chat history for the next user prompt)
+ req_tokens += num_tokens
+
+ item_stats = {
+ "conversation_turns": len(messages),
+ "user_tokens": mean(user_tokens),
+ "assistant_tokens": mean(assistant_tokens),
+ }
+
+ conv_stats.append(item_stats)
+ req_stats.extend(request_tokens)
+
+ # Print statistics
+ percentiles = [0.25, 0.5, 0.75, 0.9, 0.99]
+
+ print(TEXT_SEPARATOR)
+ print(f"{Color.YELLOW}Conversations statistics:{Color.RESET}")
+ print(TEXT_SEPARATOR)
+ df = pd.DataFrame(conv_stats)
+ print(df.describe(percentiles=percentiles).transpose())
+ print(TEXT_SEPARATOR)
+ print(f"{Color.YELLOW}Request statistics:{Color.RESET}")
+ print(TEXT_SEPARATOR)
+ df = pd.DataFrame(req_stats, columns=["request_tokens"])
+ print(df.describe(percentiles=percentiles).transpose())
+ print(TEXT_SEPARATOR)
+
+
+def generate_conversations(
+ args: GenConvArgs, tokenizer: AutoTokenizer
+) -> ConversationsMap:
+ # Text for all user prompts
+ # (text from the input text files will be appended to this line)
+ base_prompt_text = "Please rewrite the following text and add more content: "
+ base_prompt_token_count = len(
+ tokenizer.encode(base_prompt_text, add_special_tokens=False)
+ )
+
+ logger.info(f"{Color.PURPLE}Generating conversations...{Color.RESET}")
+ logger.info(args)
+
+ list_of_tokens = []
+
+ for filename in args.text_files:
+ # Load text file that will be used to generate prompts
+ with open(filename) as file:
+ data = file.read()
+ tokens_in_file = tokenizer.encode(data, add_special_tokens=False)
+ list_of_tokens.extend(tokens_in_file)
+
+ conversations: ConversationsMap = {}
+ conv_id = 0
+
+ # Generate number of turns for every conversation
+ turn_count: np.ndarray = args.input_num_turns.sample(args.num_conversations)
+
+ # Turn count should be at least 2 (one user prompt and one assistant answer)
+ turn_count = np.maximum(turn_count, 2)
+
+ # Round up to an even number (every user prompt should have an answer)
+ turn_count = turn_count + (turn_count % 2)
+
+ # Generate number of prefix tokens for every conversation
+ conv_prefix_tokens: np.ndarray = args.input_prefix_num_tokens.sample(
+ args.num_conversations
+ )
+
+ # Used to reduce shared text between conversations
+ # (jump/skip over text sections between conversations)
+ base_offset = 0
+
+ # Common prefix size for all conversations (only 1 sample required)
+ common_prefix_text = ""
+ common_prefix_tokens: int = args.input_common_prefix_num_tokens.sample(1)[0]
+ if common_prefix_tokens > 0:
+ # Using "." at the end to separate sentences
+ common_prefix_text = (
+ tokenizer.decode(list_of_tokens[: common_prefix_tokens - 2]) + "."
+ )
+ base_offset += common_prefix_tokens
+
+ for conv_id in range(args.num_conversations):
+ # Generate a single conversation
+ messages: MessagesList = []
+
+ nturns = turn_count[conv_id]
+
+ # User prompt token count per turn (with lower limit)
+ input_token_count: np.ndarray = args.input_num_tokens.sample(nturns)
+ input_token_count = np.maximum(input_token_count, base_prompt_token_count)
+
+ # Assistant answer token count per turn (with lower limit)
+ output_token_count: np.ndarray = args.output_num_tokens.sample(nturns)
+ output_token_count = np.maximum(output_token_count, 1)
+
+ user_turn = True
+ for turn_id in range(nturns):
+ if user_turn:
+ role = "user"
+ num_tokens = input_token_count[turn_id]
+
+ # Generate the user prompt,
+ # use a unique prefix (the conv_id) for each conversation
+ # (to avoid shared prefix between conversations)
+ content = f"{conv_id} is a nice number... "
+
+ if len(common_prefix_text) > 0 and turn_id == 0:
+ content = common_prefix_text + content
+
+ # Update the number of tokens left for the content
+ num_tokens -= len(tokenizer.encode(content, add_special_tokens=False))
+
+ if turn_id == 0:
+ prefix_num_tokens = conv_prefix_tokens[conv_id]
+ if prefix_num_tokens > 0:
+ # Add prefix text (context) to the first turn
+ start_offset = base_offset
+ end_offset = start_offset + prefix_num_tokens
+ assert len(list_of_tokens) > end_offset, (
+ "Not enough input text to generate "
+ f"{prefix_num_tokens} tokens for the "
+ f"prefix text ({start_offset=}, {end_offset=})"
+ )
+
+ content += f"{conv_id}, " + tokenizer.decode(
+ list_of_tokens[start_offset:end_offset]
+ )
+ base_offset += prefix_num_tokens
+
+ # Add the actual user prompt/question after the prefix text
+ content += base_prompt_text
+ num_tokens -= base_prompt_token_count
+
+ if num_tokens > 0:
+ # Add text from the input file (to reach the desired token count)
+ start_offset = base_offset + turn_id * input_token_count.max()
+ end_offset = start_offset + num_tokens
+ assert len(list_of_tokens) > end_offset, (
+ f"Not enough input text to generate {num_tokens} tokens "
+ f"for the prompt ({start_offset=}, {end_offset=})"
+ )
+
+ # Convert tokens back to text
+ content += tokenizer.decode(list_of_tokens[start_offset:end_offset])
+ else:
+ role = "assistant"
+ # This content will not be used as input to the LLM server
+ # (actual answers will be used instead).
+ # Content is only required to determine the min_tokens/max_tokens
+ # (inputs to the LLM server).
+ num_tokens = output_token_count[turn_id]
+ assert len(list_of_tokens) > num_tokens, (
+ f"Not enough input text to generate {num_tokens} "
+ "tokens for assistant content"
+ )
+ content = tokenizer.decode(list_of_tokens[:num_tokens])
+
+ # Append the user/assistant message to the list of messages
+ messages.append({"role": role, "content": content})
+ user_turn = not user_turn
+
+ # Add the new conversation
+ conversations[f"CONV_ID_{conv_id}"] = messages
+
+ # Increase base offset for the next conversation
+ base_offset += nturns
+
+ if args.print_stats:
+ print_conv_stats(conversations, tokenizer)
+
+ return conversations
+
+
+def conversations_list_to_dict(input_list: ShareGptConversations) -> ConversationsMap:
+ conversations: ConversationsMap = {}
+
+ for item in input_list:
+ conv_id: str = item["id"]
+ assert isinstance(conv_id, str)
+
+ assert conv_id not in conversations, (
+ f"Conversation ID {conv_id} found more than once in the input"
+ )
+
+ messages: MessagesList = item["messages"]
+ assert isinstance(messages, list), (
+ f"Conversation messages should be a list (ID: {conv_id})"
+ )
+ assert len(messages) > 0, f"Conversation with no messages (ID: {conv_id})"
+
+ conversations[conv_id] = messages
+
+ logger.info(f"Using {len(conversations)} unique conversations (IDs)")
+ assert len(conversations) == len(input_list)
+
+ # Print statistics about the selected conversations
+ stats: list[dict[str, Any]] = []
+ for conv_data in conversations.values():
+ stats.append({"num_turns": len(conv_data)})
+
+ print(TEXT_SEPARATOR)
+ print(f"{Color.YELLOW}Conversations statistics:{Color.RESET}")
+ print(TEXT_SEPARATOR)
+ percentiles = [0.25, 0.5, 0.75, 0.9, 0.99, 0.999, 0.9999]
+ conv_stats = pd.DataFrame(stats).describe(percentiles=percentiles)
+ print(conv_stats.transpose())
+ print(TEXT_SEPARATOR)
+
+ return conversations
+
+
+def conversations_dict_to_list(input_dict: ConversationsMap) -> ShareGptConversations:
+ output: ShareGptConversations = []
+ for conv_id, conv_data in input_dict.items():
+ new_item = {"id": conv_id, "messages": conv_data}
+ output.append(new_item)
+
+ return output
diff --git a/benchmarks/multi_turn/bench_utils.py b/benchmarks/multi_turn/bench_utils.py
new file mode 100644
index 0000000000000..e959a4be711c9
--- /dev/null
+++ b/benchmarks/multi_turn/bench_utils.py
@@ -0,0 +1,28 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import logging
+from enum import Enum
+
+
+class Color(Enum):
+ RED = "\033[91m"
+ GREEN = "\033[92m"
+ BLUE = "\033[94m"
+ PURPLE = "\033[95m"
+ CYAN = "\033[96m"
+ YELLOW = "\033[93m"
+ RESET = "\033[0m"
+
+ def __str__(self):
+ return self.value
+
+
+TEXT_SEPARATOR = "-" * 100
+
+# Configure the logger
+logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s [%(levelname)s] - %(message)s",
+ datefmt="%d-%m-%Y %H:%M:%S",
+)
+logger = logging.getLogger(__name__)
diff --git a/benchmarks/multi_turn/benchmark_serving_multi_turn.py b/benchmarks/multi_turn/benchmark_serving_multi_turn.py
new file mode 100644
index 0000000000000..53c3207491d18
--- /dev/null
+++ b/benchmarks/multi_turn/benchmark_serving_multi_turn.py
@@ -0,0 +1,1557 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import argparse
+import asyncio
+import json
+import logging
+import multiprocessing as mp
+import os
+import random
+import time
+from collections import Counter, deque
+from datetime import datetime
+from enum import Enum
+from http import HTTPStatus
+from statistics import mean
+from typing import NamedTuple, Optional, Union
+
+import aiohttp # type: ignore
+import numpy as np # type: ignore
+import pandas as pd # type: ignore
+from bench_dataset import (
+ ConversationsMap,
+ ConvId,
+ GenConvArgs,
+ MessagesList,
+ ShareGptConversations,
+ conversations_dict_to_list,
+ conversations_list_to_dict,
+ generate_conversations,
+ parse_input_json_file,
+)
+from bench_utils import TEXT_SEPARATOR, Color, logger
+from transformers import AutoTokenizer # type: ignore
+
+NUM_TOKENS_FROM_DATASET = 0
+TERM_SIGNAL = None
+
+
+class ConversationSampling(str, Enum):
+ ROUND_ROBIN = "round_robin"
+ RANDOM = "random"
+
+ def __str__(self):
+ return self.value
+
+
+class ClientArgs(NamedTuple):
+ seed: int
+ max_num_requests: Optional[int]
+ skip_first_turn: bool
+ max_turns: Optional[int]
+ max_active_conversations: int
+ verbose: bool
+ print_content: bool
+ verify_output: bool
+ conversation_sampling: ConversationSampling
+ request_rate: float
+
+
+class RequestArgs(NamedTuple):
+ chat_url: str
+ model: str
+ stream: bool
+ limit_min_tokens: int # Use negative value for no limit
+ limit_max_tokens: int # Use negative value for no limit
+
+
+class BenchmarkArgs(NamedTuple):
+ url: str
+ num_clients: int
+ early_stop: bool
+
+
+class ServerResponse(NamedTuple):
+ valid: bool
+ ttft_ms: float # time to first chunk
+ tpot_ms: float # time per output chunk (one or more tokens)
+ latency_ms: float
+ start_time_ms: float
+ first_chunk: str # first chunk of the content
+ content: str # includes the first_chunk
+ num_chunks: int
+
+ def __str__(self) -> str:
+ return f"ttft_ms {self.ttft_ms:.2f}, tpot_ms {self.tpot_ms:.2f}, latency_ms {self.latency_ms:.2f}" # noqa: E501
+
+
+class RequestStats(NamedTuple):
+ ttft_ms: float
+ tpot_ms: float
+ latency_ms: float
+ start_time_ms: float
+ input_num_turns: int
+ input_num_tokens: int
+ output_num_tokens: int
+ output_num_chunks: int
+ output_num_first_chunk_tokens: int
+ approx_cached_percent: float
+ conversation_id: str
+ client_id: int
+
+ def __str__(self) -> str:
+ return (
+ f"ttft_ms {self.ttft_ms:.2f}, tpot_ms {self.tpot_ms:.2f}, latency_ms {self.latency_ms:.2f}, input_num_tokens {self.input_num_tokens}, " # noqa: E501
+ f"output_num_tokens {self.output_num_tokens} ({self.output_num_chunks} chunks, {self.output_num_first_chunk_tokens} tokens in first chunk), " # noqa: E501
+ f"approx_cached_percent {self.approx_cached_percent:.2f}%"
+ )
+
+
+class MetricStats:
+ def __init__(self) -> None:
+ self.min: Optional[float] = None
+ self.max: Optional[float] = None
+ self.avg: Optional[float] = None
+ self.sum = 0.0
+ self.count = 0
+
+ def update(self, value: float) -> None:
+ if self.min is None:
+ self.min = value
+ else:
+ self.min = min(self.min, value)
+
+ if self.max is None:
+ self.max = value
+ else:
+ self.max = max(self.max, value)
+
+ self.sum += value
+ self.count += 1
+ self.avg = self.sum / self.count
+
+ def __repr__(self) -> str:
+ if self.count == 0:
+ return "no data"
+ return f"avg: {self.avg:>10.3f}, min: {self.min:>10.3f}, max: {self.max:>10.3f}"
+
+
+class MovingAverage:
+ def __init__(self, window_size: int) -> None:
+ self.window_size = window_size
+ self.window = np.zeros(window_size)
+ self.index = 0
+ self.sum = 0.0
+ self.count = 0
+ self.avg: Optional[float] = None
+
+ def update(self, new_value: float) -> None:
+ if self.count < self.window_size:
+ # Filling up the window
+ self.sum += new_value
+ self.window[self.count] = new_value
+ self.count += 1
+ else:
+ # Window is full, start replacing old values
+ old_value = self.window[self.index]
+ self.sum = self.sum - old_value + new_value
+ self.window[self.index] = new_value
+ self.index = (self.index + 1) % self.window_size
+
+ self.avg = self.sum / self.count
+
+ def __repr__(self) -> str:
+ if self.count == 0:
+ return "no data"
+ return f"avg: {self.avg:>10.3f} ({self.count} samples)"
+
+
+class DebugStats:
+ def __init__(self, logger: logging.Logger, window_size: int) -> None:
+ self.logger = logger
+ self.metrics: dict[str, Union[MovingAverage, MetricStats]] = {
+ "moving_avg_ttft_ms": MovingAverage(window_size),
+ "moving_avg_tpot_ms": MovingAverage(window_size),
+ "ttft_ms": MetricStats(),
+ "tpot_ms": MetricStats(),
+ "latency_ms": MetricStats(),
+ "input_num_turns": MetricStats(),
+ "input_num_tokens": MetricStats(),
+ "output_num_tokens": MetricStats(),
+ }
+
+ def update(self, data: RequestStats) -> None:
+ self.metrics["ttft_ms"].update(data.ttft_ms)
+ self.metrics["moving_avg_ttft_ms"].update(data.ttft_ms)
+ self.metrics["tpot_ms"].update(data.tpot_ms)
+ self.metrics["moving_avg_tpot_ms"].update(data.tpot_ms)
+ self.metrics["latency_ms"].update(data.latency_ms)
+ self.metrics["input_num_turns"].update(data.input_num_turns)
+ self.metrics["input_num_tokens"].update(data.input_num_tokens)
+ self.metrics["output_num_tokens"].update(data.output_num_tokens)
+
+ def print(self) -> None:
+ self.logger.info("-" * 50)
+ for k, v in self.metrics.items():
+ kv_info = f"[{k:25}] {v}"
+ self.logger.info(kv_info)
+ self.logger.info("-" * 50)
+
+
+# Must support Python 3.8, we can't use str.removeprefix(prefix)
+# introduced in Python 3.9
+def remove_prefix(text: str, prefix: str) -> str:
+ if text.startswith(prefix):
+ return text[len(prefix) :]
+ return text
+
+
+def nanosec_to_millisec(value: float) -> float:
+ return value / 1000000.0
+
+
+def nanosec_to_sec(value: float) -> float:
+ return value / 1000000000.0
+
+
+async def send_request(
+ session: aiohttp.ClientSession,
+ messages: list[dict[str, str]],
+ chat_url: str,
+ model: str,
+ stream: bool = True,
+ min_tokens: Optional[int] = None,
+ max_tokens: Optional[int] = None,
+) -> ServerResponse:
+ payload = {
+ "model": model,
+ "messages": messages,
+ "seed": 0,
+ "temperature": 0.0,
+ }
+
+ if stream:
+ payload["stream"] = True
+ payload["stream_options"] = {"include_usage": False}
+
+ if min_tokens is not None:
+ payload["min_tokens"] = min_tokens
+
+ if max_tokens is not None:
+ payload["max_tokens"] = max_tokens
+
+ headers = {"Content-Type": "application/json"}
+
+ # Calculate the timeout for the request
+ timeout_sec = 120
+ if max_tokens is not None:
+ # Assume TPOT of 200ms and use max_tokens to determine timeout
+ timeout_sec = max(timeout_sec, int(max_tokens * 0.2))
+ timeout = aiohttp.ClientTimeout(total=timeout_sec)
+
+ valid_response = True
+ ttft: Optional[float] = None
+ chunk_delay: list[int] = []
+ latency: Optional[float] = None
+ first_chunk = ""
+ generated_text = ""
+
+ start_time: int = time.perf_counter_ns()
+ most_recent_timestamp: int = start_time
+
+ async with session.post(
+ url=chat_url, json=payload, headers=headers, timeout=timeout
+ ) as response:
+ http_status = HTTPStatus(response.status)
+ if http_status == HTTPStatus.OK:
+ async for chunk_bytes in response.content:
+ chunk_bytes = chunk_bytes.strip()
+ if not chunk_bytes:
+ continue
+
+ chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
+ if chunk == "[DONE]":
+ # End of stream
+ latency = time.perf_counter_ns() - start_time
+ elif stream is False:
+ data = json.loads(chunk)
+ message = data["choices"][0]["message"]
+ assert message["role"] == "assistant"
+ generated_text += message["content"]
+ else:
+ timestamp: int = time.perf_counter_ns()
+ data = json.loads(chunk)
+
+ # Delta is the new content/text/data
+ delta = data["choices"][0]["delta"]
+ if delta.get("content", None):
+ if ttft is None:
+ # First token
+ first_token_time = time.perf_counter_ns()
+ ttft = first_token_time - start_time
+ first_chunk = delta["content"]
+ else:
+ # Decoding phase
+ chunk_delay.append(timestamp - most_recent_timestamp)
+
+ generated_text += delta["content"]
+
+ most_recent_timestamp = timestamp
+ else:
+ valid_response = False
+ content = await response.text()
+ logger.warning(
+ f"{Color.YELLOW}Received HTTP status {http_status.value} "
+ f"({http_status.phrase}): {content}{Color.RESET}"
+ )
+
+ if latency is None:
+ latency = -1.0
+ if valid_response:
+ # Streaming is disabled, latency was not set
+ latency = time.perf_counter_ns() - start_time
+
+ if ttft is None:
+ # The response was a single chunk
+ ttft = latency
+
+ # Each chunk may include more than one token
+ tpot: float = mean(chunk_delay) if len(chunk_delay) > 0 else 0.0
+ num_chunks: int = len(chunk_delay)
+
+ sr = ServerResponse(
+ valid=valid_response,
+ ttft_ms=nanosec_to_millisec(ttft) if ttft > 0.0 else -1.0,
+ tpot_ms=nanosec_to_millisec(tpot),
+ latency_ms=nanosec_to_millisec(latency),
+ start_time_ms=nanosec_to_millisec(start_time),
+ first_chunk=first_chunk,
+ content=generated_text,
+ num_chunks=num_chunks,
+ )
+ return sr
+
+
+def get_short_string(input: str) -> str:
+ n = 20
+ if len(input) < 400:
+ return input
+
+ return f"{input[:n]}...{input[-n:]}"
+
+
+def get_token_count(tokenizer: AutoTokenizer, text: str) -> int:
+ return len(tokenizer(text, add_special_tokens=False).input_ids)
+
+
+def get_messages_token_count(
+ tokenizer: AutoTokenizer, messages: list[dict[str, str]]
+) -> int:
+ token_count = 0
+ for m in messages:
+ token_count += get_token_count(tokenizer, m["content"])
+
+ return token_count
+
+
+async def send_turn(
+ session: aiohttp.ClientSession,
+ client_id: int,
+ conv_id: str,
+ conversation_messages: MessagesList,
+ messages_to_use: int,
+ tokenizer: AutoTokenizer,
+ req_args: RequestArgs,
+ verbose: bool,
+ verify_output: bool,
+) -> Optional[RequestStats]:
+ assert messages_to_use > 0
+ assert messages_to_use <= len(conversation_messages)
+
+ messages = conversation_messages[:messages_to_use]
+
+ # Index of the next message (the role should be "user")
+ index = messages_to_use - 1
+
+ # Verify that the message has only two keys, "role" and "content"
+ assert len(messages[index].keys()) == 2
+ assert "role" in messages[index] and "content" in messages[index]
+ assert messages[index]["role"] == "user", (
+ f"Failed on conversation ID {conv_id}, message role should be user"
+ )
+
+ if verbose:
+ print(
+ f"{Color.CYAN}Messages (conversation ID {conv_id},"
+ f" {len(messages)} turns):{Color.RESET}",
+ messages,
+ )
+
+ # None means that there is no upper/lower limit for the output token count
+ min_tokens = None if req_args.limit_min_tokens < 0 else req_args.limit_min_tokens
+ max_tokens = None if req_args.limit_max_tokens < 0 else req_args.limit_max_tokens
+
+ if len(conversation_messages) > messages_to_use:
+ # The conversation contains an assistant answer for the next user prompt
+ if (
+ min_tokens == NUM_TOKENS_FROM_DATASET
+ or max_tokens == NUM_TOKENS_FROM_DATASET
+ ):
+ # Compute number of tokens in the answer (from the input conversation)
+ assistant_answer = conversation_messages[messages_to_use]
+ answer_num_tokens = get_token_count(tokenizer, assistant_answer["content"])
+ assert assistant_answer["role"] == "assistant"
+
+ if min_tokens == NUM_TOKENS_FROM_DATASET:
+ min_tokens = max(1, answer_num_tokens)
+
+ if max_tokens == NUM_TOKENS_FROM_DATASET:
+ max_tokens = max(1, answer_num_tokens)
+
+ # Send the current conversation to LLM and get a response
+ response: ServerResponse = await send_request(
+ session,
+ messages,
+ req_args.chat_url,
+ req_args.model,
+ req_args.stream,
+ min_tokens,
+ max_tokens,
+ )
+
+ if response.valid is False:
+ # Request failed
+ return None
+
+ # Compute number of tokens in input / output
+ input_num_tokens = get_messages_token_count(tokenizer, messages)
+
+ # Num tokens in the user's last question
+ question_num_tokens = get_token_count(tokenizer, messages[index]["content"])
+
+ # Num tokens in the history/context of the question
+ assert input_num_tokens >= question_num_tokens
+ history_num_tokens = input_num_tokens - question_num_tokens
+
+ # Num tokens in the LLM's answer (first chunk and full answer)
+ first_chunk_tokens = get_token_count(tokenizer, response.first_chunk)
+
+ output_content = response.content
+ output_num_tokens = get_token_count(tokenizer, output_content)
+
+ # Prefix caching approximated cached percent
+ approx_cached_percent = (
+ 100.0 * (history_num_tokens / input_num_tokens) if input_num_tokens > 0 else 0.0
+ )
+
+ # Compute the correct TTFT and TPOT (based on tokens and not chunks).
+ # Required because multiple output tokens may be bundled in a single chunk.
+ if output_num_tokens > 1 and output_num_tokens > first_chunk_tokens:
+ # More than one token and more than one chunk in the output
+ decode_ms = response.latency_ms - response.ttft_ms
+ decode_num_tokens = output_num_tokens - first_chunk_tokens
+ tpot_ms = decode_ms / decode_num_tokens
+ else:
+ # In this case: output_num_tokens == first_chunk_tokens
+ # Output was a single chunk (output_num_tokens > 1)
+ # or even a single token (output_num_tokens == 1)
+ tpot_ms = 0.0
+
+ if first_chunk_tokens > 1:
+ # First chunk had multiple tokens, adjust TTFT for a single token
+ delta_ms = (first_chunk_tokens - 1) * tpot_ms
+ ttft_ms = max(0.1, response.ttft_ms - delta_ms)
+ else:
+ # First chunk had only one token
+ ttft_ms = response.ttft_ms
+
+ rs = RequestStats(
+ ttft_ms=ttft_ms,
+ tpot_ms=tpot_ms,
+ latency_ms=response.latency_ms,
+ start_time_ms=response.start_time_ms,
+ input_num_turns=len(messages),
+ input_num_tokens=input_num_tokens,
+ output_num_tokens=output_num_tokens,
+ output_num_chunks=response.num_chunks,
+ output_num_first_chunk_tokens=first_chunk_tokens,
+ approx_cached_percent=approx_cached_percent,
+ conversation_id=conv_id,
+ client_id=client_id,
+ )
+
+ if verbose:
+ print(
+ f"\n{Color.YELLOW}Response ({output_num_tokens} tokens):{Color.RESET}",
+ output_content,
+ )
+ print(f"{Color.YELLOW}Response metrics: {rs}{Color.RESET}")
+ print("-" * 70)
+
+ # Save the LLM's answer (will be used as part of the context for the next user turn)
+ answer_index = messages_to_use
+ if len(conversation_messages) > answer_index:
+ assert conversation_messages[answer_index]["role"] == "assistant", (
+ f"Failed on conversation ID {conv_id}, message role should be assistant"
+ )
+
+ orig_content = conversation_messages[answer_index]["content"]
+ if verify_output:
+ # Compare the new answer to the answer from the input file
+ debug_info = (
+ f"LLM/dataset answers do not match ({conv_id}):"
+ f"\n'{get_short_string(output_content)}' (len: {len(output_content)}),"
+ f"\n'{get_short_string(orig_content)}' (len: {len(orig_content)})"
+ )
+ if orig_content != output_content:
+ raise ValueError(debug_info)
+
+ # Update the answer
+ conversation_messages[answer_index]["content"] = output_content
+ else:
+ # A user prompt that has no answer, add the answer as a new message
+ new_answer = {"role": "assistant", "content": output_content}
+ conversation_messages.append(new_answer)
+
+ return rs
+
+
+async def poisson_sleep(request_rate: float, verbose: bool = False) -> None:
+ # Generate a random time interval from the Poisson distribution
+ assert request_rate > 0
+
+ interval = np.random.exponential(1.0 / request_rate)
+ if verbose:
+ logger.info(f"Sleeping for {interval:.3f} seconds...")
+ await asyncio.sleep(interval)
+
+
+async def client_main(
+ args: ClientArgs,
+ req_args: RequestArgs,
+ client_id: int,
+ tokenizer: AutoTokenizer,
+ stop_event: mp.Event, # type: ignore
+ task_queue: mp.Queue,
+ result_queue: mp.Queue,
+ conv_queue: mp.Queue,
+) -> None:
+ logger.info(
+ f"{Color.CYAN}Started client {client_id}: max_num_requests={args.max_num_requests}, max_active_conversations={args.max_active_conversations}{Color.RESET}" # noqa: E501
+ )
+
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+
+ # Active conversations
+ active_convs: ConversationsMap = {}
+ conv_id_queue: deque = deque(maxlen=args.max_active_conversations)
+
+ # Keep track of how many messages have been used for each conversation
+ turns_count: Counter = Counter()
+ num_successes = 0
+ num_failures = 0
+
+ # Track the timestamp (time.perf_counter())
+ # of the last turn per conversation (only for debug)
+ time_of_last_turn: dict[ConvId, float] = {}
+
+ # Flag that indicates that there are no new tasks (conversations) for the client
+ task_queue_empty = False
+
+ async with aiohttp.ClientSession() as session:
+ # Print progress
+
+ while task_queue_empty is False:
+ result = None
+
+ if (
+ args.max_num_requests
+ and num_successes + num_failures == args.max_num_requests
+ ):
+ logger.info(
+ f"{Color.YELLOW}Client {client_id} reached "
+ f"request limit{Color.RESET}"
+ )
+ break
+
+ if stop_event.is_set(): # type: ignore
+ logger.info(
+ f"{Color.YELLOW}Client {client_id} received "
+ f"a termination signal{Color.RESET}"
+ )
+ break
+
+ while (
+ len(active_convs) < args.max_active_conversations
+ and task_queue_empty is False
+ ):
+ # Get a new conversation from the task queue
+ conv_id, messages = task_queue.get()
+
+ if conv_id is TERM_SIGNAL:
+ task_queue_empty = True
+ break
+
+ if args.skip_first_turn:
+ # Skip the first turn (both user and assistant),
+ # relevant if warmup was enabled.
+ # Default turns_count[conv_id] will be zero if conv_id
+ # was never inserted/updated in turns_count.
+ turns_count[conv_id] += 2
+
+ if turns_count[conv_id] < len(messages):
+ # Add new conversation
+ active_convs[conv_id] = messages
+ conv_id_queue.append(conv_id)
+
+ if args.verbose:
+ logger.info(
+ f"{Color.GREEN}Client {client_id} will use conversation ID {conv_id} (active conversations {len(active_convs)}){Color.RESET}" # noqa: E501
+ )
+
+ elif args.verbose:
+ # No more messages (conversation finished during the warmup)
+ logger.info(
+ f"{Color.YELLOW}Client {client_id} will not use conversation ID {conv_id} (all {len(messages)} messages already sent){Color.RESET}" # noqa: E501
+ )
+
+ if len(active_convs) == 0 or task_queue_empty:
+ logger.info(
+ f"{Color.YELLOW}Client {client_id} has no more work{Color.RESET}"
+ )
+ break
+
+ # Pick an active conversation for the next request
+ if args.conversation_sampling == ConversationSampling.ROUND_ROBIN:
+ conv_id = conv_id_queue.pop()
+ else:
+ # ConversationSampling.RANDOM
+ active_ids = list(active_convs.keys())
+ conv_id = random.choice(active_ids)
+
+ messages = active_convs[conv_id]
+ assert isinstance(messages, list) and len(messages) > 0
+
+ # Update the amount of messages to use
+ turns_count[conv_id] += 1
+ current_turn = turns_count[conv_id]
+
+ assert current_turn < len(messages), (
+ f"Turn number {current_turn} is invalid for conversation ID {conv_id}"
+ f" that has only {len(messages)} messages"
+ )
+
+ if args.verbose:
+ curr_time_sec: float = time.perf_counter()
+ time_since_last_turn: Union[str, float] = "N/A"
+ if conv_id in time_of_last_turn:
+ time_since_last_turn = round(
+ curr_time_sec - time_of_last_turn[conv_id], 3
+ )
+ logger.info(
+ f"Client {client_id} using conversation ID {conv_id} (turn: {current_turn}, time since last turn [sec]: {time_since_last_turn})" # noqa: E501
+ )
+ time_of_last_turn[conv_id] = curr_time_sec
+
+ success = True
+ try:
+ result = await send_turn(
+ session,
+ client_id,
+ conv_id,
+ messages,
+ current_turn,
+ tokenizer,
+ req_args,
+ args.print_content,
+ args.verify_output,
+ )
+ if result is not None:
+ result_queue.put(result)
+ else:
+ # None means that the request failed,
+ # and should not be added to the statistics.
+ success = False
+ num_failures += 1
+
+ logger.warning(
+ f"{Color.YELLOW}Client {client_id} - Request rejected during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501
+ )
+
+ # Remove the conversation (should not be used again)
+ active_convs.pop(conv_id)
+
+ except asyncio.exceptions.TimeoutError:
+ num_failures += 1
+ logger.exception(
+ f"{Color.RED}Client {client_id} - Timeout during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501
+ )
+ break # Exit gracefully instead of raising an error
+
+ except Exception:
+ num_failures += 1
+ logger.exception(
+ f"{Color.RED}Client {client_id} - Exception during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501
+ )
+ break # Exit gracefully instead of raising an error
+
+ if success:
+ num_successes += 1
+
+ # Update the turns counter to include the LLM response
+ # The LLM response will be used as context for the next user turn
+ turns_count[conv_id] += 1
+
+ max_turns = len(messages)
+ if args.max_turns is not None:
+ # Limit the number of turns in the conversation
+ max_turns = min(args.max_turns, max_turns)
+
+ if turns_count[conv_id] >= max_turns:
+ # Conversation has no more turns (no longer active)
+ # save the updated conversation (with the LLM server's answer)
+ conv_queue.put((conv_id, active_convs.pop(conv_id)))
+ if args.verbose:
+ logger.info(
+ f"{Color.GREEN}Client {client_id} finished "
+ f"conversation ID {conv_id}{Color.RESET}"
+ )
+ else:
+ # Conversation is not finished, insert it at the back of the queue
+ conv_id_queue.appendleft(conv_id)
+
+ # Sleep between requests (if lambda is positive)
+ if args.request_rate > 0:
+ await poisson_sleep(args.request_rate, args.verbose)
+
+ # Send indication that the client is done
+ conv_queue.put((TERM_SIGNAL, TERM_SIGNAL))
+
+ logger.info(
+ f"{Color.CYAN}Client {client_id} is done "
+ f"({num_successes=}, {num_failures=}){Color.RESET}"
+ )
+
+
+def worker_function(
+ client_id: int,
+ tokenizer: AutoTokenizer,
+ client_args: ClientArgs,
+ req_args: RequestArgs,
+ stop_event: mp.Event, # type: ignore
+ task_queue: mp.Queue,
+ result_queue: mp.Queue,
+ conv_queue: mp.Queue,
+) -> None:
+ asyncio.run(
+ client_main(
+ client_args,
+ req_args,
+ client_id,
+ tokenizer,
+ stop_event,
+ task_queue,
+ result_queue,
+ conv_queue,
+ )
+ )
+
+
+def get_client_config(
+ args: argparse.Namespace, input_conv: ConversationsMap
+) -> tuple[ClientArgs, RequestArgs]:
+ if args.num_clients < 1:
+ raise ValueError("Number of clients must be a positive number")
+
+ if len(input_conv) < args.num_clients:
+ raise ValueError(
+ "Number of conversations must be equal or larger than the number of clients"
+ )
+
+ max_req_per_client: Optional[int] = None
+ if args.max_num_requests is not None:
+ # Max number of requests per client
+ req_per_client = args.max_num_requests // args.num_clients
+ if req_per_client < 1:
+ raise ValueError("Number of requests should be at least one per client")
+ max_req_per_client = req_per_client
+
+ max_active_conversations = args.max_active_conversations
+ if max_active_conversations is None:
+ # Each client will have only one active conversation at a time
+ max_active_conversations = args.num_clients
+
+ if max_active_conversations > len(input_conv):
+ raise ValueError(
+ f"Max active conversations {max_active_conversations} "
+ "must be equal or less than the total number of conversations"
+ )
+
+ # Max number of active conversations per client
+ max_active_conv_per_client = max_active_conversations // args.num_clients
+ if max_active_conv_per_client < 1:
+ raise ValueError(
+ f"Max active conversations {max_active_conversations} "
+ "must be equal or greater than the number of clients"
+ )
+
+ # Skip the first user turn (as part of the warmup)
+ skip_first_turn = args.warmup_step
+
+ # Common arguments for all clients
+ client_args = ClientArgs(
+ seed=args.seed,
+ max_num_requests=max_req_per_client,
+ skip_first_turn=skip_first_turn,
+ max_turns=args.max_turns,
+ max_active_conversations=max_active_conv_per_client,
+ verbose=args.verbose,
+ print_content=args.print_content,
+ verify_output=args.verify_output,
+ conversation_sampling=args.conversation_sampling,
+ request_rate=args.request_rate,
+ )
+
+ if args.limit_min_tokens > 0 or args.limit_max_tokens > 0:
+ if args.limit_min_tokens < 1 or args.limit_max_tokens < 1:
+ raise ValueError(
+ "Invalid min/max tokens limits (both limits should be provided)"
+ )
+ if args.limit_min_tokens > args.limit_max_tokens:
+ raise ValueError(
+ "Invalid min/max tokens limits (min should not be larger than max)"
+ )
+
+ # Arguments for API requests
+ chat_url = f"{args.url}/v1/chat/completions"
+ req_args = RequestArgs(
+ chat_url=chat_url,
+ model=args.model,
+ stream=not args.no_stream,
+ limit_min_tokens=args.limit_min_tokens,
+ limit_max_tokens=args.limit_max_tokens,
+ )
+
+ return client_args, req_args
+
+
+async def main_mp(
+ client_args: ClientArgs,
+ req_args: RequestArgs,
+ bench_args: BenchmarkArgs,
+ tokenizer: AutoTokenizer,
+ input_conv: ConversationsMap,
+) -> tuple[ConversationsMap, list[RequestStats]]:
+ # An event that will trigger graceful termination of all the clients
+ stop_event = mp.Event()
+
+ # Queue for input conversations (from the input file/dataset)
+ task_queue: mp.Queue = mp.Queue()
+
+ # Queue for client measurements (TTFT, TPOT, etc. for each request)
+ result_queue: mp.Queue = mp.Queue()
+
+ # Queue for output conversations (with the LLM answers, sent by the server)
+ conv_queue: mp.Queue = mp.Queue()
+ output_conv: ConversationsMap = {}
+ client_metrics: list[RequestStats] = []
+
+ # Start all clients
+ start_time = time.perf_counter_ns()
+ logger.info(f"{Color.GREEN}Starting {bench_args.num_clients} clients{Color.RESET}")
+
+ clients = []
+ for client_id in range(bench_args.num_clients):
+ client = mp.Process(
+ name=f"client_{client_id}",
+ target=worker_function,
+ args=(
+ client_id,
+ tokenizer,
+ client_args,
+ req_args,
+ stop_event,
+ task_queue,
+ result_queue,
+ conv_queue,
+ ),
+ )
+ clients.append(client)
+ client.start()
+
+ # Submit all the input conversations as tasks for the clients
+ for conv_id, messages in input_conv.items():
+ task_queue.put((conv_id, messages))
+
+ # Add termination signals for clients
+ for _ in range(bench_args.num_clients):
+ task_queue.put((TERM_SIGNAL, TERM_SIGNAL))
+
+ # Collect the updated conversations from all clients
+ num_clients_finished = 0
+ total_convs = len(input_conv)
+
+ debug_stats = DebugStats(logger, min(15 * bench_args.num_clients, 500))
+
+ while num_clients_finished < bench_args.num_clients:
+ # Collect updated conversation
+ conv_id, messages = conv_queue.get()
+
+ # Collect results (measurements)
+ while not result_queue.empty():
+ new_data = result_queue.get()
+ client_metrics.append(new_data)
+ debug_stats.update(new_data)
+
+ if conv_id is TERM_SIGNAL:
+ num_clients_finished += 1
+ logger.info(
+ f"{Color.CYAN}{num_clients_finished} out of "
+ f"{bench_args.num_clients} clients finished{Color.RESET}"
+ )
+
+ if bench_args.early_stop and not stop_event.is_set():
+ # Once one client finished, stop all other clients.
+ # there is no reason to continue the benchmark with fewer clients.
+ logger.info(
+ f"{Color.YELLOW}Sending termination signal to clients{Color.RESET}"
+ )
+ stop_event.set()
+ else:
+ output_conv[conv_id] = messages
+
+ finished_convs = len(output_conv)
+ percent = finished_convs / total_convs
+
+ # Tuned to control the print rate (can be changed if required)
+ print_cycle = max(3, int(bench_args.num_clients / 4))
+
+ if finished_convs % print_cycle == 0:
+ runtime_sec = nanosec_to_sec(time.perf_counter_ns() - start_time)
+ logger.info(
+ f"{Color.CYAN}Finished {finished_convs} out of {total_convs} conversations ({percent:.0%}), " # noqa: E501
+ f"{num_clients_finished} out of {bench_args.num_clients} clients finished, collected {len(client_metrics)} measurements, runtime {runtime_sec:.3f} sec{Color.RESET}" # noqa: E501
+ )
+
+ rps: Union[str, float] = round(len(client_metrics) / runtime_sec, 3)
+ if len(client_metrics) < (5 * bench_args.num_clients):
+ # Do not estimate the RPS if the number of samples is very low
+ # (threshold can be tuned if needed)
+ rps = "N/A"
+
+ runtime_left_sec: Union[str, float] = round(
+ (runtime_sec / finished_convs) * (total_convs - finished_convs), 3
+ )
+ if percent < 0.05:
+ # If less than 5% of the conversations were not finished,
+ # the estimation will probably be very inaccurate
+ # (threshold can be tuned if needed).
+ runtime_left_sec = "N/A"
+
+ logger.info(
+ f"{Color.CYAN}Estimated req/sec {rps}, estimated runtime left {runtime_left_sec} sec{Color.RESET}" # noqa: E501
+ )
+ debug_stats.print()
+
+ logger.info(
+ f"{Color.CYAN}All {bench_args.num_clients} clients finished{Color.RESET}"
+ )
+
+ # At this point all the clients finished,
+ # collect results (TTFT, TPOT, etc.) from all the clients.
+ # This needs to happens before calling join on the clients
+ # (result_queue should be emptied).
+ while not result_queue.empty():
+ client_metrics.append(result_queue.get())
+
+ logger.info(f"Collected {len(client_metrics)} samples from all the clients")
+
+ # Wait for all clients to finish
+ for client in clients:
+ logger.info(
+ f"{Color.CYAN}Waiting for client {client.name} "
+ f"(is alive: {client.is_alive()}){Color.RESET}"
+ )
+
+ client.join(timeout=120)
+
+ if client.is_alive():
+ logger.warning(
+ f"{Color.YELLOW}Client {client.name} will be terminated{Color.RESET}"
+ )
+ client.terminate()
+
+ exitcode = client.exitcode
+ if exitcode != 0:
+ logger.error(
+ f"{Color.RED}Client {client.name} exited "
+ f"with exit code {exitcode}{Color.RESET}"
+ )
+
+ logger.info(
+ f"All {bench_args.num_clients} clients exited (successfully "
+ f"finished {len(output_conv)} out of {total_convs} conversations)"
+ )
+
+ # Queues should be closed, required to avoid hang at interpreter shutdown
+ unfinished_tasks = 0
+ while not task_queue.empty():
+ task_queue.get()
+ unfinished_tasks += 1
+
+ if unfinished_tasks > 0:
+ # Can happen if not all tasks (conversations) have finished.
+ # May happen if --max-num-requests was used,
+ # or if an error occurred in one of the clients.
+ logger.debug(f"Discarding {unfinished_tasks} unfinished tasks")
+
+ task_queue.close()
+ task_queue.join_thread()
+
+ result_queue.close()
+ result_queue.join_thread()
+
+ conv_queue.close()
+ conv_queue.join_thread()
+
+ return output_conv, client_metrics
+
+
+def get_filename_with_timestamp(label: str, extension: str) -> str:
+ time_now = datetime.now()
+ timestamp = time_now.strftime("%d-%m-%Y_%H-%M-%S")
+ filename = f"{label}__{timestamp}.{extension}"
+ return filename
+
+
+def process_statistics(
+ client_metrics: list[RequestStats],
+ warmup_percentages: list[float],
+ test_params: dict,
+ verbose: bool,
+ gen_conv_args: Optional[GenConvArgs] = None,
+ excel_output: bool = False,
+) -> None:
+ if len(client_metrics) == 0:
+ logger.info("No samples to process")
+ return
+
+ logger.info(f"Processing {len(client_metrics)} samples...")
+
+ raw_data = pd.DataFrame(client_metrics)
+
+ if verbose:
+ # Calculate the time between user turns in each conversation (in a new column)
+ raw_data = raw_data.sort_values(by=["conversation_id", "start_time_ms"])
+ raw_data["time_between_user_turns_sec"] = raw_data.groupby("conversation_id")[
+ "start_time_ms"
+ ].diff()
+
+ # Convert milliseconds to seconds
+ raw_data["time_between_user_turns_sec"] = (
+ raw_data["time_between_user_turns_sec"] / 1000.0
+ )
+
+ # Final raw data should be sorted by time
+ raw_data = raw_data.sort_values(by=["start_time_ms"])
+ raw_data["end_time_ms"] = raw_data["start_time_ms"] + raw_data["latency_ms"]
+
+ percentiles = [0.25, 0.5, 0.75, 0.9]
+
+ # Add more percentiles if there are enough samples
+ if len(raw_data) >= 100:
+ percentiles.append(0.99)
+
+ if len(raw_data) >= 1000:
+ percentiles.append(0.999)
+
+ if len(raw_data) >= 10000:
+ percentiles.append(0.9999)
+
+ # Set precision for numbers in the output text (the dataframes)
+ pd.set_option("display.precision", 2)
+
+ # Exclude parameters from RequestStats
+ exclude = [
+ "start_time_ms",
+ "end_time_ms",
+ "output_num_first_chunk_tokens",
+ "approx_cached_percent",
+ "conversation_id",
+ "client_id",
+ ]
+
+ print(TEXT_SEPARATOR)
+ print(f"{Color.YELLOW}Parameters:{Color.RESET}")
+ for k, v in test_params.items():
+ print(f"{k}={v}")
+
+ # conversations generation parameters
+ if gen_conv_args is not None:
+ gen_params = {
+ "text_files": ", ".join(gen_conv_args.text_files),
+ "input_num_turns": str(gen_conv_args.input_num_turns),
+ "input_common_prefix_num_tokens": str(
+ gen_conv_args.input_common_prefix_num_tokens
+ ),
+ "input_prefix_num_tokens": str(gen_conv_args.input_prefix_num_tokens),
+ "input_num_tokens": str(gen_conv_args.input_num_tokens),
+ "output_num_tokens": str(gen_conv_args.output_num_tokens),
+ }
+
+ print(f"{Color.YELLOW}Conversations Generation Parameters:{Color.RESET}")
+ for k, v in gen_params.items():
+ print(f"{k}={v}")
+
+ print(TEXT_SEPARATOR)
+
+ params_list = []
+ df_list = []
+ for percent in warmup_percentages:
+ # Select samples from the end (tail) of the dataframe
+ warmup_count = int(percent * len(raw_data))
+ tail_count = len(raw_data) - warmup_count
+ if tail_count == 0:
+ # No reason to process if the count of samples is zero
+ break
+
+ df = raw_data.tail(tail_count)
+
+ # Runtime is the diff between the end of the last request
+ # and the start of the first request
+ runtime_sec = df["end_time_ms"].iloc[-1] - df["start_time_ms"].iloc[0]
+
+ # Convert milliseconds to seconds
+ runtime_sec = runtime_sec / 1000.0
+ requests_per_sec = float(len(df)) / runtime_sec
+
+ params = {"runtime_sec": runtime_sec, "requests_per_sec": requests_per_sec}
+
+ # Generate a summary of relevant metrics (and drop irrelevant data)
+ df = df.drop(columns=exclude).describe(percentiles=percentiles).transpose()
+
+ # List for Excel file
+ params_list.append(params)
+ df_list.append(df)
+
+ # Print the statistics summary
+ if percent > 0 or len(warmup_percentages) > 1:
+ print(
+ f"{Color.YELLOW}Statistics summary "
+ f"(assuming {percent:.0%} warmup samples):{Color.RESET}"
+ )
+ else:
+ print(f"{Color.YELLOW}Statistics summary:{Color.RESET}")
+
+ for k, v in params.items():
+ if isinstance(v, float):
+ print(f"{k} = {v:.3f}")
+ else:
+ print(f"{k} = {v}")
+ print(TEXT_SEPARATOR)
+ print(df)
+ print(TEXT_SEPARATOR)
+
+ if excel_output:
+ prefix = f"statistics_{test_params['num_clients']}_clients"
+ filename = get_filename_with_timestamp(prefix, "xlsx")
+
+ with pd.ExcelWriter(filename, engine="xlsxwriter") as writer:
+ startrow = 0
+ test_params_df = pd.DataFrame([test_params])
+ test_params_df.to_excel(
+ writer, sheet_name="Summary", index=False, startrow=startrow
+ )
+ startrow += len(test_params_df) + 3
+
+ if gen_conv_args is not None:
+ gen_params_df = pd.DataFrame([gen_params])
+ gen_params_df.to_excel(
+ writer, sheet_name="Summary", index=False, startrow=(startrow - 1)
+ )
+ startrow += len(gen_params_df) + 3
+
+ for params, df_stats in zip(params_list, df_list):
+ df_params = pd.DataFrame([params])
+ df_params.to_excel(
+ writer, sheet_name="Summary", index=False, startrow=startrow
+ )
+ startrow += len(df_params) + 2
+ df_stats.to_excel(
+ writer, sheet_name="Summary", index=True, startrow=startrow
+ )
+ startrow += len(df_stats) + 3
+
+ raw_data.to_excel(writer, sheet_name="Raw data", index=False, startrow=0)
+
+ logger.info(
+ f"{Color.GREEN}Client metrics exported to file: {filename}{Color.RESET}"
+ )
+
+
+async def get_server_info(url: str) -> None:
+ logger.info(f"{Color.BLUE}Collecting information from server: {url}{Color.RESET}")
+ async with aiohttp.ClientSession() as session:
+ # Get server version (not mandatory, "version" endpoint may not exist)
+ url_version = f"{url}/version"
+ async with session.get(url_version) as response:
+ if HTTPStatus(response.status) == HTTPStatus.OK:
+ text = await response.text()
+ logger.info(f"{Color.BLUE}Server version: {text}{Color.RESET}")
+
+ # Get available models
+ url_models = f"{url}/v1/models"
+ async with session.get(url_models) as response:
+ if HTTPStatus(response.status) == HTTPStatus.OK:
+ text = await response.text()
+ logger.info(f"{Color.BLUE}Models:{Color.RESET}")
+ models_data = json.loads(text)
+ models_list = models_data["data"]
+ for model in models_list:
+ model_id = model["id"]
+ max_model_len = model.get("max_model_len", "N/A")
+ logger.info(
+ f"{Color.BLUE}\t{model_id=}, {max_model_len=}{Color.RESET}"
+ )
+ else:
+ logger.info(f"{Color.RED}Failed to get models{Color.RESET}")
+
+
+async def main() -> None:
+ parser = argparse.ArgumentParser(
+ prog="Benchmark serving with multi-turn conversations",
+ description="Benchmark online inference using REST API",
+ )
+ parser.add_argument("--version", action="version", version="%(prog)s 1.0")
+
+ parser.add_argument(
+ "-i",
+ "--input-file",
+ type=str,
+ required=True,
+ help="Input JSON file with ShareGPT conversations or "
+ "configuration file for generation of synthetic conversations",
+ )
+ parser.add_argument(
+ "-o",
+ "--output-file",
+ type=str,
+ default=None,
+ help="Output JSON file containing conversations with updated assistant answers",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=0,
+ help="Seed for random number generators (default: 0)",
+ )
+ parser.add_argument(
+ "-m", "--model", type=str, required=True, help="Path of the LLM model"
+ )
+ parser.add_argument(
+ "-u",
+ "--url",
+ type=str,
+ default="http://localhost:8000",
+ help="Base URL for the LLM API server",
+ )
+
+ parser.add_argument(
+ "-p",
+ "--num-clients",
+ type=int,
+ default=1,
+ help="Number of clients that will send requests in parallel",
+ )
+ parser.add_argument(
+ "-k",
+ "--max-active-conversations",
+ type=int,
+ default=None,
+ help="Max number of active conversations at a time (for all clients)",
+ )
+ parser.add_argument(
+ "-n",
+ "--max-num-requests",
+ type=int,
+ default=None,
+ help="Max number of requests to send (total for all clients)",
+ )
+
+ parser.add_argument(
+ "--warmup-step",
+ default=False,
+ action="store_true",
+ help="Run a warmup step (using only the first turn of every conversation), "
+ "measurements will not be included in the final benchmark results",
+ )
+
+ parser.add_argument(
+ "--max-turns",
+ type=int,
+ default=None,
+ help="Maximum number of turns/messages per conversation, "
+ "includes both user and assistant messages "
+ "(a positive number, e.g: 2, 4, 6, etc.), disabled by default",
+ )
+ parser.add_argument(
+ "--no-early-stop",
+ default=False,
+ action="store_true",
+ help="By default, the benchmark will stop if at least one client exits."
+ " Use this flag to disable this behavior",
+ )
+
+ parser.add_argument(
+ "--limit-max-tokens",
+ type=int,
+ default=NUM_TOKENS_FROM_DATASET,
+ help="Set max_tokens for the output token count of each request "
+ "(must also set --limit-min-tokens). "
+ "Overrides output token count from the input dataset. "
+ "Use a negative value to disable this limit.",
+ )
+ parser.add_argument(
+ "--limit-min-tokens",
+ type=int,
+ default=NUM_TOKENS_FROM_DATASET,
+ help="Set min_tokens for the output token count of each request "
+ "(must also set --limit-max-tokens). "
+ "Overrides output token count from the input dataset. "
+ "Use a negative value to disable this limit.",
+ )
+
+ parser.add_argument(
+ "--request-rate",
+ type=float,
+ default=0,
+ help="Expected request rate (Poisson process) per client in requests/sec."
+ "Set to 0 for no delay between requests.",
+ )
+ parser.add_argument(
+ "--conversation-sampling",
+ type=ConversationSampling,
+ choices=list(ConversationSampling),
+ default=ConversationSampling.ROUND_ROBIN,
+ help=(
+ "Strategy for selecting which conversation to use for the next request. "
+ "Options: 'round_robin' (cycle through conversations), "
+ "'random' (pick randomly)."
+ ),
+ )
+ parser.add_argument(
+ "--verify-output",
+ default=False,
+ action="store_true",
+ help="Verify the LLM output (compare to the answers in the input JSON file)",
+ )
+
+ parser.add_argument(
+ "--no-stream",
+ default=False,
+ action="store_true",
+ help="Disable stream/streaming mode (set 'stream' to False in the API request)",
+ )
+
+ parser.add_argument(
+ "-e",
+ "--excel-output",
+ default=False,
+ action="store_true",
+ help="Export summary to Excel file (optional)",
+ )
+ parser.add_argument(
+ "-v",
+ "--verbose",
+ default=False,
+ action="store_true",
+ help="Enable verbose output",
+ )
+ parser.add_argument(
+ "--print-content",
+ default=False,
+ action="store_true",
+ help="Print the user prompts and the server's answers",
+ )
+
+ parser.add_argument(
+ "--warmup-percentages",
+ type=str,
+ default="0%",
+ help="Ignore the first X samples as warmup (X is a percentage)."
+ " A comma separated list of percentages can be used "
+ "(for example: --warmup-percentages=0%%,50%%)",
+ )
+
+ args = parser.parse_args()
+
+ logger.info(args)
+
+ logger.info(f"{Color.GREEN}Input parameters:{Color.RESET}")
+ logger.info(f"url={args.url}")
+ logger.info(f"model={args.model}")
+ logger.info(f"num_clients={args.num_clients}")
+
+ if args.verify_output:
+ logger.info(f"{Color.PURPLE}Verify is enabled{Color.RESET}")
+
+ # Calculate the amount of samples to filter (as warmup samples/measurements).
+ try:
+ warmup_percentages: list[float] = [0.0]
+ if not args.warmup_step:
+ # Warmup percentage can be used only if the warmup step was used
+ warmup_strings: list[str] = args.warmup_percentages.split(",")
+ warmup_strings = [x.replace("%", "") for x in warmup_strings]
+ warmup_percentages = [float(x) / 100 for x in warmup_strings]
+
+ # Check for valid range (0 to 1)
+ for p in warmup_percentages:
+ assert p >= 0.0 and p < 1.0
+
+ # Sort from high to low warmup percentage
+ warmup_percentages.sort()
+
+ logger.info(
+ f"Warmup percentages (percentage of samples): {warmup_percentages}"
+ )
+
+ except Exception:
+ raise ValueError(
+ f"Invalid --warmup-percentage={args.warmup_percentage}"
+ ) from None
+
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+
+ if not os.path.exists(args.model):
+ raise OSError(f"Path does not exist: {args.model}")
+ logger.info("Loading tokenizer")
+ tokenizer = AutoTokenizer.from_pretrained(args.model)
+
+ await get_server_info(args.url)
+
+ # Load the input file (either conversations of configuration file)
+ logger.info(f"Reading input file: {args.input_file}")
+ with open(args.input_file) as f:
+ input_data = json.load(f)
+
+ gen_conv_args = None
+ if isinstance(input_data, list):
+ # The conversations are stored as a list of dicts
+ logger.info(f"Found {len(input_data)} items in the input file")
+
+ # Convert the list to a ConversationsMap
+ conversations = conversations_list_to_dict(input_data)
+
+ elif isinstance(input_data, dict):
+ # The input file is a configuration file
+ # (type is determined by the field 'filetype')
+ if "filetype" not in input_data:
+ raise Exception(
+ f"Input file {args.input_file} is invalid (missing 'filetype')"
+ )
+
+ logger.info(f"Using input file with filetype: {input_data['filetype']}")
+
+ gen_conv_args = parse_input_json_file(input_data)
+
+ # Disable warning from "huggingface/tokenizers"
+ # (when using python multiprocessing and tokenizers)
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
+
+ # Generate synthetic conversations
+ conversations = generate_conversations(gen_conv_args, tokenizer)
+
+ else:
+ raise Exception(f"Input file {args.input_file} is invalid")
+
+ if args.max_turns is not None:
+ if args.max_turns < 1:
+ raise ValueError("Max turns must be a positive number")
+ logger.info(
+ f"{Color.PURPLE}Max turns per conversation "
+ f"is limited to {args.max_turns}{Color.RESET}"
+ )
+
+ # Create benchmark configurations
+ client_args, req_args = get_client_config(args, conversations)
+
+ bench_args = BenchmarkArgs(
+ url=args.url, num_clients=args.num_clients, early_stop=not args.no_early_stop
+ )
+
+ # Warm-up step
+ if args.warmup_step:
+ # Only send a single user prompt from every conversation.
+ # max_active_conversations must be 1,
+ # otherwise the clients may exit after sending a single request
+ # (because the task queue is empty).
+ warmup_client_args = client_args._replace(
+ skip_first_turn=False, max_turns=1, max_active_conversations=1
+ )
+
+ # Early stop should be disabled,
+ # all clients should finish their work before exiting
+ warmup_bench_args = bench_args._replace(early_stop=False)
+
+ logger.info(f"{Color.PURPLE}Warmup start{Color.RESET}")
+ conversations, _ = await main_mp(
+ warmup_client_args, req_args, warmup_bench_args, tokenizer, conversations
+ )
+ logger.info(f"{Color.PURPLE}Warmup done{Color.RESET}")
+
+ # Run the benchmark
+ start_time = time.perf_counter_ns()
+ client_convs, client_metrics = await main_mp(
+ client_args, req_args, bench_args, tokenizer, conversations
+ )
+ total_runtime_ms = nanosec_to_millisec(time.perf_counter_ns() - start_time)
+
+ # Calculate requests per second
+ total_runtime_sec = total_runtime_ms / 1000.0
+ rps = len(client_metrics) / total_runtime_sec
+ logger.info(
+ f"{Color.GREEN}All clients finished, total runtime: {total_runtime_sec:.3f} sec"
+ f" ({total_runtime_ms:.3f} ms), requests per second: {rps:.3f}{Color.RESET}"
+ )
+
+ # Benchmark parameters
+ params = {
+ "model": args.model,
+ "num_clients": args.num_clients,
+ "num_conversations": len(conversations),
+ "active_conversations": args.max_active_conversations,
+ "seed": args.seed,
+ }
+
+ if args.limit_min_tokens > 0:
+ params["min_tokens"] = args.limit_min_tokens
+
+ if args.limit_max_tokens > 0:
+ params["max_tokens"] = args.limit_max_tokens
+
+ # Process and print statistics (and save excel file with the statistics)
+ process_statistics(
+ client_metrics,
+ test_params=params,
+ warmup_percentages=warmup_percentages,
+ verbose=args.verbose,
+ gen_conv_args=gen_conv_args,
+ excel_output=args.excel_output,
+ )
+
+ if args.output_file is not None:
+ # Write a JSON file with the updated conversations
+ # The "assistant" content will contain the answers from the tested LLM
+ output_data: ShareGptConversations = conversations_dict_to_list(client_convs)
+ logger.info(
+ f"{Color.GREEN}Writing conversations file: {args.output_file}{Color.RESET}"
+ )
+ with open(args.output_file, "w") as f:
+ json.dump(output_data, f, indent=4)
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/benchmarks/multi_turn/convert_sharegpt_to_openai.py b/benchmarks/multi_turn/convert_sharegpt_to_openai.py
new file mode 100644
index 0000000000000..c3622c99a2e53
--- /dev/null
+++ b/benchmarks/multi_turn/convert_sharegpt_to_openai.py
@@ -0,0 +1,354 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Download dataset from:
+https://huggingface.co/datasets/philschmid/sharegpt-raw/blob/main/sharegpt_20230401_clean_lang_split.json
+
+Convert to OpenAI API:
+export INPUT_FILE=sharegpt_20230401_clean_lang_split.json
+python convert_sharegpt_to_openai.py $INPUT_FILE sharegpt_conv_128.json --max-items=128
+"""
+
+import argparse
+import json
+import random
+from statistics import mean
+from typing import Any, Optional
+
+import pandas as pd # type: ignore
+import tqdm # type: ignore
+from transformers import AutoTokenizer # type: ignore
+
+
+def has_non_english_chars(text: str) -> bool:
+ return not text.isascii()
+
+
+def content_is_valid(
+ content: str, min_content_len: Optional[int], max_content_len: Optional[int]
+) -> bool:
+ if min_content_len and len(content) < min_content_len:
+ return False
+
+ if max_content_len and len(content) > max_content_len:
+ return False
+
+ return has_non_english_chars(content)
+
+
+def print_stats(
+ conversations: "list[dict[Any, Any]]", tokenizer: Optional[AutoTokenizer] = None
+) -> None:
+ # Collect statistics
+ stats = []
+
+ print("\nCollecting statistics...")
+ for item in tqdm.tqdm(conversations):
+ # item has "id" and "messages"
+ messages = item["messages"]
+
+ user_turns = 0
+ assistant_turns = 0
+ user_words = 0
+ assistant_words = 0
+ conv_chars = 0
+
+ user_tokens: list[int] = []
+ assistant_tokens: list[int] = []
+
+ for m in messages:
+ content = m["content"]
+ conv_chars += len(content)
+ content_num_words = content.count(" ") + 1
+
+ num_tokens = 0
+ if tokenizer:
+ num_tokens = len(tokenizer(m["content"]).input_ids)
+
+ if m["role"] == "user":
+ user_turns += 1
+ user_words += content_num_words
+ if tokenizer:
+ user_tokens.append(num_tokens)
+
+ elif m["role"] == "assistant":
+ assistant_turns += 1
+ assistant_words += content_num_words
+ if tokenizer:
+ assistant_tokens.append(num_tokens)
+
+ # assert user_turns == assistant_turns, \
+ # f"Invalid conversation ID {item['id']}"
+
+ conv_words = user_words + assistant_words
+ item_stats = {
+ "user_turns": user_turns,
+ "assistant_turns": assistant_turns,
+ "user_words": user_words,
+ "assistant_words": assistant_words,
+ "conv_turns": len(messages),
+ "conv_words": conv_words,
+ "conv_characters": conv_chars,
+ }
+
+ if len(user_tokens) > 0:
+ item_stats["user_tokens"] = int(mean(user_tokens))
+
+ if len(assistant_tokens) > 0:
+ item_stats["assistant_tokens"] = int(mean(assistant_tokens))
+
+ stats.append(item_stats)
+
+ print("\nStatistics:")
+ percentiles = [0.25, 0.5, 0.75, 0.9, 0.99, 0.999, 0.9999]
+ df = pd.DataFrame(stats)
+ print(df.describe(percentiles=percentiles).transpose())
+
+
+def convert_sharegpt_to_openai(
+ seed: int,
+ input_file: str,
+ output_file: str,
+ max_items: Optional[int],
+ min_content_len: Optional[int] = None,
+ max_content_len: Optional[int] = None,
+ min_turns: Optional[int] = None,
+ max_turns: Optional[int] = None,
+ model: Optional[str] = None,
+) -> None:
+ if min_turns and max_turns:
+ assert min_turns <= max_turns
+
+ if min_content_len and max_content_len:
+ # Verify that min is not larger than max if both were given
+ assert min_content_len <= max_content_len
+
+ print(
+ f"Input parameters:\n{seed=}, {max_items=}, {min_content_len=},"
+ f" {max_content_len=}, {min_turns=}, {max_turns=}\n"
+ )
+
+ random.seed(seed)
+
+ tokenizer = None
+ if model is not None:
+ print(f"Loading tokenizer from: {model}")
+ tokenizer = AutoTokenizer.from_pretrained(model)
+
+ # Read the ShareGPT JSON file
+ print(f"Reading file: {input_file}")
+ with open(input_file, encoding="utf-8") as f:
+ # Should be a list of dicts
+ # Each dict should have "id" (string) and "conversations" (list of dicts)
+ sharegpt_data = json.load(f)
+
+ assert isinstance(sharegpt_data, list), "Input file should contain a list of dicts"
+
+ print(f"Total items in input file: {len(sharegpt_data):,}")
+
+ print(f"Shuffling dataset with seed {seed}")
+ random.shuffle(sharegpt_data)
+
+ # Map conversation ID to the all the messages
+ conversation_parts: dict[str, list[Any]] = {}
+
+ for item in tqdm.tqdm(sharegpt_data):
+ assert "id" in item, "Missing key 'id'"
+ assert "conversations" in item, "Missing key 'conversations'"
+
+ # Conversation ID (e.g: "hiWPlMD") and part/session (0, 1, 2, etc.)
+ conv_id, _ = item["id"].split("_")
+ new_turns = item["conversations"]
+
+ if conv_id not in conversation_parts:
+ # Start new conversation
+ conversation_parts[conv_id] = []
+ elif len(conversation_parts[conv_id]) > 0 and len(new_turns) > 0:
+ prev_turns = conversation_parts[conv_id][-1]
+ if prev_turns[-1]["from"] == new_turns[0]["from"]:
+ new_turns = new_turns[1:]
+
+ if len(new_turns) > 0:
+ # We assume that parts are in order in the ShareGPT dataset
+ conversation_parts[conv_id].append(new_turns)
+
+ dataset: list[dict[str, Any]] = []
+ for conv_id, conv_parts in conversation_parts.items():
+ new_item = {"id": conv_id}
+
+ conversations: list[dict[str, str]] = []
+
+ # Merge all parts
+ for conv_part in conv_parts:
+ conversations.extend(conv_part)
+
+ if len(conversations) > 0:
+ new_item["conversations"] = conversations
+ dataset.append(new_item)
+
+ print(f"Total unique conversations (IDs) in input file: {len(dataset):,}")
+
+ # Final output data
+ final_openai_dataset: list[dict] = []
+
+ # Filter conversations from the ShareGPT dataset and convert to OpenAI format
+ for item in tqdm.tqdm(dataset):
+ messages: list[dict] = []
+
+ assert "id" in item, "Missing key 'id'"
+ assert "conversations" in item, "Missing key 'conversations'"
+
+ conv_id = item["id"]
+ conversations = item["conversations"]
+
+ if min_turns is not None and len(conversations) < min_turns:
+ # Skip short conversations
+ continue
+
+ # Convert each message in the conversation, up to max_turns if specified
+ for i, turn in enumerate(conversations):
+ assert "from" in turn and "value" in turn, (
+ f"Invalid conversation ID {conv_id} - missing 'from' or 'value'"
+ )
+
+ role = None
+ turn_from = turn["from"]
+
+ if turn_from in {"human", "user"}:
+ role = "user"
+ elif turn_from in {"gpt", "bing", "chatgpt", "bard"}:
+ role = "assistant"
+ elif turn_from == "system":
+ role = "system"
+
+ assert role is not None, (
+ f"Invalid conversation ID {conv_id} - 'from'='{turn_from}' is invalid"
+ )
+
+ if i == 0 and role != "user":
+ # If the first message is from assistant (gpt), skip it.
+ # this happens when the conversation is a follow-up
+ # to a previous conversation (from the same user).
+ continue
+
+ if max_turns is not None and i >= max_turns:
+ break
+
+ # Convert message to OpenAI format (with "role" and "content")
+ content = turn["value"]
+ messages.append({"role": role, "content": content})
+
+ # Add the converted conversation to the OpenAI format
+ if len(messages) > 0:
+ valid_messages = True
+
+ # First turn should always be from the user
+ user_turn = True
+
+ for m in messages:
+ # Make sure that turns alternate between user and assistant
+ if (user_turn and m["role"] != "user") or (
+ not user_turn and m["role"] != "assistant"
+ ):
+ valid_messages = False
+ break
+
+ user_turn = not user_turn
+
+ content = m["content"]
+ valid_messages = content_is_valid(
+ content, min_content_len, max_content_len
+ )
+ if not valid_messages:
+ break
+
+ if valid_messages is True:
+ final_openai_dataset.append({"id": conv_id, "messages": messages})
+
+ assert len(final_openai_dataset) > 0, "Final number of conversations is zero"
+
+ print_stats(final_openai_dataset)
+
+ print_stats_again = False
+ if max_items is not None and len(final_openai_dataset) > max_items:
+ print(f"\n\nSampling {max_items} items from the dataset...")
+ print_stats_again = True
+ final_openai_dataset = random.sample(final_openai_dataset, max_items)
+
+ if print_stats_again:
+ # Print stats after the dataset changed
+ print_stats(final_openai_dataset, tokenizer)
+
+ # Write the converted data to a new JSON file
+ final_size = len(final_openai_dataset)
+ print(f"\nTotal conversations converted (after filtering): {final_size:,}")
+ print(f"\nWriting file: {output_file}")
+ with open(output_file, "w", encoding="utf-8") as f:
+ json.dump(final_openai_dataset, f, ensure_ascii=False, indent=2)
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(
+ description="Convert ShareGPT dataset to OpenAI API format"
+ )
+ parser.add_argument("input_file", help="Path to the input ShareGPT JSON file")
+ parser.add_argument(
+ "output_file", help="Path to the output OpenAI format JSON file"
+ )
+ parser.add_argument(
+ "--seed", type=int, default=0, help="Seed for random number generators"
+ )
+ parser.add_argument(
+ "--max-items",
+ type=int,
+ default=None,
+ help="Maximum number of items in the output file",
+ )
+ parser.add_argument(
+ "--min-turns",
+ type=int,
+ default=None,
+ help="Minimum number of turns per conversation",
+ )
+ parser.add_argument(
+ "--max-turns",
+ type=int,
+ default=None,
+ help="Maximum number of turns per conversation",
+ )
+ parser.add_argument(
+ "--min-content-len",
+ type=int,
+ default=None,
+ help="Min number of characters in the messages' content",
+ )
+ parser.add_argument(
+ "--max-content-len",
+ type=int,
+ default=None,
+ help="Max number of characters in the messages' content",
+ )
+ parser.add_argument(
+ "--model",
+ type=str,
+ default=None,
+ help="LLM model, only the tokenizer will be used",
+ )
+
+ args = parser.parse_args()
+
+ convert_sharegpt_to_openai(
+ args.seed,
+ args.input_file,
+ args.output_file,
+ args.max_items,
+ args.min_content_len,
+ args.max_content_len,
+ args.min_turns,
+ args.max_turns,
+ args.model,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/benchmarks/multi_turn/generate_multi_turn.json b/benchmarks/multi_turn/generate_multi_turn.json
new file mode 100644
index 0000000000000..274d03c2bdb2b
--- /dev/null
+++ b/benchmarks/multi_turn/generate_multi_turn.json
@@ -0,0 +1,35 @@
+{
+ "filetype": "generate_conversations",
+ "num_conversations": 24,
+ "text_files": ["pg1184.txt"],
+ "print_stats": false,
+ "prompt_input": {
+ "num_turns": {
+ "distribution": "uniform",
+ "min": 12,
+ "max": 18
+ },
+ "common_prefix_num_tokens": {
+ "distribution": "constant",
+ "value": 500
+ },
+ "prefix_num_tokens": {
+ "distribution": "lognormal",
+ "mean": 6,
+ "sigma": 4,
+ "max": 1500
+ },
+ "num_tokens": {
+ "distribution": "uniform",
+ "min": 120,
+ "max": 160
+ }
+ },
+ "prompt_output": {
+ "num_tokens": {
+ "distribution": "uniform",
+ "min": 80,
+ "max": 120
+ }
+ }
+}
\ No newline at end of file
diff --git a/benchmarks/multi_turn/requirements.txt b/benchmarks/multi_turn/requirements.txt
new file mode 100644
index 0000000000000..f0e1935914a14
--- /dev/null
+++ b/benchmarks/multi_turn/requirements.txt
@@ -0,0 +1,5 @@
+numpy>=1.24
+pandas>=2.0.0
+aiohttp>=3.10
+transformers>=4.46
+xlsxwriter>=3.2.1
\ No newline at end of file
diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake
index 6291475164baa..ee6768bce26ca 100644
--- a/cmake/external_projects/flashmla.cmake
+++ b/cmake/external_projects/flashmla.cmake
@@ -19,7 +19,7 @@ else()
FetchContent_Declare(
flashmla
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
- GIT_TAG 575f7724b9762f265bbee5889df9c7d630801845
+ GIT_TAG 0e43e774597682284358ff2c54530757b654b8d1
GIT_PROGRESS TRUE
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
@@ -37,9 +37,9 @@ cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
set(FlashMLA_SOURCES
${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
- ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_bf16_sm90.cu
- ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_fp16_sm90.cu
- ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_metadata.cu)
+ ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
+ ${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu
+ ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu)
set(FlashMLA_INCLUDES
${flashmla_SOURCE_DIR}/csrc/cutlass/include
diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake
index 59b99e9e207a8..d24d8e8e5e795 100644
--- a/cmake/external_projects/vllm_flash_attn.cmake
+++ b/cmake/external_projects/vllm_flash_attn.cmake
@@ -38,7 +38,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
- GIT_TAG 6dbc6e011a3ebe9349eeb74578940dd7095436ba
+ GIT_TAG 93cf5a08f421a3efd0c4a7e005ef8f742b578ce0
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp
index 195872e8edd3e..f2c1dcf69f69e 100644
--- a/csrc/cutlass_extensions/common.hpp
+++ b/csrc/cutlass_extensions/common.hpp
@@ -60,3 +60,13 @@ struct enable_sm100_only : Kernel {
#endif
}
};
+
+template
+struct enable_sm120_only : Kernel {
+ template
+ CUTLASS_DEVICE void operator()(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1200
+ Kernel::operator()(std::forward(args)...);
+#endif
+ }
+};
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu
new file mode 100644
index 0000000000000..5515374a57599
--- /dev/null
+++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu
@@ -0,0 +1,23 @@
+#include "scaled_mm_kernels.hpp"
+#include "scaled_mm_blockwise_sm120_fp8_dispatch.cuh"
+#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
+
+namespace vllm {
+
+void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out,
+ torch::Tensor const& a,
+ torch::Tensor const& b,
+ torch::Tensor const& a_scales,
+ torch::Tensor const& b_scales) {
+ if (out.dtype() == torch::kBFloat16) {
+ cutlass_gemm_blockwise_sm120_fp8_dispatch(
+ out, a, b, a_scales, b_scales);
+
+ } else {
+ TORCH_CHECK(out.dtype() == torch::kFloat16);
+ cutlass_gemm_blockwise_sm120_fp8_dispatch(
+ out, a, b, a_scales, b_scales);
+ }
+}
+
+} // namespace vllm
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh
new file mode 100644
index 0000000000000..d50a83ae1cd48
--- /dev/null
+++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh
@@ -0,0 +1,183 @@
+#pragma once
+
+#include "cuda_utils.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/numeric_types.h"
+
+#include "cute/tensor.hpp"
+#include "cutlass/tensor_ref.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+#include "cutlass/gemm/kernel/tile_scheduler_params.h"
+#include "cutlass/epilogue/dispatch_policy.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+
+#include "cutlass_extensions/gemm/dispatch_policy.hpp"
+#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
+
+#include "cutlass_gemm_caller.cuh"
+
+namespace vllm {
+
+using namespace cute;
+
+// clang-format off
+template
+struct cutlass_3x_gemm_fp8_blockwise {
+ using ElementAB = cutlass::float_e4m3_t;
+
+ using ElementA = ElementAB;
+ using LayoutA = cutlass::layout::RowMajor;
+ using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose::type;
+ static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value;
+
+ using ElementB = ElementAB;
+ // ColumnMajor is used for B to match the CUTLASS convention.
+ using LayoutB = cutlass::layout::ColumnMajor;
+ using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose::type;
+ static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value;
+
+ using ElementD = OutType;
+ using LayoutD = cutlass::layout::RowMajor;
+ using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose::type;
+ static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value;
+
+ using ElementC = void; // TODO: support bias
+ using LayoutC = LayoutD;
+ using LayoutC_Transpose = LayoutD_Transpose;
+ static constexpr int AlignmentC = AlignmentD;
+
+ using ElementAccumulator = float;
+ using ElementCompute = float;
+ using ElementBlockScale = float;
+
+ using ScaleConfig = cutlass::detail::Sm120BlockwiseScaleConfig<
+ ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
+ cute::UMMA::Major::MN, cute::UMMA::Major::K>;
+
+ // layout_SFA and layout_SFB cannot be swapped since they are deduced.
+ using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
+ using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
+
+ using ArchTag = cutlass::arch::Sm120;
+ using OperatorClass = cutlass::arch::OpClassTensorOp;
+
+ static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
+ using ElementScalar = float;
+ using DefaultOperation = cutlass::epilogue::fusion::LinearCombination;
+ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+ ArchTag,
+ OperatorClass,
+ MmaTileShape,
+ ClusterShape,
+ cutlass::epilogue::collective::EpilogueTileAuto,
+ ElementAccumulator,
+ ElementCompute,
+ ElementC,
+ LayoutC,
+ AlignmentC,
+ ElementD,
+ LayoutD,
+ AlignmentD,
+ EpilogueScheduler,
+ DefaultOperation
+ >::CollectiveOp;
+
+ using StageCountType = cutlass::gemm::collective::StageCountAuto;
+ using CollectiveMainloop =
+ typename cutlass::gemm::collective::CollectiveBuilder<
+ ArchTag,
+ OperatorClass,
+ ElementA,
+ cute::tuple,
+ AlignmentA,
+ ElementB,
+ cute::tuple,
+ AlignmentB,
+ ElementAccumulator,
+ MmaTileShape,
+ ClusterShape,
+ cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+ MainloopScheduler
+ >::CollectiveOp;
+
+ using KernelType = enable_sm120_only, CollectiveMainloop, CollectiveEpilogue>>;
+
+ struct GemmKernel : public KernelType {};
+};
+
+template
+void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
+ torch::Tensor const& b,
+ torch::Tensor const& a_scales,
+ torch::Tensor const& b_scales) {
+ using GemmKernel = typename Gemm::GemmKernel;
+ using StrideA = typename Gemm::GemmKernel::StrideA;
+ using StrideB = typename Gemm::GemmKernel::StrideB;
+ using StrideD = typename Gemm::GemmKernel::StrideD;
+ using StrideC = typename Gemm::GemmKernel::StrideC;
+ using LayoutSFA = typename Gemm::LayoutSFA;
+ using LayoutSFB = typename Gemm::LayoutSFB;
+ using ScaleConfig = typename Gemm::ScaleConfig;
+
+ using ElementAB = typename Gemm::ElementAB;
+ using ElementD = typename Gemm::ElementD;
+
+ int32_t m = a.size(0), n = b.size(1), k = a.size(1);
+
+ StrideA a_stride;
+ StrideB b_stride;
+ StrideC c_stride;
+ a_stride =
+ cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
+ b_stride =
+ cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
+ c_stride =
+ cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
+
+ LayoutSFA layout_SFA =
+ ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
+ LayoutSFB layout_SFB =
+ ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
+
+ auto a_ptr = static_cast(a.data_ptr());
+ auto b_ptr = static_cast(b.data_ptr());
+ auto a_scales_ptr = static_cast(a_scales.data_ptr());
+ auto b_scales_ptr = static_cast(b_scales.data_ptr());
+
+ auto mainloop_args = [&](){
+ return typename GemmKernel::MainloopArguments{
+ a_ptr, a_stride, b_ptr, b_stride,
+ a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB
+ };
+ }();
+ auto prob_shape = cute::make_shape(m, n, k, 1);
+
+ auto c_ptr = static_cast(out.data_ptr());
+ typename GemmKernel::EpilogueArguments epilogue_args{
+ {}, c_ptr, c_stride, c_ptr, c_stride};
+ c3x::cutlass_gemm_caller(a.device(), prob_shape, mainloop_args,
+ epilogue_args);
+}
+
+template
+void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::Tensor& out,
+ torch::Tensor const& a,
+ torch::Tensor const& b,
+ torch::Tensor const& a_scales,
+ torch::Tensor const& b_scales) {
+ // TODO: better heuristics
+ cutlass_gemm_caller_blockwise,
+ Shape<_1, _1, _1>, cutlass::epilogue::collective::EpilogueScheduleAuto,
+ cutlass::gemm::collective::KernelScheduleAuto>>(
+ out, a, b, a_scales, b_scales);
+}
+
+} // namespace vllm
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
index e049a5f2d2c9a..9ceb3a3ece5d6 100644
--- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
+++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
@@ -47,4 +47,10 @@ void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
+
+void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out,
+ torch::Tensor const& a,
+ torch::Tensor const& b,
+ torch::Tensor const& a_scales,
+ torch::Tensor const& b_scales);
} // namespace vllm
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu
index 0c47ab82991dc..dc87c5c35cb8e 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu
@@ -1,11 +1,9 @@
-#include
+#include "c3x/scaled_mm_helper.hpp"
#include "c3x/scaled_mm_kernels.hpp"
-#include "cuda_utils.h"
-
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
- NVIDIA GPUs with sm120 (Blackwell Geforce).
+ NVIDIA GPUs with sm120 (Blackwell).
*/
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
@@ -15,20 +13,10 @@ void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional const& bias) {
- TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
- TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
-
- int M = a.size(0), N = b.size(1), K = a.size(1);
- TORCH_CHECK(
- (a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
- (b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
- "Currently, block scaled fp8 gemm is not implemented for Blackwell");
-
- // Standard per-tensor/per-token/per-channel scaling
- TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
- TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
- "Currently, only fp8 gemm is implemented for Blackwell");
- vllm::cutlass_scaled_mm_sm120_fp8(c, a, b, a_scales, b_scales, bias);
+ dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
+ vllm::cutlass_scaled_mm_sm120_fp8,
+ nullptr, // int8 not supported on SM120
+ vllm::cutlass_scaled_mm_blockwise_sm120_fp8);
}
#endif
diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu
index 65cb1c1d1478d..e3a0e15f5304f 100644
--- a/csrc/rocm/attention.cu
+++ b/csrc/rocm/attention.cu
@@ -270,7 +270,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int num_kv_heads,
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
- const int* __restrict__ context_lens, // [num_seqs]
+ const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
@@ -304,12 +304,12 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
const auto max_num_partitions = gridDim.y;
- const int context_len = context_lens[seq_idx];
+ const int seq_len = seq_lens[seq_idx];
const int partition_start_token_idx =
partition_idx * T_PAR_SIZE; // partition_size;
// exit if partition is out of context for seq
- if (partition_start_token_idx >= context_len) {
+ if (partition_start_token_idx >= seq_len) {
return;
}
@@ -361,8 +361,8 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
// output layout from QKmfma : QH16xT4x4 16 qheads across 16 lanes, 16 tokens
// across 4 rows x 4 tokens per lane
- const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
- const int last_ctx_block = num_context_blocks - 1;
+ const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
+ const int last_seq_block = num_seq_blocks - 1;
const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq;
@@ -373,9 +373,9 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int klocal_token_idx =
TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx;
- const int kblock_idx = (kglobal_token_idx < context_len)
+ const int kblock_idx = (kglobal_token_idx < seq_len)
? kglobal_token_idx / BLOCK_SIZE
- : last_ctx_block;
+ : last_seq_block;
kphysical_block_number[token_depth] = block_table_seq[kblock_idx];
}
@@ -476,9 +476,9 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
// tokens
const int vglobal_token_idx =
partition_start_token_idx + vlocal_token_idx;
- const int vblock_idx = (vglobal_token_idx < context_len)
+ const int vblock_idx = (vglobal_token_idx < seq_len)
? vglobal_token_idx / BLOCK_SIZE
- : last_ctx_block;
+ : last_seq_block;
vphysical_block_number[vtoken_depth][vblock_depth] =
block_table_seq[vblock_idx];
}
@@ -554,7 +554,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
if constexpr (ALIBI_ENABLED) {
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int local_token_idx = qkout_token_idx + token_depth * 16;
- const int alibi_offset = local_token_idx - context_len + 1;
+ const int alibi_offset = local_token_idx - seq_len + 1;
for (int i = 0; i < 4; i++) {
d_out[token_depth][i] += alibi_slope * (alibi_offset + i);
}
@@ -568,9 +568,8 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int local_token_idx = qkout_token_idx + token_depth * 16;
for (int i = 0; i < 4; i++) {
- const float tmp = (local_token_idx + i < context_len)
- ? d_out[token_depth][i]
- : -FLT_MAX;
+ const float tmp =
+ (local_token_idx + i < seq_len) ? d_out[token_depth][i] : -FLT_MAX;
qk_max = fmaxf(qk_max, tmp);
}
}
@@ -582,7 +581,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int local_token_idx = qkout_token_idx + token_depth * 16;
for (int i = 0; i < 4; i++) {
- const float tmp = (local_token_idx + i < context_len)
+ const float tmp = (local_token_idx + i < seq_len)
? __expf(d_out[token_depth][i] - qk_max)
: 0.0f;
d_out[token_depth][i] = tmp;
@@ -780,7 +779,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const int num_kv_heads,
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
- const int* __restrict__ context_lens, // [num_seqs]
+ const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
@@ -809,10 +808,10 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const auto partition_size = blockDim.x;
const auto max_num_partitions = gridDim.y;
- const int context_len = context_lens[seq_idx];
+ const int seq_len = seq_lens[seq_idx];
const int partition_start_token_idx = partition_idx * partition_size;
// exit if partition is out of context for seq
- if (partition_start_token_idx >= context_len) {
+ if (partition_start_token_idx >= seq_len) {
return;
}
// every 4 lanes fetch 4 different qheads
@@ -855,7 +854,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const int warp_start_token_idx =
partition_start_token_idx + warpid * WARP_SIZE;
- if (warp_start_token_idx >= context_len) { // warp out of context
+ if (warp_start_token_idx >= seq_len) { // warp out of context
#pragma unroll
for (int h = 0; h < GQA_RATIO4; h++) {
shared_qk_max[warpid][h] = -FLT_MAX;
@@ -863,8 +862,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
}
} else { // warp within context
- const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
- const int last_ctx_block = num_context_blocks - 1;
+ const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
+ const int last_seq_block = num_seq_blocks - 1;
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
// token id within partition
@@ -873,9 +872,9 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const int global_token_idx = partition_start_token_idx + local_token_idx;
// fetch block number for k
- const int block_idx = (global_token_idx < context_len)
+ const int block_idx = (global_token_idx < seq_len)
? global_token_idx / BLOCK_SIZE
- : last_ctx_block;
+ : last_seq_block;
// fetch k physical block number
// int32 physical_block_number leads to overflow when multiplied with
@@ -888,7 +887,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
for (int b = 0; b < VBLOCKS; b++) {
const int vblock_idx = warp_start_block_idx + b;
const int vblock_idx_ctx =
- (vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block;
+ (vblock_idx <= last_seq_block) ? vblock_idx : last_seq_block;
vphysical_blocks[b] = block_table[vblock_idx_ctx];
}
@@ -1057,7 +1056,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const int lane4_token_idx = 4 * (global_token_idx >> 2);
if constexpr (ALIBI_ENABLED) {
- const int alibi_offset = lane4_token_idx - context_len + 1;
+ const int alibi_offset = lane4_token_idx - seq_len + 1;
for (int h = 0; h < QHLOOP; h++) {
for (int i = 0; i < 4; i++) {
d_out[h][i] += alibi_slope[h] * (alibi_offset + i);
@@ -1070,7 +1069,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
for (int h = 0; h < QHLOOP; h++) {
qk_max[h] = -FLT_MAX;
for (int i = 0; i < 4; i++) {
- qk_max[h] = (lane4_token_idx + i < context_len)
+ qk_max[h] = (lane4_token_idx + i < seq_len)
? fmaxf(qk_max[h], d_out[h][i])
: qk_max[h];
}
@@ -1101,7 +1100,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
for (int h = 0; h < QHLOOP; h++) {
exp_sum[h] = 0.0f;
for (int i = 0; i < 4; i++) {
- d_out[h][i] = (lane4_token_idx + i < context_len)
+ d_out[h][i] = (lane4_token_idx + i < seq_len)
? __expf(d_out[h][i] - qk_max[h])
: 0.0f;
exp_sum[h] += d_out[h][i];
@@ -1181,7 +1180,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
}
}
- if (warp_start_token_idx >= context_len) { // warp out of context
+ if (warp_start_token_idx >= seq_len) { // warp out of context
for (int qh = 0; qh < QHLOOP; qh++) {
for (int vh = 0; vh < VHELOOP; vh++) {
vout_shared[qh][vh][laneid][warpid] = {0};
@@ -1279,7 +1278,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
// max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
- const int* __restrict__ context_lens, // [num_seqs]
+ const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
const auto num_heads = gridDim.x;
@@ -1293,8 +1292,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
return;
}
- const int context_len = context_lens[seq_idx];
- const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
+ const int seq_len = seq_lens[seq_idx];
+ const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
const auto warpid = threadIdx.x / WARP_SIZE;
__shared__ float shared_global_exp_sum;
@@ -1581,7 +1580,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
// head_size, block_size]
const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
- const int* __restrict__ context_lens, // [num_seqs]
+ const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
@@ -1615,11 +1614,11 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int max_num_partitions = gridDim.y;
- const int context_len = context_lens[seq_idx]; // length of a seq
+ const int seq_len = seq_lens[seq_idx]; // length of a seq
const int partition_start_token_idx = partition_idx * T_PAR_SIZE;
// exit if partition is out of context for seq
- if (partition_start_token_idx >= context_len) {
+ if (partition_start_token_idx >= seq_len) {
return;
}
@@ -1715,8 +1714,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
}
}
- const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
- const int last_ctx_block = num_context_blocks - 1;
+ const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
+ const int last_seq_block = num_seq_blocks - 1;
const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq;
@@ -1727,9 +1726,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int klocal_token_idx =
TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx;
- const int kblock_idx = (kglobal_token_idx < context_len)
+ const int kblock_idx = (kglobal_token_idx < seq_len)
? kglobal_token_idx / BLOCK_SIZE
- : last_ctx_block;
+ : last_seq_block;
kphysical_block_number[token_depth] = block_table_seq[kblock_idx];
}
@@ -1781,9 +1780,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
vblock_depth * BLOCK_SIZE;
const int vglobal_token_idx =
partition_start_token_idx + vlocal_token_idx;
- const int vblock_idx = (vglobal_token_idx < context_len)
+ const int vblock_idx = (vglobal_token_idx < seq_len)
? vglobal_token_idx / BLOCK_SIZE
- : last_ctx_block;
+ : last_seq_block;
vphysical_block_number[vtoken_depth][vblock_depth] =
block_table_seq[vblock_idx];
}
@@ -1836,9 +1835,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int local_token_idx = qkout_token_idx + token_depth * 16;
for (int i = 0; i < 8; i++) {
- const float tmp = (local_token_idx + 2 * i < context_len)
- ? dout[token_depth][i]
- : -FLT_MAX;
+ const float tmp =
+ (local_token_idx + 2 * i < seq_len) ? dout[token_depth][i] : -FLT_MAX;
qk_max = fmaxf(qk_max, tmp);
}
}
@@ -1848,7 +1846,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int local_token_idx = qkout_token_idx + token_depth * 16;
for (int i = 0; i < 8; i++) {
- const float tmp = (local_token_idx + 2 * i < context_len)
+ const float tmp = (local_token_idx + 2 * i < seq_len)
? __expf(dout[token_depth][i] - qk_max)
: 0.0f;
dout[token_depth][i] = tmp;
@@ -2019,7 +2017,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
// head_size, block_size]
const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
- const int* __restrict__ context_lens, // [num_seqs]
+ const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
@@ -2046,7 +2044,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
// max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
- const int* __restrict__ context_lens, // [num_seqs]
+ const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
const auto num_heads = gridDim.x;
@@ -2060,8 +2058,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
return;
}
- const int context_len = context_lens[seq_idx];
- const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
+ const int seq_len = seq_lens[seq_idx];
+ const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
const int warpid = threadIdx.x / WARP_SIZE;
__shared__ float shared_global_exp_sum;
@@ -2349,7 +2347,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
// head_size, block_size]
const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
- const int* __restrict__ context_lens, // [num_seqs]
+ const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
@@ -2382,11 +2380,11 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int max_num_partitions = gridDim.y;
- const int context_len = context_lens[seq_idx]; // length of a seq
+ const int seq_len = seq_lens[seq_idx]; // length of a seq
const int partition_start_token_idx = partition_idx * T_PAR_SIZE;
// exit if partition is out of context for seq
- if (partition_start_token_idx >= context_len) {
+ if (partition_start_token_idx >= seq_len) {
return;
}
@@ -2482,8 +2480,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
}
}
- const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
- const int last_ctx_block = num_context_blocks - 1;
+ const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
+ const int last_seq_block = num_seq_blocks - 1;
const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq;
@@ -2494,9 +2492,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int klocal_token_idx =
TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx;
- const int kblock_idx = (kglobal_token_idx < context_len)
+ const int kblock_idx = (kglobal_token_idx < seq_len)
? kglobal_token_idx / BLOCK_SIZE
- : last_ctx_block;
+ : last_seq_block;
kphysical_block_number[token_depth] = block_table_seq[kblock_idx];
}
@@ -2548,9 +2546,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
rowid * VTOKENS_PER_LANE + vblock_depth * BLOCK_SIZE;
const int vglobal_token_idx =
partition_start_token_idx + vlocal_token_idx;
- const int vblock_idx = (vglobal_token_idx < context_len)
+ const int vblock_idx = (vglobal_token_idx < seq_len)
? vglobal_token_idx / BLOCK_SIZE
- : last_ctx_block;
+ : last_seq_block;
vphysical_block_number[vtoken_depth][vblock_depth] =
block_table_seq[vblock_idx];
}
@@ -2604,7 +2602,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int local_token_idx = qkout_token_idx + token_depth * 16;
for (int i = 0; i < 8; i++) {
const float tmp =
- (local_token_idx + i < context_len) ? dout[token_depth][i] : -FLT_MAX;
+ (local_token_idx + i < seq_len) ? dout[token_depth][i] : -FLT_MAX;
qk_max = fmaxf(qk_max, tmp);
}
}
@@ -2614,7 +2612,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int local_token_idx = qkout_token_idx + token_depth * 16;
for (int i = 0; i < 8; i++) {
- const float tmp = (local_token_idx + i < context_len)
+ const float tmp = (local_token_idx + i < seq_len)
? __expf(dout[token_depth][i] - qk_max)
: 0.0f;
dout[token_depth][i] = tmp;
@@ -2751,7 +2749,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
// head_size, block_size]
const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
- const int* __restrict__ context_lens, // [num_seqs]
+ const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
@@ -2778,7 +2776,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
// max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
- const int* __restrict__ context_lens, // [num_seqs]
+ const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
const auto num_heads = gridDim.x;
@@ -2792,8 +2790,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
return;
}
- const int context_len = context_lens[seq_idx];
- const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
+ const int seq_len = seq_lens[seq_idx];
+ const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
const int warpid = threadIdx.x / WARP_SIZE;
__shared__ float shared_global_exp_sum;
@@ -2980,7 +2978,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int num_kv_heads,
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
- const int* __restrict__ context_lens, // [num_seqs]
+ const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
@@ -3007,7 +3005,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const int num_kv_heads,
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
- const int* __restrict__ context_lens, // [num_seqs]
+ const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
@@ -3031,7 +3029,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
- const int* __restrict__ context_lens, // [num_seqs]
+ const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
UNREACHABLE_CODE
@@ -3046,7 +3044,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
GQA_RATIO> \
<<>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
- block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
+ block_tables_ptr, seq_lens_ptr, query_start_loc_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr);
@@ -3057,18 +3055,17 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
GQA_RATIO> \
<<>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
- block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
+ block_tables_ptr, seq_lens_ptr, query_start_loc_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr);
-#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \
- paged_attention_ll4mi_reduce_kernel \
- <<>>( \
- out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \
- context_lens_ptr, query_start_loc_ptr, max_num_partitions, \
- fp8_out_scale_ptr);
+#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \
+ paged_attention_ll4mi_reduce_kernel \
+ <<>>( \
+ out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
+ query_start_loc_ptr, max_num_partitions, fp8_out_scale_ptr);
template & query_start_loc, int max_context_len,
+ torch::Tensor& block_tables, torch::Tensor& seq_lens,
+ const std::optional& query_start_loc, int max_seq_len,
const std::optional& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const std::optional& fp8_out_scale) {
int num_seqs = block_tables.size(0);
@@ -3109,7 +3106,7 @@ void paged_attention_custom_launcher(
KVT* key_cache_ptr = reinterpret_cast(key_cache.data_ptr());
KVT* value_cache_ptr = reinterpret_cast(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr();
- int* context_lens_ptr = context_lens.data_ptr();
+ int* seq_lens_ptr = seq_lens.data_ptr();
const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr());
// NOTE: fp8_out_scale is optional.
@@ -3119,13 +3116,12 @@ void paged_attention_custom_launcher(
: nullptr;
OUTT* out_ptr = reinterpret_cast(out.data_ptr());
- const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
+ const int max_ctx_blocks = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE);
// partition size is fixed at 256 since both mfma4 and mfma16 kernels support
// it mfma4 kernel also supports partition size 512
constexpr int PARTITION_SIZE = 256;
- const int max_num_partitions =
- DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
+ const int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
const int gqa_ratio = num_heads / num_kv_heads;
assert(num_heads % num_kv_heads == 0);
assert(head_size == HEAD_SIZE);
@@ -3234,8 +3230,8 @@ void paged_attention_custom_launcher_navi(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, const int num_kv_heads, float scale,
- torch::Tensor& block_tables, torch::Tensor& context_lens,
- const std::optional& query_start_loc, int max_context_len,
+ torch::Tensor& block_tables, torch::Tensor& seq_lens,
+ const std::optional& query_start_loc, int max_seq_len,
const std::optional& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale) {
int num_seqs = block_tables.size(0);
@@ -3263,7 +3259,7 @@ void paged_attention_custom_launcher_navi(
KVT* key_cache_ptr = reinterpret_cast(key_cache.data_ptr());
KVT* value_cache_ptr = reinterpret_cast(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr();
- int* context_lens_ptr = context_lens.data_ptr();
+ int* seq_lens_ptr = seq_lens.data_ptr();
const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr());
@@ -3271,11 +3267,10 @@ void paged_attention_custom_launcher_navi(
const auto fp8_out_scale_ptr = nullptr;
OUTT* out_ptr = reinterpret_cast(out.data_ptr());
- const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
+ const int max_ctx_blocks = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE);
constexpr int PARTITION_SIZE = 256;
- const int max_num_partitions =
- DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
+ const int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
const int gqa_ratio = num_heads / num_kv_heads;
assert(num_heads % num_kv_heads == 0);
assert(head_size == HEAD_SIZE);
@@ -3407,14 +3402,14 @@ void paged_attention_custom_launcher_navi(
paged_attention_custom_launcher( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
- num_kv_heads, scale, block_tables, context_lens, query_start_loc, \
- max_context_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); \
+ num_kv_heads, scale, block_tables, seq_lens, query_start_loc, \
+ max_seq_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); \
} else { \
paged_attention_custom_launcher_navi< \
T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, ALIBI_ENABLED>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
- num_kv_heads, scale, block_tables, context_lens, query_start_loc, \
- max_context_len, alibi_slopes, k_scale, v_scale); \
+ num_kv_heads, scale, block_tables, seq_lens, query_start_loc, \
+ max_seq_len, alibi_slopes, k_scale, v_scale); \
}
#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
@@ -3502,9 +3497,9 @@ void paged_attention(
int64_t num_kv_heads,
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
- torch::Tensor& context_lens, // [num_seqs]
+ torch::Tensor& seq_lens, // [num_seqs]
const std::optional& query_start_loc, // [num_seqs]
- int64_t block_size, int64_t max_context_len,
+ int64_t block_size, int64_t max_seq_len,
const std::optional& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale,
diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h
index e538197dbcb04..34dcc9401aae8 100644
--- a/csrc/rocm/ops.h
+++ b/csrc/rocm/ops.h
@@ -15,8 +15,8 @@ void paged_attention(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
- torch::Tensor& block_tables, torch::Tensor& context_lens,
+ torch::Tensor& block_tables, torch::Tensor& seq_lens,
const std::optional& query_start_loc, int64_t block_size,
- int64_t max_context_len, const std::optional& alibi_slopes,
+ int64_t max_seq_len, const std::optional& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const std::optional& fp8_out_scale);
diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp
index 34575477bcc94..66bdc448da3ca 100644
--- a/csrc/rocm/torch_bindings.cpp
+++ b/csrc/rocm/torch_bindings.cpp
@@ -41,10 +41,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
" Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads,"
" float scale, Tensor block_tables,"
- " Tensor context_lens,"
+ " Tensor seq_lens,"
" Tensor? query_start_loc,"
" int block_size,"
- " int max_context_len,"
+ " int max_seq_len,"
" Tensor? alibi_slopes,"
" str kv_cache_dtype,"
" Tensor k_scale, Tensor v_scale,"
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 04a63f5d68e65..b96d50f0a1c6d 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -210,16 +210,7 @@ ARG SCCACHE_REGION_NAME=us-west-2
ARG SCCACHE_S3_NO_CREDENTIALS=0
# Flag to control whether to use pre-built vLLM wheels
-ARG VLLM_USE_PRECOMPILED
-# TODO: in setup.py VLLM_USE_PRECOMPILED is sensitive to truthiness, it will take =0 as "true", this should be fixed
-ENV VLLM_USE_PRECOMPILED=""
-RUN if [ "${VLLM_USE_PRECOMPILED}" = "1" ]; then \
- export VLLM_USE_PRECOMPILED=1 && \
- echo "Using precompiled wheels"; \
- else \
- unset VLLM_USE_PRECOMPILED && \
- echo "Leaving VLLM_USE_PRECOMPILED unset to build wheels from source"; \
- fi
+ARG VLLM_USE_PRECOMPILED=""
# if USE_SCCACHE is set, use sccache to speed up compilation
RUN --mount=type=cache,target=/root/.cache/uv \
@@ -236,6 +227,8 @@ RUN --mount=type=cache,target=/root/.cache/uv \
&& export SCCACHE_S3_NO_CREDENTIALS=${SCCACHE_S3_NO_CREDENTIALS} \
&& export SCCACHE_IDLE_TIMEOUT=0 \
&& export CMAKE_BUILD_TYPE=Release \
+ && export VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED}" \
+ && export VLLM_DOCKER_BUILD_CONTEXT=1 \
&& sccache --show-stats \
&& python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38 \
&& sccache --show-stats; \
@@ -249,6 +242,8 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
# Clean any existing CMake artifacts
rm -rf .deps && \
mkdir -p .deps && \
+ export VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED}" && \
+ export VLLM_DOCKER_BUILD_CONTEXT=1 && \
python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \
fi
@@ -392,7 +387,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
# Keep this in sync with https://github.com/vllm-project/vllm/blob/main/requirements/cuda.txt
# We use `--force-reinstall --no-deps` to avoid issues with the existing FlashInfer wheel.
-ARG FLASHINFER_GIT_REF="v0.2.10"
+ARG FLASHINFER_GIT_REF="v0.2.11"
RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
. /etc/environment
git clone --depth 1 --recursive --shallow-submodules \
diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu
index 7d5a589eb1d7d..65d2e5036b783 100644
--- a/docker/Dockerfile.xpu
+++ b/docker/Dockerfile.xpu
@@ -1,9 +1,12 @@
-# oneapi 2025.0.2 docker base image use rolling 2448 package. https://dgpu-docs.intel.com/releases/packages.html?release=Rolling+2448.13&os=Ubuntu+22.04, and we don't need install driver manually.
-FROM intel/deep-learning-essentials:2025.0.2-0-devel-ubuntu22.04 AS vllm-base
+FROM intel/deep-learning-essentials:2025.1.3-0-devel-ubuntu24.04 AS vllm-base
RUN rm /etc/apt/sources.list.d/intel-graphics.list
-RUN apt-get update -y && \
+RUN apt clean && apt-get update -y && \
+ apt-get install -y software-properties-common && \
+ add-apt-repository ppa:deadsnakes/ppa && \
+ apt-get install -y python3.10 python3.10-distutils && \
+ curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10 && \
apt-get install -y --no-install-recommends --fix-missing \
curl \
ffmpeg \
@@ -14,11 +17,13 @@ RUN apt-get update -y && \
libgl1 \
lsb-release \
numactl \
- python3 \
- python3-dev \
- python3-pip \
+ python3.10-dev \
wget
+
+RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1
+RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1
+
WORKDIR /workspace/vllm
COPY requirements/xpu.txt /workspace/vllm/requirements/xpu.txt
COPY requirements/common.txt /workspace/vllm/requirements/common.txt
diff --git a/docs/.nav.yml b/docs/.nav.yml
index ad742be3d6947..dbac0e12f1bf2 100644
--- a/docs/.nav.yml
+++ b/docs/.nav.yml
@@ -1,25 +1,17 @@
nav:
- - Home:
- - vLLM: README.md
+ - Home: README.md
+ - User Guide:
+ - usage/README.md
- Getting Started:
- getting_started/quickstart.md
- getting_started/installation
- Examples:
+ - examples/README.md
- Offline Inference: examples/offline_inference
- Online Serving: examples/online_serving
- Others: examples/others
- - Quick Links:
- - User Guide: usage/README.md
- - Developer Guide: contributing/README.md
- - API Reference: api/README.md
- - CLI Reference: cli/README.md
- - Timeline:
- - Roadmap: https://roadmap.vllm.ai
- - Releases: https://github.com/vllm-project/vllm/releases
- - User Guide:
- - Summary: usage/README.md
- - usage/v1_guide.md
- General:
+ - usage/v1_guide.md
- usage/*
- Inference and Serving:
- serving/offline_inference.md
@@ -32,7 +24,7 @@ nav:
- deployment/integrations
- Training: training
- Configuration:
- - Summary: configuration/README.md
+ - configuration/README.md
- configuration/*
- Models:
- models/supported_models.md
@@ -45,11 +37,11 @@ nav:
- features/*
- features/quantization
- Developer Guide:
- - Summary: contributing/README.md
+ - contributing/README.md
- General:
- glob: contributing/*
flatten_single_child_sections: true
- - Model Implementation:
+ - Model Implementation:
- contributing/model/README.md
- contributing/model/basic.md
- contributing/model/registration.md
@@ -58,12 +50,9 @@ nav:
- CI: contributing/ci
- Design Documents: design
- API Reference:
- - Summary: api/README.md
- - Contents:
- - glob: api/vllm/*
- preserve_directory_names: true
- - CLI Reference:
- - Summary: cli/README.md
+ - api/README.md
+ - api/vllm/*
+ - CLI Reference: cli
- Community:
- community/*
- Blog: https://blog.vllm.ai
diff --git a/docs/README.md b/docs/README.md
index 6823008ed3360..e8d2fd953a96d 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -21,6 +21,17 @@ vLLM is a fast and easy-to-use library for LLM inference and serving.
Originally developed in the [Sky Computing Lab](https://sky.cs.berkeley.edu) at UC Berkeley, vLLM has evolved into a community-driven project with contributions from both academia and industry.
+Where to get started with vLLM depends on the type of user. If you are looking to:
+
+- Run open-source models on vLLM, we recommend starting with the [Quickstart Guide](./getting_started/quickstart.md)
+- Build applications with vLLM, we recommend starting with the [User Guide](./usage)
+- Build vLLM, we recommend starting with [Developer Guide](./contributing)
+
+For information about the development of vLLM, see:
+
+- [Roadmap](https://roadmap.vllm.ai)
+- [Releases](https://github.com/vllm-project/vllm/releases)
+
vLLM is fast with:
- State-of-the-art serving throughput
diff --git a/docs/api/README.md b/docs/api/README.md
index db4dab0ae534b..327472df1d52c 100644
--- a/docs/api/README.md
+++ b/docs/api/README.md
@@ -1,7 +1,5 @@
# Summary
-[](){ #configuration }
-
## Configuration
API documentation for vLLM's configuration classes.
diff --git a/docs/cli/.meta.yml b/docs/cli/.meta.yml
new file mode 100644
index 0000000000000..0e1f7ecceebcd
--- /dev/null
+++ b/docs/cli/.meta.yml
@@ -0,0 +1 @@
+toc_depth: 3
\ No newline at end of file
diff --git a/docs/cli/.nav.yml b/docs/cli/.nav.yml
new file mode 100644
index 0000000000000..6c2c09d566a3a
--- /dev/null
+++ b/docs/cli/.nav.yml
@@ -0,0 +1,8 @@
+nav:
+ - README.md
+ - serve.md
+ - chat.md
+ - complete.md
+ - run-batch.md
+ - vllm bench:
+ - bench/*.md
diff --git a/docs/cli/README.md b/docs/cli/README.md
index b1371c82a4c4d..c708eb7958980 100644
--- a/docs/cli/README.md
+++ b/docs/cli/README.md
@@ -1,7 +1,3 @@
----
-toc_depth: 4
----
-
# vLLM CLI Guide
The vllm command-line tool is used to run and manage vLLM models. You can start by viewing the help message with:
@@ -18,37 +14,46 @@ vllm {chat,complete,serve,bench,collect-env,run-batch}
## serve
-Start the vLLM OpenAI Compatible API server.
+Starts the vLLM OpenAI Compatible API server.
-??? console "Examples"
+Start with a model:
- ```bash
- # Start with a model
- vllm serve meta-llama/Llama-2-7b-hf
+```bash
+vllm serve meta-llama/Llama-2-7b-hf
+```
- # Specify the port
- vllm serve meta-llama/Llama-2-7b-hf --port 8100
+Specify the port:
- # Check with --help for more options
- # To list all groups
- vllm serve --help=listgroup
+```bash
+vllm serve meta-llama/Llama-2-7b-hf --port 8100
+```
- # To view a argument group
- vllm serve --help=ModelConfig
+Serve over a Unix domain socket:
- # To view a single argument
- vllm serve --help=max-num-seqs
+```bash
+vllm serve meta-llama/Llama-2-7b-hf --uds /tmp/vllm.sock
+```
- # To search by keyword
- vllm serve --help=max
+Check with --help for more options:
- # To view full help with pager (less/more)
- vllm serve --help=page
- ```
+```bash
+# To list all groups
+vllm serve --help=listgroup
-### Options
+# To view a argument group
+vllm serve --help=ModelConfig
---8<-- "docs/argparse/serve.md"
+# To view a single argument
+vllm serve --help=max-num-seqs
+
+# To search by keyword
+vllm serve --help=max
+
+# To view full help with pager (less/more)
+vllm serve --help=page
+```
+
+See [vllm serve](./serve.md) for the full reference of all available arguments.
## chat
@@ -65,6 +70,8 @@ vllm chat --url http://{vllm-serve-host}:{vllm-serve-port}/v1
vllm chat --quick "hi"
```
+See [vllm chat](./chat.md) for the full reference of all available arguments.
+
## complete
Generate text completions based on the given prompt via the running API server.
@@ -80,7 +87,7 @@ vllm complete --url http://{vllm-serve-host}:{vllm-serve-port}/v1
vllm complete --quick "The future of AI is"
```
-
+See [vllm complete](./complete.md) for the full reference of all available arguments.
## bench
@@ -107,6 +114,8 @@ vllm bench latency \
--load-format dummy
```
+See [vllm bench latency](./bench/latency.md) for the full reference of all available arguments.
+
### serve
Benchmark the online serving throughput.
@@ -121,6 +130,8 @@ vllm bench serve \
--num-prompts 5
```
+See [vllm bench serve](./bench/serve.md) for the full reference of all available arguments.
+
### throughput
Benchmark offline inference throughput.
@@ -134,6 +145,8 @@ vllm bench throughput \
--load-format dummy
```
+See [vllm bench throughput](./bench/throughput.md) for the full reference of all available arguments.
+
## collect-env
Start collecting environment information.
@@ -146,24 +159,25 @@ vllm collect-env
Run batch prompts and write results to file.
-
-Examples
+Running with a local file:
```bash
-# Running with a local file
vllm run-batch \
-i offline_inference/openai_batch/openai_example_batch.jsonl \
-o results.jsonl \
--model meta-llama/Meta-Llama-3-8B-Instruct
+```
-# Using remote file
+Using remote file:
+
+```bash
vllm run-batch \
-i https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl \
-o results.jsonl \
--model meta-llama/Meta-Llama-3-8B-Instruct
```
-
+See [vllm run-batch](./run-batch.md) for the full reference of all available arguments.
## More Help
diff --git a/docs/cli/bench/latency.md b/docs/cli/bench/latency.md
new file mode 100644
index 0000000000000..21ab13e63781a
--- /dev/null
+++ b/docs/cli/bench/latency.md
@@ -0,0 +1,9 @@
+# vllm bench latency
+
+## JSON CLI Arguments
+
+--8<-- "docs/cli/json_tip.inc.md"
+
+## Options
+
+--8<-- "docs/argparse/bench_latency.md"
diff --git a/docs/cli/bench/serve.md b/docs/cli/bench/serve.md
new file mode 100644
index 0000000000000..f7c415c6becb5
--- /dev/null
+++ b/docs/cli/bench/serve.md
@@ -0,0 +1,9 @@
+# vllm bench serve
+
+## JSON CLI Arguments
+
+--8<-- "docs/cli/json_tip.inc.md"
+
+## Options
+
+--8<-- "docs/argparse/bench_serve.md"
diff --git a/docs/cli/bench/throughput.md b/docs/cli/bench/throughput.md
new file mode 100644
index 0000000000000..e4ff5ce43c9ce
--- /dev/null
+++ b/docs/cli/bench/throughput.md
@@ -0,0 +1,9 @@
+# vllm bench throughput
+
+## JSON CLI Arguments
+
+--8<-- "docs/cli/json_tip.inc.md"
+
+## Options
+
+--8<-- "docs/argparse/bench_throughput.md"
diff --git a/docs/cli/chat.md b/docs/cli/chat.md
new file mode 100644
index 0000000000000..b006cb8de60d0
--- /dev/null
+++ b/docs/cli/chat.md
@@ -0,0 +1,5 @@
+# vllm chat
+
+## Options
+
+--8<-- "docs/argparse/chat.md"
diff --git a/docs/cli/complete.md b/docs/cli/complete.md
new file mode 100644
index 0000000000000..400359acf4fb8
--- /dev/null
+++ b/docs/cli/complete.md
@@ -0,0 +1,5 @@
+# vllm complete
+
+## Options
+
+--8<-- "docs/argparse/complete.md"
diff --git a/docs/cli/json_tip.inc.md b/docs/cli/json_tip.inc.md
new file mode 100644
index 0000000000000..c22430c264c19
--- /dev/null
+++ b/docs/cli/json_tip.inc.md
@@ -0,0 +1,9 @@
+When passing JSON CLI arguments, the following sets of arguments are equivalent:
+
+- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`
+- `--json-arg.key1 value1 --json-arg.key2.key3 value2`
+
+Additionally, list elements can be passed individually using `+`:
+
+- `--json-arg '{"key4": ["value3", "value4", "value5"]}'`
+- `--json-arg.key4+ value3 --json-arg.key4+='value4,value5'`
\ No newline at end of file
diff --git a/docs/cli/run-batch.md b/docs/cli/run-batch.md
new file mode 100644
index 0000000000000..f7d401b8dad2b
--- /dev/null
+++ b/docs/cli/run-batch.md
@@ -0,0 +1,9 @@
+# vllm run-batch
+
+## JSON CLI Arguments
+
+--8<-- "docs/cli/json_tip.inc.md"
+
+## Options
+
+--8<-- "docs/argparse/run-batch.md"
diff --git a/docs/cli/serve.md b/docs/cli/serve.md
new file mode 100644
index 0000000000000..2c8f9d320f5df
--- /dev/null
+++ b/docs/cli/serve.md
@@ -0,0 +1,9 @@
+# vllm serve
+
+## JSON CLI Arguments
+
+--8<-- "docs/cli/json_tip.inc.md"
+
+## Options
+
+--8<-- "docs/argparse/serve.md"
diff --git a/docs/community/meetups.md b/docs/community/meetups.md
index e8b3a9c9c8e69..36232e6ad96cc 100644
--- a/docs/community/meetups.md
+++ b/docs/community/meetups.md
@@ -2,6 +2,7 @@
We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below:
+- [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA), August 2nd 2025. [[Slides]](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) [[Recording]](https://www.chaspark.com/#/live/1166916873711665152).
- [NYC vLLM Meetup](https://lu.ma/c1rqyf1f), May 7th, 2025. [[Slides]](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing)
- [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day), April 3rd 2025. [[Slides]](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing).
- [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama), March 27th 2025. [[Slides]](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing).
diff --git a/docs/community/sponsors.md b/docs/community/sponsors.md
index b8a1ddbe38794..6ad3a66252664 100644
--- a/docs/community/sponsors.md
+++ b/docs/community/sponsors.md
@@ -15,6 +15,7 @@ Cash Donations:
Compute Resources:
+- Alibaba Cloud
- AMD
- Anyscale
- AWS
diff --git a/docs/configuration/engine_args.md b/docs/configuration/engine_args.md
index c3c1d5a1c362d..05d4f762306a3 100644
--- a/docs/configuration/engine_args.md
+++ b/docs/configuration/engine_args.md
@@ -11,6 +11,8 @@ Engine arguments control the behavior of the vLLM engine.
The engine argument classes, [EngineArgs][vllm.engine.arg_utils.EngineArgs] and [AsyncEngineArgs][vllm.engine.arg_utils.AsyncEngineArgs], are a combination of the configuration classes defined in [vllm.config][]. Therefore, if you are interested in developer documentation, we recommend looking at these configuration classes as they are the source of truth for types, defaults and docstrings.
+--8<-- "docs/cli/json_tip.inc.md"
+
## `EngineArgs`
--8<-- "docs/argparse/engine_args.md"
diff --git a/docs/configuration/tpu.md b/docs/configuration/tpu.md
index a2941c80bd27c..a93435ed71b50 100644
--- a/docs/configuration/tpu.md
+++ b/docs/configuration/tpu.md
@@ -96,7 +96,7 @@ Although it’s common to do this with GPUs, don't try to fragment 2 or 8 differ
### Tune your workloads
-Although we try to have great default configs, we strongly recommend you check out the [vLLM auto-tuner](../../benchmarks/auto_tune/README.md) to optimize your workloads for your use case.
+Although we try to have great default configs, we strongly recommend you check out the [vLLM auto-tuner](gh-file:benchmarks/auto_tune/README.md) to optimize your workloads for your use case.
### Future Topics We'll Cover
diff --git a/docs/contributing/ci/update_pytorch_version.md b/docs/contributing/ci/update_pytorch_version.md
index 3a6026d450a67..7ef22d6f8c3f5 100644
--- a/docs/contributing/ci/update_pytorch_version.md
+++ b/docs/contributing/ci/update_pytorch_version.md
@@ -131,19 +131,6 @@ MAX_JOBS=16 uv pip install --system \
--no-build-isolation "git+https://github.com/facebookresearch/xformers@v0.0.30"
```
-### Mamba
-
-```bash
-uv pip install --system \
- --no-build-isolation "git+https://github.com/state-spaces/mamba@v2.2.5"
-```
-
-### causal-conv1d
-
-```bash
-uv pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
-```
-
## Update all the different vLLM platforms
Rather than attempting to update all vLLM platforms in a single pull request, it's more manageable
diff --git a/docs/contributing/model/basic.md b/docs/contributing/model/basic.md
index edd9a47e132fd..21b1f21d60a35 100644
--- a/docs/contributing/model/basic.md
+++ b/docs/contributing/model/basic.md
@@ -117,7 +117,7 @@ For models with interleaving sliding windows (e.g. `google/gemma-2-2b-it` and `m
To support a model with interleaving sliding windows, we need to take care of the following details:
-- Make sure the model's `config.json` contains `sliding_window_pattern`. vLLM then sets `self.hf_text_config.interleaved_sliding_window` to the value of `self.hf_text_config.sliding_window` and deletes `sliding_window` from `self.hf_text_config`. The model will then be treated as a full-attention model.
+- Make sure the model's `config.json` contains `layer_types`.
- In the modeling code, parse the correct sliding window value for every layer, and pass it to the attention layer's `per_layer_sliding_window` argument. For reference, check [this line](https://github.com/vllm-project/vllm/blob/996357e4808ca5eab97d4c97c7d25b3073f46aab/vllm/model_executor/models/llama.py#L171).
With these two steps, interleave sliding windows should work with the model.
diff --git a/docs/contributing/model/multimodal.md b/docs/contributing/model/multimodal.md
index 3295b8c711c0c..64a48be32645a 100644
--- a/docs/contributing/model/multimodal.md
+++ b/docs/contributing/model/multimodal.md
@@ -540,8 +540,10 @@ return a schema of the tensors outputted by the HF processor that are related to
The shape of `image_patches` outputted by `FuyuImageProcessor` is therefore
`(1, num_images, num_patches, patch_width * patch_height * num_channels)`.
- In order to support the use of [MultiModalFieldConfig.batched][] like in LLaVA,
- we remove the extra batch dimension by overriding [BaseMultiModalProcessor._call_hf_processor][]:
+ In order to support the use of
+ [MultiModalFieldConfig.batched][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
+ like in LLaVA, we remove the extra batch dimension by overriding
+ [BaseMultiModalProcessor._call_hf_processor][vllm.multimodal.processing.BaseMultiModalProcessor._call_hf_processor]:
??? code
@@ -816,7 +818,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies
After you have defined [BaseProcessingInfo][vllm.multimodal.processing.BaseProcessingInfo] (Step 2),
[BaseDummyInputsBuilder][vllm.multimodal.profiling.BaseDummyInputsBuilder] (Step 3),
and [BaseMultiModalProcessor][vllm.multimodal.processing.BaseMultiModalProcessor] (Step 4),
-decorate the model class with [MULTIMODAL_REGISTRY.register_processor][vllm.multimodal.processing.MultiModalRegistry.register_processor]
+decorate the model class with [MULTIMODAL_REGISTRY.register_processor][vllm.multimodal.registry.MultiModalRegistry.register_processor]
to register them to the multi-modal registry:
```diff
diff --git a/docs/design/metrics.md b/docs/design/metrics.md
index 1f65331d3c0a9..b01838883f31e 100644
--- a/docs/design/metrics.md
+++ b/docs/design/metrics.md
@@ -57,11 +57,11 @@ In v0, the following metrics are exposed via a Prometheus-compatible `/metrics`
- `vllm:spec_decode_num_draft_tokens_total` (Counter)
- `vllm:spec_decode_num_emitted_tokens_total` (Counter)
-These are documented under [Inferencing and Serving -> Production Metrics](../../usage/metrics.md).
+These are documented under [Inferencing and Serving -> Production Metrics](../usage/metrics.md).
### Grafana Dashboard
-vLLM also provides [a reference example](../../examples/online_serving/prometheus_grafana.md) for how to collect and store these metrics using Prometheus and visualize them using a Grafana dashboard.
+vLLM also provides [a reference example](../examples/online_serving/prometheus_grafana.md) for how to collect and store these metrics using Prometheus and visualize them using a Grafana dashboard.
The subset of metrics exposed in the Grafana dashboard gives us an indication of which metrics are especially important:
@@ -455,7 +455,7 @@ In general:
[an escape hatch](https://kubernetes.io/docs/concepts/cluster-administration/system-metrics/#show-hidden-metrics)
for some time before deleting them.
-See the [deprecation policy](../../contributing/deprecation_policy.md) for
+See the [deprecation policy](../contributing/deprecation_policy.md) for
the project-wide deprecation policy.
### Unimplemented - `vllm:tokens_total`
@@ -655,7 +655,7 @@ v0 has support for OpenTelemetry tracing:
- Added by
- Configured with `--oltp-traces-endpoint` and `--collect-detailed-traces`
- [OpenTelemetry blog post](https://opentelemetry.io/blog/2024/llm-observability/)
-- [User-facing docs](../../examples/online_serving/opentelemetry.md)
+- [User-facing docs](../examples/online_serving/opentelemetry.md)
- [Blog post](https://medium.com/@ronen.schaffer/follow-the-trail-supercharging-vllm-with-opentelemetry-distributed-tracing-aa655229b46f)
- [IBM product docs](https://www.ibm.com/docs/en/instana-observability/current?topic=mgaa-monitoring-large-language-models-llms-vllm-public-preview)
diff --git a/docs/examples/README.md b/docs/examples/README.md
new file mode 100644
index 0000000000000..34e4dfd408a20
--- /dev/null
+++ b/docs/examples/README.md
@@ -0,0 +1,7 @@
+# Examples
+
+vLLM's examples are split into three categories:
+
+- If you are using vLLM from within Python code, see [Offline Inference](./offline_inference/)
+- If you are using vLLM from an HTTP application or client, see [Online Serving](./online_serving/)
+- For examples of using some of vLLM's advanced features (e.g. LMCache or Tensorizer) which are not specific to either of the above use cases, see [Others](./others/)
diff --git a/docs/features/quantization/inc.md b/docs/features/quantization/inc.md
index d97a462f54320..13b151bc7f380 100644
--- a/docs/features/quantization/inc.md
+++ b/docs/features/quantization/inc.md
@@ -1,7 +1,4 @@
----
-title: FP8 INC
----
-[](){ #inc }
+# FP8 INC
vLLM supports FP8 (8-bit floating point) weight and activation quantization using Intel® Neural Compressor (INC) on Intel® Gaudi® 2 and Intel® Gaudi® 3 AI accelerators.
Currently, quantization is validated only in Llama models.
diff --git a/docs/features/sleep_mode.md b/docs/features/sleep_mode.md
new file mode 100644
index 0000000000000..5749b02d26f45
--- /dev/null
+++ b/docs/features/sleep_mode.md
@@ -0,0 +1,80 @@
+# Sleep Mode
+
+vLLM's Sleep Mode allows you to temporarily release most GPU memory used by a model, including model weights and KV cache, without stopping the server or unloading the Docker container. This is especially useful for RLHF, training, or cost-saving scenarios where GPU resources need to be freed between inference workloads.
+
+Key benefits:
+
+- **Frees GPU memory**: Offloads model weights to CPU RAM and discards KV cache, releasing up to 90%+ of GPU memory for other tasks.
+- **Fast resume**: Quickly wake up the engine and resume inference without full model reload.
+- **API endpoints**: Control sleep/wake_up state via HTTP endpoints or Python API.
+- **Supports distributed workloads**: Works with tensor parallelism, pipeline parallelism, etc.
+- **Fine-grained control**: Optionally wake up only model weights or KV cache to avoid OOM during weight updates.
+
+!!! note
+ This feature is only supported on CUDA platform.
+
+## Sleep levels
+
+Level 1 sleep will offload the model weights and discard the KV cache. The content of KV cache is forgotten. Level 1 sleep is good for sleeping and waking up the engine to run the same model again. The model weights are backed up in CPU memory. Please make sure there's enough CPU memory to store the model weights. Level 2 sleep will discard both the model weights and the KV cache (while the model's buffers are kept in CPU, like rope scaling tensors). The content of both the model weights and KV cache is forgotten. Level 2 sleep is good for sleeping and waking up the engine to run a different model or update the model, where previous model weights are not needed, e.g. RLHF weight update.
+
+## Usage
+
+### Offline inference
+
+Enable sleep mode by passing `enable_sleep_mode=True` to the `LLM` class.
+
+```python
+from vllm import LLM
+llm = LLM("Qwen/Qwen3-0.6B", enable_sleep_mode=True)
+```
+
+#### Python API
+
+```python
+# Put the engine to sleep (level=1: offload weights to CPU RAM, discard KV cache)
+llm.sleep(level=1)
+
+# Wake up the engine (restore weights)
+llm.wake_up()
+```
+
+#### RLHF weight updates
+
+During RLHF training, vLLM allows you to selectively wake up only the model weights or the KV cache using the tags argument in wake_up(). This fine-grained control is especially useful when updating model weights: by waking up just the weights (e.g., llm.wake_up(tags=["weights"])), you avoid allocating memory for the KV cache until after the weight update is complete. This approach helps prevent GPU out-of-memory (OOM) errors, particularly with large models, by minimizing peak memory usage during weight synchronization and update operations.
+
+Use `tags=["weights"]` or `tags=["kv_cache"]` to control which resources are restored, useful for RLHF and weight updates. **Note** that `is_sleeping` will report `true` until all components are awake.
+
+```python
+# Put engine to deep sleep (level=2)
+llm.sleep(level=2)
+# ... Get the new weights
+# Wake up only weights to avoid OOM
+llm.wake_up(tags=["weights"])
+# ... Update the weights
+# wake up KV cache after weights are updated
+llm.wake_up(tags=["kv_cache"])
+```
+
+### Online Serving
+
+To enable sleep mode in a vLLM server you need to initialize it with the flag `VLLM_SERVER_DEV_MODE=1` and pass `--enable-sleep-mode` to the vLLM server.
+
+#### Server in development mode
+
+When using the flag `VLLM_SERVER_DEV_MODE=1` you enable development endpoints, and these endpoints should not be exposed to users.
+
+```bash
+VLLM_SERVER_DEV_MODE=1 python -m vllm.entrypoints.openai.api_server \
+ --model Qwen/Qwen3-0.6B \
+ --enable-sleep-mode \
+ --port 8000
+```
+
+#### HTTP endpoints
+
+- `POST /sleep?level=1` — Put the model to sleep (`level=1`).
+- `POST /wake_up` — Wake up the model. Supports optional `tags` query parameters for partial wake-up (e.g., `?tags=weights`).
+- `GET /is_sleeping` — Check if the model is sleeping.
+
+!!! note
+ These endpoints are only available when passing `VLLM_SERVER_DEV_MODE=1`.
diff --git a/docs/features/spec_decode.md b/docs/features/spec_decode.md
index 89d5b489e1888..597a8e8644278 100644
--- a/docs/features/spec_decode.md
+++ b/docs/features/spec_decode.md
@@ -203,6 +203,7 @@ an [EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)](https
"model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
"draft_tensor_parallel_size": 1,
"num_speculative_tokens": 2,
+ "method": "eagle",
},
)
@@ -231,6 +232,9 @@ A few important things to consider when using the EAGLE based draft models:
reported in the reference implementation [here](https://github.com/SafeAILab/EAGLE). This issue is under
investigation and tracked here: .
+4. When using EAGLE-3 based draft model, option "method" must be set to "eagle3".
+ That is, to specify `"method": "eagle3"` in `speculative_config`.
+
A variety of EAGLE draft models are available on the Hugging Face hub:
| Base Model | EAGLE on Hugging Face | # EAGLE Parameters |
diff --git a/docs/getting_started/installation/cpu/x86.inc.md b/docs/getting_started/installation/cpu/x86.inc.md
index 49e223f9b9bf6..6dc6f94249c34 100644
--- a/docs/getting_started/installation/cpu/x86.inc.md
+++ b/docs/getting_started/installation/cpu/x86.inc.md
@@ -6,7 +6,7 @@ vLLM supports basic model inferencing and serving on x86 CPU platform, with data
# --8<-- [start:requirements]
- OS: Linux
-- CPU flags: `avx512f`, `avx512_bf16` (Optional), `avx512_vnni` (Optional)
+- CPU flags: `avx512f` (Recommended), `avx512_bf16` (Optional), `avx512_vnni` (Optional)
!!! tip
Use `lscpu` to check the CPU flags.
@@ -28,7 +28,7 @@ vLLM supports basic model inferencing and serving on x86 CPU platform, with data
[https://gallery.ecr.aws/q9t5s3a7/vllm-cpu-release-repo](https://gallery.ecr.aws/q9t5s3a7/vllm-cpu-release-repo)
!!! warning
- If deploying the pre-built images on machines only contain `avx512f`, `Illegal instruction` error may be raised. It is recommended to build images for these machines with `--build-arg VLLM_CPU_AVX512BF16=false` and `--build-arg VLLM_CPU_AVX512VNNI=false`.
+ If deploying the pre-built images on machines without `avx512f`, `avx512_bf16`, or `avx512_vnni` support, an `Illegal instruction` error may be raised. It is recommended to build images for these machines with the appropriate build arguments (e.g., `--build-arg VLLM_CPU_DISABLE_AVX512=true`, `--build-arg VLLM_CPU_AVX512BF16=false`, or `--build-arg VLLM_CPU_AVX512VNNI=false`) to disable unsupported features. Please note that without `avx512f`, AVX2 will be used and this version is not recommended because it only has basic feature support.
# --8<-- [end:pre-built-images]
# --8<-- [start:build-image-from-source]
@@ -37,6 +37,7 @@ vLLM supports basic model inferencing and serving on x86 CPU platform, with data
docker build -f docker/Dockerfile.cpu \
--build-arg VLLM_CPU_AVX512BF16=false (default)|true \
--build-arg VLLM_CPU_AVX512VNNI=false (default)|true \
+ --build-arg VLLM_CPU_DISABLE_AVX512=false (default)|true \
--tag vllm-cpu-env \
--target vllm-openai .
diff --git a/docs/mkdocs/hooks/generate_argparse.py b/docs/mkdocs/hooks/generate_argparse.py
index b003b5fd6ccef..ed5d3b0092ae7 100644
--- a/docs/mkdocs/hooks/generate_argparse.py
+++ b/docs/mkdocs/hooks/generate_argparse.py
@@ -15,8 +15,14 @@ sys.modules["aiohttp"] = MagicMock()
sys.modules["blake3"] = MagicMock()
sys.modules["vllm._C"] = MagicMock()
+from vllm.benchmarks import latency # noqa: E402
+from vllm.benchmarks import serve # noqa: E402
+from vllm.benchmarks import throughput # noqa: E402
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs # noqa: E402
-from vllm.entrypoints.openai.cli_args import make_arg_parser # noqa: E402
+from vllm.entrypoints.cli.openai import ChatCommand # noqa: E402
+from vllm.entrypoints.cli.openai import CompleteCommand # noqa: E402
+from vllm.entrypoints.openai import cli_args # noqa: E402
+from vllm.entrypoints.openai import run_batch # noqa: E402
from vllm.utils import FlexibleArgumentParser # noqa: E402
logger = logging.getLogger("mkdocs")
@@ -68,7 +74,8 @@ class MarkdownFormatter(HelpFormatter):
self._markdown_output.append(
f"Possible choices: {metavar}\n\n")
- self._markdown_output.append(f"{action.help}\n\n")
+ if action.help:
+ self._markdown_output.append(f"{action.help}\n\n")
if (default := action.default) != SUPPRESS:
self._markdown_output.append(f"Default: `{default}`\n\n")
@@ -78,7 +85,7 @@ class MarkdownFormatter(HelpFormatter):
return "".join(self._markdown_output)
-def create_parser(cls, **kwargs) -> FlexibleArgumentParser:
+def create_parser(add_cli_args, **kwargs) -> FlexibleArgumentParser:
"""Create a parser for the given class with markdown formatting.
Args:
@@ -88,18 +95,12 @@ def create_parser(cls, **kwargs) -> FlexibleArgumentParser:
Returns:
FlexibleArgumentParser: A parser with markdown formatting for the class.
"""
- parser = FlexibleArgumentParser()
+ parser = FlexibleArgumentParser(add_json_tip=False)
parser.formatter_class = MarkdownFormatter
with patch("vllm.config.DeviceConfig.__post_init__"):
- return cls.add_cli_args(parser, **kwargs)
-
-
-def create_serve_parser() -> FlexibleArgumentParser:
- """Create a parser for the serve command with markdown formatting."""
- parser = FlexibleArgumentParser()
- parser.formatter_class = lambda prog: MarkdownFormatter(
- prog, starting_heading_level=4)
- return make_arg_parser(parser)
+ _parser = add_cli_args(parser, **kwargs)
+ # add_cli_args might be in-place so return parser if _parser is None
+ return _parser or parser
def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
@@ -113,10 +114,24 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
# Create parsers to document
parsers = {
- "engine_args": create_parser(EngineArgs),
- "async_engine_args": create_parser(AsyncEngineArgs,
- async_args_only=True),
- "serve": create_serve_parser(),
+ "engine_args":
+ create_parser(EngineArgs.add_cli_args),
+ "async_engine_args":
+ create_parser(AsyncEngineArgs.add_cli_args, async_args_only=True),
+ "serve":
+ create_parser(cli_args.make_arg_parser),
+ "chat":
+ create_parser(ChatCommand.add_cli_args),
+ "complete":
+ create_parser(CompleteCommand.add_cli_args),
+ "bench_latency":
+ create_parser(latency.add_cli_args),
+ "bench_throughput":
+ create_parser(throughput.add_cli_args),
+ "bench_serve":
+ create_parser(serve.add_cli_args),
+ "run-batch":
+ create_parser(run_batch.make_arg_parser),
}
# Generate documentation for each parser
diff --git a/docs/mkdocs/hooks/generate_examples.py b/docs/mkdocs/hooks/generate_examples.py
index 0ee52bb34603b..6b4c5b31075f7 100644
--- a/docs/mkdocs/hooks/generate_examples.py
+++ b/docs/mkdocs/hooks/generate_examples.py
@@ -105,7 +105,7 @@ class Example:
return fix_case(self.path.stem.replace("_", " ").title())
def generate(self) -> str:
- content = f"---\ntitle: {self.title}\n---\n\n"
+ content = f"# {self.title}\n\n"
content += f"Source .\n\n"
# Use long code fence to avoid issues with
diff --git a/docs/mkdocs/stylesheets/extra.css b/docs/mkdocs/stylesheets/extra.css
index fb44d9cdcf3d3..6a1979b241ae0 100644
--- a/docs/mkdocs/stylesheets/extra.css
+++ b/docs/mkdocs/stylesheets/extra.css
@@ -23,6 +23,13 @@ a:not(:has(svg)):not(.md-icon):not(.autorefs-external) {
}
}
+a[href*="localhost"]::after,
+a[href*="127.0.0.1"]::after,
+a[href*="org.readthedocs.build"]::after,
+a[href*="docs.vllm.ai"]::after {
+ display: none !important;
+}
+
/* Light mode: darker section titles */
body[data-md-color-scheme="default"] .md-nav__item--section > label.md-nav__link .md-ellipsis {
color: rgba(0, 0, 0, 0.7) !important;
diff --git a/docs/models/generative_models.md b/docs/models/generative_models.md
index a3ad413593f3c..a64ecd31ebaef 100644
--- a/docs/models/generative_models.md
+++ b/docs/models/generative_models.md
@@ -4,7 +4,7 @@ vLLM provides first-class support for generative models, which covers most of LL
In vLLM, generative models implement the[VllmModelForTextGeneration][vllm.model_executor.models.VllmModelForTextGeneration] interface.
Based on the final hidden states of the input, these models output log probabilities of the tokens to generate,
-which are then passed through [Sampler][vllm.model_executor.layers.Sampler] to obtain the final text.
+which are then passed through [Sampler][vllm.model_executor.layers.sampler.Sampler] to obtain the final text.
## Configuration
@@ -19,7 +19,7 @@ Run a model in generation mode via the option `--runner generate`.
## Offline Inference
The [LLM][vllm.LLM] class provides various methods for offline inference.
-See [configuration][configuration] for a list of options when initializing the model.
+See [configuration](../api/summary.md#configuration) for a list of options when initializing the model.
### `LLM.generate`
diff --git a/docs/models/pooling_models.md b/docs/models/pooling_models.md
index c6588363b63fb..39f209d0eb7ed 100644
--- a/docs/models/pooling_models.md
+++ b/docs/models/pooling_models.md
@@ -81,7 +81,7 @@ which takes priority over both the model's and Sentence Transformers's defaults.
## Offline Inference
The [LLM][vllm.LLM] class provides various methods for offline inference.
-See [configuration][configuration] for a list of options when initializing the model.
+See [configuration](../api/summary.md#configuration) for a list of options when initializing the model.
### `LLM.embed`
diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 265643a441041..a24fa4bcce333 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -320,7 +320,7 @@ th {
}
-| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
+| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
| `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `ArceeForCausalLM` | Arcee (AFM) | `arcee-ai/AFM-4.5B-Base`, etc. | ✅︎ | ✅︎ | ✅︎ |
@@ -331,7 +331,7 @@ th {
| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | |
| `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | |
| `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `zai-org/chatglm2-6b`, `zai-org/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ |
-| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R | `CohereForAI/c4ai-command-r-v01`, `CohereForAI/c4ai-command-r7b-12-2024`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R | `CohereLabs/c4ai-command-r-v01`, `CohereLabs/c4ai-command-r7b-12-2024`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. | | ✅︎ | ✅︎ |
| `DeciLMForCausalLM` | DeciLM | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat`, etc. | | ✅︎ | ✅︎ |
@@ -349,9 +349,10 @@ th {
| `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
-| `Gemma3nForConditionalGeneration` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
+| `Gemma3nForCausalLM` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
| `GlmForCausalLM` | GLM-4 | `zai-org/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4ForCausalLM` | GLM-4-0414 | `zai-org/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `Glm4MoeForCausalLM` | GLM-4.5 | `zai-org/GLM-4.5`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `GPT2LMHeadModel` | GPT-2 | `gpt2`, `gpt2-xl`, etc. | | ✅︎ | ✅︎ |
| `GPTBigCodeForCausalLM` | StarCoder, SantaCoder, WizardCoder | `bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, `WizardLM/WizardCoder-15B-V1.0`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `GPTJForCausalLM` | GPT-J | `EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc. | | ✅︎ | ✅︎ |
@@ -404,16 +405,19 @@ th {
| `TeleChat2ForCausalLM` | TeleChat2 | `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `TeleFLMForCausalLM` | TeleFLM | `CofeAI/FLM-2-52B-Instruct-2407`, `CofeAI/Tele-FLM`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
-| `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`, etc. | | | |
-| `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | | |
+| `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`, etc. | | | ✅︎ |
+| `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | | ✅︎ |
| `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | ✅︎ |
+Some models are supported only via the [Transformers backend](#transformers). The purpose of the table below is to acknowledge models which we officially support in this way. The logs will say that the Transformers backend is being used, and you will see no warning that this is fallback behaviour. This means that, if you have issues with any of the models listed below, please [make an issue](https://github.com/vllm-project/vllm/issues/new/choose) and we'll do our best to fix it!
+
+| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
+|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
+| `SmolLM3ForCausalLM` | SmolLM3 | `HuggingFaceTB/SmolLM3-3B` | ✅︎ | ✅︎ | ✅︎ |
+
!!! note
Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
-!!! note
- Only text inputs are currently supported for `Gemma3nForConditionalGeneration`. To use this model, please upgrade Hugging Face Transformers to version 4.53.0.
-
### Pooling Models
See [this page](./pooling_models.md) for more information on how to use pooling models.
@@ -426,7 +430,7 @@ See [this page](./pooling_models.md) for more information on how to use pooling
These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) API.
-| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
+| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
| `BertModel`C | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | |
| `Gemma2Model`C | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | | ✅︎ |
@@ -466,7 +470,7 @@ of the whole prompt are extracted from the normalized hidden state corresponding
These models primarily support the [`LLM.classify`](./pooling_models.md#llmclassify) API.
-| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
+| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | |
| `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | | ✅︎ |
@@ -483,7 +487,7 @@ If your model is not in the above list, we will try to automatically convert the
Cross-encoder and reranker models are a subset of classification models that accept two prompts as input.
These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) API.
-| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
+| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | |
| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
@@ -521,7 +525,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
These models primarily support the [`LLM.reward`](./pooling_models.md#llmreward) API.
-| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
+| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
| `InternLM2ForRewardModel` | InternLM2-based | `internlm/internlm2-1_8b-reward`, `internlm/internlm2-7b-reward`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `LlamaForCausalLM`C | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ | ✅︎ |
@@ -583,6 +587,9 @@ See [this page](../features/multimodal_inputs.md) on how to pass multi-modal inp
**This is no longer required if you are using vLLM V1.**
+!!! tip
+ For hybrid-only models such as Llama-4, Step3 and Mistral-3, a text-only mode can be enabled by setting all supported multimodal modalities to 0 (e.g, `--limit-mm-per-prompt '{"image":0}`) so that their multimodal modules will not be loaded to free up more GPU memory for KV cache.
+
!!! note
vLLM currently only supports adding LoRA to the language backbone of multimodal models.
@@ -594,20 +601,21 @@ See [this page](generative_models.md) for more information on how to use generat
These models primarily accept the [`LLM.generate`](./generative_models.md#llmgenerate) API. Chat/Instruct models additionally support the [`LLM.chat`](./generative_models.md#llmchat) API.
-| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
+| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|--------------|--------|--------|-------------------|----------------------|---------------------------|---------------------|
| `AriaForConditionalGeneration` | Aria | T + I+ | `rhymes-ai/Aria` | | | ✅︎ |
| `AyaVisionForConditionalGeneration` | Aya Vision | T + I+ | `CohereForAI/aya-vision-8b`, `CohereForAI/aya-vision-32b`, etc. | | ✅︎ | ✅︎ |
| `Blip2ForConditionalGeneration` | BLIP-2 | T + IE | `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. | | ✅︎ | ✅︎ |
| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ | ✅︎ |
+| `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I+ | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ | ✅︎ |
| `DeepseekVLV2ForCausalLM`^ | DeepSeek-VL2 | T + I+ | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | ✅︎ |
| `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | |
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ |
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |
+| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
| `GLM4VForCausalLM`^ | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + IE+ + VE+ | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ |
-| `Glm4MoeForCausalLM` | GLM-4.5 | T + IE+ + VE+ | `zai-org/GLM-4.5`, etc. | ✅︎ | ✅︎ | ✅︎ |
-| `Glm4v_moeForConditionalGeneration` | GLM-4.5V | T + IE+ + VE+ | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + IE+ + VE+ | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
| `H2OVLChatModel` | H2OVL | T + IE+ | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ |
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ |
@@ -647,7 +655,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
Some models are supported only via the [Transformers backend](#transformers). The purpose of the table below is to acknowledge models which we officially support in this way. The logs will say that the Transformers backend is being used, and you will see no warning that this is fallback behaviour. This means that, if you have issues with any of the models listed below, please [make an issue](https://github.com/vllm-project/vllm/issues/new/choose) and we'll do our best to fix it!
-| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
+| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|--------------|--------|--------|-------------------|-----------------------------|-----------------------------------------|---------------------|
| `Emu3ForConditionalGeneration` | Emu3 | T + I | `BAAI/Emu3-Chat-hf` | ✅︎ | ✅︎ | ✅︎ |
@@ -674,6 +682,15 @@ Some models are supported only via the [Transformers backend](#transformers). Th
This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.
+!!! note
+ `Gemma3nForConditionalGeneration` is only supported on V1 due to shared KV caching and it depends on `timm>=1.0.17` to make use of its
+ MobileNet-v5 vision backbone.
+
+ Performance is not yet fully optimized mainly due to:
+
+ - Both audio and vision MM encoders use `transformers.AutoModel` implementation.
+ - There's no PLE caching or out-of-memory swapping support, as described in [Google's blog](https://developers.googleblog.com/en/introducing-gemma-3n/). These features might be too model-specific for vLLM, and swapping in particular may be better suited for constrained setups.
+
!!! note
Only `InternVLChatModel` with Qwen2.5 text backbone (`OpenGVLab/InternVL3-2B`, `OpenGVLab/InternVL2.5-1B` etc) has video inputs support currently.
@@ -726,7 +743,7 @@ Some models are supported only via the [Transformers backend](#transformers). Th
Speech2Text models trained specifically for Automatic Speech Recognition.
-| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
+| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
| `WhisperForConditionalGeneration` | Whisper | `openai/whisper-small`, `openai/whisper-large-v3-turbo`, etc. | | | |
| `VoxtralForConditionalGeneration` | Voxtral (Mistral format) | `mistralai/Voxtral-Mini-3B-2507`, `mistralai/Voxtral-Small-24B-2507`, etc. | | ✅︎ | ✅︎ |
@@ -744,7 +761,7 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A
The following table lists those that are tested in vLLM.
-| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
+| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|--------------|--------|--------|-------------------|----------------------|---------------------------|---------------------|
| `LlavaNextForConditionalGeneration`C | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | | |
| `Phi3VForCausalLM`C | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | 🚧 | ✅︎ | |
@@ -760,7 +777,7 @@ The following table lists those that are tested in vLLM.
Cross-encoder and reranker models are a subset of classification models that accept two prompts as input.
These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) API.
-| Architecture | Models | Inputs | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) |
+| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|-------------------------------------|--------------------|----------|--------------------------|------------------------|-----------------------------|-----------------------|
| `JinaVLForSequenceClassification` | JinaVL-based | T + IE+ | `jinaai/jina-reranker-m0`, etc. | | | ✅︎ |
diff --git a/docs/serving/distributed_serving.md b/docs/serving/parallelism_scaling.md
similarity index 99%
rename from docs/serving/distributed_serving.md
rename to docs/serving/parallelism_scaling.md
index fc9d9f8a34347..fa7fc1b290d50 100644
--- a/docs/serving/distributed_serving.md
+++ b/docs/serving/parallelism_scaling.md
@@ -1,4 +1,4 @@
-# Distributed inference and serving
+# Parallelism and Scaling
## Distributed inference strategies for a single-model replica
diff --git a/docs/usage/README.md b/docs/usage/README.md
index 681db57d8e0f5..83aea121819f8 100644
--- a/docs/usage/README.md
+++ b/docs/usage/README.md
@@ -1,6 +1,8 @@
# Using vLLM
-vLLM supports the following usage patterns:
+First, vLLM must be [installed](../getting_started/installation) for your chosen device in either a Python or Docker environment.
+
+Then, vLLM supports the following usage patterns:
- [Inference and Serving](../serving/offline_inference.md): Run a single instance of a model.
- [Deployment](../deployment/docker.md): Scale up model instances for production.
diff --git a/docs/usage/troubleshooting.md b/docs/usage/troubleshooting.md
index f9ba32c58c4e1..9715ad66d9b35 100644
--- a/docs/usage/troubleshooting.md
+++ b/docs/usage/troubleshooting.md
@@ -289,7 +289,7 @@ Traceback (most recent call last):
...
```
-This indicates vLLM failed to initialize the NCCL communicator, possibly due to a missing `IPC_LOCK` linux capability or an unmounted `/dev/shm`. Refer to [Distributed Inference and Serving](../serving/distributed_serving.md#running-vllm-on-multiple-nodes) for guidance on properly configuring the environment for distributed serving.
+This indicates vLLM failed to initialize the NCCL communicator, possibly due to a missing `IPC_LOCK` linux capability or an unmounted `/dev/shm`. Refer to [Enabling GPUDirect RDMA](../serving/parallelism_scaling.md#enabling-gpudirect-rdma) for guidance on properly configuring the environment for GPUDirect RDMA.
## Known Issues
diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md
index d30144e8a8253..54af970ea842d 100644
--- a/docs/usage/v1_guide.md
+++ b/docs/usage/v1_guide.md
@@ -59,12 +59,13 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the
### Hardware
-| Hardware | Status |
-|------------|------------------------------------|
-| **NVIDIA** | 🚀 |
-| **AMD** | 🟢 |
-| **TPU** | 🟢 |
-| **CPU** | 🟢 (x86) 🟡 (MacOS) |
+| Hardware | Status |
+|------------|-----------------------------------------------|
+| **NVIDIA** | 🚀 |
+| **AMD** | 🟢 |
+| **INTEL GPU** | 🟢 |
+| **TPU** | 🟢 |
+| **CPU** | 🟢 (x86\_64/aarch64) 🟡 (MacOS) |
!!! note
@@ -72,6 +73,7 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the
- [vllm-ascend](https://github.com/vllm-project/vllm-ascend)
- [vllm-spyre](https://github.com/vllm-project/vllm-spyre)
+ - [vllm-gaudi](https://github.com/vllm-project/vllm-gaudi)
- [vllm-openvino](https://github.com/vllm-project/vllm-openvino)
Please check their corresponding repositories for more details.
@@ -111,6 +113,10 @@ Models that combine Mamba-2 and Mamba-1 layers with standard attention layers ar
`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`). Please note that
these models currently require disabling prefix caching and using the FlashInfer attention backend in V1.
+Hybrid models with mechanisms different to Mamba are also supported (e.g, `MiniMaxText01ForCausalLM`, `MiniMaxM1ForCausalLM`).
+Please note that these models currently require disabling prefix caching, enforcing eager mode, and using the FlashInfer
+attention backend in V1.
+
#### Encoder-Decoder Models
Models requiring cross-attention between separate encoder and decoder (e.g., `BartForConditionalGeneration`, `MllamaForConditionalGeneration`)
diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py
index 01d6a188be994..22cb8b057dac7 100644
--- a/examples/offline_inference/audio_language.py
+++ b/examples/offline_inference/audio_language.py
@@ -96,6 +96,25 @@ def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
)
+# Gemma3N
+def run_gemma3n(question: str, audio_count: int) -> ModelRequestData:
+ model_name = "google/gemma-3n-E2B-it"
+ engine_args = EngineArgs(
+ model=model_name,
+ max_model_len=2048,
+ max_num_batched_tokens=2048,
+ max_num_seqs=2,
+ limit_mm_per_prompt={"audio": audio_count},
+ enforce_eager=True,
+ )
+ prompt = f"user\n{question}"
+ "\nmodel\n"
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompt=prompt,
+ )
+
+
# Granite Speech
def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
# NOTE - the setting in this example are somehat different than what is
@@ -331,6 +350,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
model_example_map = {
"voxtral": run_voxtral,
+ "gemma3n": run_gemma3n,
"granite_speech": run_granite_speech,
"minicpmo": run_minicpmo,
"phi4_mm": run_phi4mm,
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index 1314d33e90093..988ad35cdd7e6 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -126,6 +126,29 @@ def run_chameleon(questions: list[str], modality: str) -> ModelRequestData:
)
+def run_command_a_vision(questions: list[str], modality: str) -> ModelRequestData:
+ assert modality == "image"
+
+ model_name = "CohereLabs/command-a-vision-07-2025"
+
+ engine_args = EngineArgs(
+ model=model_name,
+ max_model_len=32768,
+ tensor_parallel_size=4,
+ limit_mm_per_prompt={modality: 1},
+ )
+
+ prompts = [
+ f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|><|IMG_PATCH|>{question}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
+ for question in questions
+ ]
+
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompts=prompts,
+ )
+
+
# Deepseek-VL2
def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
@@ -211,7 +234,33 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
)
for question in questions
]
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompts=prompts,
+ )
+
+# Gemma3N
+def run_gemma3n(questions: list[str], modality: str) -> ModelRequestData:
+ assert modality == "image"
+ model_name = "google/gemma-3n-E2B-it"
+
+ engine_args = EngineArgs(
+ model=model_name,
+ max_model_len=2048,
+ max_num_seqs=2,
+ limit_mm_per_prompt={modality: 1},
+ enforce_eager=True,
+ )
+
+ prompts = [
+ (
+ "user\n"
+ f"{question}\n"
+ "model\n"
+ )
+ for question in questions
+ ]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
@@ -1391,10 +1440,12 @@ model_example_map = {
"aya_vision": run_aya_vision,
"blip-2": run_blip2,
"chameleon": run_chameleon,
+ "command_a_vision": run_command_a_vision,
"deepseek_vl_v2": run_deepseek_vl2,
"florence2": run_florence2,
"fuyu": run_fuyu,
"gemma3": run_gemma3,
+ "gemma3n": run_gemma3n,
"glm4v": run_glm4v,
"glm4_1v": run_glm4_1v,
"h2ovl_chat": run_h2ovl,
diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py
index 1ab405fa14f3a..799337ed68503 100644
--- a/examples/offline_inference/vision_language_multi_image.py
+++ b/examples/offline_inference/vision_language_multi_image.py
@@ -107,6 +107,42 @@ def load_aya_vision(question: str, image_urls: list[str]) -> ModelRequestData:
)
+def load_command_a_vision(question: str, image_urls: list[str]) -> ModelRequestData:
+ model_name = "CohereLabs/command-a-vision-07-2025"
+
+ # NOTE: This model is 122B parameters and requires tensor parallelism
+ # Recommended to use tp=4 on H100 GPUs
+ engine_args = EngineArgs(
+ model=model_name,
+ max_model_len=32768,
+ tensor_parallel_size=4,
+ limit_mm_per_prompt={"image": len(image_urls)},
+ )
+
+ placeholders = [{"type": "image", "image": url} for url in image_urls]
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ *placeholders,
+ {"type": "text", "text": question},
+ ],
+ }
+ ]
+
+ processor = AutoProcessor.from_pretrained(model_name)
+
+ prompt = processor.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompt=prompt,
+ image_data=[fetch_image(url) for url in image_urls],
+ )
+
+
def load_deepseek_vl2(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "deepseek-ai/deepseek-vl2-tiny"
@@ -1031,6 +1067,7 @@ def load_tarsier2(question: str, image_urls: list[str]) -> ModelRequestData:
model_example_map = {
"aria": load_aria,
"aya_vision": load_aya_vision,
+ "command_a_vision": load_command_a_vision,
"deepseek_vl_v2": load_deepseek_vl2,
"gemma3": load_gemma3,
"h2ovl_chat": load_h2ovl,
diff --git a/examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh b/examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh
index 1284466a45580..682df45d95d79 100644
--- a/examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh
+++ b/examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh
@@ -15,6 +15,14 @@ else
MODEL=$2
fi
+# The prefillers and decoders in LMCache use the same hash seed for all chunk keys.
+# This seed must be aligned so that decoders can identify and retrieve KV cache
+# entries stored by prefillers.
+#
+# WARNING: Using a fixed hash seed is insecure and makes the application vulnerable to
+# denial-of-service attacks. In a production environment, this should be set to a
+# secure random value. This is set to a fixed value for demonstration purposes only.
+export PYTHONHASHSEED=${VLLM_PYTHON_HASH_SEED:-123}
if [[ $1 == "prefiller" ]]; then
# Prefiller listens on port 8100
diff --git a/mkdocs.yaml b/mkdocs.yaml
index e5b7454003310..47fe1ebce9712 100644
--- a/mkdocs.yaml
+++ b/mkdocs.yaml
@@ -34,11 +34,13 @@ theme:
- content.action.edit
- content.code.copy
- content.tabs.link
+ - navigation.instant
+ - navigation.instant.progress
- navigation.tracking
- navigation.tabs
- navigation.tabs.sticky
- navigation.sections
- - navigation.prune
+ - navigation.indexes
- navigation.top
- search.highlight
- search.share
@@ -51,11 +53,6 @@ hooks:
- docs/mkdocs/hooks/generate_argparse.py
- docs/mkdocs/hooks/url_schemes.py
-# Required to stop api-autonav from raising an error
-# https://github.com/tlambert03/mkdocs-api-autonav/issues/16
-nav:
- - api
-
plugins:
- meta
- search
diff --git a/pyproject.toml b/pyproject.toml
index dfad5d2cdf319..03a32ac0ba3d7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -73,8 +73,6 @@ line-length = 80
"vllm/engine/**/*.py" = ["UP006", "UP035"]
"vllm/executor/**/*.py" = ["UP006", "UP035"]
"vllm/worker/**/*.py" = ["UP006", "UP035"]
-# Python 3.8 typing - skip utils for ROCm
-"vllm/utils/__init__.py" = ["UP006", "UP035"]
[tool.ruff.lint]
select = [
diff --git a/requirements/common.txt b/requirements/common.txt
index 5c422500e1ceb..1a8fea0dd7d93 100644
--- a/requirements/common.txt
+++ b/requirements/common.txt
@@ -12,7 +12,7 @@ tokenizers >= 0.21.1 # Required for fast incremental detokenization.
protobuf # Required by LlamaTokenizer.
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
aiohttp
-openai >= 1.98.0 # For Responses API with reasoning content
+openai >= 1.99.1 # For Responses API with reasoning content
pydantic >= 2.10
prometheus_client >= 0.18.0
pillow # Required for image processing
diff --git a/requirements/docs.txt b/requirements/docs.txt
index c589093110dad..a24b9c7e924bf 100644
--- a/requirements/docs.txt
+++ b/requirements/docs.txt
@@ -29,3 +29,5 @@ setproctitle
torch
transformers
zmq
+uvloop
+prometheus-client
diff --git a/requirements/test.in b/requirements/test.in
index 1e0cab80a24f3..6652bfdfe66c9 100644
--- a/requirements/test.in
+++ b/requirements/test.in
@@ -10,7 +10,7 @@ pytest-timeout
# testing utils
backoff # required for phi4mm test
blobfile # required for kimi-vl test
-einops # required for MPT, qwen-vl and Mamba
+einops # required for MPT, qwen-vl
httpx
librosa # required for audio tests
vector_quantize_pytorch # required for minicpmo_26 test
@@ -21,12 +21,11 @@ ray[cgraph,default]>=2.48.0 # Ray Compiled Graph, required by pipeline paralleli
sentence-transformers # required for embedding tests
soundfile # required for audio tests
jiwer # required for audio tests
-timm # required for internvl test
+timm >=1.0.17 # required for internvl and gemma3n-mm test
torch==2.7.1
torchaudio==2.7.1
torchvision==0.22.1
transformers_stream_generator # required for qwen-vl test
-mamba_ssm==2.2.5 # required for plamo2 test
matplotlib # required for qwen-vl test
mistral_common[image,audio] >= 1.8.2 # required for voxtral test
num2words # required for smolvlm test
@@ -53,4 +52,4 @@ runai-model-streamer==0.11.0
runai-model-streamer-s3==0.11.0
fastsafetensors>=0.1.10
pydantic>=2.10 # 2.9 leads to error on python 3.10
-terratorch==1.1rc2 # required for PrithviMAE test
\ No newline at end of file
+terratorch==1.1rc2 # required for PrithviMAE test
diff --git a/requirements/test.txt b/requirements/test.txt
index 324f8153b2ac4..ff9886a315976 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -178,7 +178,6 @@ einops==0.8.1
# via
# -r requirements/test.in
# encodec
- # mamba-ssm
# terratorch
# torchgeo
# vector-quantize-pytorch
@@ -417,8 +416,6 @@ lxml==5.3.0
# sacrebleu
mako==1.3.10
# via alembic
-mamba-ssm==2.2.5
- # via -r requirements/test.in
markdown==3.8.2
# via mlflow
markdown-it-py==3.0.0
@@ -475,8 +472,6 @@ networkx==3.2.1
# via
# scikit-image
# torch
-ninja==1.11.1.3
- # via mamba-ssm
nltk==3.9.1
# via rouge-score
num2words==0.5.14
@@ -629,7 +624,6 @@ packaging==24.2
# lazy-loader
# lightning
# lightning-utilities
- # mamba-ssm
# matplotlib
# mlflow-skinny
# peft
@@ -973,7 +967,6 @@ sentencepiece==0.2.0
setuptools==77.0.3
# via
# lightning-utilities
- # mamba-ssm
# pytablewriter
# torch
# triton
@@ -1058,7 +1051,7 @@ tiktoken==0.7.0
# via
# lm-eval
# mistral-common
-timm==1.0.15
+timm==1.0.17
# via
# -r requirements/test.in
# open-clip-torch
@@ -1085,7 +1078,6 @@ torch==2.7.1+cu128
# lightly
# lightning
# lm-eval
- # mamba-ssm
# mteb
# open-clip-torch
# peft
@@ -1152,16 +1144,13 @@ transformers==4.55.0
# -r requirements/test.in
# genai-perf
# lm-eval
- # mamba-ssm
# peft
# sentence-transformers
# transformers-stream-generator
transformers-stream-generator==0.0.5
# via -r requirements/test.in
triton==3.3.1
- # via
- # mamba-ssm
- # torch
+ # via torch
tritonclient==2.51.0
# via
# -r requirements/test.in
diff --git a/requirements/xpu.txt b/requirements/xpu.txt
index 0d95dc57152de..4607c3efdf14c 100644
--- a/requirements/xpu.txt
+++ b/requirements/xpu.txt
@@ -10,15 +10,10 @@ wheel
jinja2>=3.1.6
datasets # for benchmark scripts
numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
-
-torch==2.7.0+xpu
+--extra-index-url=https://download.pytorch.org/whl/xpu
+torch==2.8.0+xpu
torchaudio
torchvision
pytorch-triton-xpu
---extra-index-url=https://download.pytorch.org/whl/xpu
-
-# Please refer xpu doc, we need manually install intel-extension-for-pytorch 2.6.10+xpu due to there are some conflict dependencies with torch 2.6.0+xpu
-# FIXME: This will be fix in ipex 2.7. just leave this here for awareness.
-intel-extension-for-pytorch==2.7.10+xpu
-oneccl_bind_pt==2.7.0+xpu
--extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+intel-extension-for-pytorch==2.8.10+xpu
diff --git a/setup.py b/setup.py
index e374fcb816e7f..919300e143c1e 100644
--- a/setup.py
+++ b/setup.py
@@ -7,6 +7,7 @@ import json
import logging
import os
import re
+import shutil
import subprocess
import sys
from pathlib import Path
@@ -281,10 +282,81 @@ class cmake_build_ext(build_ext):
self.copy_file(file, dst_file)
-class repackage_wheel(build_ext):
+class precompiled_build_ext(build_ext):
+ """Disables extension building when using precompiled binaries."""
+
+ def run(self) -> None:
+ assert _is_cuda(
+ ), "VLLM_USE_PRECOMPILED is only supported for CUDA builds"
+
+ def build_extensions(self) -> None:
+ print("Skipping build_ext: using precompiled extensions.")
+ return
+
+
+class precompiled_wheel_utils:
"""Extracts libraries and other files from an existing wheel."""
- def get_base_commit_in_main_branch(self) -> str:
+ @staticmethod
+ def extract_precompiled_and_patch_package(wheel_url_or_path: str) -> dict:
+ import tempfile
+ import zipfile
+
+ temp_dir = None
+ try:
+ if not os.path.isfile(wheel_url_or_path):
+ wheel_filename = wheel_url_or_path.split("/")[-1]
+ temp_dir = tempfile.mkdtemp(prefix="vllm-wheels")
+ wheel_path = os.path.join(temp_dir, wheel_filename)
+ print(f"Downloading wheel from {wheel_url_or_path} "
+ f"to {wheel_path}")
+ from urllib.request import urlretrieve
+ urlretrieve(wheel_url_or_path, filename=wheel_path)
+ else:
+ wheel_path = wheel_url_or_path
+ print(f"Using existing wheel at {wheel_path}")
+
+ package_data_patch = {}
+
+ with zipfile.ZipFile(wheel_path) as wheel:
+ files_to_copy = [
+ "vllm/_C.abi3.so",
+ "vllm/_moe_C.abi3.so",
+ "vllm/_flashmla_C.abi3.so",
+ "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
+ "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
+ "vllm/cumem_allocator.abi3.so",
+ ]
+
+ compiled_regex = re.compile(
+ r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py")
+ file_members = list(
+ filter(lambda x: x.filename in files_to_copy,
+ wheel.filelist))
+ file_members += list(
+ filter(lambda x: compiled_regex.match(x.filename),
+ wheel.filelist))
+
+ for file in file_members:
+ print(f"[extract] {file.filename}")
+ target_path = os.path.join(".", file.filename)
+ os.makedirs(os.path.dirname(target_path), exist_ok=True)
+ with wheel.open(file.filename) as src, open(
+ target_path, "wb") as dst:
+ shutil.copyfileobj(src, dst)
+
+ pkg = os.path.dirname(file.filename).replace("/", ".")
+ package_data_patch.setdefault(pkg, []).append(
+ os.path.basename(file.filename))
+
+ return package_data_patch
+ finally:
+ if temp_dir is not None:
+ print(f"Removing temporary directory {temp_dir}")
+ shutil.rmtree(temp_dir)
+
+ @staticmethod
+ def get_base_commit_in_main_branch() -> str:
# Force to use the nightly wheel. This is mainly used for CI testing.
if envs.VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL:
return "nightly"
@@ -297,6 +369,10 @@ class repackage_wheel(build_ext):
]).decode("utf-8")
upstream_main_commit = json.loads(resp_json)["sha"]
+ # In Docker build context, .git may be immutable or missing.
+ if envs.VLLM_DOCKER_BUILD_CONTEXT:
+ return upstream_main_commit
+
# Check if the upstream_main_commit exists in the local repo
try:
subprocess.check_output(
@@ -329,86 +405,6 @@ class repackage_wheel(build_ext):
"wheel may not be compatible with your dev branch: %s", err)
return "nightly"
- def run(self) -> None:
- assert _is_cuda(
- ), "VLLM_USE_PRECOMPILED is only supported for CUDA builds"
-
- wheel_location = os.getenv("VLLM_PRECOMPILED_WHEEL_LOCATION", None)
- if wheel_location is None:
- base_commit = self.get_base_commit_in_main_branch()
- wheel_location = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
- # Fallback to nightly wheel if latest commit wheel is unavailable,
- # in this rare case, the nightly release CI hasn't finished on main.
- if not is_url_available(wheel_location):
- wheel_location = "https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
-
- import zipfile
-
- if os.path.isfile(wheel_location):
- wheel_path = wheel_location
- print(f"Using existing wheel={wheel_path}")
- else:
- # Download the wheel from a given URL, assume
- # the filename is the last part of the URL
- wheel_filename = wheel_location.split("/")[-1]
-
- import tempfile
-
- # create a temporary directory to store the wheel
- temp_dir = tempfile.mkdtemp(prefix="vllm-wheels")
- wheel_path = os.path.join(temp_dir, wheel_filename)
-
- print(f"Downloading wheel from {wheel_location} to {wheel_path}")
-
- from urllib.request import urlretrieve
-
- try:
- urlretrieve(wheel_location, filename=wheel_path)
- except Exception as e:
- from setuptools.errors import SetupError
-
- raise SetupError(
- f"Failed to get vLLM wheel from {wheel_location}") from e
-
- with zipfile.ZipFile(wheel_path) as wheel:
- files_to_copy = [
- "vllm/_C.abi3.so",
- "vllm/_moe_C.abi3.so",
- "vllm/_flashmla_C.abi3.so",
- "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
- "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
- "vllm/cumem_allocator.abi3.so",
- # "vllm/_version.py", # not available in nightly wheels yet
- ]
-
- file_members = list(
- filter(lambda x: x.filename in files_to_copy, wheel.filelist))
-
- # vllm_flash_attn python code:
- # Regex from
- # `glob.translate('vllm/vllm_flash_attn/**/*.py', recursive=True)`
- compiled_regex = re.compile(
- r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py")
- file_members += list(
- filter(lambda x: compiled_regex.match(x.filename),
- wheel.filelist))
-
- for file in file_members:
- print(f"Extracting and including {file.filename} "
- "from existing wheel")
- package_name = os.path.dirname(file.filename).replace("/", ".")
- file_name = os.path.basename(file.filename)
-
- if package_name not in package_data:
- package_data[package_name] = []
-
- wheel.extract(file)
- if file_name.endswith(".py"):
- # python files shouldn't be added to package_data
- continue
-
- package_data[package_name].append(file_name)
-
def _no_device() -> bool:
return VLLM_TARGET_DEVICE == "empty"
@@ -639,6 +635,29 @@ package_data = {
]
}
+# If using precompiled, extract and patch package_data (in advance of setup)
+if envs.VLLM_USE_PRECOMPILED:
+ assert _is_cuda(), "VLLM_USE_PRECOMPILED is only supported for CUDA builds"
+ wheel_location = os.getenv("VLLM_PRECOMPILED_WHEEL_LOCATION", None)
+ if wheel_location is not None:
+ wheel_url = wheel_location
+ else:
+ base_commit = precompiled_wheel_utils.get_base_commit_in_main_branch()
+ wheel_url = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
+ from urllib.request import urlopen
+ try:
+ with urlopen(wheel_url) as resp:
+ if resp.status != 200:
+ wheel_url = "https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
+ except Exception as e:
+ print(f"[warn] Falling back to nightly wheel: {e}")
+ wheel_url = "https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
+
+ patch = precompiled_wheel_utils.extract_precompiled_and_patch_package(
+ wheel_url)
+ for pkg, files in patch.items():
+ package_data.setdefault(pkg, []).extend(files)
+
if _no_device():
ext_modules = []
@@ -647,7 +666,7 @@ if not ext_modules:
else:
cmdclass = {
"build_ext":
- repackage_wheel if envs.VLLM_USE_PRECOMPILED else cmake_build_ext
+ precompiled_build_ext if envs.VLLM_USE_PRECOMPILED else cmake_build_ext
}
setup(
@@ -665,7 +684,7 @@ setup(
"mistral_common[audio]"], # Required for audio processing
"video": [], # Kept for backwards compatibility
# FlashInfer should be updated together with the Dockerfile
- "flashinfer": ["flashinfer-python==0.2.10"],
+ "flashinfer": ["flashinfer-python==0.2.11"],
},
cmdclass=cmdclass,
package_data=package_data,
diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py
index fae49c41d5f83..9212c04deec90 100644
--- a/tests/distributed/test_custom_all_reduce.py
+++ b/tests/distributed/test_custom_all_reduce.py
@@ -10,8 +10,7 @@ import torch.distributed as dist
from vllm.distributed.communication_op import ( # noqa
tensor_model_parallel_all_reduce)
-from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
- get_tp_group, graph_capture)
+from vllm.distributed.parallel_state import get_tp_group, graph_capture
from ..utils import (ensure_model_parallel_initialized,
init_test_distributed_environment, multi_process_parallel)
@@ -37,7 +36,7 @@ def graph_allreduce(
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)
ensure_model_parallel_initialized(tp_size, pp_size)
- group = get_tensor_model_parallel_group().device_group
+ group = get_tp_group().device_group
# A small all_reduce for warmup.
# this is needed because device communicators might be created lazily
diff --git a/tests/distributed/test_quick_all_reduce.py b/tests/distributed/test_quick_all_reduce.py
index a4added29144e..6245ccbeca877 100644
--- a/tests/distributed/test_quick_all_reduce.py
+++ b/tests/distributed/test_quick_all_reduce.py
@@ -10,8 +10,7 @@ import torch.distributed as dist
from vllm.distributed.communication_op import ( # noqa
tensor_model_parallel_all_reduce)
-from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
- get_tp_group, graph_capture)
+from vllm.distributed.parallel_state import get_tp_group, graph_capture
from vllm.platforms import current_platform
from ..utils import (ensure_model_parallel_initialized,
@@ -42,7 +41,7 @@ def graph_quickreduce(
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)
ensure_model_parallel_initialized(tp_size, pp_size)
- group = get_tensor_model_parallel_group().device_group
+ group = get_tp_group().device_group
# A small all_reduce for warmup.
# this is needed because device communicators might be created lazily
diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py
index c282bf002304a..93ac18dfcc7b4 100644
--- a/tests/engine/test_arg_utils.py
+++ b/tests/engine/test_arg_utils.py
@@ -93,32 +93,6 @@ class NestedConfig:
"""field"""
-@config
-@dataclass
-class FromCliConfig1:
- field: int = 1
- """field"""
-
- @classmethod
- def from_cli(cls, cli_value: str):
- inst = cls(**json.loads(cli_value))
- inst.field += 1
- return inst
-
-
-@config
-@dataclass
-class FromCliConfig2:
- field: int = 1
- """field"""
-
- @classmethod
- def from_cli(cls, cli_value: str):
- inst = cls(**json.loads(cli_value))
- inst.field += 2
- return inst
-
-
@config
@dataclass
class DummyConfig:
@@ -144,10 +118,6 @@ class DummyConfig:
"""Dict which will be JSON in CLI"""
nested_config: NestedConfig = field(default_factory=NestedConfig)
"""Nested config"""
- from_cli_config1: FromCliConfig1 = field(default_factory=FromCliConfig1)
- """Config with from_cli method"""
- from_cli_config2: FromCliConfig2 = field(default_factory=FromCliConfig2)
- """Different config with from_cli method"""
@pytest.mark.parametrize(("type_hint", "expected"), [
@@ -199,9 +169,6 @@ def test_get_kwargs():
assert json_tip in kwargs["json_tip"]["help"]
# nested config should should construct the nested config
assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2)
- # from_cli configs should be constructed with the correct method
- assert kwargs["from_cli_config1"]["type"]('{"field": 2}').field == 3
- assert kwargs["from_cli_config2"]["type"]('{"field": 2}').field == 4
@pytest.mark.parametrize(
diff --git a/tests/entrypoints/llm/test_accuracy.py b/tests/entrypoints/llm/test_accuracy.py
index 39bc8ab07d45f..5d605e906e81b 100644
--- a/tests/entrypoints/llm/test_accuracy.py
+++ b/tests/entrypoints/llm/test_accuracy.py
@@ -96,9 +96,6 @@ def test_lm_eval_accuracy_v1_engine_fp8_kv_cache(
more_args = None
if current_platform.is_tpu():
# Limit compilation time for TPU V1
-
- # xet doesn't work well for Qwen/Qwen3-1.7B
- m.setenv("HF_HUB_DISABLE_XET", "1")
more_args = "max_model_len=2048,max_num_seqs=128,kv_cache_dtype=fp8"
# Add TP test (if provided)
diff --git a/tests/entrypoints/llm/test_classify.py b/tests/entrypoints/llm/test_classify.py
index abdce8935ea58..71e76abcb7d2c 100644
--- a/tests/entrypoints/llm/test_classify.py
+++ b/tests/entrypoints/llm/test_classify.py
@@ -65,3 +65,9 @@ def test_pooling_params(llm: LLM):
assert torch.allclose(
softmax(wo_activation), w_activation, atol=1e-2
), "w_activation should be close to activation(wo_activation)."
+
+
+def test_encode_api(llm: LLM):
+ err_msg = "pooling_task must be one of.+"
+ with pytest.raises(ValueError, match=err_msg):
+ llm.encode(prompts, use_tqdm=False)
diff --git a/tests/entrypoints/openai/test_async_tokenization.py b/tests/entrypoints/openai/test_async_tokenization.py
index ab3c809054384..80261597b11a8 100644
--- a/tests/entrypoints/openai/test_async_tokenization.py
+++ b/tests/entrypoints/openai/test_async_tokenization.py
@@ -2,15 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
-import contextlib
import random
-import time
from typing import Callable
import openai
import pytest
import pytest_asyncio
-import requests
from tests.utils import RemoteOpenAIServer
@@ -87,54 +84,3 @@ async def test_with_and_without_truncate(
responses = await asyncio.gather(*[get_status_code(**b) for b in bodies])
assert 500 not in responses
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
- ids=["single completion", "multiple completions", "chat"],
- argnames=["create_func_gen", "content_body"],
- argvalues=[
- (lambda x: x.completions.create, {
- "prompt": " ".join(['A'] * 300_000)
- }),
- (lambda x: x.completions.create, {
- "prompt": [" ".join(['A'] * 300_000)] * 2
- }),
- (lambda x: x.chat.completions.create, {
- "messages": [{
- "role": "user",
- "content": " ".join(['A'] * 300_000)
- }]
- }),
- ],
-)
-async def test_healthcheck_response_time(
- server: RemoteOpenAIServer,
- client: openai.AsyncOpenAI,
- create_func_gen: Callable,
- content_body: dict,
-):
- num_requests = 50
-
- create_func = create_func_gen(client)
- body = {"model": MODEL_NAME, **content_body, "max_tokens": 10}
-
- def get_response_time(url):
- start_time = time.monotonic()
- res = requests.get(url)
- end_time = time.monotonic()
- assert res.status_code == 200
- return end_time - start_time
-
- no_load_response_time = get_response_time(server.url_for("health"))
- tasks = [
- asyncio.create_task(create_func(**body)) for _ in range(num_requests)
- ]
- await asyncio.sleep(1) # give the tasks a chance to start running
- load_response_time = get_response_time(server.url_for("health"))
-
- with contextlib.suppress(openai.APIStatusError):
- await asyncio.gather(*tasks)
-
- assert load_response_time < 100 * no_load_response_time
- assert load_response_time < 0.1
diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/test_audio.py
index d67c05ab3e8de..2d33d3c3a6b54 100644
--- a/tests/entrypoints/openai/test_audio.py
+++ b/tests/entrypoints/openai/test_audio.py
@@ -23,6 +23,8 @@ MAXIMUM_AUDIOS = 2
@pytest.fixture(scope="module")
def server():
args = [
+ "--dtype",
+ "float32",
"--max-model-len",
"2048",
"--max-num-seqs",
diff --git a/tests/entrypoints/openai/test_classification.py b/tests/entrypoints/openai/test_classification.py
index 886267c211243..30078fe90257a 100644
--- a/tests/entrypoints/openai/test_classification.py
+++ b/tests/entrypoints/openai/test_classification.py
@@ -211,3 +211,18 @@ async def test_activation(server: RemoteOpenAIServer, model_name: str):
assert torch.allclose(
F.softmax(wo_activation, dim=-1), w_activation, atol=1e-2
), "w_activation should be close to activation(wo_activation)."
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+def test_pooling(server: RemoteOpenAIServer, model_name: str):
+ # pooling api uses ALL pooling, which does not support chunked prefill.
+ response = requests.post(
+ server.url_for("pooling"),
+ json={
+ "model": model_name,
+ "input": "test",
+ "encoding_format": "float"
+ },
+ )
+ assert response.json()["error"]["type"] == "BadRequestError"
diff --git a/tests/entrypoints/openai/test_rerank.py b/tests/entrypoints/openai/test_rerank.py
index f121693e329fa..73364294cbcdc 100644
--- a/tests/entrypoints/openai/test_rerank.py
+++ b/tests/entrypoints/openai/test_rerank.py
@@ -126,7 +126,9 @@ def test_invocations(server: RemoteOpenAIServer):
invocation_output["results"]):
assert rerank_result.keys() == invocations_result.keys()
assert rerank_result["relevance_score"] == pytest.approx(
- invocations_result["relevance_score"], rel=0.01)
+ invocations_result["relevance_score"], rel=0.05)
+ # TODO: reset this tolerance to 0.01 once we find
+ # an alternative to flash_attn with bfloat16
@pytest.mark.asyncio
diff --git a/tests/entrypoints/openai/test_response_api_with_harmony.py b/tests/entrypoints/openai/test_response_api_with_harmony.py
new file mode 100644
index 0000000000000..1ca52599c519d
--- /dev/null
+++ b/tests/entrypoints/openai/test_response_api_with_harmony.py
@@ -0,0 +1,624 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import json
+import time
+
+import pytest
+import pytest_asyncio
+import requests
+from openai import BadRequestError, NotFoundError, OpenAI
+
+from ...utils import RemoteOpenAIServer
+
+pytest.skip(allow_module_level=True, reason="gpt-oss can't run on CI yet.")
+
+MODEL_NAME = "openai/gpt-oss-20b"
+DTYPE = "bfloat16"
+
+
+@pytest.fixture(scope="module")
+def server():
+ args = ["--enforce-eager", "--tool-server", "demo"]
+
+ with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+ yield remote_server
+
+
+@pytest_asyncio.fixture
+async def client(server):
+ async with server.get_async_client() as async_client:
+ yield async_client
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_basic(client: OpenAI, model_name: str):
+ response = await client.responses.create(
+ model=model_name,
+ input="What is 13 * 24?",
+ )
+ assert response is not None
+ print("response: ", response)
+ assert response.status == "completed"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_basic_with_instructions(client: OpenAI, model_name: str):
+ response = await client.responses.create(
+ model=model_name,
+ input="What is 13 * 24?",
+ instructions="Respond in Korean.",
+ )
+ assert response is not None
+ assert response.status == "completed"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_basic_with_reasoning_effort(client: OpenAI, model_name: str):
+ response = await client.responses.create(
+ model=model_name,
+ input="What is the capital of South Korea?",
+ reasoning={"effort": "low"},
+ )
+ assert response is not None
+ assert response.status == "completed"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_chat(client: OpenAI, model_name: str):
+ response = await client.responses.create(
+ model=model_name,
+ input=[
+ {
+ "role": "system",
+ "content": "Respond in Korean."
+ },
+ {
+ "role": "user",
+ "content": "Hello!"
+ },
+ {
+ "role": "assistant",
+ "content": "Hello! How can I help you today?"
+ },
+ {
+ "role": "user",
+ "content": "What is 13 * 24? Explain your answer."
+ },
+ ],
+ )
+ assert response is not None
+ assert response.status == "completed"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_chat_with_input_type(client: OpenAI, model_name: str):
+ response = await client.responses.create(
+ model=model_name,
+ input=[
+ {
+ "role": "user",
+ "content": [{
+ "type": "input_text",
+ "text": "What is 13*24?"
+ }],
+ },
+ ],
+ )
+ assert response is not None
+ assert response.status == "completed"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_structured_output(client: OpenAI, model_name: str):
+ response = await client.responses.create(
+ model=model_name,
+ input=[
+ {
+ "role": "system",
+ "content": "Extract the event information."
+ },
+ {
+ "role": "user",
+ "content":
+ "Alice and Bob are going to a science fair on Friday.",
+ },
+ ],
+ text={
+ "format": {
+ "type": "json_schema",
+ "name": "calendar_event",
+ "schema": {
+ "type": "object",
+ "properties": {
+ "name": {
+ "type": "string"
+ },
+ "date": {
+ "type": "string"
+ },
+ "participants": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ },
+ },
+ "required": ["name", "date", "participants"],
+ "additionalProperties": False,
+ },
+ "description": "A calendar event.",
+ "strict": True,
+ }
+ },
+ )
+ assert response is not None
+ assert response.status == "completed"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_structured_output_with_parse(client: OpenAI, model_name: str):
+ from pydantic import BaseModel
+
+ class CalendarEvent(BaseModel):
+ name: str
+ date: str
+ participants: list[str]
+
+ response = await client.responses.parse(
+ model=model_name,
+ input="Alice and Bob are going to a science fair on Friday",
+ instructions="Extract the event information",
+ text_format=CalendarEvent,
+ )
+ assert response is not None
+ assert response.status == "completed"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_store(client: OpenAI, model_name: str):
+ for store in [True, False]:
+ response = await client.responses.create(
+ model=model_name,
+ input="What is 13 * 24?",
+ store=store,
+ )
+ assert response is not None
+
+ try:
+ _retrieved_response = await client.responses.retrieve(response.id)
+ is_not_found = False
+ except NotFoundError:
+ is_not_found = True
+
+ assert is_not_found == (not store)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_background(client: OpenAI, model_name: str):
+ response = await client.responses.create(
+ model=model_name,
+ input="What is 13 * 24?",
+ background=True,
+ )
+ assert response is not None
+
+ retries = 0
+ max_retries = 30
+ while retries < max_retries:
+ response = await client.responses.retrieve(response.id)
+ if response.status == "completed":
+ break
+ time.sleep(1)
+ retries += 1
+
+ assert response.status == "completed"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_background_cancel(client: OpenAI, model_name: str):
+ response = await client.responses.create(
+ model=model_name,
+ input="Write a long story about a cat.",
+ background=True,
+ )
+ assert response is not None
+ time.sleep(1)
+
+ cancelled_response = await client.responses.cancel(response.id)
+ assert cancelled_response is not None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_stateful_multi_turn(client: OpenAI, model_name: str):
+ response1 = await client.responses.create(
+ model=model_name,
+ input="What is 13 * 24?",
+ )
+ assert response1 is not None
+ assert response1.status == "completed"
+
+ response2 = await client.responses.create(
+ model=model_name,
+ input="What if I increase both numbers by 1?",
+ previous_response_id=response1.id,
+ )
+ assert response2 is not None
+ assert response2.status == "completed"
+
+ response3 = await client.responses.create(
+ model=model_name,
+ input="Divide the result by 2.",
+ previous_response_id=response2.id,
+ )
+ assert response3 is not None
+ assert response3.status == "completed"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_streaming(client: OpenAI, model_name: str):
+ prompts = [
+ "tell me a story about a cat in 20 words",
+ "What is 13 * 24? Use python to calculate the result.",
+ "When did Jensen found NVIDIA? Search it and answer the year only.",
+ ]
+
+ for prompt in prompts:
+ response = await client.responses.create(
+ model=model_name,
+ input=prompt,
+ reasoning={"effort": "low"},
+ tools=[
+ {
+ "type": "web_search_preview"
+ },
+ {
+ "type": "code_interpreter",
+ "container": {
+ "type": "auto"
+ }
+ },
+ ],
+ stream=True,
+ )
+
+ events = []
+ current_event_mode = None
+ async for event in response:
+ if current_event_mode != event.type:
+ current_event_mode = event.type
+ print(f"\n[{event.type}] ", end="", flush=True)
+
+ if "text.delta" in event.type:
+ print(event.delta, end="", flush=True)
+ elif "reasoning_text.delta" in event.type:
+ print(f"{event.delta}", end="", flush=True)
+ elif "response.code_interpreter_call_code.done" in event.type:
+ print(f"Code: {event.code}", end="", flush=True)
+ elif ("response.output_item.added" in event.type
+ and event.item.type == "web_search_call"):
+ print(f"Web search: {event.item.action}", end="", flush=True)
+ events.append(event)
+
+ assert len(events) > 0
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_web_search(client: OpenAI, model_name: str):
+ response = await client.responses.create(
+ model=model_name,
+ input="Who is the president of South Korea as of now?",
+ tools=[{
+ "type": "web_search_preview"
+ }],
+ )
+ assert response is not None
+ assert response.status == "completed"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_code_interpreter(client: OpenAI, model_name: str):
+ response = await client.responses.create(
+ model=model_name,
+ input="Multiply 64548*15151 using builtin python interpreter.",
+ tools=[{
+ "type": "code_interpreter",
+ "container": {
+ "type": "auto"
+ }
+ }],
+ )
+ assert response is not None
+ assert response.status == "completed"
+
+
+def get_weather(latitude, longitude):
+ response = requests.get(
+ f"https://api.open-meteo.com/v1/forecast?latitude={latitude}&longitude={longitude}¤t=temperature_2m,wind_speed_10m&hourly=temperature_2m,relative_humidity_2m,wind_speed_10m" # noqa
+ )
+ data = response.json()
+ return data["current"]["temperature_2m"]
+
+
+def get_place_to_travel():
+ return "Paris"
+
+
+def call_function(name, args):
+ if name == "get_weather":
+ return get_weather(**args)
+ elif name == "get_place_to_travel":
+ return get_place_to_travel()
+ else:
+ raise ValueError(f"Unknown function: {name}")
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_function_calling(client: OpenAI, model_name: str):
+ tools = [{
+ "type": "function",
+ "name": "get_weather",
+ "description":
+ "Get current temperature for provided coordinates in celsius.", # noqa
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "latitude": {
+ "type": "number"
+ },
+ "longitude": {
+ "type": "number"
+ },
+ },
+ "required": ["latitude", "longitude"],
+ "additionalProperties": False,
+ },
+ "strict": True,
+ }]
+
+ response = await client.responses.create(
+ model=model_name,
+ input="What's the weather like in Paris today?",
+ tools=tools,
+ )
+ assert response is not None
+ assert response.status == "completed"
+ assert len(response.output) == 2
+ assert response.output[0].type == "reasoning"
+ assert response.output[1].type == "function_call"
+
+ tool_call = response.output[1]
+ name = tool_call.name
+ args = json.loads(tool_call.arguments)
+
+ result = call_function(name, args)
+
+ response_2 = await client.responses.create(
+ model=model_name,
+ input=[{
+ "type": "function_call_output",
+ "call_id": tool_call.call_id,
+ "output": str(result),
+ }],
+ tools=tools,
+ previous_response_id=response.id,
+ )
+ assert response_2 is not None
+ assert response_2.status == "completed"
+ assert response_2.output_text is not None
+
+ # NOTE: chain-of-thought should be removed.
+ response_3 = await client.responses.create(
+ model=model_name,
+ input="What's the weather like in Paris today?",
+ tools=tools,
+ previous_response_id=response_2.id,
+ )
+ assert response_3 is not None
+ assert response_3.status == "completed"
+ assert response_3.output_text is not None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_function_calling_multi_turn(client: OpenAI, model_name: str):
+ tools = [
+ {
+ "type": "function",
+ "name": "get_place_to_travel",
+ "description": "Get a random place to travel",
+ "parameters": {
+ "type": "object",
+ "properties": {},
+ "required": [],
+ "additionalProperties": False,
+ },
+ "strict": True,
+ },
+ {
+ "type": "function",
+ "name": "get_weather",
+ "description":
+ "Get current temperature for provided coordinates in celsius.", # noqa
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "latitude": {
+ "type": "number"
+ },
+ "longitude": {
+ "type": "number"
+ },
+ },
+ "required": ["latitude", "longitude"],
+ "additionalProperties": False,
+ },
+ "strict": True,
+ },
+ ]
+
+ response = await client.responses.create(
+ model=model_name,
+ input=
+ "Help me plan a trip to a random place. And tell me the weather there.",
+ tools=tools,
+ )
+ assert response is not None
+ assert response.status == "completed"
+ assert len(response.output) == 2
+ assert response.output[0].type == "reasoning"
+ assert response.output[1].type == "function_call"
+
+ tool_call = response.output[1]
+ name = tool_call.name
+ args = json.loads(tool_call.arguments)
+
+ result = call_function(name, args)
+
+ response_2 = await client.responses.create(
+ model=model_name,
+ input=[{
+ "type": "function_call_output",
+ "call_id": tool_call.call_id,
+ "output": str(result),
+ }],
+ tools=tools,
+ previous_response_id=response.id,
+ )
+ assert response_2 is not None
+ assert response_2.status == "completed"
+ assert len(response_2.output) == 2
+ assert response_2.output[0].type == "reasoning"
+ assert response_2.output[1].type == "function_call"
+
+ tool_call = response_2.output[1]
+ name = tool_call.name
+ args = json.loads(tool_call.arguments)
+
+ result = call_function(name, args)
+
+ response_3 = await client.responses.create(
+ model=model_name,
+ input=[{
+ "type": "function_call_output",
+ "call_id": tool_call.call_id,
+ "output": str(result),
+ }],
+ tools=tools,
+ previous_response_id=response_2.id,
+ )
+ assert response_3 is not None
+ assert response_3.status == "completed"
+ assert response_3.output_text is not None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_function_calling_required(client: OpenAI, model_name: str):
+ tools = [{
+ "type": "function",
+ "name": "get_weather",
+ "description":
+ "Get current temperature for provided coordinates in celsius.", # noqa
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "latitude": {
+ "type": "number"
+ },
+ "longitude": {
+ "type": "number"
+ },
+ },
+ "required": ["latitude", "longitude"],
+ "additionalProperties": False,
+ },
+ "strict": True,
+ }]
+
+ with pytest.raises(BadRequestError):
+ await client.responses.create(
+ model=model_name,
+ input="What's the weather like in Paris today?",
+ tools=tools,
+ tool_choice="required",
+ )
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_function_calling_full_history(client: OpenAI, model_name: str):
+ tools = [{
+ "type": "function",
+ "name": "get_weather",
+ "description":
+ "Get current temperature for provided coordinates in celsius.", # noqa
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "latitude": {
+ "type": "number"
+ },
+ "longitude": {
+ "type": "number"
+ },
+ },
+ "required": ["latitude", "longitude"],
+ "additionalProperties": False,
+ },
+ "strict": True,
+ }]
+
+ input_messages = [{
+ "role": "user",
+ "content": "What's the weather like in Paris today?"
+ }]
+
+ response = await client.responses.create(
+ model=model_name,
+ input=input_messages,
+ tools=tools,
+ )
+
+ assert response is not None
+ assert response.status == "completed"
+
+ tool_call = response.output[-1]
+ name = tool_call.name
+ args = json.loads(tool_call.arguments)
+
+ result = call_function(name, args)
+
+ input_messages.extend(
+ response.output) # append model's function call message
+ input_messages.append(
+ { # append result message
+ "type": "function_call_output",
+ "call_id": tool_call.call_id,
+ "output": str(result),
+ }
+ )
+
+ response_2 = await client.responses.create(
+ model=model_name,
+ input=input_messages,
+ tools=tools,
+ )
+ assert response_2 is not None
+ assert response_2.status == "completed"
+ assert response_2.output_text is not None
diff --git a/tests/entrypoints/openai/test_score.py b/tests/entrypoints/openai/test_score.py
index 1a5df1d2dbd2d..cb6ec795ae969 100644
--- a/tests/entrypoints/openai/test_score.py
+++ b/tests/entrypoints/openai/test_score.py
@@ -220,7 +220,9 @@ class TestModel:
invocation_output["data"]):
assert score_data.keys() == invocation_data.keys()
assert score_data["score"] == pytest.approx(
- invocation_data["score"], rel=0.01)
+ invocation_data["score"], rel=0.05)
+ # TODO: reset this tolerance to 0.01 once we find
+ # an alternative to flash_attn with bfloat16
def test_activation(self, server: RemoteOpenAIServer, model: dict[str,
Any]):
diff --git a/tests/entrypoints/openai/test_tensorizer_entrypoint.py b/tests/entrypoints/openai/test_tensorizer_entrypoint.py
index 4bf3798503656..058e96f203c38 100644
--- a/tests/entrypoints/openai/test_tensorizer_entrypoint.py
+++ b/tests/entrypoints/openai/test_tensorizer_entrypoint.py
@@ -44,7 +44,7 @@ def model_uri(tmp_dir):
def tensorize_model_and_lora(tmp_dir, model_uri):
tensorizer_config = TensorizerConfig(tensorizer_uri=model_uri,
lora_dir=tmp_dir)
- args = EngineArgs(model=MODEL_NAME, device="cuda")
+ args = EngineArgs(model=MODEL_NAME)
tensorize_lora_adapter(LORA_PATH, tensorizer_config)
tensorize_vllm_model(args, tensorizer_config)
diff --git a/tests/entrypoints/openai/test_uds.py b/tests/entrypoints/openai/test_uds.py
new file mode 100644
index 0000000000000..5c39869a794fd
--- /dev/null
+++ b/tests/entrypoints/openai/test_uds.py
@@ -0,0 +1,43 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from tempfile import TemporaryDirectory
+
+import httpx
+import pytest
+
+from vllm.version import __version__ as VLLM_VERSION
+
+from ...utils import RemoteOpenAIServer
+
+MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
+
+
+@pytest.fixture(scope="module")
+def server():
+ with TemporaryDirectory() as tmpdir:
+ args = [
+ # use half precision for speed and memory savings in CI environment
+ "--dtype",
+ "bfloat16",
+ "--max-model-len",
+ "8192",
+ "--enforce-eager",
+ "--max-num-seqs",
+ "128",
+ "--uds",
+ f"{tmpdir}/vllm.sock",
+ ]
+
+ with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+ yield remote_server
+
+
+@pytest.mark.asyncio
+async def test_show_version(server: RemoteOpenAIServer):
+ transport = httpx.HTTPTransport(uds=server.uds)
+ client = httpx.Client(transport=transport)
+ response = client.get(server.url_for("version"))
+ response.raise_for_status()
+
+ assert response.json() == {"version": VLLM_VERSION}
diff --git a/tests/kernels/core/test_mrope.py b/tests/kernels/core/test_mrope.py
new file mode 100644
index 0000000000000..3f2f330f6dc3b
--- /dev/null
+++ b/tests/kernels/core/test_mrope.py
@@ -0,0 +1,215 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import torch
+from transformers import AutoConfig
+
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.platforms import current_platform
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int,
+ head_size: int, max_position_embeddings: int,
+ dtype: torch.dtype, device: torch.device):
+ """Generate test data for given configuration."""
+ # Create 2D positions (3, num_tokens) for multimodal case
+ positions = torch.randint(0,
+ max_position_embeddings // 4, (3, num_tokens),
+ device=device)
+
+ # Create query and key tensors
+ query = torch.randn(num_tokens,
+ num_q_heads * head_size,
+ dtype=dtype,
+ device=device)
+ key = torch.randn(num_tokens,
+ num_kv_heads * head_size,
+ dtype=dtype,
+ device=device)
+
+ return positions, query, key
+
+
+def unroll_model_tp_dict(model_tp_dict):
+ return [(model_name, tp_size)
+ for model_name, tp_sizes in model_tp_dict.items()
+ for tp_size in tp_sizes]
+
+
+model_tp_dict = {
+ "Qwen/Qwen2-VL-7B-Instruct": [1, 2],
+ "Qwen/Qwen2-VL-72B-Instruct": [1, 2],
+ "Qwen/Qwen2.5-VL-72B-Instruct": [1, 2],
+ "zai-org/GLM-4.1V-9B-Thinking": [1, 2],
+}
+
+# https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317
+dtype_atol_rtol_list = [
+ [torch.bfloat16, 1e-2, 1.6e-2],
+]
+
+num_tokens_list = [11, 8192]
+
+
+@pytest.mark.skipif(not current_platform.is_cuda_alike(),
+ reason="Skipping CUDA/ROCm only tests.")
+@pytest.mark.parametrize("model_name, tp_size",
+ unroll_model_tp_dict(model_tp_dict))
+@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list)
+@pytest.mark.parametrize("num_tokens", num_tokens_list)
+def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):
+
+ config = AutoConfig.from_pretrained(model_name)
+
+ # get the model config
+ total_num_kv_heads = config.num_key_value_heads
+ total_num_heads = config.num_attention_heads
+ num_heads = total_num_heads // tp_size
+ num_kv_heads = max(1, total_num_kv_heads // tp_size)
+ head_dim = config.hidden_size // total_num_heads
+ is_neox_style = True
+
+ rope_theta = config.rope_theta
+ max_position = config.max_position_embeddings
+ partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
+ rotary_dim = int(head_dim * partial_rotary_factor)
+
+ mrope_helper_class = get_rope(
+ head_size=head_dim,
+ rotary_dim=rotary_dim,
+ max_position=max_position,
+ base=rope_theta,
+ is_neox_style=is_neox_style,
+ rope_scaling=config.rope_scaling,
+ dtype=dtype,
+ ).to(device=device)
+
+ # create q k v input tensors
+ # create rotary pos emb input tensors
+ positions, query, key = generate_test_data(num_tokens, num_heads,
+ num_kv_heads, head_dim,
+ max_position, dtype, device)
+
+ query_native, key_native = mrope_helper_class.forward_native(
+ positions,
+ query.clone(),
+ key.clone(),
+ )
+
+ query_cuda, key_cuda = mrope_helper_class.forward_cuda(
+ positions,
+ query.clone(),
+ key.clone(),
+ )
+
+ torch.testing.assert_close(query_native, query_cuda, atol=atol, rtol=rtol)
+ torch.testing.assert_close(key_native, key_cuda, atol=atol, rtol=rtol)
+
+
+@pytest.mark.skipif(not current_platform.is_cuda_alike(),
+ reason="Skipping CUDA/ROCm only tests.")
+@pytest.mark.parametrize(
+ "model_name, tp_size",
+ unroll_model_tp_dict({
+ "Qwen/Qwen2-VL-7B-Instruct": [1, 2],
+ "zai-org/GLM-4.1V-9B-Thinking": [1, 2]
+ }))
+@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list)
+@pytest.mark.parametrize("num_tokens", [4])
+def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol,
+ num_tokens):
+ config = AutoConfig.from_pretrained(model_name)
+
+ # get the model config
+ total_num_kv_heads = config.num_key_value_heads
+ total_num_heads = config.num_attention_heads
+ num_heads = total_num_heads // tp_size
+ num_kv_heads = max(1, total_num_kv_heads // tp_size)
+ head_dim = config.hidden_size // total_num_heads
+ is_neox_style = True
+ rope_theta = config.rope_theta
+ max_position = config.max_position_embeddings
+ partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
+ rotary_dim = int(head_dim * partial_rotary_factor)
+
+ mrope_helper_class = get_rope(
+ head_size=head_dim,
+ rotary_dim=rotary_dim,
+ max_position=max_position,
+ base=rope_theta,
+ is_neox_style=is_neox_style,
+ rope_scaling=config.rope_scaling,
+ dtype=dtype,
+ ).to(device=device)
+
+ # Generate test data
+ positions, query, key = generate_test_data(num_tokens, num_heads,
+ num_kv_heads, head_dim,
+ max_position, dtype, device)
+
+ # Create a wrapper that makes the in-place function appear functional
+ def functional_forward_cuda(pos, q, k):
+ """Wrapper that converts in-place operation to functional style
+
+ CUDA Graph does not support in-place operations.
+ This wrapper creates working copies of the
+ input tensors and modifies them.
+ """
+ q_work = q.clone() # Create working copies
+ k_work = k.clone()
+ # Your in-place function modifies q_work and k_work
+ mrope_helper_class.forward_cuda(pos, q_work, k_work)
+ return q_work, k_work # Return the modified tensors
+
+ # Get reference results
+ query_native, key_native = mrope_helper_class.forward_native(
+ positions,
+ query.clone(),
+ key.clone(),
+ )
+
+ try:
+ compiled_forward_cuda = torch.compile(functional_forward_cuda,
+ fullgraph=True,
+ backend="inductor",
+ mode="reduce-overhead",
+ dynamic=False)
+
+ # Run compiled version
+ query_compiled_cuda, key_compiled_cuda = compiled_forward_cuda(
+ positions,
+ query,
+ key,
+ )
+
+ # Run original version for comparison
+ query_cuda = query.clone()
+ key_cuda = key.clone()
+ mrope_helper_class.forward_cuda(positions, query_cuda, key_cuda)
+
+ # Verify results
+ torch.testing.assert_close(query_compiled_cuda,
+ query_cuda,
+ atol=atol,
+ rtol=rtol)
+ torch.testing.assert_close(key_compiled_cuda,
+ key_cuda,
+ atol=atol,
+ rtol=rtol)
+ torch.testing.assert_close(query_compiled_cuda,
+ query_native,
+ atol=atol,
+ rtol=rtol)
+ torch.testing.assert_close(key_compiled_cuda,
+ key_native,
+ atol=atol,
+ rtol=rtol)
+
+ print("✓ forward_cuda successfully traced with torch.compile inductor")
+
+ except Exception as e:
+ pytest.fail(
+ f"forward_cuda failed to trace with torch.compile inductor: {e}")
diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py
index 67b14a7faa89f..d2b893ffff7c3 100644
--- a/tests/kernels/mamba/test_mamba_ssm_ssd.py
+++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py
@@ -187,7 +187,7 @@ def generate_continuous_batched_examples(example_lens_by_batch,
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32])
@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128])
-@pytest.mark.parametrize("seq_len_chunk_size", [(119, 17), (128, 32)])
+@pytest.mark.parametrize("seq_len_chunk_size", [(112, 16), (128, 32)])
def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
itype):
@@ -253,15 +253,15 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
(8, 8, 16, 32, 16),
]), # mode examples with varied lengths
- # odd chunk_size
- (64, 29, 2, [(11, 4), (13, 23), (19, 22),
- (21, 15)]), # irregular sizes
-
# large-ish chunk_size (256)
(64, 256, 1, [(5, ), (1, ), (1, ),
(1, )]), # irregular sizes with small sequences
(64, 256, 2, [(5, 30), (1, 2), (1, 2),
(1, 2)]), # irregular sizes with small sequences
+
+ # we also need to test some large seqlen
+ # to catch errors with init states decay
+ (768, 128, 2, [(138, 225), (138, 225)]),
])
def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
itype):
@@ -271,10 +271,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
- # TODO: the irregular chunk size cases have some issues and require higher
- # tolerance. This is to be invesigated
- if chunk_size not in {8, 256}:
- atol, rtol = 5e-1, 5e-1
+ # This test can have larger error for longer sequences
+ if seqlen > 256:
+ atol, rtol = 1e-2, 5e-3
else:
atol, rtol = 5e-3, 5e-3
diff --git a/tests/kernels/moe/modular_kernel_tools/parallel_utils.py b/tests/kernels/moe/modular_kernel_tools/parallel_utils.py
index 1f8d21a7a7027..459b785e6504e 100644
--- a/tests/kernels/moe/modular_kernel_tools/parallel_utils.py
+++ b/tests/kernels/moe/modular_kernel_tools/parallel_utils.py
@@ -36,7 +36,6 @@ def _set_vllm_config(vllm_config: VllmConfig, world_size: int, rank: int,
import tempfile
temp_file = tempfile.mkstemp()[1]
- set_current_vllm_config(vllm_config)
with set_current_vllm_config(vllm_config):
init_distributed_environment(
world_size=world_size,
diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py
index 7dc6282326b66..75b2e9f791789 100644
--- a/tests/kernels/moe/test_block_fp8.py
+++ b/tests/kernels/moe/test_block_fp8.py
@@ -16,7 +16,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe)
from vllm.platforms import current_platform
from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
+from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
dg_available = has_deep_gemm()
@@ -224,7 +224,8 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
-@pytest.mark.skipif(is_blackwell_deep_gemm_used(), reason="Not E8M0 scale MOE")
+@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(),
+ reason="Not E8M0 scale MOE")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
monkeypatch):
diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py
index 266f1161a684b..9b064db973ddf 100644
--- a/tests/kernels/moe/test_deepep_deepgemm_moe.py
+++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py
@@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm
-from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_used,
+from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
is_deep_gemm_supported)
from .parallel_utils import ProcessGroupInfo, parallel_launch
@@ -370,7 +370,7 @@ NUM_EXPERTS = [32]
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@requires_deep_ep
@requires_deep_gemm
-@pytest.mark.skipif(is_blackwell_deep_gemm_used(),
+@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(),
reason="Skipping test for Blackwell DeepGEMM")
def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
topk: int, world_dp_size: tuple[int, int]):
@@ -427,7 +427,7 @@ USE_FP8_DISPATCH = [False]
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@requires_deep_ep
@requires_deep_gemm
-@pytest.mark.skipif(is_blackwell_deep_gemm_used(),
+@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(),
reason="Skipping test for Blackwell DeepGEMM")
def test_ll_deepep_deepgemm_moe(
mnk: tuple[int, int, int],
diff --git a/tests/kernels/moe/test_gpt_oss_triton_kernels.py b/tests/kernels/moe/test_gpt_oss_triton_kernels.py
new file mode 100644
index 0000000000000..54f2351bf6d9b
--- /dev/null
+++ b/tests/kernels/moe/test_gpt_oss_triton_kernels.py
@@ -0,0 +1,453 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from dataclasses import dataclass, fields
+
+import pytest
+import torch
+import torch.nn.functional as F
+
+from vllm.utils import has_triton_kernels
+
+if not has_triton_kernels():
+ pytest.skip(
+ "triton_kernels not found, skipping all related tests",
+ allow_module_level=True,
+ )
+
+import triton_kernels.swiglu
+from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
+from triton_kernels.numerics import InFlexData
+from triton_kernels.numerics_details.mxfp import (downcast_to_mxfp,
+ upcast_from_mxfp)
+from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
+from triton_kernels.tensor_details import layout
+from triton_kernels.testing import assert_close
+
+from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
+ BatchedPrepareAndFinalize)
+from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
+from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
+ BatchedOAITritonExperts, triton_kernel_moe_forward)
+from vllm.model_executor.layers.fused_moe.modular_kernel import (
+ FusedMoEModularKernel)
+from vllm.model_executor.layers.utils import shuffle_weight
+from vllm.utils import round_up
+
+
+def deshuffle(w: torch.Tensor):
+ first = w[..., ::2]
+ second = w[..., 1::2]
+
+ deshuffled = torch.concat((first, second), dim=-1)
+ return deshuffled
+
+
+def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
+ randbits = [torch.randperm(E) for _ in range(M)]
+ x_list = [
+ (-1)**i *
+ ((16384 +
+ ((i * 512) % 4096) + bits).to(torch.int16).view(torch.bfloat16))
+ for i, bits in enumerate(randbits)
+ ]
+ exp_data = torch.stack(x_list).to(
+ device="cuda") # simulating gate_output (M, E)
+
+ # create input tensor
+ x = torch.randn((M, K), dtype=torch.bfloat16, device="cuda")
+ w1 = torch.randn((E, 2 * N, K), dtype=torch.bfloat16, device="cuda")
+ w1_bias = torch.randn((E, 2 * N), dtype=torch.bfloat16, device="cuda")
+
+ w2 = torch.randn((E, K, N), dtype=torch.bfloat16, device="cuda")
+ w2_bias = torch.randn((E, K), dtype=torch.bfloat16, device="cuda")
+
+ exp_data_tri = exp_data.clone()
+ x_tri = x.clone()
+ w1_tri = w1.clone()
+ w2_tri = w2.clone()
+
+ w1_bias_tri = w1_bias.clone()
+ w2_bias_tri = w2_bias.clone()
+ w1_bias_tri = w1_bias_tri.to(torch.float32)
+ w2_bias_tri = w2_bias_tri.to(torch.float32)
+
+ dtype_dict = {
+ "bf16": torch.bfloat16,
+ "fp8_e4m3": torch.float8_e4m3fn,
+ "fp8_e5m2": torch.float8_e5m2,
+ }
+
+ x = x.to(dtype_dict[a_dtype]).to(torch.bfloat16)
+ if w_dtype != "mx4":
+ # simulate quantization support on reference impl
+ w1 = w1.to(dtype_dict[w_dtype]).to(torch.bfloat16)
+ w2 = w2.to(dtype_dict[w_dtype]).to(torch.bfloat16)
+
+ # triton moe kernel use transposed shape for matmul
+ w1_tri = w1_tri.transpose(-2, -1)
+ w2_tri = w2_tri.transpose(-2, -1)
+
+ # shuffle weights
+ w1_tri = shuffle_weight(w1_tri)
+ w1_bias_tri = shuffle_weight(w1_bias_tri)
+
+ # quant triton_weights
+ x_tri = x.to(dtype_dict[a_dtype])
+ if w_dtype != "mx4":
+ pytest.skip("NYI")
+ else: # quantize to mx4
+ # careful on the padding here, the activation padding need to be
+ # multiple of 64, the actual engine is not implemented
+ w1_bottom_pad = round_up(w1_tri.shape[1], 64) - w1_tri.shape[1]
+ w1_right_pad = round_up(w1_tri.shape[2], 128) - w1_tri.shape[2]
+
+ w2_bottom_pad = w1_right_pad // 2
+ w2_right_pad = w1_bottom_pad
+
+ x_pad = w1_bottom_pad
+
+ w1_tri = F.pad(
+ w1_tri,
+ (0, w1_right_pad, 0, w1_bottom_pad, 0, 0),
+ mode="constant",
+ value=0,
+ )
+ w2_tri = F.pad(
+ w2_tri,
+ (0, w2_right_pad, 0, w2_bottom_pad, 0, 0),
+ mode="constant",
+ value=0,
+ )
+
+ w1_bias_tri = F.pad(w1_bias_tri, (0, w1_right_pad, 0, 0),
+ mode="constant",
+ value=0)
+ w2_bias_tri = F.pad(w2_bias_tri, (0, w2_right_pad, 0, 0),
+ mode="constant",
+ value=0)
+
+ x_tri = F.pad(x_tri, (0, x_pad, 0, 0), mode="constant", value=0)
+
+ w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(
+ mx_axis=1)
+ w_scale_layout, w_scale_layout_opts = (
+ layout.make_default_matmul_mxfp4_w_scale_layout(
+ mx_axis=1, num_warps=num_warps))
+
+ w1_tri, w1_scale_tri = downcast_to_mxfp(w1_tri, torch.uint8, axis=1)
+ w1 = upcast_from_mxfp(w1_tri, w1_scale_tri, torch.bfloat16, axis=1)
+
+ w2_tri, w2_scale_tri = downcast_to_mxfp(w2_tri, torch.uint8, axis=1)
+ w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, torch.bfloat16, axis=1)
+
+ w1_tri = convert_layout(wrap_torch_tensor(w1_tri, FP4), w_layout,
+ **w_layout_opts)
+ w1_scale_tri = convert_layout(
+ wrap_torch_tensor(w1_scale_tri),
+ w_scale_layout,
+ **w_scale_layout_opts,
+ )
+
+ w2_tri = convert_layout(wrap_torch_tensor(w2_tri, FP4), w_layout,
+ **w_layout_opts)
+ w2_scale_tri = convert_layout(
+ wrap_torch_tensor(w2_scale_tri),
+ w_scale_layout,
+ **w_scale_layout_opts,
+ )
+
+ pc1 = PrecisionConfig(weight_scale=w1_scale_tri,
+ flex_ctx=FlexCtx(rhs_data=InFlexData()))
+ pc2 = PrecisionConfig(weight_scale=w2_scale_tri,
+ flex_ctx=FlexCtx(rhs_data=InFlexData()))
+
+ # tucuate so the rest can run properly
+ w1 = w1[..., :K, :2 * N]
+ w2 = w2[..., :N, :K]
+
+ w1 = deshuffle(w1)
+
+ w1 = w1.transpose(-1, -2).contiguous()
+ w2 = w2.transpose(-1, -2).contiguous()
+
+ return (
+ x,
+ w1,
+ w1_bias,
+ w2,
+ w2_bias,
+ exp_data,
+ x_tri,
+ w1_tri,
+ w2_tri,
+ exp_data_tri,
+ w1_bias_tri,
+ w2_bias_tri,
+ pc1,
+ pc2,
+ )
+
+
+@dataclass
+class ModelConfig:
+ num_hidden_layers: int = 36
+ num_experts: int = 128
+ experts_per_token: int = 4
+ vocab_size: int = 201088
+ hidden_size: int = 2880
+ intermediate_size: int = 2880
+ head_dim: int = 64
+ num_attention_heads: int = 64
+ num_key_value_heads: int = 8
+ sliding_window: int = 128
+ initial_context_length: int = 4096
+ rope_theta: float = 150000.0
+ rope_scaling_factor: float = 32.0
+ rope_ntk_alpha: float = 1.0
+ rope_ntk_beta: float = 32.0
+
+
+def swiglu(x, alpha: float = 1.702, limit: float = 1.0):
+ # Note we add an extra bias of 1 to the linear layer
+ x_glu, x_linear = torch.chunk(x, 2, dim=-1)
+ if limit is not None:
+ x_glu = x_glu.clamp(max=limit)
+ out_glu = x_glu * torch.sigmoid(alpha * x_glu)
+ if limit is not None:
+ x_linear = x_linear.clamp(min=-limit, max=limit)
+ return out_glu * (x_linear + 1)
+
+
+def oai_moe_forward(
+ hidden_states: torch.Tensor, # (M, K)
+ w1: torch.Tensor, # (E, 2N)
+ w1_bias: torch.Tensor, # (E, 2N, K)
+ w2: torch.Tensor, # (E, K, N)
+ w2_bias: torch.Tensor, # (E, N)
+ gating_output: torch.Tensor, # (M, E)
+ topk: int,
+):
+ # model.py 309:330, assuming gating and norm
+ t = hidden_states
+ experts = torch.topk(gating_output, k=topk, dim=-1, sorted=True)
+ expert_weights = torch.nn.functional.softmax(experts.values, dim=1)
+ expert_indices = experts.indices
+
+ # MLP #1
+ mlp1_weight = w1[expert_indices, ...]
+ mlp1_bias = w1_bias[expert_indices, ...]
+ t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
+ t = swiglu(t, limit=7)
+
+ # MLP #2
+ mlp2_weight = w2[expert_indices, ...]
+ mlp2_bias = w2_bias[expert_indices, ...]
+ t = torch.einsum("beck,bek->bec", mlp2_weight, t)
+ t += mlp2_bias
+
+ # Weighted sum of experts
+ t = torch.einsum("bec,be->bc", t, expert_weights)
+
+ return t
+
+
+@dataclass
+class Case:
+ a_dtype: str
+ w_dtype: str
+
+
+@pytest.mark.parametrize(
+ ", ".join(f.name for f in fields(Case)),
+ [
+ tuple(getattr(case, f.name) for f in fields(Case)) for case in [
+ # Case(a_dtype="bf16", w_dtype="bf16"),
+ # Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"),
+ Case(a_dtype="bf16", w_dtype="mx4")
+ ]
+ ],
+)
+@pytest.mark.parametrize("num_token", [2])
+@pytest.mark.parametrize("tp", [1, 2, 4, 8])
+def test_equiv(num_token, a_dtype, w_dtype, tp):
+ M = num_token
+ E = ModelConfig.num_experts
+ K = ModelConfig.hidden_size
+ N = ModelConfig.intermediate_size // tp
+ topk = ModelConfig.experts_per_token
+
+ (
+ x,
+ w1,
+ w1_bias,
+ w2,
+ w2_bias,
+ exp_data,
+ x_tri,
+ w1_tri,
+ w2_tri,
+ exp_data_tri,
+ w1_bias_tri,
+ w2_bias_tri,
+ pc1,
+ pc2,
+ ) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8)
+
+ out_triton_monolithic = triton_kernel_moe_forward(
+ hidden_states=x_tri,
+ w1=w1_tri,
+ w2=w2_tri,
+ gating_output=exp_data_tri,
+ topk=topk,
+ renormalize=True,
+ w1_bias=w1_bias_tri,
+ w2_bias=w2_bias_tri,
+ w1_precision=pc1,
+ w2_precision=pc2,
+ )
+ out_triton_monolithic = out_triton_monolithic[..., :K]
+
+ out_ref = oai_moe_forward(
+ hidden_states=x,
+ w1=w1,
+ w1_bias=w1_bias,
+ w2=w2,
+ w2_bias=w2_bias,
+ gating_output=exp_data,
+ topk=topk,
+ )
+ assert_close(ref=out_ref,
+ tri=out_triton_monolithic,
+ maxtol=0.025,
+ rmstol=0.005)
+
+
+def batched_moe(
+ a: torch.Tensor,
+ w1,
+ w2,
+ gating_output: torch.Tensor,
+ topk: int,
+ renormalize: bool,
+ w1_bias: torch.Tensor,
+ w2_bias: torch.Tensor,
+ w1_precision: PrecisionConfig,
+ w2_precision: PrecisionConfig,
+) -> torch.Tensor:
+ max_num_tokens = round_up(a.shape[0], 64)
+
+ fused_experts = FusedMoEModularKernel(
+ BatchedPrepareAndFinalize(
+ max_num_tokens,
+ num_dispatchers=1,
+ num_local_experts=w1.shape[0],
+ rank=0,
+ ),
+ BatchedOAITritonExperts(
+ None,
+ max_num_tokens=max_num_tokens,
+ num_dispatchers=1,
+ w1_precision=w1_precision,
+ w2_precision=w2_precision,
+ ),
+ )
+
+ extra_expert_args = {
+ "w1_bias": w1_bias,
+ "w2_bias": w2_bias,
+ }
+
+ topk_weight, topk_ids, _ = fused_topk(a, gating_output, topk, renormalize)
+
+ return fused_experts(
+ a,
+ w1,
+ w2,
+ topk_weight,
+ topk_ids,
+ extra_expert_args=extra_expert_args,
+ )
+
+
+@pytest.mark.parametrize(
+ ", ".join(f.name for f in fields(Case)),
+ [
+ tuple(getattr(case, f.name) for f in fields(Case)) for case in [
+ # Case(a_dtype="bf16", w_dtype="bf16"),
+ # Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"),
+ Case(a_dtype="bf16", w_dtype="mx4")
+ ]
+ ],
+)
+@pytest.mark.parametrize("num_token", [64])
+@pytest.mark.parametrize("ep", [1, 2, 4, 8])
+def test_triton_kernel_batched_moe(num_token, a_dtype, w_dtype, ep):
+ M = num_token
+ E = ModelConfig.num_experts // ep
+ K = ModelConfig.hidden_size
+ N = ModelConfig.intermediate_size
+ topk = ModelConfig.experts_per_token
+
+ (
+ x,
+ w1,
+ w1_bias,
+ w2,
+ w2_bias,
+ exp_data,
+ x_tri,
+ w1_tri,
+ w2_tri,
+ exp_data_tri,
+ w1_bias_tri,
+ w2_bias_tri,
+ pc1,
+ pc2,
+ ) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=4)
+
+ out_tri = batched_moe(
+ a=x_tri,
+ w1=w1_tri,
+ w2=w2_tri,
+ gating_output=exp_data_tri,
+ topk=topk,
+ renormalize=True,
+ w1_bias=w1_bias_tri,
+ w2_bias=w2_bias_tri,
+ w1_precision=pc1,
+ w2_precision=pc2,
+ )
+ out_tri = out_tri[..., :K]
+
+ out_ref = oai_moe_forward(
+ hidden_states=x,
+ w1=w1,
+ w1_bias=w1_bias,
+ w2=w2,
+ w2_bias=w2_bias,
+ gating_output=exp_data,
+ topk=topk,
+ )
+ assert_close(ref=out_ref, tri=out_tri, maxtol=0.025, rmstol=0.005)
+
+
+def test_unit_shuffle():
+ N = ModelConfig.intermediate_size
+ K = ModelConfig.hidden_size
+ m = torch.randn((K, 2 * N), dtype=torch.bfloat16, device="cuda")
+
+ x = torch.randn(K, dtype=torch.bfloat16, device="cuda")
+
+ m_shuffled = shuffle_weight(m)
+
+ out_ref = x @ m
+ out_ref = swiglu(out_ref, limit=1.0)
+
+ out = x @ m_shuffled
+ out = triton_kernels.swiglu.swiglu_torch(
+ out,
+ alpha=1.702,
+ precision_config=triton_kernels.swiglu.PrecisionConfig(limit=1.0),
+ )
+
+ assert_close(ref=out_ref, tri=out)
diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py
index 67ba2f25593df..76f6c226bab7c 100644
--- a/tests/models/language/generation/test_hybrid.py
+++ b/tests/models/language/generation/test_hybrid.py
@@ -20,19 +20,15 @@ pytestmark = pytest.mark.hybrid_model
SSM_MODELS = [
"state-spaces/mamba-130m-hf",
"tiiuae/falcon-mamba-tiny-dev",
- "mistralai/Mamba-Codestral-7B-v0.1",
+ "yujiepan/mamba2-codestral-v0.1-tiny-random",
]
HYBRID_MODELS = [
"ai21labs/Jamba-tiny-dev",
- # NOTE: Running Plamo2 in transformers implementation requires to install
- # causal-conv1d package, which is not listed as a test dependency as it's
- # not compatible with pip-compile.
- "pfnet/plamo-2-1b",
+ # skipping until vLLM implementation issues are resolved
+ # "pfnet/plamo-2-1b",
"Zyphra/Zamba2-1.2B-instruct",
"hmellor/tiny-random-BambaForCausalLM",
- "ibm-ai-platform/Bamba-9B-v1",
- "nvidia/Nemotron-H-8B-Base-8K",
"ibm-granite/granite-4.0-tiny-preview",
"tiiuae/Falcon-H1-0.5B-Base",
]
@@ -42,23 +38,18 @@ HF_UNSUPPORTED_MODELS = [
# Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test
# doesn't compare vLLM output with HF output.
# See https://github.com/huggingface/transformers/pull/35943
- "mistralai/Mamba-Codestral-7B-v0.1",
- # Note: I'm not seeing the same output from vLLM V0 vs. HF transformers
- # for Nemotron-H-8B; currently only compare vLLM V0 vs. vLLM V1
- "nvidia/Nemotron-H-8B-Base-8K",
- # NOTE: Currently the test fails due to HF transformers issue fixed in:
- # https://github.com/huggingface/transformers/pull/39033
- # We will enable vLLM test for Granite after next HF transformers release.
- "ibm-granite/granite-4.0-tiny-preview",
+ "yujiepan/mamba2-codestral-v0.1-tiny-random",
+ # transformers 4.55 is still producing garbage for this model
+ # TODO(tdoublep): follow-up on transformers side
+ "ibm-granite/granite-4.0-tiny-preview"
]
V1_SUPPORTED_MODELS = [
"state-spaces/mamba-130m-hf",
"ai21labs/Jamba-tiny-dev",
- "mistralai/Mamba-Codestral-7B-v0.1",
- "ibm-ai-platform/Bamba-9B-v1",
+ "yujiepan/mamba2-codestral-v0.1-tiny-random",
"Zyphra/Zamba2-1.2B-instruct",
- "nvidia/Nemotron-H-8B-Base-8K",
+ "hmellor/tiny-random-BambaForCausalLM",
"ibm-granite/granite-4.0-tiny-preview",
"tiiuae/Falcon-H1-0.5B-Base",
]
@@ -83,12 +74,16 @@ def test_models(
try:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
- model_info.check_transformers_version(on_fail="skip")
+ hf_version_check = model_info.check_transformers_version(
+ on_fail="return")
except ValueError:
- pass
+ hf_version_check = None
+
+ if hf_version_check is not None:
+ print(f"Skipping transformers comparison because: {hf_version_check}")
with hf_runner(model) as hf_model:
- if model not in HF_UNSUPPORTED_MODELS:
+ if model not in HF_UNSUPPORTED_MODELS and hf_version_check is None:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
else:
@@ -389,3 +384,63 @@ def test_distributed_correctness(
name_0="vllm_tp_1",
name_1="vllm_tp_2",
)
+
+
+@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"])
+@pytest.mark.parametrize("max_tokens", [64])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_full_cuda_graph(
+ hf_runner,
+ vllm_runner,
+ example_prompts,
+ monkeypatch,
+ model: str,
+ max_tokens: int,
+ num_logprobs: int,
+) -> None:
+
+ try:
+ model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
+ model_info.check_available_online(on_fail="skip")
+ model_info.check_transformers_version(on_fail="skip")
+ except ValueError:
+ pass
+
+ with hf_runner(model) as hf_model:
+ if model not in HF_UNSUPPORTED_MODELS:
+ hf_outputs = hf_model.generate_greedy_logprobs_limit(
+ example_prompts, max_tokens, num_logprobs)
+ else:
+ hf_outputs = None
+
+ with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+ vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
+ example_prompts, max_tokens, num_logprobs)
+
+ with monkeypatch.context() as m:
+ m.setenv("VLLM_USE_V1", "1")
+ if model in HYBRID_MODELS:
+ # required due to reorder_batch behaviour
+ m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
+ with vllm_runner(model,
+ max_num_seqs=MAX_NUM_SEQS,
+ compilation_config={'full_cuda_graph': True},
+ enable_prefix_caching=False) as vllm_model:
+ vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+ example_prompts, max_tokens, num_logprobs)
+
+ if hf_outputs is not None:
+ check_logprobs_close(
+ outputs_0_lst=hf_outputs,
+ outputs_1_lst=vllm_v0_outputs,
+ name_0="hf",
+ name_1="vllm-v0",
+ )
+
+ ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
+ check_logprobs_close(
+ outputs_0_lst=ref_outputs,
+ outputs_1_lst=vllm_v1_outputs,
+ name_0="hf" if hf_outputs is not None else "vllm-v0",
+ name_1="vllm-v1",
+ )
diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py
index 8c93bbdc98c02..d024c76dddfdd 100644
--- a/tests/models/language/pooling/mteb_utils.py
+++ b/tests/models/language/pooling/mteb_utils.py
@@ -162,7 +162,8 @@ def mteb_test_embed_models(hf_runner,
vllm_runner,
model_info: EmbedModelInfo,
vllm_extra_kwargs=None,
- hf_model_callback=None):
+ hf_model_callback=None,
+ atol=MTEB_RERANK_TOL):
if not model_info.enable_test:
# A model family has many models with the same architecture,
# and we don't need to test each one.
@@ -176,9 +177,12 @@ def mteb_test_embed_models(hf_runner,
max_model_len=None,
**vllm_extra_kwargs) as vllm_model:
+ model_config = vllm_model.llm.llm_engine.model_config
+
if model_info.architecture:
- assert (model_info.architecture
- in vllm_model.llm.llm_engine.model_config.architectures)
+ assert model_info.architecture in model_config.architectures
+ assert (model_config._model_info.default_pooling_type ==
+ model_info.default_pooling_type)
vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
MTEB_EMBED_TASKS)
@@ -198,7 +202,7 @@ def mteb_test_embed_models(hf_runner,
print("SentenceTransformers:", st_dtype, st_main_score)
print("Difference:", st_main_score - vllm_main_score)
- assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL)
+ assert st_main_score == pytest.approx(vllm_main_score, abs=atol)
def run_mteb_rerank(cross_encoder, tasks, languages):
@@ -285,7 +289,12 @@ def mteb_test_rerank_models(hf_runner,
**vllm_extra_kwargs) as vllm_model:
model_config = vllm_model.llm.llm_engine.model_config
+
+ if model_info.architecture:
+ assert (model_info.architecture in model_config.architectures)
assert model_config.hf_config.num_labels == 1
+ assert (model_config._model_info.default_pooling_type ==
+ model_info.default_pooling_type)
vllm_main_score = run_mteb_rerank(vllm_mteb_encoder(vllm_model),
tasks=MTEB_RERANK_TASKS,
diff --git a/tests/models/language/pooling/test_auto_prefix_cache_support.py b/tests/models/language/pooling/test_auto_prefix_cache_support.py
new file mode 100644
index 0000000000000..15e24c59d1dd9
--- /dev/null
+++ b/tests/models/language/pooling/test_auto_prefix_cache_support.py
@@ -0,0 +1,93 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+import torch
+from transformers import AutoModelForSequenceClassification
+
+from tests.models.language.pooling.embed_utils import (
+ run_embedding_correctness_test)
+
+
+@pytest.mark.parametrize(
+ "model",
+ ["jason9693/Qwen2.5-1.5B-apeach"],
+)
+@pytest.mark.parametrize("dtype", ["half"])
+def test_classify_models(
+ hf_runner,
+ vllm_runner,
+ example_prompts,
+ model: str,
+ dtype: str,
+) -> None:
+
+ example_prompts = example_prompts * 2
+
+ with vllm_runner(model,
+ max_model_len=512,
+ dtype=dtype,
+ enable_prefix_caching=True) as vllm_model:
+ cache_config = vllm_model.llm.llm_engine.cache_config
+ assert cache_config.enable_prefix_caching
+ vllm_outputs = vllm_model.classify(example_prompts)
+
+ with hf_runner(model,
+ dtype=dtype,
+ auto_cls=AutoModelForSequenceClassification) as hf_model:
+ hf_outputs = hf_model.classify(example_prompts)
+
+ for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
+ hf_output = torch.tensor(hf_output)
+ vllm_output = torch.tensor(vllm_output)
+
+ assert torch.allclose(hf_output, vllm_output,
+ 1e-3 if dtype == "float" else 1e-2)
+
+
+@pytest.mark.parametrize(
+ "model",
+ ["Qwen/Qwen3-Embedding-0.6B"],
+)
+@pytest.mark.parametrize("dtype", ["half"])
+def test_embed_models(
+ hf_runner,
+ vllm_runner,
+ example_prompts,
+ model: str,
+ dtype: str,
+):
+ example_prompts = [str(s).strip() for s in example_prompts] * 2
+
+ with vllm_runner(
+ model,
+ runner="pooling",
+ max_model_len=None,
+ enable_prefix_caching=True,
+ ) as vllm_model:
+ cache_config = vllm_model.llm.llm_engine.cache_config
+ assert cache_config.enable_prefix_caching
+ vllm_outputs = vllm_model.embed(example_prompts)
+
+ with hf_runner(
+ model,
+ is_sentence_transformer=True,
+ ) as hf_model:
+ run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)
+
+
+@pytest.mark.parametrize(
+ "model",
+ [
+ "intfloat/e5-small",
+ "Alibaba-NLP/gte-Qwen2-1.5B-instruct", # is_causal == False
+ "papluca/xlm-roberta-base-language-detection",
+ ])
+@pytest.mark.parametrize("dtype", ["half"])
+def test_non_causal_models(hf_runner, vllm_runner, example_prompts, model: str,
+ dtype: str) -> None:
+ with vllm_runner(model,
+ max_model_len=512,
+ dtype=dtype,
+ enable_prefix_caching=True) as vllm_model:
+ cache_config = vllm_model.llm.llm_engine.cache_config
+ assert not cache_config.enable_prefix_caching
diff --git a/tests/models/language/pooling/test_baai.py b/tests/models/language/pooling/test_baai.py
index 64a8f25220dab..6fbe0e82d7f8a 100644
--- a/tests/models/language/pooling/test_baai.py
+++ b/tests/models/language/pooling/test_baai.py
@@ -2,73 +2,78 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
-from ...utils import EmbedModelInfo, RerankModelInfo
+from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo,
+ EmbedModelInfo, LASTPoolingEmbedModelInfo,
+ RerankModelInfo)
from .embed_utils import correctness_test_embed_models
from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models
MODELS = [
########## BertModel
- EmbedModelInfo("BAAI/bge-base-en",
- architecture="BertModel",
- enable_test=True),
- EmbedModelInfo("BAAI/bge-base-zh",
- architecture="BertModel",
- enable_test=False),
- EmbedModelInfo("BAAI/bge-small-en",
- architecture="BertModel",
- enable_test=False),
- EmbedModelInfo("BAAI/bge-small-zh",
- architecture="BertModel",
- enable_test=False),
- EmbedModelInfo("BAAI/bge-large-en",
- architecture="BertModel",
- enable_test=False),
- EmbedModelInfo("BAAI/bge-large-zh",
- architecture="BertModel",
- enable_test=False),
- EmbedModelInfo("BAAI/bge-large-zh-noinstruct",
- architecture="BertModel",
- enable_test=False),
- EmbedModelInfo("BAAI/bge-base-en-v1.5",
- architecture="BertModel",
- enable_test=False),
- EmbedModelInfo("BAAI/bge-base-zh-v1.5",
- architecture="BertModel",
- enable_test=False),
- EmbedModelInfo("BAAI/bge-small-en-v1.5",
- architecture="BertModel",
- enable_test=False),
- EmbedModelInfo("BAAI/bge-small-zh-v1.5",
- architecture="BertModel",
- enable_test=False),
- EmbedModelInfo("BAAI/bge-large-en-v1.5",
- architecture="BertModel",
- enable_test=False),
- EmbedModelInfo("BAAI/bge-large-zh-v1.5",
- architecture="BertModel",
- enable_test=False),
+ CLSPoolingEmbedModelInfo("BAAI/bge-base-en",
+ architecture="BertModel",
+ enable_test=True),
+ CLSPoolingEmbedModelInfo("BAAI/bge-base-zh",
+ architecture="BertModel",
+ enable_test=False),
+ CLSPoolingEmbedModelInfo("BAAI/bge-small-en",
+ architecture="BertModel",
+ enable_test=False),
+ CLSPoolingEmbedModelInfo("BAAI/bge-small-zh",
+ architecture="BertModel",
+ enable_test=False),
+ CLSPoolingEmbedModelInfo("BAAI/bge-large-en",
+ architecture="BertModel",
+ enable_test=False),
+ CLSPoolingEmbedModelInfo("BAAI/bge-large-zh",
+ architecture="BertModel",
+ enable_test=False),
+ CLSPoolingEmbedModelInfo("BAAI/bge-large-zh-noinstruct",
+ architecture="BertModel",
+ enable_test=False),
+ CLSPoolingEmbedModelInfo("BAAI/bge-base-en-v1.5",
+ architecture="BertModel",
+ enable_test=False),
+ CLSPoolingEmbedModelInfo("BAAI/bge-base-zh-v1.5",
+ architecture="BertModel",
+ enable_test=False),
+ CLSPoolingEmbedModelInfo("BAAI/bge-small-en-v1.5",
+ architecture="BertModel",
+ enable_test=False),
+ CLSPoolingEmbedModelInfo("BAAI/bge-small-zh-v1.5",
+ architecture="BertModel",
+ enable_test=False),
+ CLSPoolingEmbedModelInfo("BAAI/bge-large-en-v1.5",
+ architecture="BertModel",
+ enable_test=False),
+ CLSPoolingEmbedModelInfo("BAAI/bge-large-zh-v1.5",
+ architecture="BertModel",
+ enable_test=False),
########## XLMRobertaModel
- EmbedModelInfo("BAAI/bge-m3",
- architecture="XLMRobertaModel",
- enable_test=True),
+ CLSPoolingEmbedModelInfo("BAAI/bge-m3",
+ architecture="XLMRobertaModel",
+ enable_test=True),
########## Qwen2Model
- EmbedModelInfo("BAAI/bge-code-v1",
- architecture="Qwen2Model",
- dtype="float32",
- enable_test=True),
+ LASTPoolingEmbedModelInfo("BAAI/bge-code-v1",
+ architecture="Qwen2Model",
+ dtype="float32",
+ enable_test=True),
]
RERANK_MODELS = [
########## XLMRobertaForSequenceClassification
- RerankModelInfo("BAAI/bge-reranker-base",
- architecture="XLMRobertaForSequenceClassification",
- enable_test=True),
- RerankModelInfo("BAAI/bge-reranker-large",
- architecture="XLMRobertaForSequenceClassification",
- enable_test=False),
- RerankModelInfo("BAAI/bge-reranker-v2-m3",
- architecture="XLMRobertaForSequenceClassification",
- enable_test=False)
+ CLSPoolingRerankModelInfo(
+ "BAAI/bge-reranker-base",
+ architecture="XLMRobertaForSequenceClassification",
+ enable_test=True),
+ CLSPoolingRerankModelInfo(
+ "BAAI/bge-reranker-large",
+ architecture="XLMRobertaForSequenceClassification",
+ enable_test=False),
+ CLSPoolingRerankModelInfo(
+ "BAAI/bge-reranker-v2-m3",
+ architecture="XLMRobertaForSequenceClassification",
+ enable_test=False)
]
diff --git a/tests/models/language/pooling/test_bge_reranker_v2_gemma.py b/tests/models/language/pooling/test_bge_reranker_v2_gemma.py
index 7fa9485dbc7f7..206524d7caad3 100644
--- a/tests/models/language/pooling/test_bge_reranker_v2_gemma.py
+++ b/tests/models/language/pooling/test_bge_reranker_v2_gemma.py
@@ -8,12 +8,12 @@ import torch
from tests.conftest import HfRunner
-from .mteb_utils import (RerankModelInfo, VllmMtebEncoder,
- mteb_test_rerank_models)
+from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
+from .mteb_utils import VllmMtebEncoder, mteb_test_rerank_models
RERANK_MODELS = [
- RerankModelInfo("BAAI/bge-reranker-v2-gemma",
- architecture="GemmaForSequenceClassification"),
+ LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma",
+ architecture="GemmaForSequenceClassification"),
]
PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." # noqa: E501
diff --git a/tests/models/language/pooling/test_cross_encoder.py b/tests/models/language/pooling/test_cross_encoder.py
index 9a33063d7b469..8c1bc5779b8a1 100644
--- a/tests/models/language/pooling/test_cross_encoder.py
+++ b/tests/models/language/pooling/test_cross_encoder.py
@@ -2,13 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
-from .mteb_utils import RerankModelInfo, mteb_test_rerank_models
+from ...utils import (CLSPoolingRerankModelInfo, LASTPoolingRerankModelInfo,
+ RerankModelInfo)
+from .mteb_utils import mteb_test_rerank_models
RERANK_MODELS = [
- RerankModelInfo("cross-encoder/ms-marco-TinyBERT-L-2-v2",
- architecture="BertForSequenceClassification"),
- RerankModelInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls",
- architecture="Qwen3ForSequenceClassification")
+ CLSPoolingRerankModelInfo("cross-encoder/ms-marco-TinyBERT-L-2-v2",
+ architecture="BertForSequenceClassification"),
+ LASTPoolingRerankModelInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls",
+ architecture="Qwen3ForSequenceClassification")
]
diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py
index 48a0cd64fec12..f805a64103c06 100644
--- a/tests/models/language/pooling/test_gte.py
+++ b/tests/models/language/pooling/test_gte.py
@@ -4,57 +4,67 @@ from typing import Any
import pytest
-from ...utils import check_transformers_version
-from .embed_utils import EmbedModelInfo, correctness_test_embed_models
-from .mteb_utils import mteb_test_embed_models
+from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo,
+ EmbedModelInfo, LASTPoolingEmbedModelInfo,
+ RerankModelInfo, check_transformers_version)
+from .embed_utils import correctness_test_embed_models
+from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models
MODELS = [
########## BertModel
- EmbedModelInfo("thenlper/gte-large",
- architecture="BertModel",
- enable_test=True),
- EmbedModelInfo("thenlper/gte-base",
- architecture="BertModel",
- enable_test=False),
- EmbedModelInfo("thenlper/gte-small",
- architecture="BertModel",
- enable_test=False),
- EmbedModelInfo("thenlper/gte-large-zh",
- architecture="BertModel",
- enable_test=False),
- EmbedModelInfo("thenlper/gte-base-zh",
- architecture="BertModel",
- enable_test=False),
- EmbedModelInfo("thenlper/gte-small-zh",
- architecture="BertModel",
- enable_test=False),
+ CLSPoolingEmbedModelInfo("thenlper/gte-large",
+ architecture="BertModel",
+ enable_test=True),
+ CLSPoolingEmbedModelInfo("thenlper/gte-base",
+ architecture="BertModel",
+ enable_test=False),
+ CLSPoolingEmbedModelInfo("thenlper/gte-small",
+ architecture="BertModel",
+ enable_test=False),
+ CLSPoolingEmbedModelInfo("thenlper/gte-large-zh",
+ architecture="BertModel",
+ enable_test=False),
+ CLSPoolingEmbedModelInfo("thenlper/gte-base-zh",
+ architecture="BertModel",
+ enable_test=False),
+ CLSPoolingEmbedModelInfo("thenlper/gte-small-zh",
+ architecture="BertModel",
+ enable_test=False),
########### NewModel
- EmbedModelInfo("Alibaba-NLP/gte-multilingual-base",
- architecture="GteNewModel",
- enable_test=True),
- EmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5",
- architecture="GteNewModel",
- enable_test=True),
- EmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5",
- architecture="GteNewModel",
- enable_test=True),
+ CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-multilingual-base",
+ architecture="GteNewModel",
+ enable_test=True),
+ CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5",
+ architecture="GteNewModel",
+ enable_test=True),
+ CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5",
+ architecture="GteNewModel",
+ enable_test=True),
########### Qwen2ForCausalLM
- EmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
- architecture="Qwen2ForCausalLM",
- enable_test=True),
+ LASTPoolingEmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
+ architecture="Qwen2ForCausalLM",
+ enable_test=True),
########## ModernBertModel
- EmbedModelInfo("Alibaba-NLP/gte-modernbert-base",
- architecture="ModernBertModel",
- enable_test=True),
+ CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-modernbert-base",
+ architecture="ModernBertModel",
+ enable_test=True),
########## Qwen3ForCausalLM
- EmbedModelInfo("Qwen/Qwen3-Embedding-0.6B",
- architecture="Qwen3ForCausalLM",
- dtype="float32",
- enable_test=True),
- EmbedModelInfo("Qwen/Qwen3-Embedding-4B",
- architecture="Qwen3ForCausalLM",
- dtype="float32",
- enable_test=False),
+ LASTPoolingEmbedModelInfo("Qwen/Qwen3-Embedding-0.6B",
+ architecture="Qwen3ForCausalLM",
+ dtype="float32",
+ enable_test=True),
+ LASTPoolingEmbedModelInfo("Qwen/Qwen3-Embedding-4B",
+ architecture="Qwen3ForCausalLM",
+ dtype="float32",
+ enable_test=False),
+]
+
+RERANK_MODELS = [
+ # classifier_pooling: mean
+ CLSPoolingRerankModelInfo(
+ "Alibaba-NLP/gte-reranker-modernbert-base",
+ architecture="ModernBertForSequenceClassification",
+ enable_test=True),
]
@@ -87,3 +97,9 @@ def test_embed_models_correctness(hf_runner, vllm_runner,
correctness_test_embed_models(hf_runner, vllm_runner, model_info,
example_prompts, vllm_extra_kwargs)
+
+
+@pytest.mark.parametrize("model_info", RERANK_MODELS)
+def test_rerank_models_mteb(hf_runner, vllm_runner,
+ model_info: RerankModelInfo) -> None:
+ mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
diff --git a/tests/models/language/pooling/test_intfloat.py b/tests/models/language/pooling/test_intfloat.py
index d899aaada2623..e48bdbe940be7 100644
--- a/tests/models/language/pooling/test_intfloat.py
+++ b/tests/models/language/pooling/test_intfloat.py
@@ -2,34 +2,34 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
-from ...utils import EmbedModelInfo
+from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo
from .embed_utils import correctness_test_embed_models
from .mteb_utils import mteb_test_embed_models
MODELS = [
########## BertModel
- EmbedModelInfo("intfloat/e5-small",
- architecture="BertModel",
- enable_test=True),
- EmbedModelInfo("intfloat/e5-base",
- architecture="BertModel",
- enable_test=False),
- EmbedModelInfo("intfloat/e5-large",
- architecture="BertModel",
- enable_test=False),
- EmbedModelInfo("intfloat/multilingual-e5-small",
- architecture="BertModel",
- enable_test=False),
+ CLSPoolingEmbedModelInfo("intfloat/e5-small",
+ architecture="BertModel",
+ enable_test=True),
+ CLSPoolingEmbedModelInfo("intfloat/e5-base",
+ architecture="BertModel",
+ enable_test=False),
+ CLSPoolingEmbedModelInfo("intfloat/e5-large",
+ architecture="BertModel",
+ enable_test=False),
+ CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-small",
+ architecture="BertModel",
+ enable_test=False),
########## XLMRobertaModel
- EmbedModelInfo("intfloat/multilingual-e5-base",
- architecture="XLMRobertaModel",
- enable_test=True),
- EmbedModelInfo("intfloat/multilingual-e5-large",
- architecture="XLMRobertaModel",
- enable_test=False),
- EmbedModelInfo("intfloat/multilingual-e5-large-instruct",
- architecture="XLMRobertaModel",
- enable_test=False),
+ CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-base",
+ architecture="XLMRobertaModel",
+ enable_test=True),
+ CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-large",
+ architecture="XLMRobertaModel",
+ enable_test=False),
+ CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-large-instruct",
+ architecture="XLMRobertaModel",
+ enable_test=False),
]
diff --git a/tests/models/language/pooling/test_jina.py b/tests/models/language/pooling/test_jina.py
index 59b634428ceff..37c5bdc97dd98 100644
--- a/tests/models/language/pooling/test_jina.py
+++ b/tests/models/language/pooling/test_jina.py
@@ -6,20 +6,22 @@ import pytest
from vllm import PoolingParams
-from ...utils import EmbedModelInfo, RerankModelInfo
+from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo,
+ EmbedModelInfo, RerankModelInfo)
from .embed_utils import (check_embeddings_close,
correctness_test_embed_models, matryoshka_fy)
from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models
EMBEDDING_MODELS = [
- EmbedModelInfo("jinaai/jina-embeddings-v3",
- architecture="XLMRobertaModel",
- is_matryoshka=True)
+ CLSPoolingEmbedModelInfo("jinaai/jina-embeddings-v3",
+ architecture="XLMRobertaModel",
+ is_matryoshka=True)
]
RERANK_MODELS = [
- RerankModelInfo("jinaai/jina-reranker-v2-base-multilingual",
- architecture="XLMRobertaForSequenceClassification")
+ CLSPoolingRerankModelInfo(
+ "jinaai/jina-reranker-v2-base-multilingual",
+ architecture="XLMRobertaForSequenceClassification")
]
diff --git a/tests/models/language/pooling/test_mxbai_rerank.py b/tests/models/language/pooling/test_mxbai_rerank.py
index e74c58744dd2e..480bd5e4567cb 100644
--- a/tests/models/language/pooling/test_mxbai_rerank.py
+++ b/tests/models/language/pooling/test_mxbai_rerank.py
@@ -7,15 +7,16 @@ import torch
from tests.conftest import HfRunner
-from .mteb_utils import RerankModelInfo, mteb_test_rerank_models
+from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
+from .mteb_utils import mteb_test_rerank_models
RERANK_MODELS = [
- RerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2",
- architecture="Qwen2ForSequenceClassification",
- enable_test=True),
- RerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2",
- architecture="Qwen2ForSequenceClassification",
- enable_test=False)
+ LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2",
+ architecture="Qwen2ForSequenceClassification",
+ enable_test=True),
+ LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2",
+ architecture="Qwen2ForSequenceClassification",
+ enable_test=False)
]
diff --git a/tests/models/language/pooling/test_nomic.py b/tests/models/language/pooling/test_nomic.py
index e16ec239a3381..2d05958e9bcda 100644
--- a/tests/models/language/pooling/test_nomic.py
+++ b/tests/models/language/pooling/test_nomic.py
@@ -3,22 +3,23 @@
import pytest
-from .embed_utils import EmbedModelInfo, correctness_test_embed_models
+from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo
+from .embed_utils import correctness_test_embed_models
from .mteb_utils import mteb_test_embed_models
MODELS = [
- EmbedModelInfo("nomic-ai/nomic-embed-text-v1",
- architecture="NomicBertModel",
- enable_test=True),
- EmbedModelInfo("nomic-ai/nomic-embed-text-v1.5",
- architecture="NomicBertModel",
- enable_test=False),
- EmbedModelInfo("nomic-ai/CodeRankEmbed",
- architecture="NomicBertModel",
- enable_test=False),
- EmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe",
- architecture="NomicBertModel",
- enable_test=True)
+ CLSPoolingEmbedModelInfo("nomic-ai/nomic-embed-text-v1",
+ architecture="NomicBertModel",
+ enable_test=True),
+ CLSPoolingEmbedModelInfo("nomic-ai/nomic-embed-text-v1.5",
+ architecture="NomicBertModel",
+ enable_test=False),
+ CLSPoolingEmbedModelInfo("nomic-ai/CodeRankEmbed",
+ architecture="NomicBertModel",
+ enable_test=False),
+ CLSPoolingEmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe",
+ architecture="NomicBertModel",
+ enable_test=True)
]
diff --git a/tests/models/language/pooling/test_qwen3_reranker.py b/tests/models/language/pooling/test_qwen3_reranker.py
index 68e96f32700ca..37f5566a330d0 100644
--- a/tests/models/language/pooling/test_qwen3_reranker.py
+++ b/tests/models/language/pooling/test_qwen3_reranker.py
@@ -8,15 +8,16 @@ import torch
from tests.conftest import HfRunner
from tests.utils import multi_gpu_test
-from .mteb_utils import RerankModelInfo, mteb_test_rerank_models
+from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
+from .mteb_utils import mteb_test_rerank_models
RERANK_MODELS = [
- RerankModelInfo("Qwen/Qwen3-Reranker-0.6B",
- architecture="Qwen3ForSequenceClassification",
- enable_test=True),
- RerankModelInfo("Qwen/Qwen3-Reranker-4B",
- architecture="Qwen3ForSequenceClassification",
- enable_test=False)
+ LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-0.6B",
+ architecture="Qwen3ForSequenceClassification",
+ enable_test=True),
+ LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-4B",
+ architecture="Qwen3ForSequenceClassification",
+ enable_test=False)
]
diff --git a/tests/models/language/pooling/test_scoring.py b/tests/models/language/pooling/test_scoring.py
index ef9d5530cde15..6b5ff70681459 100644
--- a/tests/models/language/pooling/test_scoring.py
+++ b/tests/models/language/pooling/test_scoring.py
@@ -23,6 +23,15 @@ TEXTS_2 = [
"The capital of Germany is Berlin.",
]
+
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines):
+ # Simple autouse wrapper to run both engines for each test
+ # This can be promoted up to conftest.py to run for every
+ # test in a package
+ pass
+
+
DTYPE = "half"
diff --git a/tests/models/language/pooling/test_snowflake_arctic_embed.py b/tests/models/language/pooling/test_snowflake_arctic_embed.py
index d6b5dbd08372e..585fa0e683da2 100644
--- a/tests/models/language/pooling/test_snowflake_arctic_embed.py
+++ b/tests/models/language/pooling/test_snowflake_arctic_embed.py
@@ -3,42 +3,43 @@
import pytest
-from .embed_utils import EmbedModelInfo, correctness_test_embed_models
+from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo
+from .embed_utils import correctness_test_embed_models
from .mteb_utils import mteb_test_embed_models
MODELS = [
- EmbedModelInfo("Snowflake/snowflake-arctic-embed-xs",
- is_matryoshka=False,
- architecture="BertModel",
- enable_test=True),
- EmbedModelInfo("Snowflake/snowflake-arctic-embed-s",
- is_matryoshka=False,
- architecture="BertModel",
- enable_test=False),
- EmbedModelInfo("Snowflake/snowflake-arctic-embed-m",
- is_matryoshka=False,
- architecture="BertModel",
- enable_test=False),
- EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long",
- is_matryoshka=False,
- architecture="NomicBertModel",
- enable_test=True),
- EmbedModelInfo("Snowflake/snowflake-arctic-embed-l",
- is_matryoshka=False,
- architecture="BertModel",
- enable_test=False),
- EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5",
- is_matryoshka=True,
- architecture="BertModel",
- enable_test=True),
- EmbedModelInfo("Snowflake/snowflake-arctic-embed-l-v2.0",
- is_matryoshka=True,
- architecture="XLMRobertaModel",
- enable_test=True),
- EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
- is_matryoshka=True,
- architecture="GteModel",
- enable_test=True),
+ CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-xs",
+ is_matryoshka=False,
+ architecture="BertModel",
+ enable_test=True),
+ CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-s",
+ is_matryoshka=False,
+ architecture="BertModel",
+ enable_test=False),
+ CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m",
+ is_matryoshka=False,
+ architecture="BertModel",
+ enable_test=False),
+ CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long",
+ is_matryoshka=False,
+ architecture="NomicBertModel",
+ enable_test=True),
+ CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-l",
+ is_matryoshka=False,
+ architecture="BertModel",
+ enable_test=False),
+ CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5",
+ is_matryoshka=True,
+ architecture="BertModel",
+ enable_test=True),
+ CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-l-v2.0",
+ is_matryoshka=True,
+ architecture="XLMRobertaModel",
+ enable_test=True),
+ CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
+ is_matryoshka=True,
+ architecture="GteModel",
+ enable_test=True),
]
diff --git a/tests/models/multimodal/generation/test_mllama.py b/tests/models/multimodal/generation/test_mllama.py
index 2bb01e494d436..b413c4d6b3667 100644
--- a/tests/models/multimodal/generation/test_mllama.py
+++ b/tests/models/multimodal/generation/test_mllama.py
@@ -6,6 +6,7 @@ from typing import Optional, overload
import pytest
import torch
from transformers import AutoConfig, AutoModelForImageTextToText, AutoTokenizer
+from transformers import __version__ as TRANSFORMERS_VERSION
from vllm import LLM, SamplingParams
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
@@ -285,6 +286,10 @@ def clear_cache():
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
+@pytest.mark.skipif(
+ TRANSFORMERS_VERSION == "4.55.0",
+ reason="Transformers v4.55.0 has a regression issue on mllama, "
+ "see: https://github.com/huggingface/transformers/pull/40083")
def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,
model, sizes, dtype, max_tokens,
num_logprobs,
@@ -313,6 +318,10 @@ def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
+@pytest.mark.skipif(
+ TRANSFORMERS_VERSION == "4.55.0",
+ reason="Transformers v4.55.0 has a regression issue on mllama, "
+ "see: https://github.com/huggingface/transformers/pull/40083")
def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
model, dtype, max_tokens, num_logprobs,
attn_backend: _Backend) -> None:
@@ -362,6 +371,10 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
+@pytest.mark.skipif(
+ TRANSFORMERS_VERSION == "4.55.0",
+ reason="Transformers v4.55.0 has a regression issue on mllama, "
+ "see: https://github.com/huggingface/transformers/pull/40083")
def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
dtype, max_tokens, num_logprobs,
attn_backend: _Backend) -> None:
@@ -402,6 +415,10 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
+@pytest.mark.skipif(
+ TRANSFORMERS_VERSION == "4.55.0",
+ reason="Transformers v4.55.0 has a regression issue on mllama, "
+ "see: https://github.com/huggingface/transformers/pull/40083")
def test_models_distributed(
hf_runner,
vllm_runner,
diff --git a/tests/models/multimodal/generation/test_pixtral.py b/tests/models/multimodal/generation/test_pixtral.py
index e157d6f4a79df..d39cf706786e2 100644
--- a/tests/models/multimodal/generation/test_pixtral.py
+++ b/tests/models/multimodal/generation/test_pixtral.py
@@ -18,7 +18,7 @@ from vllm.multimodal.inputs import PlaceholderRange
from vllm.sequence import Logprob, SampleLogprobs
from ....utils import VLLM_PATH, large_gpu_test
-from ...utils import check_logprobs_close
+from ...utils import check_logprobs_close, dummy_hf_overrides
if TYPE_CHECKING:
from _typeshed import StrPath
@@ -29,10 +29,10 @@ MISTRAL_SMALL_3_1_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
MODELS = [PIXTRAL_ID, MISTRAL_SMALL_3_1_ID]
IMG_URLS = [
- "https://picsum.photos/id/237/400/300",
- "https://picsum.photos/id/231/200/300",
- "https://picsum.photos/id/27/500/500",
- "https://picsum.photos/id/17/150/600",
+ "https://huggingface.co/datasets/Isotr0py/mistral-test-images/resolve/main/237-400x300.jpg",
+ "https://huggingface.co/datasets/Isotr0py/mistral-test-images/resolve/main/231-200x300.jpg",
+ "https://huggingface.co/datasets/Isotr0py/mistral-test-images/resolve/main/27-500x500.jpg",
+ "https://huggingface.co/datasets/Isotr0py/mistral-test-images/resolve/main/17-150x600.jpg",
]
PROMPT = "Describe each image in one short sentence."
@@ -110,11 +110,6 @@ MSGS = [
_create_msg_format(IMG_URLS[:2]),
_create_msg_format(IMG_URLS),
]
-ENGINE_INPUTS = [
- _create_engine_inputs(IMG_URLS[:1]),
- _create_engine_inputs(IMG_URLS[:2]),
- _create_engine_inputs(IMG_URLS),
-]
SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)
LIMIT_MM_PER_PROMPT = dict(image=4)
@@ -195,7 +190,6 @@ def test_chat(
name_1="output")
-@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("prompt,expected_ranges",
[(_create_engine_inputs_hf(IMG_URLS[:1]),
[PlaceholderRange(offset=11, length=494)]),
@@ -204,7 +198,7 @@ def test_chat(
PlaceholderRange(offset=277, length=1056),
PlaceholderRange(offset=1333, length=418)
])])
-def test_multi_modal_placeholders(vllm_runner, prompt,
+def test_multi_modal_placeholders(vllm_runner, prompt: TextPrompt,
expected_ranges: list[PlaceholderRange],
monkeypatch) -> None:
@@ -215,6 +209,8 @@ def test_multi_modal_placeholders(vllm_runner, prompt,
"mistral-community/pixtral-12b",
max_model_len=8192,
limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
+ load_format="dummy",
+ hf_overrides=dummy_hf_overrides,
) as vllm_model:
outputs = vllm_model.llm.generate(prompt)
@@ -230,5 +226,7 @@ def test_multi_modal_placeholders(vllm_runner, prompt,
expected_ranges), f"{image_placeholder_ranges=}"
for real_range, expected_range in zip(image_placeholder_ranges,
expected_ranges):
- assert real_range == expected_range, \
+ assert real_range.offset == expected_range.offset, \
+ f"{real_range=} {expected_range=}"
+ assert real_range.length == expected_range.length, \
f"{real_range=} {expected_range=}"
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index bd1c55d95dac2..906966ddd0649 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -271,6 +271,7 @@ def _test_processing_correctness_one(
"microsoft/Florence-2-base",
"adept/fuyu-8b",
"google/gemma-3-4b-it",
+ "google/gemma-3n-E2B-it",
"zai-org/glm-4v-9b",
"zai-org/GLM-4.1V-9B-Thinking",
"ibm-granite/granite-speech-3.3-2b",
@@ -315,7 +316,7 @@ def _test_processing_correctness_one(
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
"openai/whisper-large-v3",
"omni-research/Tarsier-7b",
- "omni-research/Tarsier2-Recap-7b"
+ "omni-research/Tarsier2-Recap-7b",
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@@ -327,6 +328,8 @@ def test_processing_correctness(
num_batches: int,
simplify_rate: float,
):
+ if model_id == "google/gemma-3n-E2B-it":
+ pytest.skip("Skipping gemma-3n-E2B-it due to transformers #39911 bug.")
_test_processing_correctness(
model_id,
hit_rate=hit_rate,
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 2c2d094e048fb..d7d20d1f3abf7 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -79,17 +79,17 @@ class _HfExamplesInfo:
def check_transformers_version(
self,
*,
- on_fail: Literal["error", "skip"],
+ on_fail: Literal["error", "skip", "return"],
check_min_version: bool = True,
check_max_version: bool = True,
- ) -> None:
+ ) -> Optional[str]:
"""
If the installed transformers version does not meet the requirements,
perform the given action.
"""
if (self.min_transformers_version is None
and self.max_transformers_version is None):
- return
+ return None
current_version = TRANSFORMERS_VERSION
cur_base_version = Version(current_version).base_version
@@ -105,16 +105,18 @@ class _HfExamplesInfo:
and Version(cur_base_version) > Version(max_version)):
msg += f"<={max_version}` is required to run this model."
else:
- return
+ return None
if self.transformers_version_reason:
msg += f" Reason: {self.transformers_version_reason}"
if on_fail == "error":
raise RuntimeError(msg)
- else:
+ elif on_fail == "skip":
pytest.skip(msg)
+ return msg
+
def check_available_online(
self,
*,
@@ -148,7 +150,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True),
"BailingMoeForCausalLM": _HfExamplesInfo("inclusionAI/Ling-lite-1.5",
trust_remote_code=True),
- "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B",
+ "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B-v1",
+ min_transformers_version="4.55.1",
extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}), # noqa: E501
"BloomForCausalLM": _HfExamplesInfo("bigscience/bloom-560m",
{"1b": "bigscience/bloomz-1b1"}),
@@ -183,7 +186,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
- "Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501
+ "Gemma3nForCausalLM": _HfExamplesInfo("google/gemma-3n-E2B-it",
min_transformers_version="4.53"),
"GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"),
"Glm4ForCausalLM": _HfExamplesInfo("zai-org/GLM-4-9B-0414"),
@@ -197,7 +200,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
{"6b": "EleutherAI/gpt-j-6b"}),
"GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-70m",
{"1b": "EleutherAI/pythia-1.4b"}),
- "GptOssForCausalLM": _HfExamplesInfo("openai/gpt-oss-20b"),
+ "GptOssForCausalLM": _HfExamplesInfo("lmsys/gpt-oss-20b-bf16"),
"GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"),
"GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"),
"GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview"), # noqa: E501
@@ -223,6 +226,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True),
"JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"),
"JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini",
+ min_transformers_version="4.55.1",
extras={
"tiny": "ai21labs/Jamba-tiny-dev",
"random": "ai21labs/Jamba-tiny-random", # noqa: E501
@@ -278,6 +282,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
transformers_version_reason="vLLM impl inherits PreTrainedModel and clashes with get_input_embeddings", # noqa: E501
trust_remote_code=True),
"QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat",
+ max_transformers_version="4.53",
+ transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501
trust_remote_code=True),
"Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-0.5B-Instruct",
extras={"2.5": "Qwen/Qwen2.5-0.5B-Instruct"}), # noqa: E501
@@ -285,6 +291,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
+ "SmolLM3ForCausalLM": _HfExamplesInfo("HuggingFaceTB/SmolLM3-3B"),
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
"Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
@@ -377,6 +384,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b", # noqa: E501
extras={"6b": "Salesforce/blip2-opt-6.7b"}), # noqa: E501
"ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501
+ "Cohere2VisionForConditionalGeneration": _HfExamplesInfo("CohereLabs/command-a-vision-07-2025"), # noqa: E501
"DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501
extras={"fork": "Isotr0py/deepseek-vl2-tiny"}, # noqa: E501
max_transformers_version="4.48", # noqa: E501
@@ -385,12 +393,14 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Emu3ForConditionalGeneration": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
+ "Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501
+ min_transformers_version="4.53"),
"GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-2b"), # noqa: E501
"GLM4VForCausalLM": _HfExamplesInfo("zai-org/glm-4v-9b",
trust_remote_code=True,
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
"Glm4vForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.1V-9B-Thinking"), # noqa: E501
- "Glm4v_moeForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.5V",
+ "Glm4vMoeForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.5V",
is_available_online=False), # noqa: E501
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m",
trust_remote_code=True,
@@ -517,6 +527,11 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
trust_remote_code=True,
speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
tokenizer="meta-llama/Llama-3.1-8B-Instruct"),
+ # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
+ # "LlamaForCausalLMEagle3": _HfExamplesInfo("AngelSlim/Qwen3-8B_eagle3", # noqa: E501
+ # trust_remote_code=True,
+ # speculative_model="AngelSlim/Qwen3-8B_eagle3", # noqa: E501
+ # tokenizer="Qwen/Qwen3-8B"),
"EagleLlama4ForCausalLM": _HfExamplesInfo(
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
trust_remote_code=True,
diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py
index f0aa91566b57a..f06b34285eaea 100644
--- a/tests/models/test_initialization.py
+++ b/tests/models/test_initialization.py
@@ -68,6 +68,11 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
if model_arch == "Phi4FlashForCausalLM":
# Phi4FlashForCausalLM only supports DIFFERENTIAL_FLASH_ATTN backend
m.setenv("VLLM_ATTENTION_BACKEND", "DIFFERENTIAL_FLASH_ATTN")
+ if model_arch == "GptOssForCausalLM":
+ # FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU
+ # has cc==8.9 which hasn't supported FA3 yet. Remove this hack when
+ # L4 supports FA3.
+ m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
LLM(
model_info.default,
tokenizer=model_info.tokenizer,
diff --git a/tests/models/utils.py b/tests/models/utils.py
index 1e3d51aeec64e..84aeb927c5fa9 100644
--- a/tests/models/utils.py
+++ b/tests/models/utils.py
@@ -345,19 +345,38 @@ class EmbedModelInfo(NamedTuple):
matryoshka_dimensions: Optional[list[int]] = None
architecture: str = ""
dtype: str = "auto"
+ default_pooling_type: str = ""
enable_test: bool = True
+class CLSPoolingEmbedModelInfo(EmbedModelInfo):
+ default_pooling_type: str = "CLS"
+
+
+class LASTPoolingEmbedModelInfo(EmbedModelInfo):
+ default_pooling_type: str = "LAST"
+
+
class RerankModelInfo(NamedTuple):
name: str
architecture: str = ""
dtype: str = "auto"
+ default_pooling_type: str = ""
enable_test: bool = True
+class CLSPoolingRerankModelInfo(RerankModelInfo):
+ default_pooling_type: str = "CLS"
+
+
+class LASTPoolingRerankModelInfo(RerankModelInfo):
+ default_pooling_type: str = "LAST"
+
+
def dummy_hf_overrides(
hf_config: PretrainedConfig,
- model_arch: str,
+ *,
+ model_arch: str = "",
exist_overrides: Optional[dict[str, Any]] = None,
) -> PretrainedConfig:
"""
diff --git a/tests/multimodal/test_registry.py b/tests/multimodal/test_registry.py
new file mode 100644
index 0000000000000..d31e75bc279f6
--- /dev/null
+++ b/tests/multimodal/test_registry.py
@@ -0,0 +1,38 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Unit tests for MultiModalRegistry.supports_multimodal_inputs and
+Qwen2.5-VL visual component loading behavior.
+"""
+
+import pytest
+
+from vllm.multimodal import MULTIMODAL_REGISTRY
+
+from ..models.utils import build_model_context
+
+
+@pytest.mark.parametrize(
+ "model_id,limit_mm_per_prompt,expected",
+ [
+ ("Qwen/Qwen2-0.5B-Instruct", {}, False),
+ ("Qwen/Qwen2.5-VL-3B-Instruct", {}, True),
+ ("Qwen/Qwen2.5-VL-3B-Instruct", {
+ "image": 0,
+ "video": 0
+ }, False),
+ ("Qwen/Qwen2.5-VL-3B-Instruct", {
+ "image": 0
+ }, True),
+ ],
+)
+@pytest.mark.core_model
+def test_supports_multimodal_inputs(model_id, limit_mm_per_prompt, expected):
+ """Test supports_multimodal_inputs returns correct boolean for various
+ configs."""
+ ctx = build_model_context(
+ model_id,
+ limit_mm_per_prompt=limit_mm_per_prompt,
+ )
+ assert MULTIMODAL_REGISTRY.supports_multimodal_inputs(
+ ctx.model_config) is expected
\ No newline at end of file
diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
index e67825f89d815..8d0687b49bb47 100644
--- a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
+++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
@@ -25,5 +25,6 @@ class DummyPlatform(Platform):
compilation_config.custom_ops = ["all"]
def get_attn_backend_cls(self, backend_name, head_size, dtype,
- kv_cache_dtype, block_size, use_v1, use_mla):
- return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501
\ No newline at end of file
+ kv_cache_dtype, block_size, use_v1, use_mla,
+ has_sink):
+ return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501
diff --git a/tests/speculative_decoding/speculators/test_eagle3.py b/tests/speculative_decoding/speculators/test_eagle3.py
index c46ac7a88b751..45ddb2178722a 100644
--- a/tests/speculative_decoding/speculators/test_eagle3.py
+++ b/tests/speculative_decoding/speculators/test_eagle3.py
@@ -3,12 +3,20 @@
import pytest
import torch
+from vllm.model_executor.models.interfaces import supports_eagle3
+
@pytest.mark.parametrize(
"model_path",
[("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized")])
-def test_llama(vllm_runner, example_prompts, model_path):
+def test_llama(vllm_runner, example_prompts, model_path, monkeypatch):
+ # Set environment variable for V1 engine serialization
+ monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
+
with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
+ eagle3_supported = vllm_model.apply_model(supports_eagle3)
+ assert eagle3_supported
+
vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens=20)
print(vllm_outputs)
@@ -18,8 +26,14 @@ def test_llama(vllm_runner, example_prompts, model_path):
@pytest.mark.parametrize(
"model_path",
[("nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized")])
-def test_qwen(vllm_runner, example_prompts, model_path):
+def test_qwen(vllm_runner, example_prompts, model_path, monkeypatch):
+ # Set environment variable for V1 engine serialization
+ monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
+
with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
+ eagle3_supported = vllm_model.apply_model(supports_eagle3)
+ assert eagle3_supported
+
vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens=20)
print(vllm_outputs)
diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py
index b8d7892e57f21..0fb142a1b6e56 100644
--- a/tests/tensorizer_loader/test_tensorizer.py
+++ b/tests/tensorizer_loader/test_tensorizer.py
@@ -166,7 +166,7 @@ def test_load_without_tensorizer_load_format(vllm_runner, capfd, model_ref):
combined_output = out + err
assert ("ValueError: Model loader extra config "
"is not supported for load "
- "format LoadFormat.AUTO") in combined_output
+ "format auto") in combined_output
finally:
del model
gc.collect()
@@ -186,7 +186,7 @@ def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd,
combined_output = out + err
assert ("ValueError: Model loader extra config is not supported "
- "for load format LoadFormat.SAFETENSORS") in combined_output
+ "for load format safetensors") in combined_output
finally:
del model
gc.collect()
diff --git a/tests/test_config.py b/tests/test_config.py
index 441c07b99acfa..957771a4226bc 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -200,28 +200,6 @@ def test_disable_sliding_window(model_id_expected):
assert model_config.max_model_len == expected
-def test_get_sliding_window():
- TEST_SLIDING_WINDOW = 4096
- # Test that the sliding window is correctly computed.
- # For Qwen1.5/Qwen2, get_sliding_window() should be None
- # when use_sliding_window is False.
- qwen2_model_config = ModelConfig("Qwen/Qwen1.5-7B")
-
- qwen2_model_config.hf_config.use_sliding_window = False
- qwen2_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
- assert qwen2_model_config.get_sliding_window() is None
-
- qwen2_model_config.hf_config.use_sliding_window = True
- assert qwen2_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
-
- mistral_model_config = ModelConfig("mistralai/Mistral-7B-v0.1")
- mistral_model_config.hf_config.sliding_window = None
- assert mistral_model_config.get_sliding_window() is None
-
- mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
- assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
-
-
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
def test_get_pooling_config():
@@ -249,6 +227,20 @@ def test_get_pooling_config_from_args():
assert asdict(pooling_config) == asdict(override_pooler_config)
+@pytest.mark.parametrize(
+ ("model_id", "default_pooling_type", "pooling_type"),
+ [
+ ("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", "LAST", "LAST"), # LLM
+ ("intfloat/e5-small", "CLS", "MEAN"), # BertModel
+ ("Qwen/Qwen2.5-Math-RM-72B", "ALL", "ALL"), # reward
+ ("Qwen/Qwen2.5-Math-PRM-7B", "STEP", "STEP") # step reward
+ ])
+def test_default_pooling_type(model_id, default_pooling_type, pooling_type):
+ model_config = ModelConfig(model_id)
+ assert model_config._model_info.default_pooling_type == default_pooling_type
+ assert model_config.pooler_config.pooling_type == pooling_type
+
+
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
def test_get_bert_tokenization_sentence_transformer_config():
diff --git a/tests/test_sharded_state_loader.py b/tests/test_sharded_state_loader.py
index 1bb4203d21c3e..42afdfa3c7468 100644
--- a/tests/test_sharded_state_loader.py
+++ b/tests/test_sharded_state_loader.py
@@ -118,8 +118,17 @@ def test_sharded_state_loader(enable_lora, tp_size, num_gpus_available,
tensor_parallel_size=tp_size,
))
p.start()
- p.join()
+ # Call queue.get() before p.join() to prevent deadlock:
+ # If p.join() is called before queue.get() and the queue is full,
+ # the child process may block while writing to the queue and never
+ # terminate, causing the parent to wait indefinitely on p.join().
+ # See: https://github.com/vllm-project/vllm/pull/22371#discussion_r2257773814
out_before = queue.get()
+ p.join()
+ queue.close()
+ queue.join_thread()
+
+ queue = ctx.Queue()
p = ctx.Process(target=_run_generate,
args=(output_dir, queue),
@@ -131,7 +140,14 @@ def test_sharded_state_loader(enable_lora, tp_size, num_gpus_available,
load_format="sharded_state",
))
p.start()
- p.join()
+ # Call queue.get() before p.join() to prevent deadlock:
+ # If p.join() is called before queue.get() and the queue is full,
+ # the child process may block while writing to the queue and never
+ # terminate, causing the parent to wait indefinitely on p.join().
+ # See: https://github.com/vllm-project/vllm/pull/22371#discussion_r2257773814
out_after = queue.get()
+ p.join()
+ queue.close()
+ queue.join_thread()
assert out_before == out_after
diff --git a/tests/test_test.py b/tests/test_test.py
new file mode 100644
index 0000000000000..dc8c9814ede39
--- /dev/null
+++ b/tests/test_test.py
@@ -0,0 +1,61 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+
+from vllm import LLM, envs
+from vllm.sampling_params import SamplingParams
+
+if not envs.VLLM_USE_V1:
+ pytest.skip(
+ "Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
+ allow_module_level=True,
+ )
+
+
+@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
+# TODO TPU will appear busy if we fan-out test params here
+@pytest.mark.parametrize("n_prompts", [1])
+def test_logprobs(model_name: str, n_prompts: int):
+ """
+ Request top logprobs with different sampling settings and check
+ that results contains the requested number, ordered ascendingly.
+ """
+
+ def check_num_logprobs(logprobs, expected_num: int):
+ for step in logprobs:
+ prev_logp = 1.0
+ # order by rank
+ sorted_step = dict(
+ sorted(step.items(), key=lambda item: item[1].rank))
+
+ if len(step) != expected_num:
+ print("watch out", sorted_step)
+
+ # check results are ordered by prob value
+ # assert len(step) == expected_num
+ for rankno, (tid, logp) in enumerate(sorted_step.items()):
+ assert logp.logprob <= prev_logp
+ prev_logp = logp.logprob
+ assert logp.rank == rankno + 1
+
+ llm = LLM(model_name,
+ enforce_eager=False,
+ max_num_seqs=1,
+ max_model_len=128,
+ max_num_batched_tokens=128)
+ prompts = [
+ "Write a short story about a robot that dreams for the first time."
+ ] * n_prompts
+ greedy_sampling_params = SamplingParams(temperature=0.0, max_tokens=64,\
+ logprobs=4)
+ regular_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,\
+ logprobs=4)
+ topkp_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,\
+ logprobs=4, top_k=12, top_p=0.5)
+
+ for sp in [greedy_sampling_params, regular_sampling_params, \
+ topkp_sampling_params]:
+ output = llm.generate(prompts, sp)
+ for o in output:
+ check_num_logprobs(o.outputs[0].logprobs, 4)
diff --git a/tests/utils.py b/tests/utils.py
index 1c1a1cc6014ec..18fcde949160e 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -17,6 +17,7 @@ from pathlib import Path
from typing import Any, Callable, Literal, Optional, Union
import cloudpickle
+import httpx
import openai
import pytest
import requests
@@ -88,10 +89,12 @@ class RemoteOpenAIServer:
raise ValueError("You have manually specified the port "
"when `auto_port=True`.")
- # Don't mutate the input args
- vllm_serve_args = vllm_serve_args + [
- "--port", str(get_open_port())
- ]
+ # No need for a port if using unix sockets
+ if "--uds" not in vllm_serve_args:
+ # Don't mutate the input args
+ vllm_serve_args = vllm_serve_args + [
+ "--port", str(get_open_port())
+ ]
if seed is not None:
if "--seed" in vllm_serve_args:
raise ValueError("You have manually specified the seed "
@@ -104,8 +107,13 @@ class RemoteOpenAIServer:
subparsers = parser.add_subparsers(required=False, dest="subparser")
parser = ServeSubcommand().subparser_init(subparsers)
args = parser.parse_args(["--model", model, *vllm_serve_args])
- self.host = str(args.host or 'localhost')
- self.port = int(args.port)
+ self.uds = args.uds
+ if args.uds:
+ self.host = None
+ self.port = None
+ else:
+ self.host = str(args.host or 'localhost')
+ self.port = int(args.port)
self.show_hidden_metrics = \
args.show_hidden_metrics_for_version is not None
@@ -150,9 +158,11 @@ class RemoteOpenAIServer:
def _wait_for_server(self, *, url: str, timeout: float):
# run health check
start = time.time()
+ client = (httpx.Client(transport=httpx.HTTPTransport(
+ uds=self.uds)) if self.uds else requests)
while True:
try:
- if requests.get(url).status_code == 200:
+ if client.get(url).status_code == 200:
break
except Exception:
# this exception can only be raised by requests.get,
@@ -170,7 +180,8 @@ class RemoteOpenAIServer:
@property
def url_root(self) -> str:
- return f"http://{self.host}:{self.port}"
+ return (f"http://{self.uds.split('/')[-1]}"
+ if self.uds else f"http://{self.host}:{self.port}")
def url_for(self, *parts: str) -> str:
return self.url_root + "/" + "/".join(parts)
@@ -986,3 +997,19 @@ def has_module_attribute(module_name, attribute_name):
return hasattr(module, attribute_name)
except ImportError:
return False
+
+
+def get_attn_backend_list_based_on_platform() -> list[str]:
+ if current_platform.is_cuda():
+ return ["FLASH_ATTN_VLLM_V1", "TRITON_ATTN_VLLM_V1", "TREE_ATTN"]
+ elif current_platform.is_rocm():
+ attn_backend_list = ["TRITON_ATTN_VLLM_V1"]
+ try:
+ import aiter # noqa: F401
+ attn_backend_list.append("FLASH_ATTN_VLLM_V1")
+ except Exception:
+ print("Skip FLASH_ATTN_VLLM_V1 on ROCm as aiter is not installed")
+
+ return attn_backend_list
+ else:
+ raise ValueError("Unsupported platform")
diff --git a/tests/utils_/__init__.py b/tests/utils_/__init__.py
new file mode 100644
index 0000000000000..e6b4c3f6364cd
--- /dev/null
+++ b/tests/utils_/__init__.py
@@ -0,0 +1,6 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+This module is named `utils_` instead of `utils` to avoid obscuring
+`tests/utils.py`.
+"""
diff --git a/tests/standalone_tests/test_tensor_schema.py b/tests/utils_/test_tensor_schema.py
similarity index 73%
rename from tests/standalone_tests/test_tensor_schema.py
rename to tests/utils_/test_tensor_schema.py
index e98aa3f53fb53..6aa781c1564de 100644
--- a/tests/standalone_tests/test_tensor_schema.py
+++ b/tests/utils_/test_tensor_schema.py
@@ -4,8 +4,8 @@
import pytest
import torch
-from vllm.model_executor.models.fuyu import FuyuImagePatchInputs
from vllm.model_executor.models.glm4_1v import Glm4vImageEmbeddingInputs
+from vllm.model_executor.models.granite_speech import GraniteSpeechAudioInputs
from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs
@@ -33,6 +33,31 @@ def test_tensor_schema_constant_dim_failure():
)
+def test_tensor_schema_invalid_types_in_list():
+ with pytest.raises(ValueError, match="is not a torch.Tensor"):
+ Phi3VImagePixelInputs(
+ data=[
+ torch.randn(64, 3, 32, 32),
+ "not_a_tensor",
+ torch.randn(64, 3, 32, 32),
+ ],
+ image_sizes=torch.randint(0, 256, (3, 2)),
+ )
+
+
+def test_tensor_schema_rank_mismatch():
+ with pytest.raises(ValueError, match="has rank 3 but expected 5"):
+ Phi3VImagePixelInputs(
+ data=torch.randn(16, 64, 3),
+ image_sizes=torch.randint(0, 256, (16, 2)),
+ )
+
+
+def test_tensor_schema_missing_required_field():
+ with pytest.raises(ValueError, match="Required field 'data' is missing"):
+ Phi3VImagePixelInputs(image_sizes=torch.randint(0, 256, (16, 2)), )
+
+
def test_tensor_schema_symbolic_dim_mismatch():
with pytest.raises(ValueError, match="expected 'bn'=12, got 16"):
Phi3VImagePixelInputs(
@@ -129,23 +154,27 @@ def test_tensor_schema_with_invalid_resolve_binding_dims():
def test_tensor_schema_with_list_of_symbolic_dim():
- flat_data = torch.stack([torch.randn(768) for _ in range(3)]) # (bn=3, fn)
- patches_per_image = [64, 64, 64] # len = bn = 3
+ input_features = torch.randn(3, 10, 160) # (b=3, fi=10, 160)
+ input_features_mask = torch.randn(3, 8) # (b=3, fo=8)
+ audio_embed_sizes = [8, 8, 8] # len = b = 3
- FuyuImagePatchInputs(
- flat_data=flat_data,
- patches_per_image=patches_per_image,
+ GraniteSpeechAudioInputs(
+ input_features=input_features,
+ input_features_mask=input_features_mask,
+ audio_embed_sizes=audio_embed_sizes,
)
def test_tensor_schema_with_list_of_symbolic_dim_mismatch_in_length():
- flat_data = torch.stack([torch.randn(768) for _ in range(4)]) # (bn=4, fn)
- patches_per_image = [64, 64, 64] # len = 3 ≠ bn
+ input_features = torch.randn(4, 10, 160) # (b=4, fi=10, 160)
+ input_features_mask = torch.randn(4, 8) # (b=4, fo=8)
+ audio_embed_sizes = [8, 8, 8] # len = 3 ≠ b
- with pytest.raises(ValueError, match="expected 'bn'=4, got 3"):
- FuyuImagePatchInputs(
- flat_data=flat_data,
- patches_per_image=patches_per_image,
+ with pytest.raises(ValueError, match="expected 'b'=4, got 3"):
+ GraniteSpeechAudioInputs(
+ input_features=input_features,
+ input_features_mask=input_features_mask,
+ audio_embed_sizes=audio_embed_sizes,
)
diff --git a/tests/test_utils.py b/tests/utils_/test_utils.py
similarity index 99%
rename from tests/test_utils.py
rename to tests/utils_/test_utils.py
index 53a34642e5baf..a2db1ae684341 100644
--- a/tests/test_utils.py
+++ b/tests/utils_/test_utils.py
@@ -5,7 +5,6 @@
import asyncio
import hashlib
import json
-import logging
import pickle
import socket
from collections.abc import AsyncIterator
@@ -29,7 +28,7 @@ from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
merge_async_iterators, sha256, split_host_port,
split_zmq_path, supports_kw, swap_dict_values)
-from .utils import create_new_process_for_each_test, error_on_warning
+from ..utils import create_new_process_for_each_test, error_on_warning
@pytest.mark.asyncio
diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py
index e9e574501d63e..a4e38eb32f6a1 100644
--- a/tests/v1/attention/utils.py
+++ b/tests/v1/attention/utils.py
@@ -11,7 +11,7 @@ import torch
from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
LoadConfig, ModelConfig, ModelDType, ParallelConfig,
SchedulerConfig, VllmConfig)
-from vllm.platforms import _Backend
+from vllm.platforms import _Backend, current_platform
from vllm.utils import resolve_obj_by_qualname
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec
@@ -119,7 +119,10 @@ def get_attention_backend(backend_name: _Backend):
"""
backend_map = {
_Backend.FLASH_ATTN_VLLM_V1:
- "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend",
+ ("vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
+ if current_platform.is_cuda() else
+ "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
+ ),
_Backend.FLASHINFER_VLLM_V1:
"vllm.v1.attention.backends.flashinfer.FlashInferBackend",
_Backend.FLEX_ATTENTION:
diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py
index 31f25e94c5b4b..599916c0d1cfb 100644
--- a/tests/v1/e2e/test_spec_decode.py
+++ b/tests/v1/e2e/test_spec_decode.py
@@ -8,10 +8,12 @@ from typing import Any, Union
import pytest
import torch
+from tests.utils import get_attn_backend_list_based_on_platform
from vllm import LLM, SamplingParams
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
from vllm.distributed import cleanup_dist_env_and_memory
+from vllm.platforms import current_platform
def get_test_prompts(mm_enabled: bool):
@@ -124,7 +126,10 @@ def test_ngram_correctness(
@pytest.mark.parametrize(
- ["model_setup", "mm_enabled"], [
+ ["model_setup", "mm_enabled"],
+ [
+ # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
+ # (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
@@ -140,12 +145,18 @@ def test_ngram_correctness(
True,
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
],
- ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"])
+ ids=[
+ "qwen3_eagle3", "llama3_eagle", "llama3_eagle3", "llama4_eagle",
+ "llama4_eagle_mm"
+ ])
+@pytest.mark.parametrize("attn_backend",
+ get_attn_backend_list_based_on_platform())
def test_eagle_correctness(
monkeypatch: pytest.MonkeyPatch,
sampling_config: SamplingParams,
model_setup: tuple[str, str, str, int],
mm_enabled: bool,
+ attn_backend: str,
):
# Generate test prompts inside the function instead of using fixture
test_prompts = get_test_prompts(mm_enabled)
@@ -156,6 +167,16 @@ def test_eagle_correctness(
'''
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
+ m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
+
+ if (attn_backend == "TRITON_ATTN_VLLM_V1"
+ and not current_platform.is_rocm()):
+ pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
+ "multi-token eagle spec decode on current platform")
+
+ if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
+ m.setenv("VLLM_ROCM_USE_AITER", "1")
+
method, model_name, spec_model_name, tp_size = model_setup
ref_llm = LLM(model=model_name,
diff --git a/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py b/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py
new file mode 100644
index 0000000000000..41f1d02bf7870
--- /dev/null
+++ b/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py
@@ -0,0 +1,104 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import base64
+import io
+import json
+
+import openai # use the official client for correctness check
+import pytest
+import pytest_asyncio
+import torch
+from transformers import AutoConfig
+
+from tests.conftest import ImageTestAssets
+from tests.utils import RemoteOpenAIServer
+
+# any model with a chat template should work here
+MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
+CONFIG = AutoConfig.from_pretrained(MODEL_NAME)
+MAXIMUM_IMAGES = 2
+
+
+@pytest.fixture(scope="module")
+def default_image_embeds_server_args() -> list[str]:
+ return [
+ "--dtype",
+ "bfloat16",
+ "--max-model-len",
+ "2048",
+ "--max-num-seqs",
+ "4",
+ "--enforce-eager",
+ "--limit-mm-per-prompt",
+ json.dumps({"image": MAXIMUM_IMAGES}),
+ ]
+
+
+@pytest.fixture(scope="module")
+def server_with_image_embeds(default_image_embeds_server_args):
+ with RemoteOpenAIServer(MODEL_NAME,
+ default_image_embeds_server_args,
+ max_wait_seconds=600) as remote_server:
+ yield remote_server
+
+
+@pytest_asyncio.fixture
+async def client_with_image_embeds(server_with_image_embeds):
+ async with server_with_image_embeds.get_async_client() as async_client:
+ yield async_client
+
+
+def encode_image_embedding_to_base64(image_embedding) -> str:
+ """
+ Encode image embedding to base64 string
+ """
+ buffer = io.BytesIO()
+ torch.save(image_embedding, buffer)
+ buffer.seek(0)
+ binary_data = buffer.read()
+ base64_image_embedding = base64.b64encode(binary_data).decode('utf-8')
+ return base64_image_embedding
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("dtype", [torch.half, torch.float16, torch.float32])
+async def test_completions_with_image_embeds(
+ client_with_image_embeds: openai.AsyncOpenAI,
+ model_name: str,
+ image_assets: ImageTestAssets,
+ dtype: torch.dtype,
+):
+ # Test case: Single image embeds input
+ image_embeds = image_assets[0].image_embeds.to(dtype=dtype)
+ base64_image_embedding = encode_image_embedding_to_base64(image_embeds)
+ chat_completion = await client_with_image_embeds.chat.completions.create(
+ messages=[
+ {
+ "role": "system",
+ "content": "You are a helpful assistant."
+ },
+ {
+ "role":
+ "user",
+ "content": [
+ {
+ "type":
+ "text",
+ "text":
+ "Describe these images separately. For each image,"
+ "reply with a short sentence (no more than 10 words).",
+ },
+ {
+ "type": "image_embeds",
+ "image_embeds": base64_image_embedding,
+ },
+ ],
+ },
+ ],
+ model=model_name,
+ )
+ assert chat_completion.choices[0].message.content is not None
+ assert isinstance(chat_completion.choices[0].message.content, str)
+ assert len(chat_completion.choices[0].message.content) > 0
diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index c5ca7df836853..c6739832355fe 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -173,9 +173,9 @@ def test_prompt_less_than_block_size():
"""
Test that we can handle case where prompt is < block.
- In this case, the P worker will send empty remote_block_ids.
- The D worker should not schedule an async read in this case,
- since there is nothing to pull.
+ In this case, the P worker will still send remote_block_ids of the
+ partial block. The D worker should schedule an async read
+ in this case.
"""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
@@ -184,22 +184,20 @@ def test_prompt_less_than_block_size():
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_TOKENS = int(BLOCK_SIZE * 0.5)
- # Request will have 0 remote blocks.
+ # Request will have 1 partial remote block.
request = create_request(request_id=1,
num_tokens=NUM_TOKENS,
do_remote_prefill=True,
- num_remote_blocks=0)
+ num_remote_blocks=1)
scheduler.add_request(request)
scheduler_output = scheduler.schedule()
- # This request should not have to read async.
+ # This request will read async.
kv_connector_metadata = scheduler_output.kv_connector_metadata
assert kv_connector_metadata is not None
assert isinstance(kv_connector_metadata, NixlConnectorMetadata)
- assert len(kv_connector_metadata.reqs_to_recv) == 0
-
- # This request should be scheduled regularly.
- assert len(scheduler_output.scheduled_new_reqs) == 1
+ assert len(kv_connector_metadata.reqs_to_recv) == 1
+ assert len(scheduler_output.scheduled_new_reqs) == 0
class FakeNixlConnectorWorker(NixlConnectorWorker):
diff --git a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
index 76394a540aacd..2f8228864e7b4 100644
--- a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
+++ b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
@@ -121,13 +121,19 @@ def test_short_prompt_lifecycle():
model_runner_output = create_model_runner_output(reqs=[request])
# (1c): update_from_output()
- # Since tokens < block_size, there will be no kv xfer.
- # So this should be cleaned up immediately.
- _ = scheduler.update_from_output(scheduler_output, model_runner_output)
+ # Even though tokens < block_size, there will be kv xfer for partial block.
+ eco = scheduler.update_from_output(scheduler_output, model_runner_output)
+ kv_transfer_params = eco[0].outputs[0].kv_transfer_params
+
+ assert (len(kv_transfer_params["remote_block_ids"]) == 1)
# Confirm we do not have any memory leaks after req lifecycle.
- # We need one more call to schedule() to clear data for persistent batch.
- _ = scheduler.schedule()
+ # We need to mark sending finish to clear data for persistent batch.
+ scheduler_output = scheduler.schedule()
+ # Use create_model_runner_output to pass kv_connector_output along
+ model_runner_output = create_model_runner_output(
+ reqs=[request], finished_sending=[request.request_id])
+ scheduler.update_from_output(scheduler_output, model_runner_output)
assert_scheduler_empty(scheduler)
@@ -169,16 +175,16 @@ def test_prefix_cache_lifecycle():
eco = scheduler.update_from_output(scheduler_output, model_runner_output)
kv_transfer_params = eco[0].outputs[0].kv_transfer_params
- # Ensure we send all block ids, even if there is a cache hit.
+ # Ensure we send all block ids, including the partial blocks,
+ # even if there is a cache hit.
assert (len(
- kv_transfer_params["remote_block_ids"]) == NUM_EXTERNAL_FULL_BLOCKS)
+ kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS +
+ 1))
# STEP (2): Ensure it is freed.
scheduler_output = scheduler.schedule()
- scheduler.schedule()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_sending=[request_remote.request_id])
scheduler.update_from_output(scheduler_output, model_runner_output)
- _ = scheduler.schedule()
assert_scheduler_empty(scheduler)
diff --git a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
index 3d52ea526d96b..87f7490698a31 100644
--- a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
+++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
@@ -362,7 +362,7 @@ def test_cannot_schedule_after_recv():
BLOCK_SIZE = vllm_config.cache_config.block_size
# Prompt will use 2 blocks + 1 block after we schedule.
NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
- NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5))
+ NUM_TOKENS_REMOTE = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS_LOCAL)
request_remote = create_request(request_id=2,
@@ -393,14 +393,24 @@ def test_cannot_schedule_after_recv():
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
- # Step 4: try to schedule, not enough blocks.
+ # Step 4: try to schedule, remote request is put to running list
+ # because the transfer is completed.
+ scheduler_output = scheduler.schedule()
+ model_runner_output = create_model_runner_output(
+ reqs=[request_normal, request_remote])
+ scheduler.update_from_output(scheduler_output, model_runner_output)
+ assert len(scheduler.running) == 2
+ assert len(scheduler.waiting) == 0
+
+ # Step 5: Remote request will be put back to waiting list
+ # because it needs new block to hold generated token.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_normal])
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
- # Step 5: finish the request, free it.
+ # Step 6: finish the request, free it.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_normal],
use_eos=True)
@@ -408,15 +418,99 @@ def test_cannot_schedule_after_recv():
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
- # Step 6: now we can schedule (with 2 blocks computed).
+ # Step 7: now we can schedule (with 2 blocks computed),
+ # request is retrieved from preempted list.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_remote])
- assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens ==
+ assert (scheduler_output.scheduled_cached_reqs.num_computed_tokens[0] ==
NUM_PROMPT_BLOCKS * BLOCK_SIZE)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0
+ # Step 8: free everything.
+ scheduler_output = scheduler.schedule()
+ model_runner_output = create_model_runner_output(reqs=[request_remote],
+ use_eos=True)
+ scheduler.update_from_output(scheduler_output, model_runner_output)
+ _ = scheduler.schedule()
+ assert_scheduler_empty(scheduler)
+
+
+def test_cannot_recv():
+ """
+ Test that we can handle no schedule KV block transfer due to not
+ enough remaining KV blocks.
+ """
+
+ # NOTE: the KVCacheManager will use 1 null block.
+ # So there are 5 total working blocks.
+ TOTAL_NUM_BLOCKS = 6
+ vllm_config = create_vllm_config()
+ scheduler = create_scheduler(vllm_config, num_blocks=TOTAL_NUM_BLOCKS)
+
+ # Prime the KVCache.
+ NUM_PROMPT_BLOCKS = 2
+ BLOCK_SIZE = vllm_config.cache_config.block_size
+ # Prompt will use 2 blocks + 1 block after we schedule.
+ NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
+ NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5))
+
+ request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS_LOCAL)
+ request_remote = create_request(request_id=2,
+ num_tokens=NUM_TOKENS_REMOTE,
+ do_remote_prefill=True)
+
+ # STEP 1: 3 blocks are in use (2 for prompt, 1 for decode).
+ scheduler.add_request(request_normal)
+ scheduler_output = scheduler.schedule()
+ model_runner_output = create_model_runner_output(reqs=[request_normal])
+ scheduler.update_from_output(scheduler_output, model_runner_output)
+ assert len(scheduler.running) == 1
+ assert len(scheduler.waiting) == 0
+
+ # Step 2: 3 blocks are in use,
+ # need 3 new for remote blocks but only 2 are available.
+ scheduler.add_request(request_remote)
+ scheduler_output = scheduler.schedule()
+ model_runner_output = create_model_runner_output(reqs=[request_normal])
+ scheduler.update_from_output(scheduler_output, model_runner_output)
+ assert len(scheduler.running) == 1
+ assert len(scheduler.waiting) == 1
+ # Should not have KV transfer in progress.
+ assert (request_remote.status != RequestStatus.WAITING_FOR_REMOTE_KVS)
+
+ # Step 3: finish the request, free it.
+ scheduler_output = scheduler.schedule()
+ model_runner_output = create_model_runner_output(reqs=[request_normal],
+ use_eos=True)
+ scheduler.update_from_output(scheduler_output, model_runner_output)
+ assert len(scheduler.running) == 0
+ assert len(scheduler.waiting) == 1
+
+ # Step 4: now we can initiate KV transfer (with 2 blocks computed).
+ scheduler_output = scheduler.schedule()
+ model_runner_output = create_model_runner_output(reqs=[])
+ scheduler.update_from_output(scheduler_output, model_runner_output)
+ assert len(scheduler.running) == 0
+ assert len(scheduler.waiting) == 1
+ assert (request_remote.status == RequestStatus.WAITING_FOR_REMOTE_KVS)
+
+ # Step 5: finish recving (5 blocks in use)
+ scheduler_output = scheduler.schedule()
+ model_runner_output = create_model_runner_output(
+ reqs=[], finished_recving=[request_remote.request_id])
+ scheduler.update_from_output(scheduler_output, model_runner_output)
+ assert len(scheduler.running) == 0
+ assert len(scheduler.waiting) == 1
+
+ # Step 6: schedule remote request
+ scheduler_output = scheduler.schedule()
+ model_runner_output = create_model_runner_output(reqs=[request_remote])
+ scheduler.update_from_output(scheduler_output, model_runner_output)
+ assert len(scheduler.running) == 1
+ assert len(scheduler.waiting) == 0
+
# Step 7: free everything.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_remote],
diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py
index 291c84d117cb6..c22d5b861e3f4 100644
--- a/tests/v1/kv_connector/unit/utils.py
+++ b/tests/v1/kv_connector/unit/utils.py
@@ -179,6 +179,13 @@ def create_model_runner_output(
sampled_token = EOS_TOKEN_ID if use_eos else 0
sampled_token_ids = [[sampled_token] for _ in req_ids]
+ kv_connector_output = None if (
+ finished_sending is None
+ and finished_recving is None) else KVConnectorOutput(
+ finished_sending=finished_sending,
+ finished_recving=finished_recving,
+ )
+
# Make output data structure.
return ModelRunnerOutput(
req_ids=req_ids,
@@ -188,10 +195,7 @@ def create_model_runner_output(
logprobs=None,
prompt_logprobs_dict={},
pooler_output=None,
- kv_connector_output=KVConnectorOutput(
- finished_sending=finished_sending,
- finished_recving=finished_recving,
- ),
+ kv_connector_output=kv_connector_output,
)
diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py
index ea10661ea1137..31c6c881d7b83 100644
--- a/tests/v1/sample/test_sampler.py
+++ b/tests/v1/sample/test_sampler.py
@@ -90,6 +90,27 @@ def _create_bad_words_token_ids(
return bad_words_token_ids
+# Returns all last tokens of bad word sequences that share the same prefix
+# as `given_prefix` (excluding the last token).
+def _collect_suffixes_with_same_prefix(
+ given_prefix: list[int],
+ bad_words_token_ids: list[list[int]]) -> list[int]:
+ return [bwt[-1] for bwt in bad_words_token_ids if bwt[:-1] == given_prefix]
+
+
+# generate a valid token id that is not in bad_words_token_ids
+def _generate_valid_token_id(bad_words_token_ids: list[list[int]],
+ vocab_size: int) -> int:
+ forbidden_start_tokens = set()
+ for bad_word in bad_words_token_ids:
+ forbidden_start_tokens.add(bad_word[0])
+ # Get a safe token that's not in forbidden starts
+ safe_token_candidates = list(
+ set(range(vocab_size)) - forbidden_start_tokens)
+ # Pick a random safe token
+ return np.random.choice(safe_token_candidates)
+
+
def _update_output_token_ids_for_bad_words(
metadata: SamplingMetadata, vocab_size: int) -> dict[int, list[int]]:
bad_words_last_tokens = {}
@@ -104,12 +125,17 @@ def _update_output_token_ids_for_bad_words(
prefix_length = len(bad_word_token_ids) - 1
has_bad_words = np.random.choice([True, False])
if has_bad_words:
- output_token_ids[-prefix_length:] = bad_word_token_ids[:-1]
- bad_words_last_token.append(bad_word_token_ids[-1])
+ prefix = bad_word_token_ids[:-1]
+ output_token_ids[-prefix_length:] = prefix
+ # Collect all last tokens from other bad words
+ # that share this prefix
+ bad_words_last_token.extend(
+ _collect_suffixes_with_same_prefix(
+ prefix, bad_words_token_ids))
break # Maximum one update to output_token_ids
else: # Make sure no accidental match to bad words
- output_token_ids[-1] = (bad_word_token_ids[-2] +
- 1) % vocab_size
+ output_token_ids[-1] = _generate_valid_token_id(
+ bad_words_token_ids, vocab_size)
bad_words_last_tokens[batch_idx] = bad_words_last_token
return bad_words_last_tokens
diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py
index 73b47f8974397..2b4f8bd2a8b90 100644
--- a/tests/v1/spec_decode/test_eagle.py
+++ b/tests/v1/spec_decode/test_eagle.py
@@ -6,6 +6,7 @@ from unittest import mock
import pytest
import torch
+from tests.utils import get_attn_backend_list_based_on_platform
from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata,
create_standard_kv_cache_spec,
@@ -120,17 +121,28 @@ def test_prepare_inputs():
assert torch.equal(token_indices, expected_token_indices)
-@pytest.mark.parametrize("method,proposer_helper", [
- ("eagle", lambda k: _create_proposer("eagle", k)),
- ("eagle3", lambda k: _create_proposer("eagle3", k)),
-])
+@pytest.mark.parametrize("method", ["eagle", "eagle3"])
+@pytest.mark.parametrize("attn_backend",
+ get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("pp_size", [1, 2])
@pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
- proposer_helper, pp_size, use_distinct_embed_tokens):
+ attn_backend, pp_size, use_distinct_embed_tokens,
+ monkeypatch):
+
+ monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
+
+ if (attn_backend == "TRITON_ATTN_VLLM_V1"
+ and not current_platform.is_rocm()):
+ pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
+ "multi-token eagle spec decode on current platform")
+
+ if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
+ monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
# Setup draft model mock
mock_model = mock.MagicMock()
if use_distinct_embed_tokens:
@@ -177,7 +189,7 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
target_model.lm_head = mock.MagicMock()
# Create proposer using the helper function
- proposer = proposer_helper(k=8)
+ proposer = _create_proposer(method, k=8)
# Call the method under test
proposer.load_model(target_model)
@@ -201,10 +213,22 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
target_model.model.embed_tokens
+@pytest.mark.parametrize("method", ["eagle", "eagle3"])
+@pytest.mark.parametrize("attn_backend",
+ get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
-@pytest.mark.parametrize("backend",
- [_Backend.FLASH_ATTN_VLLM_V1, _Backend.TREE_ATTN])
-def test_propose(num_speculative_tokens, backend):
+def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
+
+ monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
+
+ if (attn_backend == "TRITON_ATTN_VLLM_V1"
+ and not current_platform.is_rocm()):
+ pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
+ "multi-token eagle spec decode on current platform")
+
+ if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
+ monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
# Use GPU device
device = torch.device(current_platform.device_type)
@@ -303,7 +327,18 @@ def test_propose(num_speculative_tokens, backend):
device=device)
sampling_metadata = mock.MagicMock()
- attn_metadata_builder_cls, _ = get_attention_backend(backend)
+ if attn_backend == "FLASH_ATTN_VLLM_V1":
+ attn_metadata_builder_cls, _ = get_attention_backend(
+ _Backend.FLASH_ATTN_VLLM_V1)
+ elif attn_backend == "TRITON_ATTN_VLLM_V1":
+ attn_metadata_builder_cls, _ = get_attention_backend(
+ _Backend.TRITON_ATTN_VLLM_V1)
+ elif attn_backend == "TREE_ATTN":
+ attn_metadata_builder_cls, _ = get_attention_backend(
+ _Backend.TREE_ATTN)
+ else:
+ raise ValueError(f"Unsupported attention backend: {attn_backend}")
+
attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
layer_names=proposer.attn_layer_names,
diff --git a/tests/v1/spec_decode/test_max_len.py b/tests/v1/spec_decode/test_max_len.py
index 9070d2b10f8b5..01019b29e010a 100644
--- a/tests/v1/spec_decode/test_max_len.py
+++ b/tests/v1/spec_decode/test_max_len.py
@@ -4,7 +4,9 @@
import pytest
+from tests.utils import get_attn_backend_list_based_on_platform
from vllm import LLM, SamplingParams
+from vllm.platforms import current_platform
_PROMPTS = [
"1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1",
@@ -14,36 +16,45 @@ _PROMPTS = [
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
-def test_ngram_max_len(
- monkeypatch: pytest.MonkeyPatch,
- num_speculative_tokens: int,
-):
- with monkeypatch.context() as m:
- m.setenv("VLLM_USE_V1", "1")
-
- llm = LLM(
- model="facebook/opt-125m",
- max_model_len=100,
- enforce_eager=True, # For faster initialization.
- speculative_config={
- "method": "ngram",
- "prompt_lookup_max": 5,
- "prompt_lookup_min": 3,
- "num_speculative_tokens": num_speculative_tokens,
- },
- )
- sampling_params = SamplingParams(max_tokens=100, ignore_eos=True)
- llm.generate(_PROMPTS, sampling_params)
+def test_ngram_max_len(num_speculative_tokens: int):
+ llm = LLM(
+ model="facebook/opt-125m",
+ max_model_len=100,
+ enforce_eager=True, # For faster initialization.
+ speculative_config={
+ "method": "ngram",
+ "prompt_lookup_max": 5,
+ "prompt_lookup_min": 3,
+ "num_speculative_tokens": num_speculative_tokens,
+ },
+ )
+ sampling_params = SamplingParams(max_tokens=100, ignore_eos=True)
+ llm.generate(_PROMPTS, sampling_params)
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
-def test_eagle_max_len(
- monkeypatch: pytest.MonkeyPatch,
- num_speculative_tokens: int,
-):
+@pytest.mark.parametrize("attn_backend",
+ get_attn_backend_list_based_on_platform())
+def test_eagle_max_len(monkeypatch: pytest.MonkeyPatch,
+ num_speculative_tokens: int, attn_backend: str):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
+ if attn_backend == "TREE_ATTN" and num_speculative_tokens > 1:
+ # TREE_ATTN fails the test with multi-token spec decode
+ # TODO: Investigate why
+ pytest.skip("TREE_ATTN fails the test")
+
+ m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
+
+ if (attn_backend == "TRITON_ATTN_VLLM_V1"
+ and not current_platform.is_rocm()):
+ pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
+ "multi-token eagle spec decode on current platform")
+
+ if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
+ m.setenv("VLLM_ROCM_USE_AITER", "1")
+
llm = LLM(
model="meta-llama/Meta-Llama-3-8B-Instruct",
enforce_eager=True, # For faster initialization.
diff --git a/tests/v1/tpu/test_kv_cache_update_kernel.py b/tests/v1/tpu/test_kv_cache_update_kernel.py
index f82737325e9b8..acb607247d754 100644
--- a/tests/v1/tpu/test_kv_cache_update_kernel.py
+++ b/tests/v1/tpu/test_kv_cache_update_kernel.py
@@ -43,11 +43,6 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
np.cumsum(slice_lens[:-1])])
slot_mapping = np.stack(
[kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1)
- padded_size = (slot_mapping.shape[0] + num_slices_per_block -
- 1) // num_slices_per_block * num_slices_per_block
- slot_mapping = np.pad(slot_mapping,
- [[0, padded_size - slot_mapping.shape[0]], [0, 0]],
- constant_values=0)
slot_mapping = np.transpose(slot_mapping)
slot_mapping_cpu = torch.tensor(slot_mapping,
device="cpu",
diff --git a/tests/v1/tpu/test_tpu_int8.py b/tests/v1/tpu/test_tpu_int8.py
new file mode 100644
index 0000000000000..991070dc9239d
--- /dev/null
+++ b/tests/v1/tpu/test_tpu_int8.py
@@ -0,0 +1,73 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests whether TPU Int8 computation is enabled correctly.
+
+Run `pytest tests/quantization/test_tpu_int8.py`.
+"""
+import pytest
+
+from vllm.model_executor.layers.linear import LinearBase
+from vllm.model_executor.layers.quantization.tpu_int8 import (
+ TPUInt8LinearMethod)
+from vllm.platforms import current_platform
+
+from ...models.registry import HF_EXAMPLE_MODELS
+
+MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"]
+
+
+@pytest.mark.skipif(not current_platform.is_tpu(),
+ reason="TPU Int8 is only enabled for TPUs.")
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [10])
+@pytest.mark.parametrize(
+ "hf_overrides",
+ [
+ # w8a8 dynamic activation
+ {
+ 'quantization_config': {
+ 'quant_method': 'tpu_int8',
+ 'activation_scheme': 'dynamic'
+ }
+ }
+ ])
+def test_model_tpu_int8(vllm_runner, model: str, dtype: str, max_tokens: int,
+ hf_overrides: dict, monkeypatch) -> None:
+ model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
+ model_info.check_transformers_version(on_fail="skip")
+
+ activation_scheme = hf_overrides.get('quantization_config',
+ {}).get('activation_scheme')
+ quantize_activation = activation_scheme == 'dynamic'
+
+ # Allows using apply_model
+ monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
+ # Prevent error from re-initializing cache
+ monkeypatch.setenv("VLLM_XLA_CACHE_PATH", "")
+
+ prompts = [
+ "A robot may not injure a human being",
+ "It is only with the heart that one can see rightly;",
+ "The greatest glory in living lies not in never falling,",
+ ]
+ answers = [
+ "or, being injured, not kill, except in",
+ "without the heart, one can only see wrongly.",
+ "but in rising every time we fall. - Nelson"
+ ]
+
+ with vllm_runner(model, dtype=dtype, hf_overrides=hf_overrides) as vllm:
+
+ def check_model(model):
+ for name, module in model.named_modules():
+ if not isinstance(module, LinearBase):
+ continue
+ quant_method = module.quant_method
+ assert isinstance(quant_method, TPUInt8LinearMethod)
+ assert quant_method.quantize_activation == quantize_activation
+
+ vllm.apply_model(check_model)
+ outputs = vllm.generate_greedy(prompts, max_tokens)
+ for (_, output), answer in zip(outputs, answers):
+ assert answer in output
diff --git a/tools/check_pickle_imports.py b/tools/check_pickle_imports.py
index 5e99dc63ebe0c..444e2bf53f995 100644
--- a/tools/check_pickle_imports.py
+++ b/tools/check_pickle_imports.py
@@ -32,7 +32,7 @@ ALLOWED_FILES = set([
'vllm/multimodal/hasher.py',
'vllm/transformers_utils/config.py',
'vllm/model_executor/models/registry.py',
- 'tests/test_utils.py',
+ 'tests/utils_/test_utils.py',
'tests/tokenization/test_cached_tokenizer.py',
'vllm/distributed/utils.py',
'vllm/distributed/parallel_state.py',
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 92de39418054b..70605d3c5f52c 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -311,7 +311,7 @@ def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor,
output_mask: A boolean tensor indicating which tokens appear in the output.
repetition_penalties: The repetition penalties of shape (num_seqs, ).
"""
- if current_platform.is_cuda() and logits.is_contiguous():
+ if logits.is_cuda and logits.is_contiguous():
apply_repetition_penalties_cuda(logits, prompt_mask, output_mask,
repetition_penalties)
else:
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index b4c3cbd7c9d64..1a9c0e26b53ca 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -138,6 +138,7 @@ class Attention(nn.Module):
self.head_size = head_size
self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window
+ self.has_sink = extra_impl_args.get("sinks") is not None
quant_method = quant_config.get_quant_method(
self, prefix=prefix) if quant_config else None
@@ -165,7 +166,8 @@ class Attention(nn.Module):
kv_cache_dtype,
block_size,
is_attention_free,
- use_mla=use_mla)
+ use_mla=use_mla,
+ has_sink=self.has_sink)
else:
self.attn_backend = attn_backend
diff --git a/vllm/attention/layers/__init__.py b/vllm/attention/layers/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py
index b85f27ac417cf..1af26dfc3daa3 100644
--- a/vllm/attention/ops/flashmla.py
+++ b/vllm/attention/ops/flashmla.py
@@ -91,7 +91,6 @@ def flash_mla_with_kvcache(
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
q,
k_cache,
- None,
head_dim_v,
cache_seqlens,
block_table,
diff --git a/vllm/attention/ops/pallas_kv_cache_update.py b/vllm/attention/ops/pallas_kv_cache_update.py
index e7d727a45e91c..d75983bd407d0 100644
--- a/vllm/attention/ops/pallas_kv_cache_update.py
+++ b/vllm/attention/ops/pallas_kv_cache_update.py
@@ -14,6 +14,7 @@ def _kv_cache_update_kernel(
# Prefetch
slices_ref, # [3, padded_num_slices], list of (kv_cache_start,
# new_kv_start, slice_len)
+ num_slices_ref, # [1]
# Input
new_kv_hbm_ref, # [num_tokens, num_combined_kv_heads, head_dim]
kv_cache_hbm_ref, # [total_num_pages * page_size, num_combined_kv_heads,
@@ -32,8 +33,10 @@ def _kv_cache_update_kernel(
# Copy from new_kv_hbm_ref to scratch
for i in range(num_slices_per_block):
offset_i = i + block_idx * num_slices_per_block
- new_kv_start = slices_ref[1, offset_i]
- length = slices_ref[2, offset_i]
+ new_kv_start = jax.lax.select(offset_i < num_slices_ref[0],
+ slices_ref[1, offset_i], 0)
+ length = jax.lax.select(offset_i < num_slices_ref[0],
+ slices_ref[2, offset_i], 0)
async_copy = pltpu.make_async_copy(
new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
scratch.at[i, pl.ds(0, length), ...],
@@ -49,8 +52,10 @@ def _kv_cache_update_kernel(
async_copies.clear()
for i in range(num_slices_per_block):
offset_i = i + block_idx * num_slices_per_block
- kv_cache_start = slices_ref[0, offset_i]
- length = slices_ref[2, offset_i]
+ kv_cache_start = jax.lax.select(offset_i < num_slices_ref[0],
+ slices_ref[0, offset_i], 0)
+ length = jax.lax.select(offset_i < num_slices_ref[0],
+ slices_ref[2, offset_i], 0)
async_copy = pltpu.make_async_copy(
scratch.at[i, pl.ds(0, length), ...],
kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...],
@@ -77,7 +82,6 @@ def kv_cache_update(
page_size: int = 32,
num_slices_per_block: int = 8,
):
- assert slices.shape[1] % num_slices_per_block == 0
_, num_combined_kv_heads, head_dim = new_kv.shape
assert kv_cache.shape[1] == num_combined_kv_heads
assert kv_cache.shape[2] == head_dim
@@ -93,7 +97,7 @@ def kv_cache_update(
out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)]
out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)]
- scalar_prefetches = [slices]
+ scalar_prefetches = [slices, num_kv_update_slices]
scratch = pltpu.VMEM(
(num_slices_per_block, page_size, num_combined_kv_heads, head_dim),
new_kv.dtype,
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index 508470bb363e1..3a235ba6e0b42 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -144,6 +144,7 @@ def get_attn_backend(
block_size: int,
is_attention_free: bool = False,
use_mla: bool = False,
+ has_sink: bool = False,
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
@@ -158,6 +159,7 @@ def get_attn_backend(
is_attention_free=is_attention_free,
use_v1=envs.VLLM_USE_V1,
use_mla=use_mla,
+ has_sink=has_sink,
)
@@ -170,6 +172,7 @@ def _cached_get_attn_backend(
is_attention_free: bool,
use_v1: bool = False,
use_mla: bool = False,
+ has_sink: bool = False,
) -> type[AttentionBackend]:
# If there are no attention layers (e.g. we are running Mamba),
# use the placeholder NO_ATTENTION
@@ -201,7 +204,7 @@ def _cached_get_attn_backend(
# get device-specific attn_backend
attention_cls = current_platform.get_attn_backend_cls(
selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1,
- use_mla)
+ use_mla, has_sink)
if not attention_cls:
raise ValueError(
f"Invalid attention backend for {current_platform.device_name}")
diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py
index 45b58035ebe32..4e8ac5162542f 100644
--- a/vllm/benchmarks/datasets.py
+++ b/vllm/benchmarks/datasets.py
@@ -71,7 +71,9 @@ class SampleRequest:
prompt: Union[str, Any]
prompt_len: int
expected_output_len: int
- multi_modal_data: Optional[Union[MultiModalDataDict, dict]] = None
+ multi_modal_data: Optional[
+ Union[MultiModalDataDict, dict, list[dict]]
+ ] = None
lora_request: Optional[LoRARequest] = None
diff --git a/vllm/benchmarks/lib/endpoint_request_func.py b/vllm/benchmarks/lib/endpoint_request_func.py
index 2d64cc115f00f..47bc288774504 100644
--- a/vllm/benchmarks/lib/endpoint_request_func.py
+++ b/vllm/benchmarks/lib/endpoint_request_func.py
@@ -28,7 +28,7 @@ class RequestFuncInput:
model_name: Optional[str] = None
logprobs: Optional[int] = None
extra_body: Optional[dict] = None
- multi_modal_content: Optional[dict] = None
+ multi_modal_content: Optional[dict | list[dict]] = None
ignore_eos: bool = False
language: Optional[str] = None
@@ -172,7 +172,16 @@ async def async_request_openai_chat_completions(
content = [{"type": "text", "text": request_func_input.prompt}]
if request_func_input.multi_modal_content:
- content.append(request_func_input.multi_modal_content)
+ mm_content = request_func_input.multi_modal_content
+ if isinstance(mm_content, list):
+ content.extend(mm_content)
+ elif isinstance(mm_content, dict):
+ content.append(mm_content)
+ else:
+ raise TypeError(
+ "multi_modal_content must be a dict or list[dict] "
+ "for openai-chat"
+ )
payload = {
"model":
request_func_input.model_name
@@ -310,7 +319,10 @@ async def async_request_openai_audio(
buffer.seek(0)
return buffer
- with to_bytes(*request_func_input.multi_modal_content["audio"]) as f:
+ mm_audio = request_func_input.multi_modal_content
+ if not isinstance(mm_audio, dict) or "audio" not in mm_audio:
+ raise TypeError("multi_modal_content must be a dict containing 'audio'")
+ with to_bytes(*mm_audio["audio"]) as f:
form = aiohttp.FormData()
form.add_field("file", f, content_type="audio/wav")
for key, value in payload.items():
diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py
index 6d52b51a9fcd0..7bf04c7532411 100644
--- a/vllm/benchmarks/serve.py
+++ b/vllm/benchmarks/serve.py
@@ -365,7 +365,14 @@ async def benchmark(
input_requests[0].multi_modal_data,
)
- assert test_mm_content is None or isinstance(test_mm_content, dict)
+ assert (
+ test_mm_content is None
+ or isinstance(test_mm_content, dict)
+ or (
+ isinstance(test_mm_content, list)
+ and all(isinstance(item, dict) for item in test_mm_content)
+ )
+ ), "multi_modal_data must be a dict or list[dict]"
test_input = RequestFuncInput(
model=model_id,
model_name=model_name,
@@ -665,7 +672,7 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace,
pt_records = convert_to_pytorch_benchmark_format(
args=args,
metrics={k: [results[k]]
- for k in metrics},
+ for k in metrics if k in results},
extra_info={
k: results[k]
for k in results if k not in metrics and k not in ignored_metrics
diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py
index bbd18ca3ae22e..fdf6548ada5b6 100644
--- a/vllm/benchmarks/throughput.py
+++ b/vllm/benchmarks/throughput.py
@@ -24,8 +24,6 @@ from vllm.benchmarks.datasets import (AIMODataset, BurstGPTDataset,
from vllm.benchmarks.lib.utils import (convert_to_pytorch_benchmark_format,
write_to_json)
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
-from vllm.entrypoints.openai.api_server import (
- build_async_engine_client_from_engine_args)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
@@ -146,6 +144,8 @@ async def run_vllm_async(
disable_detokenize: bool = False,
) -> float:
from vllm import SamplingParams
+ from vllm.entrypoints.openai.api_server import (
+ build_async_engine_client_from_engine_args)
async with build_async_engine_client_from_engine_args(
engine_args,
diff --git a/vllm/config.py b/vllm/config/__init__.py
similarity index 72%
rename from vllm/config.py
rename to vllm/config/__init__.py
index 7147702eddde2..df4eb33f5d45d 100644
--- a/vllm/config.py
+++ b/vllm/config/__init__.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# ruff: noqa: F401
import ast
import copy
import enum
@@ -10,11 +11,9 @@ import json
import textwrap
import uuid
import warnings
-from collections import Counter
from collections.abc import Mapping
from contextlib import contextmanager
-from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass,
- replace)
+from dataclasses import MISSING, Field, field, fields, is_dataclass, replace
from functools import cached_property, lru_cache
from importlib.util import find_spec
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
@@ -22,16 +21,21 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
import regex as re
import torch
-from pydantic import (ConfigDict, SkipValidation, TypeAdapter, field_validator,
+from pydantic import (ConfigDict, SkipValidation, field_validator,
model_validator)
from pydantic.dataclasses import dataclass
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
-from torch.distributed import ProcessGroup, ReduceOp
from typing_extensions import Self, assert_never, runtime_checkable
import vllm.envs as envs
from vllm import version
-from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
+from vllm.config.cache import (BlockSize, CacheConfig, CacheDType,
+ PrefixCachingHashAlgo)
+from vllm.config.compilation import (CompilationConfig, CompilationLevel,
+ PassConfig)
+from vllm.config.parallel import DistributedExecutorBackend, ParallelConfig
+from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
+from vllm.config.utils import ConfigType, config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.platforms import current_platform
@@ -39,51 +43,35 @@ from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config,
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder,
- maybe_override_with_speculators_target_model, try_get_generation_config,
- try_get_safetensors_metadata, try_get_tokenizer_config, uses_mrope)
+ is_interleaved, maybe_override_with_speculators_target_model,
+ try_get_generation_config, try_get_safetensors_metadata,
+ try_get_tokenizer_config, uses_mrope)
from vllm.transformers_utils.s3_utils import S3Model
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
-# yapf conflicts with isort for this block
-# yapf: disable
-from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
- MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
- POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, GiB_bytes,
- LayerBlockType, LazyLoader, common_broadcastable_dtype,
- cuda_device_count_stateless, get_cpu_memory,
- get_open_port, is_torch_equal_or_newer, random_uuid,
- resolve_obj_by_qualname)
-
-# yapf: enable
+from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS, LayerBlockType,
+ LazyLoader, common_broadcastable_dtype, random_uuid)
if TYPE_CHECKING:
from _typeshed import DataclassInstance
- from ray.runtime_env import RuntimeEnv
- from ray.util.placement_group import PlacementGroup
from transformers.configuration_utils import PretrainedConfig
import vllm.model_executor.layers.quantization as me_quant
import vllm.model_executor.models as me_models
- from vllm.executor.executor_base import ExecutorBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader import LoadFormats
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
- ConfigType = type[DataclassInstance]
HfOverrides = Union[dict, Callable[[type], type]]
else:
DataclassInstance = Any
- PlacementGroup = Any
- RuntimeEnv = Any
PretrainedConfig = Any
- ExecutorBase = Any
QuantizationConfig = Any
QuantizationMethods = Any
BaseModelLoader = Any
LoadFormats = Any
TensorizerConfig = Any
- ConfigType = type
HfOverrides = Union[dict[str, Any], Callable[[type], type]]
me_quant = LazyLoader("model_executor", globals(),
@@ -93,7 +81,6 @@ else:
logger = init_logger(__name__)
DataclassInstanceT = TypeVar("DataclassInstanceT", bound=DataclassInstance)
-ConfigT = TypeVar("ConfigT", bound=ConfigType)
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
"score", "reward", "transcription", "draft"]
@@ -234,23 +221,6 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]:
return out
-def config(cls: ConfigT) -> ConfigT:
- """
- A decorator that ensures all fields in a dataclass have default values
- and that each field has a docstring.
-
- If a `ConfigT` is used as a CLI argument itself, the default value provided
- by `get_kwargs` will be the result parsing a JSON string as the kwargs
- (i.e. `ConfigT(**json.loads(cli_arg))`). However, if a particular `ConfigT`
- requires custom construction from CLI (i.e. `CompilationConfig`), it can
- have a `from_cli` method, which will be called instead.
-
- Config validation is performed by the tools/validate_config.py
- script, which is invoked during the pre-commit checks.
- """
- return cls
-
-
def get_field(cls: ConfigType, name: str) -> Field:
"""Get the default factory field of a dataclass by name. Used for getting
default factory fields in `EngineArgs`."""
@@ -741,53 +711,31 @@ class ModelConfig:
revision=self.revision,
)
- # Workaround for Gemma 2 which uses interleaved sliding window
- # attention, but it's not specified in its config.
- # TODO: remove this when Gemma 2 config updated in HuggingFace.
- if self.hf_text_config.model_type == "gemma2":
- self.hf_text_config.sliding_window_pattern = 2
-
- # TODO: remove this when Gemma 3n config updated in HuggingFace.
- if self.hf_text_config.model_type == "gemma3n_text":
- # 4 sliding window attention followed by 1 full attention
- self.hf_text_config.sliding_window_pattern = "LLLLG"
-
- sliding_window = getattr(self.hf_text_config, "sliding_window", None)
- sliding_window_pattern = getattr(self.hf_text_config,
- "sliding_window_pattern", None)
- has_interleaved_attention = sliding_window_pattern is not None or (
- isinstance(sliding_window, list))
-
- if not self.disable_sliding_window and has_interleaved_attention:
- if not envs.VLLM_USE_V1 and (backend := envs.VLLM_ATTENTION_BACKEND
- ) in ("XFORMERS", "FLASHINFER"):
- sliding_window_len_min = get_min_sliding_window(
- self.hf_text_config.sliding_window)
-
- logger.warning_once(
- "%s has interleaved attention, which is currently not supported by the %s backend. Disabling sliding window and capping the max length to the sliding window size (%d).", # noqa: E501
- self.hf_text_config.model_type,
- backend,
- sliding_window_len_min,
- )
- self.disable_sliding_window = True
- else:
- # for a model with interleaved attention,
- # the scheduler and the model treat it as full attention
- # (i.e., not dropping any tokens outside the window).
- # only the attention layer itself is aware of the sliding
- # window, and use the window size to compute the attention.
- self.hf_text_config.interleaved_sliding_window = sliding_window
-
- if hasattr(self.hf_text_config, "sliding_window"):
- delattr(self.hf_text_config, "sliding_window")
-
- sliding_window = None
+ # Interleaved attention is not supported by some backends in V0
+ if (not self.disable_sliding_window
+ and is_interleaved(self.hf_text_config)
+ and not envs.VLLM_USE_V1
+ and (backend := envs.VLLM_ATTENTION_BACKEND)
+ in ("XFORMERS", "FLASHINFER")):
+ logger.warning_once(
+ "%s has interleaved attention, which is currently not "
+ "supported by the %s backend. Disabling sliding window and "
+ "capping the max length to the sliding window size (%d).",
+ self.hf_text_config.model_type,
+ backend,
+ self.hf_text_config.sliding_window,
+ )
+ self.disable_sliding_window = True
self.original_max_model_len = self.max_model_len
self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
self.multimodal_config = self._init_multimodal_config()
+ if self.disable_sliding_window:
+ # Set after get_and_verify_max_len to ensure that max_model_len
+ # can be correctly capped to sliding window size
+ self.hf_text_config.sliding_window = None
+
if not self.skip_tokenizer_init:
self._verify_tokenizer_mode()
@@ -918,6 +866,10 @@ class ModelConfig:
if getattr(pooler_config, k) is None:
setattr(pooler_config, k, v)
+ default_pooling_type = self._model_info.default_pooling_type
+ if pooler_config.pooling_type is None:
+ pooler_config.pooling_type = default_pooling_type
+
return pooler_config
return None
@@ -1212,8 +1164,18 @@ class ModelConfig:
"non-quantized models.", self.quantization)
def _verify_cuda_graph(self) -> None:
+ # The `max_seq_len_to_capture` was incorrectly
+ # based on the encoder's input length (448)
+ # but not the decoder's larger input length (1500).
+ # This change ensures the CUDA Graph captures the correct,
+ # larger sequence length, allowing it to work as intended.
+ effective_max_seq_len = self.max_model_len
+ if self.is_encoder_decoder:
+ effective_max_seq_len = max(
+ effective_max_seq_len,
+ getattr(self.hf_config, "max_source_positions", 0))
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
- self.max_model_len)
+ effective_max_seq_len)
# CUDAGraph capture not supported for enc-dec models and mllama on ROCm
ROCM_UNSUPPORTED_MODELS = ['mllama']
unsupported_rocm = (self.hf_config.model_type
@@ -1349,27 +1311,10 @@ class ModelConfig:
if self.use_async_output_proc:
self.use_async_output_proc = False
- def get_hf_config_sliding_window(
- self) -> Union[Optional[int], list[Optional[int]]]:
- """Get the sliding window size, or None if disabled."""
-
- # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
- # addition to sliding window size. We check if that field is present
- # and if it's False, return None.
- if (hasattr(self.hf_text_config, "use_sliding_window")
- and not self.hf_text_config.use_sliding_window):
- return None
+ def get_sliding_window(self) -> Optional[int]:
+ """Get the sliding window size from the HF text config if present."""
return getattr(self.hf_text_config, "sliding_window", None)
- def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]:
- """Get the sliding window size, or None if disabled.
- """
- # If user disables sliding window, return None.
- if self.disable_sliding_window:
- return None
- # Otherwise get the value from the hf config.
- return self.get_hf_config_sliding_window()
-
def get_vocab_size(self) -> int:
return getattr(self.hf_text_config, "vocab_size", 0)
@@ -1715,15 +1660,6 @@ class ModelConfig:
return mm_config.mm_processor_cache_gb > 0
- @property
- def enable_mm_input_cache(self) -> bool:
- """Whether the multi-modal input cache should be enabled."""
- mm_config = self.multimodal_config
- if mm_config is None:
- return False
-
- return mm_config.mm_processor_cache_gb > 0
-
def get_mm_input_cache_gb(self) -> int:
mm_config = self.multimodal_config
if mm_config is None:
@@ -1798,196 +1734,13 @@ class ModelConfig:
tokenizer_config=tokenizer_config,
max_model_len=max_model_len,
disable_sliding_window=self.disable_sliding_window,
- sliding_window_len=self.get_hf_config_sliding_window(),
+ sliding_window=self.get_sliding_window(),
spec_target_max_model_len=self.spec_target_max_model_len,
encoder_config=self.encoder_config)
logger.info("Using max model len %s", max_model_len)
return max_model_len
-BlockSize = Literal[1, 8, 16, 32, 64, 128]
-CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
-PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"]
-
-
-@config
-@dataclass
-class CacheConfig:
- """Configuration for the KV cache."""
-
- block_size: SkipValidation[BlockSize] = None # type: ignore
- """Size of a contiguous cache block in number of tokens. This is ignored on
- neuron devices and set to `--max-model-len`. On CUDA devices, only block
- sizes up to 32 are supported. On HPU devices, block size defaults to 128.
-
- This config has no static default. If left unspecified by the user, it will
- be set in `Platform.check_and_update_config()` based on the current
- platform."""
- gpu_memory_utilization: float = 0.9
- """The fraction of GPU memory to be used for the model executor, which can
- range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
- utilization. If unspecified, will use the default value of 0.9. This is a
- per-instance limit, and only applies to the current vLLM instance. It does
- not matter if you have another vLLM instance running on the same GPU. For
- example, if you have two vLLM instances running on the same GPU, you can
- set the GPU memory utilization to 0.5 for each instance."""
- swap_space: float = 4
- """Size of the CPU swap space per GPU (in GiB)."""
- cache_dtype: CacheDType = "auto"
- """Data type for kv cache storage. If "auto", will use model data type.
- CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
- fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc)."""
- is_attention_free: bool = False
- """Whether the model is attention-free. This is primarily set in
- `ModelConfig` and that value should be manually duplicated here."""
- num_gpu_blocks_override: Optional[int] = None
- """Number of GPU blocks to use. This overrides the profiled `num_gpu_blocks`
- if specified. Does nothing if `None`. Used for testing preemption."""
- sliding_window: Optional[int] = None
- """Sliding window size for the KV cache. This is primarily set in
- `ModelConfig` and that value should be manually duplicated here."""
- enable_prefix_caching: Optional[bool] = None
- """Whether to enable prefix caching. Disabled by default for V0. Enabled by
- default for V1."""
- prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin"
- """Set the hash algorithm for prefix caching:\n
- - "builtin" is Python's built-in hash.\n
- - "sha256" is collision resistant but with certain overheads.
- This option uses Pickle for object serialization before hashing.\n
- - "sha256_cbor_64bit" provides a reproducible, cross-language compatible
- hash. It serializes objects using canonical CBOR and hashes them with
- SHA-256. The resulting hash consists of the lower 64 bits of the SHA-256
- digest."""
- cpu_offload_gb: float = 0
- """The space in GiB to offload to CPU, per GPU. Default is 0, which means
- no offloading. Intuitively, this argument can be seen as a virtual way to
- increase the GPU memory size. For example, if you have one 24 GB GPU and
- set this to 10, virtually you can think of it as a 34 GB GPU. Then you can
- load a 13B model with BF16 weight, which requires at least 26GB GPU memory.
- Note that this requires fast CPU-GPU interconnect, as part of the model is
- loaded from CPU memory to GPU memory on the fly in each model forward pass.
- """
- calculate_kv_scales: bool = False
- """This enables dynamic calculation of `k_scale` and `v_scale` when
- kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model
- checkpoint if available. Otherwise, the scales will default to 1.0."""
- cpu_kvcache_space_bytes: Optional[int] = None
- """(CPU backend only) CPU key-value cache space."""
- mamba_page_size_padded: Optional[int] = None
- """ Optional override for mamba page size; used by hybrid mamba/attention
- models to ensure exact alignment with attention page size."""
-
- # Will be set after profiling.
- num_gpu_blocks: Optional[int] = field(default=None, init=False)
- """The number of blocks to allocate for GPU memory."""
- num_cpu_blocks: Optional[int] = field(default=None, init=False)
- """The number of blocks to allocate for CPU memory."""
-
- kv_sharing_fast_prefill: bool = False
- """This feature is work in progress and no prefill optimization takes place
- with this flag enabled currently.
-
- In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254),
- some layers can skip tokens corresponding to prefill. This flag enables
- attention metadata for eligible layers to be overriden with metadata
- necessary for implementating this optimization in some models (e.g. Gemma3n)
- """
-
- def compute_hash(self) -> str:
- """
- WARNING: Whenever a new field is added to this config,
- ensure that it is included in the factors list if
- it affects the computation graph.
-
- Provide a hash that uniquely identifies all the configs
- that affect the structure of the computation
- graph from input ids/embeddings to the final hidden states,
- excluding anything before input ids/embeddings and after
- the final hidden states.
- """
- factors: list[Any] = []
- factors.append(self.cache_dtype)
- # `cpu_offload_gb` does not use `torch.compile` yet.
- hash_str = hashlib.md5(str(factors).encode(),
- usedforsecurity=False).hexdigest()
- return hash_str
-
- def __post_init__(self) -> None:
- self.swap_space_bytes = self.swap_space * GiB_bytes
-
- self._verify_cache_dtype()
- self._verify_prefix_caching()
-
- def metrics_info(self):
- # convert cache_config to dict(key: str, value: str) for prometheus
- # metrics info
- return {key: str(value) for key, value in self.__dict__.items()}
-
- @model_validator(mode='after')
- def _verify_args(self) -> Self:
- if self.cpu_offload_gb < 0:
- raise ValueError("CPU offload space must be non-negative"
- f", but got {self.cpu_offload_gb}")
-
- if self.gpu_memory_utilization > 1.0:
- raise ValueError(
- "GPU memory utilization must be less than 1.0. Got "
- f"{self.gpu_memory_utilization}.")
-
- if self.kv_sharing_fast_prefill:
- logger.warning_once(
- "--kv-sharing-fast-prefill is currently work in progress "
- "and not functional yet (i.e. no prefill savings)")
-
- return self
-
- def _verify_cache_dtype(self) -> None:
- if self.cache_dtype == "auto":
- pass
- elif self.cache_dtype in get_args(CacheDType):
- logger.info(
- "Using fp8 data type to store kv cache. It reduces the GPU "
- "memory footprint and boosts the performance. "
- "Meanwhile, it may cause accuracy drop without a proper "
- "scaling factor.")
- else:
- raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
-
- def _verify_prefix_caching(self) -> None:
- if not self.enable_prefix_caching:
- return
-
- if self.sliding_window is not None and not envs.VLLM_USE_V1:
- raise NotImplementedError(
- "Prefix caching is not supported with sliding window. "
- "Run with --disable-sliding-window to use prefix caching.")
-
- if (self.enable_prefix_caching and self.prefix_caching_hash_algo
- not in get_args(PrefixCachingHashAlgo)):
- raise ValueError(
- "Unknown prefix caching hash algorithm: "
- f"{self.prefix_caching_hash_algo}. Must be one of "
- f"{get_args(PrefixCachingHashAlgo)}.")
-
- def verify_with_parallel_config(
- self,
- parallel_config: "ParallelConfig",
- ) -> None:
- total_cpu_memory = get_cpu_memory()
- # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
- # group are in the same node. However, the GPUs may span multiple nodes.
- num_gpus_per_node = parallel_config.tensor_parallel_size
- cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
-
- msg = (f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the "
- f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory "
- "is allocated for the swap space.")
- if cpu_memory_usage > 0.7 * total_cpu_memory:
- raise ValueError("Too large swap space. " + msg)
- elif cpu_memory_usage > 0.4 * total_cpu_memory:
- logger.warning("Possibly too large swap space. %s", msg)
-
-
@config
@dataclass
class LoadConfig:
@@ -2072,659 +1825,6 @@ class LoadConfig:
self.ignore_patterns = ["original/**/*"]
-DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
-
-
-@config
-@dataclass
-class ParallelConfig:
- """Configuration for the distributed execution."""
-
- pipeline_parallel_size: int = 1
- """Number of pipeline parallel groups."""
- tensor_parallel_size: int = 1
- """Number of tensor parallel groups."""
- data_parallel_size: int = 1
- """Number of data parallel groups. MoE layers will be sharded according to
- the product of the tensor parallel size and data parallel size."""
- data_parallel_size_local: int = 1
- """Number of local data parallel groups."""
- data_parallel_rank: int = 0
- """Rank of the data parallel group."""
- data_parallel_rank_local: Optional[int] = None
- """Local rank of the data parallel group,
- set only in SPMD mode."""
- data_parallel_master_ip: str = "127.0.0.1"
- """IP of the data parallel master."""
- data_parallel_rpc_port: int = 29550
- """Port for data parallel messaging."""
- data_parallel_master_port: int = 29500
- """Port of the data parallel master."""
- data_parallel_backend: str = "mp"
- """Backend to use for data parallel, either "mp" or "ray"."""
- data_parallel_external_lb: bool = False
- """Whether to use "external" DP LB mode. Applies only to online serving
- and when data_parallel_size > 0. This is useful for a "one-pod-per-rank"
- wide-EP setup in Kuberentes. Set implicitly when --data-parallel-rank
- is provided explicitly to vllm serve."""
- data_parallel_hybrid_lb: bool = False
- """Whether to use "hybrid" DP LB mode. Applies only to online serving
- and when data_parallel_size > 0. Enables running an AsyncLLM
- and API server on a "per-node" basis where vLLM load balances
- between local data parallel ranks, but an external LB balances
- between vLLM nodes/replicas. Set explicitly in conjunction with
- --data-parallel-start-rank."""
- enable_expert_parallel: bool = False
- """Use expert parallelism instead of tensor parallelism for MoE layers."""
- enable_eplb: bool = False
- """Enable expert parallelism load balancing for MoE layers."""
- num_redundant_experts: int = 0
- """Number of redundant experts to use for expert parallelism."""
- eplb_window_size: int = 1000
- """Window size for expert load recording."""
- eplb_step_interval: int = 3000
- """
- Interval for rearranging experts in expert parallelism.
-
- Note that if this is greater than the EPLB window size, only the metrics
- of the last `eplb_window_size` steps will be used for rearranging experts.
- """
- eplb_log_balancedness: bool = False
- """
- Log the balancedness each step of expert parallelism.
- This is turned off by default since it will cause communication overhead.
- """
-
- max_parallel_loading_workers: Optional[int] = None
- """Maximum number of parallel loading workers when loading model
- sequentially in multiple batches. To avoid RAM OOM when using tensor
- parallel and large models."""
-
- disable_custom_all_reduce: bool = False
- """Disable the custom all-reduce kernel and fall back to NCCL."""
-
- ray_workers_use_nsight: bool = False
- """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
-
- ray_runtime_env: Optional["RuntimeEnv"] = None
- """Ray runtime environment to pass to distributed workers."""
-
- placement_group: Optional["PlacementGroup"] = None
- """ray distributed model workers placement group."""
-
- distributed_executor_backend: Optional[Union[DistributedExecutorBackend,
- type["ExecutorBase"]]] = None
- """Backend to use for distributed model
- workers, either "ray" or "mp" (multiprocessing). If the product
- of pipeline_parallel_size and tensor_parallel_size is less than
- or equal to the number of GPUs available, "mp" will be used to
- keep processing on a single host. Otherwise, this will default
- to "ray" if Ray is installed and fail otherwise. Note that tpu
- only support Ray for distributed inference."""
-
- worker_cls: str = "auto"
- """The full name of the worker class to use. If "auto", the worker class
- will be determined based on the platform."""
- sd_worker_cls: str = "auto"
- """The full name of the worker class to use for speculative decoding.
- If "auto", the worker class will be determined based on the platform."""
- worker_extension_cls: str = ""
- """The full name of the worker extension class to use. The worker extension
- class is dynamically inherited by the worker class. This is used to inject
- new attributes and methods to the worker class for use in collective_rpc
- calls."""
-
- world_size: int = field(init=False)
- """world_size is TPxPP, it affects the number of workers we create."""
-
- rank: int = 0
- """Global rank in distributed setup."""
-
- enable_multimodal_encoder_data_parallel: bool = False
- """ Use data parallelism instead of tensor parallelism for vision encoder.
- Only support LLama4 for now"""
-
- @property
- def world_size_across_dp(self) -> int:
- """world_size_across_dp is TPxPPxDP, it is the size of the world
- including data parallelism."""
- return self.world_size * self.data_parallel_size
-
- def get_next_dp_init_port(self) -> int:
- """
- We might need to initialize process groups in multiple
- processes that is related to data parallelism,
- e.g. both in the worker and in the engine, which
- can live in different processes. To avoid port conflicts, we
- increment the port number each time we need to initialize a
- new process group related to data parallelism.
- """
- answer = self.data_parallel_master_port
- self.data_parallel_master_port += 1
- return answer
-
- def stateless_init_dp_group(self) -> "ProcessGroup":
- # NOTE: In high-concurrency scenarios multiple processes
- # can pick the same (currently free) port through a race
- # condition when calling `get_open_port()`. When the first
- # process binds the port the others will subsequently fail
- # with `torch.distributed.DistNetworkError: EADDRINUSE`.
- # To make the initialization more robust we retry a few times
- # with a fresh port whenever this specific error is observed.
- from torch.distributed import DistNetworkError
-
- from vllm.distributed.utils import (
- stateless_init_torch_distributed_process_group)
-
- max_retries = 5
- last_exc: Optional[Exception] = None
- for _ in range(max_retries):
- try:
- # use gloo since the engine process might not have cuda device
- return stateless_init_torch_distributed_process_group(
- self.data_parallel_master_ip,
- self.get_next_dp_init_port(),
- self.data_parallel_rank,
- self.data_parallel_size,
- backend="gloo")
- except DistNetworkError as e:
- # We only want to retry when the root cause is EADDRINUSE.
- if "EADDRINUSE" in str(e):
- logger.warning(
- "Address already in use. Retrying with a new port.")
- last_exc = e
- continue # try again with a new port
- raise e
-
- # If we get here all retries have failed.
- assert last_exc is not None
- raise last_exc
-
- @staticmethod
- def has_unfinished_dp(dp_group: "ProcessGroup",
- has_unfinished: bool) -> bool:
- tensor = torch.tensor([has_unfinished],
- dtype=torch.int32,
- device="cpu")
- # dp rank 0: has_unfinished_seqs=True
- # dp rank 1: has_unfinished_seqs=False
- # aggregated: has_unfinished_seqs=True
- # so this is an OR operation, i.e. MAX in integers
- torch.distributed.all_reduce(tensor, op=ReduceOp.MAX, group=dp_group)
- aggregated_has_unfinished = bool(tensor.item())
- return aggregated_has_unfinished
-
- @staticmethod
- def sync_kv_cache_memory_size(dp_group: "ProcessGroup",
- kv_cache_memory: int) -> int:
- if kv_cache_memory == -1:
- kv_cache_memory = torch.iinfo(torch.int64).max
- tensor = torch.tensor([kv_cache_memory],
- dtype=torch.int64,
- device="cpu")
- # we cannot use broadcast for stateless dp group since it depends
- # on global rank
- torch.distributed.all_reduce(tensor, op=ReduceOp.MIN, group=dp_group)
- return tensor.item()
-
- def compute_hash(self):
- """
- Provide a hash that uniquely identifies all the configs
- that affect the structure of the computation
- graph from input ids/embeddings to the final hidden states,
- excluding anything before input ids/embeddings and after
- the final hidden states.
- """
- factors: list[Any] = []
- factors.append(self.pipeline_parallel_size)
- factors.append(self.tensor_parallel_size)
- factors.append(self.enable_expert_parallel)
- factors.append(self.data_parallel_size)
- factors.append(envs.VLLM_ALL2ALL_BACKEND)
- return hashlib.sha256(str(factors).encode()).hexdigest()
-
- def __post_init__(self) -> None:
- self.world_size = self.pipeline_parallel_size * \
- self.tensor_parallel_size
-
- if self.data_parallel_size_local > self.data_parallel_size:
- raise ValueError(
- f"data_parallel_size_local ({self.data_parallel_size_local}) "
- f"must be <= data_parallel_size ({self.data_parallel_size})")
-
- if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
- # Data parallel was specified in the engine args.
- self.data_parallel_master_port = get_open_port()
-
- if not (0 <= self.data_parallel_rank < self.data_parallel_size):
- raise ValueError(
- f"data_parallel_rank ({self.data_parallel_rank})"
- f" must be in the range [0, {self.data_parallel_size})")
- else:
- # Otherwise fall back to env vars (e.g. for offline SPMD case).
- self.data_parallel_size = envs.VLLM_DP_SIZE
- self.data_parallel_rank = envs.VLLM_DP_RANK
- self.data_parallel_rank_local = envs.VLLM_DP_RANK_LOCAL
- self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
- self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
-
- if self.data_parallel_external_lb:
- raise ValueError("data_parallel_external_lb can only "
- "be set when data_parallel_size > 1")
-
- if self.distributed_executor_backend == "external_launcher":
- import os
- os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
- logger.info("Disabling V1 multiprocessing for external launcher.")
-
- if self.enable_eplb:
- if not current_platform.is_cuda():
- raise ValueError(
- "Expert parallelism load balancing is only supported on "
- "CUDA devices now.")
- if self.num_redundant_experts < 0:
- raise ValueError(
- "num_redundant_experts must be non-negative, but got "
- f"{self.num_redundant_experts}.")
- if not self.enable_expert_parallel:
- raise ValueError(
- "enable_expert_parallel must be True to use EPLB.")
- if self.tensor_parallel_size * self.data_parallel_size <= 1:
- raise ValueError(
- "EPLB requires tensor_parallel_size or data_parallel_size "
- f"to be greater than 1, but got "
- f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}."
- )
- else:
- if self.num_redundant_experts != 0:
- raise ValueError(
- "num_redundant_experts should be used with EPLB."
- f"{self.num_redundant_experts}.")
- if self.distributed_executor_backend is None and self.world_size > 1:
- # We use multiprocessing by default if world_size fits on the
- # current node and we aren't in a ray placement group.
-
- from vllm.executor import ray_utils
- backend: DistributedExecutorBackend = "mp"
- ray_found = ray_utils.ray_is_available()
- if current_platform.is_neuron():
- # neuron uses single process to control multiple devices
- backend = "uni"
- elif current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
- backend = "uni"
- elif (current_platform.is_cuda()
- and cuda_device_count_stateless() < self.world_size):
- if not ray_found:
- raise ValueError("Unable to load Ray: "
- f"{ray_utils.ray_import_err}. Ray is "
- "required for multi-node inference, "
- "please install Ray with `pip install "
- "ray`.")
- backend = "ray"
- elif self.data_parallel_backend == "ray":
- logger.info("Using ray distributed inference because "
- "data_parallel_backend is ray")
- backend = "ray"
- elif ray_found:
- if self.placement_group:
- backend = "ray"
- else:
- from ray import is_initialized as ray_is_initialized
- if ray_is_initialized():
- from ray.util import get_current_placement_group
- if get_current_placement_group():
- backend = "ray"
- self.distributed_executor_backend = backend
- logger.debug("Defaulting to use %s for distributed inference",
- backend)
-
- if self.distributed_executor_backend is None and self.world_size == 1:
- self.distributed_executor_backend = "uni"
-
- @property
- def use_ray(self) -> bool:
- return self.distributed_executor_backend == "ray" or (
- isinstance(self.distributed_executor_backend, type)
- and self.distributed_executor_backend.uses_ray)
-
- @model_validator(mode='after')
- def _verify_args(self) -> Self:
- # Lazy import to avoid circular import
- from vllm.executor.executor_base import ExecutorBase
- from vllm.platforms import current_platform
- if self.distributed_executor_backend not in (
- "ray", "mp", "uni",
- "external_launcher", None) and not (isinstance(
- self.distributed_executor_backend, type) and issubclass(
- self.distributed_executor_backend, ExecutorBase)):
- raise ValueError(
- "Unrecognized distributed executor backend "
- f"{self.distributed_executor_backend}. Supported "
- "values are 'ray', 'mp' 'uni', 'external_launcher' or"
- " custom ExecutorBase subclass.")
- if self.use_ray:
- from vllm.executor import ray_utils
- ray_utils.assert_ray_available()
-
- if not current_platform.use_custom_allreduce():
- self.disable_custom_all_reduce = True
- logger.debug(
- "Disabled the custom all-reduce kernel because it is not "
- "supported on current platform.")
- if self.ray_workers_use_nsight and not self.use_ray:
- raise ValueError("Unable to use nsight profiling unless workers "
- "run with Ray.")
-
- return self
-
-
-PreemptionMode = Literal["swap", "recompute"]
-SchedulerPolicy = Literal["fcfs", "priority"]
-
-
-@config
-@dataclass
-class SchedulerConfig:
- """Scheduler configuration."""
-
- runner_type: RunnerType = "generate"
- """The runner type to launch for the model."""
-
- max_num_batched_tokens: SkipValidation[int] = None # type: ignore
- """Maximum number of tokens to be processed in a single iteration.
-
- This config has no static default. If left unspecified by the user, it will
- be set in `EngineArgs.create_engine_config` based on the usage context."""
-
- max_num_seqs: SkipValidation[int] = None # type: ignore
- """Maximum number of sequences to be processed in a single iteration.
-
- This config has no static default. If left unspecified by the user, it will
- be set in `EngineArgs.create_engine_config` based on the usage context."""
-
- max_model_len: SkipValidation[int] = None # type: ignore
- """Maximum length of a sequence (including prompt and generated text). This
- is primarily set in `ModelConfig` and that value should be manually
- duplicated here."""
-
- max_num_partial_prefills: int = 1
- """For chunked prefill, the maximum number of sequences that can be
- partially prefilled concurrently."""
-
- max_long_partial_prefills: int = 1
- """For chunked prefill, the maximum number of prompts longer than
- long_prefill_token_threshold that will be prefilled concurrently. Setting
- this less than max_num_partial_prefills will allow shorter prompts to jump
- the queue in front of longer prompts in some cases, improving latency."""
-
- long_prefill_token_threshold: int = 0
- """For chunked prefill, a request is considered long if the prompt is
- longer than this number of tokens."""
-
- num_lookahead_slots: int = 0
- """The number of slots to allocate per sequence per
- step, beyond the known token ids. This is used in speculative
- decoding to store KV activations of tokens which may or may not be
- accepted.
-
- NOTE: This will be replaced by speculative config in the future; it is
- present to enable correctness tests until then."""
-
- cuda_graph_sizes: list[int] = field(default_factory=list)
- """Cuda graph capture sizes
- 1. if none provided, then default set to [min(max_num_seqs * 2, 512)]
- 2. if one value is provided, then the capture list would follow the
- pattern: [1, 2, 4] + [i for i in range(8, cuda_graph_sizes + 1, 8)]
- 3. more than one value (e.g. 1 2 128) is provided, then the capture list
- will follow the provided list."""
-
- delay_factor: float = 0.0
- """Apply a delay (of delay factor multiplied by previous
- prompt latency) before scheduling next prompt."""
-
- enable_chunked_prefill: SkipValidation[bool] = None # type: ignore
- """If True, prefill requests can be chunked based
- on the remaining max_num_batched_tokens."""
-
- is_multimodal_model: bool = False
- """True if the model is multimodal."""
-
- # TODO (ywang96): Make this configurable.
- max_num_encoder_input_tokens: int = field(init=False)
- """Multimodal encoder compute budget, only used in V1.
-
- NOTE: This is not currently configurable. It will be overridden by
- max_num_batched_tokens in case max multimodal embedding size is larger."""
-
- # TODO (ywang96): Make this configurable.
- encoder_cache_size: int = field(init=False)
- """Multimodal encoder cache size, only used in V1.
-
- NOTE: This is not currently configurable. It will be overridden by
- max_num_batched_tokens in case max multimodal embedding size is larger."""
-
- preemption_mode: Optional[PreemptionMode] = None
- """Whether to perform preemption by swapping or
- recomputation. If not specified, we determine the mode as follows:
- We use recomputation by default since it incurs lower overhead than
- swapping. However, when the sequence group has multiple sequences
- (e.g., beam search), recomputation is not currently supported. In
- such a case, we use swapping instead."""
-
- num_scheduler_steps: int = 1
- """Maximum number of forward steps per scheduler call."""
-
- multi_step_stream_outputs: bool = True
- """If False, then multi-step will stream outputs at the end of all steps"""
-
- send_delta_data: bool = False
- """Private API. If used, scheduler sends delta data to
- workers instead of an entire data. It should be enabled only
- when SPMD worker architecture is enabled. I.e.,
- VLLM_USE_RAY_SPMD_WORKER=1"""
-
- policy: SchedulerPolicy = "fcfs"
- """The scheduling policy to use:\n
- - "fcfs" means first come first served, i.e. requests are handled in order
- of arrival.\n
- - "priority" means requests are handled based on given priority (lower
- value means earlier handling) and time of arrival deciding any ties)."""
-
- chunked_prefill_enabled: bool = field(init=False)
- """True if chunked prefill is enabled."""
-
- disable_chunked_mm_input: bool = False
- """If set to true and chunked prefill is enabled, we do not want to
- partially schedule a multimodal item. Only used in V1
- This ensures that if a request has a mixed prompt
- (like text tokens TTTT followed by image tokens IIIIIIIIII) where only
- some image tokens can be scheduled (like TTTTIIIII, leaving IIIII),
- it will be scheduled as TTTT in one step and IIIIIIIIII in the next."""
-
- # scheduler class or path. "vllm.core.scheduler.Scheduler" (default)
- # or "mod.custom_class".
- scheduler_cls: Union[str, type[object]] = "vllm.core.scheduler.Scheduler"
- """The scheduler class to use. "vllm.core.scheduler.Scheduler" is the
- default scheduler. Can be a class directly or the path to a class of form
- "mod.custom_class"."""
-
- disable_hybrid_kv_cache_manager: bool = False
- """If set to True, KV cache manager will allocate the same size of KV cache
- for all attention layers even if there are multiple type of attention layers
- like full attention and sliding window attention.
- """
-
- async_scheduling: bool = False
- """EXPERIMENTAL: If set to True, perform async scheduling. This may help
- reduce the CPU overheads, leading to better latency and throughput. However,
- async scheduling is currently not supported with some features such as
- structured outputs, speculative decoding, and pipeline parallelism.
- """
-
- def compute_hash(self) -> str:
- """
- WARNING: Whenever a new field is added to this config,
- ensure that it is included in the factors list if
- it affects the computation graph.
-
- Provide a hash that uniquely identifies all the configs
- that affect the structure of the computation
- graph from input ids/embeddings to the final hidden states,
- excluding anything before input ids/embeddings and after
- the final hidden states.
- """
- # no factors to consider.
- # this config will not affect the computation graph.
- factors: list[Any] = []
- hash_str = hashlib.md5(str(factors).encode(),
- usedforsecurity=False).hexdigest()
- return hash_str
-
- def __post_init__(self) -> None:
- if self.max_model_len is None:
- self.max_model_len = 8192
-
- if self.max_num_seqs is None:
- self.max_num_seqs = 128
-
- if self.max_num_batched_tokens is None:
- if self.enable_chunked_prefill:
- if self.num_scheduler_steps > 1:
- # Multi-step Chunked-Prefill doesn't allow prompt-chunking
- # for now. Have max_num_batched_tokens set to max_model_len
- # so we don't reject sequences on account of a short
- # max_num_batched_tokens.
- self.max_num_batched_tokens = max(
- self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS)
- else:
- self.max_num_batched_tokens = (
- DEFAULT_MAX_NUM_BATCHED_TOKENS)
- else:
- # If max_model_len is too short, use
- # DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
- # for higher throughput.
- self.max_num_batched_tokens = max(
- self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS)
-
- if self.runner_type == "pooling":
- # Choose specific value for higher throughput
- self.max_num_batched_tokens = max(
- self.max_num_batched_tokens,
- POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
- )
- if self.is_multimodal_model:
- # The value needs to be at least the number of multimodal tokens
- self.max_num_batched_tokens = max(
- self.max_num_batched_tokens,
- MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
- )
-
- # When using default settings,
- # Ensure max_num_batched_tokens does not exceed model limit.
- # Some models (e.g., Whisper) have embeddings tied to max length.
- self.max_num_batched_tokens = min(
- self.max_num_seqs * self.max_model_len,
- self.max_num_batched_tokens)
-
- self.max_num_encoder_input_tokens = self.max_num_batched_tokens
- self.encoder_cache_size = self.max_num_batched_tokens
-
- if self.enable_chunked_prefill:
- logger.info(
- "Chunked prefill is enabled with max_num_batched_tokens=%d.",
- self.max_num_batched_tokens)
-
- self.chunked_prefill_enabled = self.enable_chunked_prefill
- if self.max_num_partial_prefills > 1:
- if self.long_prefill_token_threshold == 0:
- self.long_prefill_token_threshold = int(self.max_model_len *
- 0.04)
-
- logger.info(
- "Concurrent partial prefills enabled with "
- "max_num_partial_prefills=%d, max_long_partial_prefills=%d, "
- "long_prefill_token_threshold=%d",
- self.max_num_partial_prefills, self.max_long_partial_prefills,
- self.long_prefill_token_threshold)
-
- # NOTE: Default set cuda_graph_sizes to [min(max_num_seqs * 2, 512)].
- # This avoids OOM in tight memory scenarios with small max_num_seqs,
- # and prevents capture of many large graphs (>512) that would greatly
- # increase startup time with limited performance benefit.
- if not self.cuda_graph_sizes:
- self.cuda_graph_sizes = [min(self.max_num_seqs * 2, 512)]
-
- if self.async_scheduling:
- self.scheduler_cls = (
- "vllm.v1.core.sched.async_scheduler.AsyncScheduler")
-
- @model_validator(mode='after')
- def _verify_args(self) -> Self:
- if (self.max_num_batched_tokens < self.max_model_len
- and not self.chunked_prefill_enabled):
- raise ValueError(
- f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
- f"smaller than max_model_len ({self.max_model_len}). "
- "This effectively limits the maximum sequence length to "
- "max_num_batched_tokens and makes vLLM reject longer "
- "sequences. Please increase max_num_batched_tokens or "
- "decrease max_model_len.")
-
- if self.max_num_batched_tokens < self.max_num_seqs:
- raise ValueError(
- f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
- "be greater than or equal to max_num_seqs "
- f"({self.max_num_seqs}).")
-
- if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len:
- logger.warning(
- "max_num_batched_tokens (%d) exceeds max_num_seqs "
- "* max_model_len (%d). This may lead to unexpected behavior.",
- self.max_num_batched_tokens,
- self.max_num_seqs * self.max_model_len)
-
- if self.num_lookahead_slots < 0:
- raise ValueError(
- "num_lookahead_slots "
- f"({self.num_lookahead_slots}) must be greater than or "
- "equal to 0.")
-
- if self.num_scheduler_steps < 1:
- raise ValueError(
- "num_scheduler_steps "
- f"({self.num_scheduler_steps}) must be greater than or "
- "equal to 1.")
-
- if self.max_num_partial_prefills < 1:
- raise ValueError(
- f"max_num_partial_prefills ({self.max_num_partial_prefills}) "
- "must be greater than or equal to 1.")
- elif self.max_num_partial_prefills > 1:
- if not self.chunked_prefill_enabled:
- raise ValueError("Chunked prefill must be enabled to set "
- "max_num_partial_prefills > 1.")
-
- if self.long_prefill_token_threshold > self.max_model_len:
- raise ValueError(
- "long_prefill_token_threshold "
- f"({self.long_prefill_token_threshold}) cannot be greater "
- f"than the max_model_len ({self.max_model_len}).")
-
- if (self.max_long_partial_prefills
- < 1) or (self.max_long_partial_prefills
- > self.max_num_partial_prefills):
- raise ValueError(
- f"max_long_partial_prefills ({self.max_long_partial_prefills}) "
- "must be greater than or equal to 1 and less than or equal to "
- f"max_num_partial_prefills ({self.max_num_partial_prefills}).")
-
- return self
-
- @property
- def is_multi_step(self) -> bool:
- return self.num_scheduler_steps > 1
-
-
Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu"]
@@ -3234,13 +2334,7 @@ class SpeculativeConfig:
"speculative decoding is > 1, but got "
f"{self.disable_by_batch_size=}")
- from vllm.transformers_utils.configs import SpeculatorsConfig
-
- eagle3_target_supported = ["llama"]
- if self.draft_model_config and isinstance(
- self.draft_model_config.hf_config, SpeculatorsConfig):
- eagle3_target_supported.append("qwen")
-
+ eagle3_target_supported = ["llama", "qwen"]
if self.method == "eagle3" and self.target_model_config and not any(
supported_model in
self.target_model_config.hf_text_config.model_type
@@ -3693,7 +2787,7 @@ def _get_and_verify_max_len(
tokenizer_config: Optional[dict],
max_model_len: Optional[int],
disable_sliding_window: bool,
- sliding_window_len: Optional[Union[int, list[Optional[int]]]],
+ sliding_window: Optional[int],
spec_target_max_model_len: Optional[int] = None,
encoder_config: Optional[Any] = None,
) -> int:
@@ -3732,13 +2826,10 @@ def _get_and_verify_max_len(
# If sliding window is manually disabled, max_length should be less
# than the sliding window length in the model config.
- if disable_sliding_window and sliding_window_len is not None:
-
- sliding_window_len_min = get_min_sliding_window(sliding_window_len)
- max_len_key = "sliding_window" \
- if sliding_window_len_min < derived_max_model_len else max_len_key
- derived_max_model_len = min(derived_max_model_len,
- sliding_window_len_min)
+ if (disable_sliding_window and sliding_window is not None
+ and sliding_window < derived_max_model_len):
+ max_len_key = "sliding_window"
+ derived_max_model_len = sliding_window
# Consider model_max_length in tokenizer_config
if tokenizer_config:
@@ -3839,14 +2930,6 @@ def _get_and_verify_max_len(
return int(max_model_len)
-def get_min_sliding_window(
- sliding_window: Union[int, list[Optional[int]]]) -> int:
- if isinstance(sliding_window, list):
- return min(s for s in sliding_window if s is not None)
-
- return sliding_window
-
-
def get_served_model_name(model: str,
served_model_name: Optional[Union[str, list[str]]]):
"""
@@ -4154,421 +3237,6 @@ class KVEventsConfig:
"""
-class CompilationLevel:
- # constants for the levels of the compilation process
- NO_COMPILATION = 0
- DYNAMO_AS_IS = 1
- DYNAMO_ONCE = 2
- PIECEWISE = 3
-
-
-@config
-@dataclass
-class PassConfig:
- """Configuration for custom Inductor passes.
-
- This is separate from general `CompilationConfig` so that inductor passes
- don't all have access to full configuration - that would create a cycle as
- the `PassManager` is set as a property of config."""
-
- enable_fusion: bool = field(default_factory=lambda: not envs.VLLM_USE_V1)
- """Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass."""
- enable_attn_fusion: bool = False
- """Whether to enable the custom attention+quant fusion pass."""
- enable_noop: bool = field(default_factory=lambda: not envs.VLLM_USE_V1)
- """Whether to enable the custom no-op elimination pass."""
- enable_sequence_parallelism: bool = False
- """Whether to enable sequence parallelism."""
- enable_async_tp: bool = False
- """Whether to enable async TP."""
- enable_fi_allreduce_fusion: bool = False
- """Whether to enable flashinfer allreduce fusion."""
- fi_allreduce_fusion_max_token_num: int = 16384
- """Max number of tokens to used in flashinfer allreduce fusion."""
-
- # TODO(luka) better pass enabling system.
-
- def uuid(self):
- """
- Produces a hash unique to the pass configuration.
- Any new fields that affect compilation should be added to the hash.
- Any future fields that don't affect compilation should be excluded.
- """
- return InductorPass.hash_dict(asdict(self))
-
- def __post_init__(self) -> None:
- if not self.enable_noop:
- if self.enable_fusion:
- logger.warning_once(
- "Fusion enabled but reshape elimination disabled. "
- "RMSNorm/SiluMul + quant (fp8) fusion might not work")
- if self.enable_attn_fusion:
- logger.warning_once(
- "Fusion enabled but reshape elimination disabled. "
- "Attention + quant (fp8) fusion might not work")
-
-
-@config
-@dataclass
-class CompilationConfig:
- """Configuration for compilation. It has three parts:
-
- - Top-level Compilation control:
- - [`level`][vllm.config.CompilationConfig.level]
- - [`debug_dump_path`][vllm.config.CompilationConfig.debug_dump_path]
- - [`cache_dir`][vllm.config.CompilationConfig.cache_dir]
- - [`backend`][vllm.config.CompilationConfig.backend]
- - [`custom_ops`][vllm.config.CompilationConfig.custom_ops]
- - [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops]
- - CudaGraph capture:
- - [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph]
- - [`cudagraph_capture_sizes`]
- [vllm.config.CompilationConfig.cudagraph_capture_sizes]
- - [`cudagraph_num_of_warmups`]
- [vllm.config.CompilationConfig.cudagraph_num_of_warmups]
- - [`cudagraph_copy_inputs`]
- [vllm.config.CompilationConfig.cudagraph_copy_inputs]
- - [`full_cuda_graph`][vllm.config.CompilationConfig.full_cuda_graph]
- - Inductor compilation:
- - [`use_inductor`][vllm.config.CompilationConfig.use_inductor]
- - [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes]
- - [`inductor_compile_config`]
- [vllm.config.CompilationConfig.inductor_compile_config]
- - [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes]
- - custom inductor passes
-
- Why we have different sizes for cudagraph and inductor:
- - cudagraph: a cudagraph captured for a specific size can only be used
- for the same size. We need to capture all the sizes we want to use.
- - inductor: a graph compiled by inductor for a general shape can be used
- for different sizes. Inductor can also compile for specific sizes,
- where it can have more information to optimize the graph with fully
- static shapes. However, we find the general shape compilation is
- sufficient for most cases. It might be beneficial to compile for
- certain small batchsizes, where inductor is good at optimizing.
- """
- # Top-level Compilation control
- level: Optional[int] = None
- """The level of compilation:
-
- - None: If None, we will select the default compilation level.
- For V1 engine this is 3, for V0 engine this is 0.
- - 0: no compilation.
- - 1: dynamo as is.
- - 2: dynamo once.
- - 3: piecewise compilation."""
- debug_dump_path: str = ""
- """The path to dump the debug information."""
- cache_dir: str = ""
- """The directory to store the compiled graph, to accelerate Inductor
- compilation. By default, it will use model-related information to generate
- a cache directory."""
- backend: str = ""
- """The backend for compilation. It needs to be a string:
-
- - "" (empty string): use the default backend.
- - "eager"/"openxla"/...: use the specified backend registered in PyTorch.
- - "full.module.name": a qualified name which can be used to import the
-
- backend function.
- We use string to avoid serialization issues when using compilation in a
- distributed setting. When the compilation level is 1 or 2, the backend is
- used for the compilation directly (it sees the whole graph). When the
- compilation level is 3, the backend is used for the piecewise compilation
- (it sees a part of the graph)."""
- custom_ops: list[str] = field(default_factory=list)
- """Fine-grained control over which custom ops to enable/disable. Use 'all'
- to enable all, 'none' to disable all. Also specify a list of custom op
- names to enable (prefixed with a '+'), or disable (prefixed with a '-').
- Examples:
-
- - 'all,-op1' to enable all except op1
- - 'none,+op1,+op2' to enable only op1 and op2
-
- By default, all custom ops are enabled when running without Inductor and
- disabled when running with Inductor: level>=PIECEWISE and use_inductor=True.
- Inductor generates (fused) Triton kernels for disabled custom ops."""
- splitting_ops: list[str] = field(default_factory=list)
- """A list of ops to split the full graph into subgraphs, used in piecewise
- compilation."""
-
- # Inductor capture
- use_inductor: bool = True
- """Whether to use inductor compilation:
-
- - False: inductor compilation is not used. graph runs in eager
- (custom_ops enabled by default).
- - True: inductor compilation is used (custom_ops disabled by default).
- One graph for symbolic shape and one graph per size in compile_sizes
- are compiled using configurations in inductor_compile_config.
-
- This setting is ignored if level1."""
-
- def compute_hash(self) -> str:
- """
- WARNING: Whenever a new field is added to this config,
- ensure that it is included in the factors list if
- it affects the computation graph.
-
- Provide a hash that uniquely identifies all the configs
- that affect the structure of the computation
- graph from input ids/embeddings to the final hidden states,
- excluding anything before input ids/embeddings and after
- the final hidden states.
- """
- factors: list[Any] = []
- factors.append(self.level)
- factors.append(self.backend)
- factors.append(self.custom_ops)
- factors.append(self.splitting_ops)
- factors.append(self.use_inductor)
- factors.append(self.inductor_compile_config)
- factors.append(self.inductor_passes)
- factors.append(self.pass_config.uuid())
- return hashlib.sha256(str(factors).encode()).hexdigest()
-
- def __repr__(self) -> str:
- exclude = {
- "static_forward_context": True,
- "enabled_custom_ops": True,
- "disabled_custom_ops": True,
- "compilation_time": True,
- "bs_to_padded_graph_size": True,
- "traced_files": True,
- "inductor_compile_config": {
- "post_grad_custom_post_pass": True,
- },
- }
-
- # exclude default attr in pass_config
- pass_config_exclude = {}
- for attr, default_val in vars(PassConfig()).items():
- if getattr(self.pass_config, attr) == default_val:
- pass_config_exclude[attr] = True
- if pass_config_exclude:
- exclude["pass_config"] = pass_config_exclude
-
- # The cast to string is necessary because Pydantic is mocked in docs
- # builds and sphinx-argparse doesn't know the return type of decode()
- return str(
- TypeAdapter(CompilationConfig).dump_json(
- self,
- exclude=exclude, # type: ignore[arg-type]
- exclude_unset=True).decode())
-
- __str__ = __repr__
-
- @classmethod
- def from_cli(cls, cli_value: str) -> "CompilationConfig":
- """Parse the CLI value for the compilation config.
- -O1, -O2, -O3, etc. is handled in FlexibleArgumentParser.
- """
- return TypeAdapter(CompilationConfig).validate_json(cli_value)
-
- def __post_init__(self) -> None:
- count_none = self.custom_ops.count("none")
- count_all = self.custom_ops.count("all")
- assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"
-
- # TODO(zou3519/luka): There are 2 issues with auto-functionalization V2:
- # 1. A bug in PyTorch, fixed in 2.7:
- # https://github.com/pytorch/pytorch/issues/147924
- # 2. Custom passes (fusion) rely on auto-functionalization V1 and don't
- # work with V2. Addressing this will take extra engineering effort
- # and it is not yet a priority. RFC here:
- # https://github.com/vllm-project/vllm/issues/14703
-
- if is_torch_equal_or_newer("2.6"):
- KEY = 'enable_auto_functionalized_v2'
- if KEY not in self.inductor_compile_config:
- self.inductor_compile_config[KEY] = False
-
- for k, v in self.inductor_passes.items():
- if not isinstance(v, str):
- assert callable(v), (
- f"pass {k} should be callable or a qualified name")
- self.inductor_compile_config[k] = v if isinstance(
- v, InductorPass) else CallableInductorPass(v)
- continue
-
- # resolve function from qualified name
- names = v.split(".")
- module = ".".join(names[:-1])
- func_name = names[-1]
- func = __import__(module).__dict__[func_name]
- self.inductor_compile_config[k] = func if isinstance(
- func, InductorPass) else CallableInductorPass(func)
-
- if isinstance(self.pass_config, dict):
- self.pass_config = PassConfig(**self.pass_config)
-
- def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
- if self.level == CompilationLevel.NO_COMPILATION:
- raise ValueError("No compilation level is set.")
-
- from torch._dynamo.backends.registry import list_backends
- torch_backends = list_backends(exclude_tags=tuple())
- if self.level in [
- CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE
- ]:
- if self.backend == "":
- return "eager"
- if self.backend in torch_backends:
- return self.backend
- return resolve_obj_by_qualname(self.backend)
-
- # TODO: pass user-specified backend to piecewise compilation
- # merge with the config use_inductor
- assert self.level == CompilationLevel.PIECEWISE
-
- from vllm.compilation.backends import VllmBackend
- return VllmBackend(vllm_config)
-
- def init_with_cudagraph_sizes(self,
- cudagraph_capture_sizes: list[int]) -> None:
- """To complete the initialization of config,
- we need to know the cudagraph sizes."""
-
- if self.cudagraph_capture_sizes is None:
- self.cudagraph_capture_sizes = cudagraph_capture_sizes
- else:
- # de-duplicate the sizes provided by the config
- dedup_sizes = list(set(self.cudagraph_capture_sizes))
- if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
- logger.info(("cudagraph sizes specified by model runner"
- " %s is overridden by config %s"),
- cudagraph_capture_sizes, dedup_sizes)
- self.cudagraph_capture_sizes = dedup_sizes
-
- computed_compile_sizes = []
- if self.compile_sizes is not None:
- # de-duplicate the sizes provided by the config
- self.compile_sizes = list(set(self.compile_sizes))
- for x in self.compile_sizes:
- if isinstance(x, str):
- assert x == "cudagraph_capture_sizes", \
- "Unrecognized size type in compile_sizes, " \
- f"expect 'cudagraph_capture_sizes', got {x}"
- computed_compile_sizes.extend(self.cudagraph_capture_sizes)
- else:
- assert isinstance(x, int)
- computed_compile_sizes.append(x)
- self.compile_sizes = computed_compile_sizes # type: ignore
-
- # sort to make sure cudagraph capture sizes are in descending order
- self.cudagraph_capture_sizes.sort(reverse=True)
- self.max_capture_size = self.cudagraph_capture_sizes[
- 0] if self.cudagraph_capture_sizes else 0
-
- # pre-compute the mapping from batch size to padded graph size
- self.bs_to_padded_graph_size = [
- 0 for i in range(self.max_capture_size + 1)
- ]
- for end, start in zip(self.cudagraph_capture_sizes,
- self.cudagraph_capture_sizes[1:] + [0]):
- for bs in range(start, end):
- if bs == start:
- self.bs_to_padded_graph_size[bs] = start
- else:
- self.bs_to_padded_graph_size[bs] = end
- self.bs_to_padded_graph_size[
- self.max_capture_size] = self.max_capture_size
-
- def set_splitting_ops_for_v1(self):
- # NOTE: this function needs to be called
- if self.splitting_ops and self.full_cuda_graph:
- raise ValueError("full_cuda_graph cannot be used together with "
- "splitting_ops, as Full CUDA graph will override "
- f"the splitting_ops: {self.splitting_ops}")
-
- if not self.splitting_ops:
- self.splitting_ops = [] if self.full_cuda_graph else [
- "vllm.unified_attention",
- "vllm.unified_attention_with_output",
- "vllm.mamba_mixer2",
- ]
-
-
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class VllmConfig:
@@ -4878,6 +3546,10 @@ class VllmConfig:
disable_chunked_prefill_reasons.append(
"Only \"last\" pooling supports chunked "
"prefill and prefix caching; disabling both.")
+ elif not getattr(self.model_config.hf_config, "is_causal", True):
+ disable_chunked_prefill_reasons.append(
+ "Only models using causal attention supports chunked "
+ "prefill and prefix caching; disabling both.")
if disable_chunked_prefill_reasons:
for reason in disable_chunked_prefill_reasons:
diff --git a/vllm/config/cache.py b/vllm/config/cache.py
new file mode 100644
index 0000000000000..69cb0d9732fac
--- /dev/null
+++ b/vllm/config/cache.py
@@ -0,0 +1,204 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import hashlib
+from dataclasses import field
+from typing import TYPE_CHECKING, Any, Literal, Optional, get_args
+
+from pydantic import SkipValidation, model_validator
+from pydantic.dataclasses import dataclass
+from typing_extensions import Self
+
+import vllm.envs as envs
+from vllm.config.utils import config
+from vllm.logger import init_logger
+from vllm.utils import GiB_bytes, get_cpu_memory
+
+if TYPE_CHECKING:
+ from vllm.config.parallel import ParallelConfig
+else:
+ ParallelConfig = Any
+
+logger = init_logger(__name__)
+
+BlockSize = Literal[1, 8, 16, 32, 64, 128]
+CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
+PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"]
+
+
+@config
+@dataclass
+class CacheConfig:
+ """Configuration for the KV cache."""
+
+ block_size: SkipValidation[BlockSize] = None # type: ignore
+ """Size of a contiguous cache block in number of tokens. This is ignored on
+ neuron devices and set to `--max-model-len`. On CUDA devices, only block
+ sizes up to 32 are supported. On HPU devices, block size defaults to 128.
+
+ This config has no static default. If left unspecified by the user, it will
+ be set in `Platform.check_and_update_config()` based on the current
+ platform."""
+ gpu_memory_utilization: float = 0.9
+ """The fraction of GPU memory to be used for the model executor, which can
+ range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
+ utilization. If unspecified, will use the default value of 0.9. This is a
+ per-instance limit, and only applies to the current vLLM instance. It does
+ not matter if you have another vLLM instance running on the same GPU. For
+ example, if you have two vLLM instances running on the same GPU, you can
+ set the GPU memory utilization to 0.5 for each instance."""
+ swap_space: float = 4
+ """Size of the CPU swap space per GPU (in GiB)."""
+ cache_dtype: CacheDType = "auto"
+ """Data type for kv cache storage. If "auto", will use model data type.
+ CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
+ fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc)."""
+ is_attention_free: bool = False
+ """Whether the model is attention-free. This is primarily set in
+ `ModelConfig` and that value should be manually duplicated here."""
+ num_gpu_blocks_override: Optional[int] = None
+ """Number of GPU blocks to use. This overrides the profiled `num_gpu_blocks`
+ if specified. Does nothing if `None`. Used for testing preemption."""
+ sliding_window: Optional[int] = None
+ """Sliding window size for the KV cache. This is primarily set in
+ `ModelConfig` and that value should be manually duplicated here."""
+ enable_prefix_caching: Optional[bool] = None
+ """Whether to enable prefix caching. Disabled by default for V0. Enabled by
+ default for V1."""
+ prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin"
+ """Set the hash algorithm for prefix caching:\n
+ - "builtin" is Python's built-in hash.\n
+ - "sha256" is collision resistant but with certain overheads.
+ This option uses Pickle for object serialization before hashing.\n
+ - "sha256_cbor_64bit" provides a reproducible, cross-language compatible
+ hash. It serializes objects using canonical CBOR and hashes them with
+ SHA-256. The resulting hash consists of the lower 64 bits of the SHA-256
+ digest."""
+ cpu_offload_gb: float = 0
+ """The space in GiB to offload to CPU, per GPU. Default is 0, which means
+ no offloading. Intuitively, this argument can be seen as a virtual way to
+ increase the GPU memory size. For example, if you have one 24 GB GPU and
+ set this to 10, virtually you can think of it as a 34 GB GPU. Then you can
+ load a 13B model with BF16 weight, which requires at least 26GB GPU memory.
+ Note that this requires fast CPU-GPU interconnect, as part of the model is
+ loaded from CPU memory to GPU memory on the fly in each model forward pass.
+ """
+ calculate_kv_scales: bool = False
+ """This enables dynamic calculation of `k_scale` and `v_scale` when
+ kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model
+ checkpoint if available. Otherwise, the scales will default to 1.0."""
+ cpu_kvcache_space_bytes: Optional[int] = None
+ """(CPU backend only) CPU key-value cache space."""
+ mamba_page_size_padded: Optional[int] = None
+ """ Optional override for mamba page size; used by hybrid mamba/attention
+ models to ensure exact alignment with attention page size."""
+
+ # Will be set after profiling.
+ num_gpu_blocks: Optional[int] = field(default=None, init=False)
+ """The number of blocks to allocate for GPU memory."""
+ num_cpu_blocks: Optional[int] = field(default=None, init=False)
+ """The number of blocks to allocate for CPU memory."""
+
+ kv_sharing_fast_prefill: bool = False
+ """This feature is work in progress and no prefill optimization takes place
+ with this flag enabled currently.
+
+ In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254),
+ some layers can skip tokens corresponding to prefill. This flag enables
+ attention metadata for eligible layers to be overriden with metadata
+ necessary for implementating this optimization in some models (e.g. Gemma3n)
+ """
+
+ def compute_hash(self) -> str:
+ """
+ WARNING: Whenever a new field is added to this config,
+ ensure that it is included in the factors list if
+ it affects the computation graph.
+
+ Provide a hash that uniquely identifies all the configs
+ that affect the structure of the computation
+ graph from input ids/embeddings to the final hidden states,
+ excluding anything before input ids/embeddings and after
+ the final hidden states.
+ """
+ factors: list[Any] = []
+ factors.append(self.cache_dtype)
+ # `cpu_offload_gb` does not use `torch.compile` yet.
+ hash_str = hashlib.md5(str(factors).encode(),
+ usedforsecurity=False).hexdigest()
+ return hash_str
+
+ def __post_init__(self) -> None:
+ self.swap_space_bytes = self.swap_space * GiB_bytes
+
+ self._verify_cache_dtype()
+ self._verify_prefix_caching()
+
+ def metrics_info(self):
+ # convert cache_config to dict(key: str, value: str) for prometheus
+ # metrics info
+ return {key: str(value) for key, value in self.__dict__.items()}
+
+ @model_validator(mode='after')
+ def _verify_args(self) -> Self:
+ if self.cpu_offload_gb < 0:
+ raise ValueError("CPU offload space must be non-negative"
+ f", but got {self.cpu_offload_gb}")
+
+ if self.gpu_memory_utilization > 1.0:
+ raise ValueError(
+ "GPU memory utilization must be less than 1.0. Got "
+ f"{self.gpu_memory_utilization}.")
+
+ if self.kv_sharing_fast_prefill:
+ logger.warning_once(
+ "--kv-sharing-fast-prefill is currently work in progress "
+ "and not functional yet (i.e. no prefill savings)")
+
+ return self
+
+ def _verify_cache_dtype(self) -> None:
+ if self.cache_dtype == "auto":
+ pass
+ elif self.cache_dtype in get_args(CacheDType):
+ logger.info(
+ "Using fp8 data type to store kv cache. It reduces the GPU "
+ "memory footprint and boosts the performance. "
+ "Meanwhile, it may cause accuracy drop without a proper "
+ "scaling factor.")
+ else:
+ raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
+
+ def _verify_prefix_caching(self) -> None:
+ if not self.enable_prefix_caching:
+ return
+
+ if self.sliding_window is not None and not envs.VLLM_USE_V1:
+ raise NotImplementedError(
+ "Prefix caching is not supported with sliding window. "
+ "Run with --disable-sliding-window to use prefix caching.")
+
+ if (self.enable_prefix_caching and self.prefix_caching_hash_algo
+ not in get_args(PrefixCachingHashAlgo)):
+ raise ValueError(
+ "Unknown prefix caching hash algorithm: "
+ f"{self.prefix_caching_hash_algo}. Must be one of "
+ f"{get_args(PrefixCachingHashAlgo)}.")
+
+ def verify_with_parallel_config(
+ self,
+ parallel_config: ParallelConfig,
+ ) -> None:
+ total_cpu_memory = get_cpu_memory()
+ # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
+ # group are in the same node. However, the GPUs may span multiple nodes.
+ num_gpus_per_node = parallel_config.tensor_parallel_size
+ cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
+
+ msg = (f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the "
+ f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory "
+ "is allocated for the swap space.")
+ if cpu_memory_usage > 0.7 * total_cpu_memory:
+ raise ValueError("Too large swap space. " + msg)
+ elif cpu_memory_usage > 0.4 * total_cpu_memory:
+ logger.warning("Possibly too large swap space. %s", msg)
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
new file mode 100644
index 0000000000000..8a78d811b9a25
--- /dev/null
+++ b/vllm/config/compilation.py
@@ -0,0 +1,428 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import hashlib
+from collections import Counter
+from dataclasses import asdict, field
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+
+from pydantic import TypeAdapter
+from pydantic.dataclasses import dataclass
+
+import vllm.envs as envs
+from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
+from vllm.config.utils import config
+from vllm.logger import init_logger
+from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname
+
+if TYPE_CHECKING:
+ from vllm.config import VllmConfig
+else:
+ VllmConfig = object
+
+logger = init_logger(__name__)
+
+
+class CompilationLevel:
+ # constants for the levels of the compilation process
+ NO_COMPILATION = 0
+ DYNAMO_AS_IS = 1
+ DYNAMO_ONCE = 2
+ PIECEWISE = 3
+
+
+@config
+@dataclass
+class PassConfig:
+ """Configuration for custom Inductor passes.
+
+ This is separate from general `CompilationConfig` so that inductor passes
+ don't all have access to full configuration - that would create a cycle as
+ the `PassManager` is set as a property of config."""
+
+ enable_fusion: bool = field(default_factory=lambda: not envs.VLLM_USE_V1)
+ """Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass."""
+ enable_attn_fusion: bool = False
+ """Whether to enable the custom attention+quant fusion pass."""
+ enable_noop: bool = field(default_factory=lambda: not envs.VLLM_USE_V1)
+ """Whether to enable the custom no-op elimination pass."""
+ enable_sequence_parallelism: bool = False
+ """Whether to enable sequence parallelism."""
+ enable_async_tp: bool = False
+ """Whether to enable async TP."""
+ enable_fi_allreduce_fusion: bool = False
+ """Whether to enable flashinfer allreduce fusion."""
+ fi_allreduce_fusion_max_token_num: int = 16384
+ """Max number of tokens to used in flashinfer allreduce fusion."""
+
+ # TODO(luka) better pass enabling system.
+
+ def uuid(self):
+ """
+ Produces a hash unique to the pass configuration.
+ Any new fields that affect compilation should be added to the hash.
+ Any future fields that don't affect compilation should be excluded.
+ """
+ return InductorPass.hash_dict(asdict(self))
+
+ def __post_init__(self) -> None:
+ if not self.enable_noop:
+ if self.enable_fusion:
+ logger.warning_once(
+ "Fusion enabled but reshape elimination disabled. "
+ "RMSNorm/SiluMul + quant (fp8) fusion might not work")
+ if self.enable_attn_fusion:
+ logger.warning_once(
+ "Fusion enabled but reshape elimination disabled. "
+ "Attention + quant (fp8) fusion might not work")
+
+
+@config
+@dataclass
+class CompilationConfig:
+ """Configuration for compilation. It has three parts:
+
+ - Top-level Compilation control:
+ - [`level`][vllm.config.CompilationConfig.level]
+ - [`debug_dump_path`][vllm.config.CompilationConfig.debug_dump_path]
+ - [`cache_dir`][vllm.config.CompilationConfig.cache_dir]
+ - [`backend`][vllm.config.CompilationConfig.backend]
+ - [`custom_ops`][vllm.config.CompilationConfig.custom_ops]
+ - [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops]
+ - CudaGraph capture:
+ - [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph]
+ - [`cudagraph_capture_sizes`]
+ [vllm.config.CompilationConfig.cudagraph_capture_sizes]
+ - [`cudagraph_num_of_warmups`]
+ [vllm.config.CompilationConfig.cudagraph_num_of_warmups]
+ - [`cudagraph_copy_inputs`]
+ [vllm.config.CompilationConfig.cudagraph_copy_inputs]
+ - [`full_cuda_graph`][vllm.config.CompilationConfig.full_cuda_graph]
+ - Inductor compilation:
+ - [`use_inductor`][vllm.config.CompilationConfig.use_inductor]
+ - [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes]
+ - [`inductor_compile_config`]
+ [vllm.config.CompilationConfig.inductor_compile_config]
+ - [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes]
+ - custom inductor passes
+
+ Why we have different sizes for cudagraph and inductor:
+ - cudagraph: a cudagraph captured for a specific size can only be used
+ for the same size. We need to capture all the sizes we want to use.
+ - inductor: a graph compiled by inductor for a general shape can be used
+ for different sizes. Inductor can also compile for specific sizes,
+ where it can have more information to optimize the graph with fully
+ static shapes. However, we find the general shape compilation is
+ sufficient for most cases. It might be beneficial to compile for
+ certain small batchsizes, where inductor is good at optimizing.
+ """
+ # Top-level Compilation control
+ level: Optional[int] = None
+ """The level of compilation:
+
+ - None: If None, we will select the default compilation level.
+ For V1 engine this is 3, for V0 engine this is 0.
+ - 0: no compilation.
+ - 1: dynamo as is.
+ - 2: dynamo once.
+ - 3: piecewise compilation."""
+ debug_dump_path: str = ""
+ """The path to dump the debug information."""
+ cache_dir: str = ""
+ """The directory to store the compiled graph, to accelerate Inductor
+ compilation. By default, it will use model-related information to generate
+ a cache directory."""
+ backend: str = ""
+ """The backend for compilation. It needs to be a string:
+
+ - "" (empty string): use the default backend.
+ - "eager"/"openxla"/...: use the specified backend registered in PyTorch.
+ - "full.module.name": a qualified name which can be used to import the
+
+ backend function.
+ We use string to avoid serialization issues when using compilation in a
+ distributed setting. When the compilation level is 1 or 2, the backend is
+ used for the compilation directly (it sees the whole graph). When the
+ compilation level is 3, the backend is used for the piecewise compilation
+ (it sees a part of the graph)."""
+ custom_ops: list[str] = field(default_factory=list)
+ """Fine-grained control over which custom ops to enable/disable. Use 'all'
+ to enable all, 'none' to disable all. Also specify a list of custom op
+ names to enable (prefixed with a '+'), or disable (prefixed with a '-').
+ Examples:
+
+ - 'all,-op1' to enable all except op1
+ - 'none,+op1,+op2' to enable only op1 and op2
+
+ By default, all custom ops are enabled when running without Inductor and
+ disabled when running with Inductor: level>=PIECEWISE and use_inductor=True.
+ Inductor generates (fused) Triton kernels for disabled custom ops."""
+ splitting_ops: list[str] = field(default_factory=list)
+ """A list of ops to split the full graph into subgraphs, used in piecewise
+ compilation."""
+
+ # Inductor capture
+ use_inductor: bool = True
+ """Whether to use inductor compilation:
+
+ - False: inductor compilation is not used. graph runs in eager
+ (custom_ops enabled by default).
+ - True: inductor compilation is used (custom_ops disabled by default).
+ One graph for symbolic shape and one graph per size in compile_sizes
+ are compiled using configurations in inductor_compile_config.
+
+ This setting is ignored if level1."""
+
+ def compute_hash(self) -> str:
+ """
+ WARNING: Whenever a new field is added to this config,
+ ensure that it is included in the factors list if
+ it affects the computation graph.
+
+ Provide a hash that uniquely identifies all the configs
+ that affect the structure of the computation
+ graph from input ids/embeddings to the final hidden states,
+ excluding anything before input ids/embeddings and after
+ the final hidden states.
+ """
+ factors: list[Any] = []
+ factors.append(self.level)
+ factors.append(self.backend)
+ factors.append(self.custom_ops)
+ factors.append(self.splitting_ops)
+ factors.append(self.use_inductor)
+ factors.append(self.inductor_compile_config)
+ factors.append(self.inductor_passes)
+ factors.append(self.pass_config.uuid())
+ return hashlib.sha256(str(factors).encode()).hexdigest()
+
+ def __repr__(self) -> str:
+ exclude = {
+ "static_forward_context": True,
+ "enabled_custom_ops": True,
+ "disabled_custom_ops": True,
+ "compilation_time": True,
+ "bs_to_padded_graph_size": True,
+ "traced_files": True,
+ "inductor_compile_config": {
+ "post_grad_custom_post_pass": True,
+ },
+ }
+
+ # exclude default attr in pass_config
+ pass_config_exclude = {}
+ for attr, default_val in vars(PassConfig()).items():
+ if getattr(self.pass_config, attr) == default_val:
+ pass_config_exclude[attr] = True
+ if pass_config_exclude:
+ exclude["pass_config"] = pass_config_exclude
+
+ return TypeAdapter(CompilationConfig).dump_json(
+ self,
+ exclude=exclude, # type: ignore[arg-type]
+ exclude_unset=True).decode()
+
+ __str__ = __repr__
+
+ def __post_init__(self) -> None:
+ count_none = self.custom_ops.count("none")
+ count_all = self.custom_ops.count("all")
+ assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"
+
+ # TODO(zou3519/luka): There are 2 issues with auto-functionalization V2:
+ # 1. A bug in PyTorch, fixed in 2.7:
+ # https://github.com/pytorch/pytorch/issues/147924
+ # 2. Custom passes (fusion) rely on auto-functionalization V1 and don't
+ # work with V2. Addressing this will take extra engineering effort
+ # and it is not yet a priority. RFC here:
+ # https://github.com/vllm-project/vllm/issues/14703
+
+ if is_torch_equal_or_newer("2.6"):
+ KEY = 'enable_auto_functionalized_v2'
+ if KEY not in self.inductor_compile_config:
+ self.inductor_compile_config[KEY] = False
+
+ for k, v in self.inductor_passes.items():
+ if not isinstance(v, str):
+ assert callable(v), (
+ f"pass {k} should be callable or a qualified name")
+ self.inductor_compile_config[k] = v if isinstance(
+ v, InductorPass) else CallableInductorPass(v)
+ continue
+
+ # resolve function from qualified name
+ names = v.split(".")
+ module = ".".join(names[:-1])
+ func_name = names[-1]
+ func = __import__(module).__dict__[func_name]
+ self.inductor_compile_config[k] = func if isinstance(
+ func, InductorPass) else CallableInductorPass(func)
+
+ if isinstance(self.pass_config, dict):
+ self.pass_config = PassConfig(**self.pass_config)
+
+ def init_backend(self, vllm_config: VllmConfig) -> Union[str, Callable]:
+ if self.level == CompilationLevel.NO_COMPILATION:
+ raise ValueError("No compilation level is set.")
+
+ from torch._dynamo.backends.registry import list_backends
+ torch_backends = list_backends(exclude_tags=tuple())
+ if self.level in [
+ CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE
+ ]:
+ if self.backend == "":
+ return "eager"
+ if self.backend in torch_backends:
+ return self.backend
+ return resolve_obj_by_qualname(self.backend)
+
+ # TODO: pass user-specified backend to piecewise compilation
+ # merge with the config use_inductor
+ assert self.level == CompilationLevel.PIECEWISE
+
+ from vllm.compilation.backends import VllmBackend
+ return VllmBackend(vllm_config)
+
+ def init_with_cudagraph_sizes(self,
+ cudagraph_capture_sizes: list[int]) -> None:
+ """To complete the initialization of config,
+ we need to know the cudagraph sizes."""
+
+ if self.cudagraph_capture_sizes is None:
+ self.cudagraph_capture_sizes = cudagraph_capture_sizes
+ else:
+ # de-duplicate the sizes provided by the config
+ dedup_sizes = list(set(self.cudagraph_capture_sizes))
+ if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
+ logger.info(("cudagraph sizes specified by model runner"
+ " %s is overridden by config %s"),
+ cudagraph_capture_sizes, dedup_sizes)
+ self.cudagraph_capture_sizes = dedup_sizes
+
+ computed_compile_sizes = []
+ if self.compile_sizes is not None:
+ # de-duplicate the sizes provided by the config
+ self.compile_sizes = list(set(self.compile_sizes))
+ for x in self.compile_sizes:
+ if isinstance(x, str):
+ assert x == "cudagraph_capture_sizes", \
+ "Unrecognized size type in compile_sizes, " \
+ f"expect 'cudagraph_capture_sizes', got {x}"
+ computed_compile_sizes.extend(self.cudagraph_capture_sizes)
+ else:
+ assert isinstance(x, int)
+ computed_compile_sizes.append(x)
+ self.compile_sizes = computed_compile_sizes # type: ignore
+
+ # sort to make sure cudagraph capture sizes are in descending order
+ self.cudagraph_capture_sizes.sort(reverse=True)
+ self.max_capture_size = self.cudagraph_capture_sizes[
+ 0] if self.cudagraph_capture_sizes else 0
+
+ # pre-compute the mapping from batch size to padded graph size
+ self.bs_to_padded_graph_size = [
+ 0 for i in range(self.max_capture_size + 1)
+ ]
+ for end, start in zip(self.cudagraph_capture_sizes,
+ self.cudagraph_capture_sizes[1:] + [0]):
+ for bs in range(start, end):
+ if bs == start:
+ self.bs_to_padded_graph_size[bs] = start
+ else:
+ self.bs_to_padded_graph_size[bs] = end
+ self.bs_to_padded_graph_size[
+ self.max_capture_size] = self.max_capture_size
+
+ def set_splitting_ops_for_v1(self):
+ # NOTE: this function needs to be called
+ if self.splitting_ops and self.full_cuda_graph:
+ raise ValueError("full_cuda_graph cannot be used together with "
+ "splitting_ops, as Full CUDA graph will override "
+ f"the splitting_ops: {self.splitting_ops}")
+
+ if not self.splitting_ops:
+ self.splitting_ops = [] if self.full_cuda_graph else [
+ "vllm.unified_attention",
+ "vllm.unified_attention_with_output",
+ "vllm.mamba_mixer2",
+ ]
diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py
new file mode 100644
index 0000000000000..bac1e63800d7b
--- /dev/null
+++ b/vllm/config/parallel.py
@@ -0,0 +1,375 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import hashlib
+from dataclasses import field
+from typing import TYPE_CHECKING, Any, Literal, Optional, Union
+
+import torch
+from pydantic import model_validator
+from pydantic.dataclasses import dataclass
+from torch.distributed import ProcessGroup, ReduceOp
+from typing_extensions import Self
+
+import vllm.envs as envs
+from vllm.config.utils import config
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+from vllm.utils import cuda_device_count_stateless, get_open_port
+
+if TYPE_CHECKING:
+ from ray.runtime_env import RuntimeEnv
+ from ray.util.placement_group import PlacementGroup
+
+ from vllm.executor.executor_base import ExecutorBase
+else:
+ RuntimeEnv = Any
+ PlacementGroup = Any
+ ExecutorBase = Any
+
+logger = init_logger(__name__)
+
+DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
+
+
+@config
+@dataclass
+class ParallelConfig:
+ """Configuration for the distributed execution."""
+
+ pipeline_parallel_size: int = 1
+ """Number of pipeline parallel groups."""
+ tensor_parallel_size: int = 1
+ """Number of tensor parallel groups."""
+ data_parallel_size: int = 1
+ """Number of data parallel groups. MoE layers will be sharded according to
+ the product of the tensor parallel size and data parallel size."""
+ data_parallel_size_local: int = 1
+ """Number of local data parallel groups."""
+ data_parallel_rank: int = 0
+ """Rank of the data parallel group."""
+ data_parallel_rank_local: Optional[int] = None
+ """Local rank of the data parallel group,
+ set only in SPMD mode."""
+ data_parallel_master_ip: str = "127.0.0.1"
+ """IP of the data parallel master."""
+ data_parallel_rpc_port: int = 29550
+ """Port for data parallel messaging."""
+ data_parallel_master_port: int = 29500
+ """Port of the data parallel master."""
+ data_parallel_backend: str = "mp"
+ """Backend to use for data parallel, either "mp" or "ray"."""
+ data_parallel_external_lb: bool = False
+ """Whether to use "external" DP LB mode. Applies only to online serving
+ and when data_parallel_size > 0. This is useful for a "one-pod-per-rank"
+ wide-EP setup in Kuberentes. Set implicitly when --data-parallel-rank
+ is provided explicitly to vllm serve."""
+ data_parallel_hybrid_lb: bool = False
+ """Whether to use "hybrid" DP LB mode. Applies only to online serving
+ and when data_parallel_size > 0. Enables running an AsyncLLM
+ and API server on a "per-node" basis where vLLM load balances
+ between local data parallel ranks, but an external LB balances
+ between vLLM nodes/replicas. Set explicitly in conjunction with
+ --data-parallel-start-rank."""
+ enable_expert_parallel: bool = False
+ """Use expert parallelism instead of tensor parallelism for MoE layers."""
+ enable_eplb: bool = False
+ """Enable expert parallelism load balancing for MoE layers."""
+ num_redundant_experts: int = 0
+ """Number of redundant experts to use for expert parallelism."""
+ eplb_window_size: int = 1000
+ """Window size for expert load recording."""
+ eplb_step_interval: int = 3000
+ """
+ Interval for rearranging experts in expert parallelism.
+
+ Note that if this is greater than the EPLB window size, only the metrics
+ of the last `eplb_window_size` steps will be used for rearranging experts.
+ """
+ eplb_log_balancedness: bool = False
+ """
+ Log the balancedness each step of expert parallelism.
+ This is turned off by default since it will cause communication overhead.
+ """
+
+ max_parallel_loading_workers: Optional[int] = None
+ """Maximum number of parallel loading workers when loading model
+ sequentially in multiple batches. To avoid RAM OOM when using tensor
+ parallel and large models."""
+
+ disable_custom_all_reduce: bool = False
+ """Disable the custom all-reduce kernel and fall back to NCCL."""
+
+ ray_workers_use_nsight: bool = False
+ """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
+
+ ray_runtime_env: Optional[RuntimeEnv] = None
+ """Ray runtime environment to pass to distributed workers."""
+
+ placement_group: Optional[PlacementGroup] = None
+ """ray distributed model workers placement group."""
+
+ distributed_executor_backend: Optional[Union[DistributedExecutorBackend,
+ type[ExecutorBase]]] = None
+ """Backend to use for distributed model
+ workers, either "ray" or "mp" (multiprocessing). If the product
+ of pipeline_parallel_size and tensor_parallel_size is less than
+ or equal to the number of GPUs available, "mp" will be used to
+ keep processing on a single host. Otherwise, this will default
+ to "ray" if Ray is installed and fail otherwise. Note that tpu
+ only support Ray for distributed inference."""
+
+ worker_cls: str = "auto"
+ """The full name of the worker class to use. If "auto", the worker class
+ will be determined based on the platform."""
+ sd_worker_cls: str = "auto"
+ """The full name of the worker class to use for speculative decoding.
+ If "auto", the worker class will be determined based on the platform."""
+ worker_extension_cls: str = ""
+ """The full name of the worker extension class to use. The worker extension
+ class is dynamically inherited by the worker class. This is used to inject
+ new attributes and methods to the worker class for use in collective_rpc
+ calls."""
+
+ world_size: int = field(init=False)
+ """world_size is TPxPP, it affects the number of workers we create."""
+
+ rank: int = 0
+ """Global rank in distributed setup."""
+
+ enable_multimodal_encoder_data_parallel: bool = False
+ """ Use data parallelism instead of tensor parallelism for vision encoder.
+ Only support LLama4 for now"""
+
+ @property
+ def world_size_across_dp(self) -> int:
+ """world_size_across_dp is TPxPPxDP, it is the size of the world
+ including data parallelism."""
+ return self.world_size * self.data_parallel_size
+
+ def get_next_dp_init_port(self) -> int:
+ """
+ We might need to initialize process groups in multiple
+ processes that is related to data parallelism,
+ e.g. both in the worker and in the engine, which
+ can live in different processes. To avoid port conflicts, we
+ increment the port number each time we need to initialize a
+ new process group related to data parallelism.
+ """
+ answer = self.data_parallel_master_port
+ self.data_parallel_master_port += 1
+ return answer
+
+ def stateless_init_dp_group(self) -> ProcessGroup:
+ # NOTE: In high-concurrency scenarios multiple processes
+ # can pick the same (currently free) port through a race
+ # condition when calling `get_open_port()`. When the first
+ # process binds the port the others will subsequently fail
+ # with `torch.distributed.DistNetworkError: EADDRINUSE`.
+ # To make the initialization more robust we retry a few times
+ # with a fresh port whenever this specific error is observed.
+ from torch.distributed import DistNetworkError
+
+ from vllm.distributed.utils import (
+ stateless_init_torch_distributed_process_group)
+
+ max_retries = 5
+ last_exc: Optional[Exception] = None
+ for _ in range(max_retries):
+ try:
+ # use gloo since the engine process might not have cuda device
+ return stateless_init_torch_distributed_process_group(
+ self.data_parallel_master_ip,
+ self.get_next_dp_init_port(),
+ self.data_parallel_rank,
+ self.data_parallel_size,
+ backend="gloo")
+ except DistNetworkError as e:
+ # We only want to retry when the root cause is EADDRINUSE.
+ if "EADDRINUSE" in str(e):
+ logger.warning(
+ "Address already in use. Retrying with a new port.")
+ last_exc = e
+ continue # try again with a new port
+ raise e
+
+ # If we get here all retries have failed.
+ assert last_exc is not None
+ raise last_exc
+
+ @staticmethod
+ def has_unfinished_dp(dp_group: ProcessGroup,
+ has_unfinished: bool) -> bool:
+ tensor = torch.tensor([has_unfinished],
+ dtype=torch.int32,
+ device="cpu")
+ # dp rank 0: has_unfinished_seqs=True
+ # dp rank 1: has_unfinished_seqs=False
+ # aggregated: has_unfinished_seqs=True
+ # so this is an OR operation, i.e. MAX in integers
+ torch.distributed.all_reduce(tensor, op=ReduceOp.MAX, group=dp_group)
+ aggregated_has_unfinished = bool(tensor.item())
+ return aggregated_has_unfinished
+
+ @staticmethod
+ def sync_kv_cache_memory_size(dp_group: ProcessGroup,
+ kv_cache_memory: int) -> int:
+ if kv_cache_memory == -1:
+ kv_cache_memory = torch.iinfo(torch.int64).max
+ tensor = torch.tensor([kv_cache_memory],
+ dtype=torch.int64,
+ device="cpu")
+ # we cannot use broadcast for stateless dp group since it depends
+ # on global rank
+ torch.distributed.all_reduce(tensor, op=ReduceOp.MIN, group=dp_group)
+ return tensor.item()
+
+ def compute_hash(self):
+ """
+ Provide a hash that uniquely identifies all the configs
+ that affect the structure of the computation
+ graph from input ids/embeddings to the final hidden states,
+ excluding anything before input ids/embeddings and after
+ the final hidden states.
+ """
+ factors: list[Any] = []
+ factors.append(self.pipeline_parallel_size)
+ factors.append(self.tensor_parallel_size)
+ factors.append(self.enable_expert_parallel)
+ factors.append(self.data_parallel_size)
+ factors.append(envs.VLLM_ALL2ALL_BACKEND)
+ return hashlib.sha256(str(factors).encode()).hexdigest()
+
+ def __post_init__(self) -> None:
+ self.world_size = self.pipeline_parallel_size * \
+ self.tensor_parallel_size
+
+ if self.data_parallel_size_local > self.data_parallel_size:
+ raise ValueError(
+ f"data_parallel_size_local ({self.data_parallel_size_local}) "
+ f"must be <= data_parallel_size ({self.data_parallel_size})")
+
+ if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
+ # Data parallel was specified in the engine args.
+ self.data_parallel_master_port = get_open_port()
+
+ if not (0 <= self.data_parallel_rank < self.data_parallel_size):
+ raise ValueError(
+ f"data_parallel_rank ({self.data_parallel_rank})"
+ f" must be in the range [0, {self.data_parallel_size})")
+ else:
+ # Otherwise fall back to env vars (e.g. for offline SPMD case).
+ self.data_parallel_size = envs.VLLM_DP_SIZE
+ self.data_parallel_rank = envs.VLLM_DP_RANK
+ self.data_parallel_rank_local = envs.VLLM_DP_RANK_LOCAL
+ self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
+ self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
+
+ if self.data_parallel_external_lb:
+ raise ValueError("data_parallel_external_lb can only "
+ "be set when data_parallel_size > 1")
+
+ if self.distributed_executor_backend == "external_launcher":
+ import os
+ os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
+ logger.info("Disabling V1 multiprocessing for external launcher.")
+
+ if self.enable_eplb:
+ if not current_platform.is_cuda():
+ raise ValueError(
+ "Expert parallelism load balancing is only supported on "
+ "CUDA devices now.")
+ if self.num_redundant_experts < 0:
+ raise ValueError(
+ "num_redundant_experts must be non-negative, but got "
+ f"{self.num_redundant_experts}.")
+ if not self.enable_expert_parallel:
+ raise ValueError(
+ "enable_expert_parallel must be True to use EPLB.")
+ if self.tensor_parallel_size * self.data_parallel_size <= 1:
+ raise ValueError(
+ "EPLB requires tensor_parallel_size or data_parallel_size "
+ f"to be greater than 1, but got "
+ f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}."
+ )
+ else:
+ if self.num_redundant_experts != 0:
+ raise ValueError(
+ "num_redundant_experts should be used with EPLB."
+ f"{self.num_redundant_experts}.")
+ if self.distributed_executor_backend is None and self.world_size > 1:
+ # We use multiprocessing by default if world_size fits on the
+ # current node and we aren't in a ray placement group.
+
+ from vllm.executor import ray_utils
+ backend: DistributedExecutorBackend = "mp"
+ ray_found = ray_utils.ray_is_available()
+ if current_platform.is_neuron():
+ # neuron uses single process to control multiple devices
+ backend = "uni"
+ elif current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
+ backend = "uni"
+ elif (current_platform.is_cuda()
+ and cuda_device_count_stateless() < self.world_size):
+ if not ray_found:
+ raise ValueError("Unable to load Ray: "
+ f"{ray_utils.ray_import_err}. Ray is "
+ "required for multi-node inference, "
+ "please install Ray with `pip install "
+ "ray`.")
+ backend = "ray"
+ elif self.data_parallel_backend == "ray":
+ logger.info("Using ray distributed inference because "
+ "data_parallel_backend is ray")
+ backend = "ray"
+ elif ray_found:
+ if self.placement_group:
+ backend = "ray"
+ else:
+ from ray import is_initialized as ray_is_initialized
+ if ray_is_initialized():
+ from ray.util import get_current_placement_group
+ if get_current_placement_group():
+ backend = "ray"
+ self.distributed_executor_backend = backend
+ logger.debug("Defaulting to use %s for distributed inference",
+ backend)
+
+ if self.distributed_executor_backend is None and self.world_size == 1:
+ self.distributed_executor_backend = "uni"
+
+ @property
+ def use_ray(self) -> bool:
+ return self.distributed_executor_backend == "ray" or (
+ isinstance(self.distributed_executor_backend, type)
+ and self.distributed_executor_backend.uses_ray)
+
+ @model_validator(mode='after')
+ def _verify_args(self) -> Self:
+ # Lazy import to avoid circular import
+ from vllm.executor.executor_base import ExecutorBase
+ from vllm.platforms import current_platform
+ if self.distributed_executor_backend not in (
+ "ray", "mp", "uni",
+ "external_launcher", None) and not (isinstance(
+ self.distributed_executor_backend, type) and issubclass(
+ self.distributed_executor_backend, ExecutorBase)):
+ raise ValueError(
+ "Unrecognized distributed executor backend "
+ f"{self.distributed_executor_backend}. Supported "
+ "values are 'ray', 'mp' 'uni', 'external_launcher' or"
+ " custom ExecutorBase subclass.")
+ if self.use_ray:
+ from vllm.executor import ray_utils
+ ray_utils.assert_ray_available()
+
+ if not current_platform.use_custom_allreduce():
+ self.disable_custom_all_reduce = True
+ logger.debug(
+ "Disabled the custom all-reduce kernel because it is not "
+ "supported on current platform.")
+ if self.ray_workers_use_nsight and not self.use_ray:
+ raise ValueError("Unable to use nsight profiling unless workers "
+ "run with Ray.")
+
+ return self
diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py
new file mode 100644
index 0000000000000..db669600a0cc3
--- /dev/null
+++ b/vllm/config/scheduler.py
@@ -0,0 +1,329 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import hashlib
+from dataclasses import field
+from typing import TYPE_CHECKING, Any, Literal, Optional, Union
+
+from pydantic import SkipValidation, model_validator
+from pydantic.dataclasses import dataclass
+from typing_extensions import Self
+
+from vllm.config.utils import config
+from vllm.logger import init_logger
+from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
+ MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
+ POOLING_MODEL_MAX_NUM_BATCHED_TOKENS)
+
+if TYPE_CHECKING:
+ from vllm.config import RunnerType
+else:
+ RunnerType = Any
+
+logger = init_logger(__name__)
+
+PreemptionMode = Literal["swap", "recompute"]
+SchedulerPolicy = Literal["fcfs", "priority"]
+
+
+@config
+@dataclass
+class SchedulerConfig:
+ """Scheduler configuration."""
+
+ runner_type: RunnerType = "generate"
+ """The runner type to launch for the model."""
+
+ max_num_batched_tokens: SkipValidation[int] = None # type: ignore
+ """Maximum number of tokens to be processed in a single iteration.
+
+ This config has no static default. If left unspecified by the user, it will
+ be set in `EngineArgs.create_engine_config` based on the usage context."""
+
+ max_num_seqs: SkipValidation[int] = None # type: ignore
+ """Maximum number of sequences to be processed in a single iteration.
+
+ This config has no static default. If left unspecified by the user, it will
+ be set in `EngineArgs.create_engine_config` based on the usage context."""
+
+ max_model_len: SkipValidation[int] = None # type: ignore
+ """Maximum length of a sequence (including prompt and generated text). This
+ is primarily set in `ModelConfig` and that value should be manually
+ duplicated here."""
+
+ max_num_partial_prefills: int = 1
+ """For chunked prefill, the maximum number of sequences that can be
+ partially prefilled concurrently."""
+
+ max_long_partial_prefills: int = 1
+ """For chunked prefill, the maximum number of prompts longer than
+ long_prefill_token_threshold that will be prefilled concurrently. Setting
+ this less than max_num_partial_prefills will allow shorter prompts to jump
+ the queue in front of longer prompts in some cases, improving latency."""
+
+ long_prefill_token_threshold: int = 0
+ """For chunked prefill, a request is considered long if the prompt is
+ longer than this number of tokens."""
+
+ num_lookahead_slots: int = 0
+ """The number of slots to allocate per sequence per
+ step, beyond the known token ids. This is used in speculative
+ decoding to store KV activations of tokens which may or may not be
+ accepted.
+
+ NOTE: This will be replaced by speculative config in the future; it is
+ present to enable correctness tests until then."""
+
+ cuda_graph_sizes: list[int] = field(default_factory=list)
+ """Cuda graph capture sizes
+ 1. if none provided, then default set to [min(max_num_seqs * 2, 512)]
+ 2. if one value is provided, then the capture list would follow the
+ pattern: [1, 2, 4] + [i for i in range(8, cuda_graph_sizes + 1, 8)]
+ 3. more than one value (e.g. 1 2 128) is provided, then the capture list
+ will follow the provided list."""
+
+ delay_factor: float = 0.0
+ """Apply a delay (of delay factor multiplied by previous
+ prompt latency) before scheduling next prompt."""
+
+ enable_chunked_prefill: SkipValidation[bool] = None # type: ignore
+ """If True, prefill requests can be chunked based
+ on the remaining max_num_batched_tokens."""
+
+ is_multimodal_model: bool = False
+ """True if the model is multimodal."""
+
+ # TODO (ywang96): Make this configurable.
+ max_num_encoder_input_tokens: int = field(init=False)
+ """Multimodal encoder compute budget, only used in V1.
+
+ NOTE: This is not currently configurable. It will be overridden by
+ max_num_batched_tokens in case max multimodal embedding size is larger."""
+
+ # TODO (ywang96): Make this configurable.
+ encoder_cache_size: int = field(init=False)
+ """Multimodal encoder cache size, only used in V1.
+
+ NOTE: This is not currently configurable. It will be overridden by
+ max_num_batched_tokens in case max multimodal embedding size is larger."""
+
+ preemption_mode: Optional[PreemptionMode] = None
+ """Whether to perform preemption by swapping or
+ recomputation. If not specified, we determine the mode as follows:
+ We use recomputation by default since it incurs lower overhead than
+ swapping. However, when the sequence group has multiple sequences
+ (e.g., beam search), recomputation is not currently supported. In
+ such a case, we use swapping instead."""
+
+ num_scheduler_steps: int = 1
+ """Maximum number of forward steps per scheduler call."""
+
+ multi_step_stream_outputs: bool = True
+ """If False, then multi-step will stream outputs at the end of all steps"""
+
+ send_delta_data: bool = False
+ """Private API. If used, scheduler sends delta data to
+ workers instead of an entire data. It should be enabled only
+ when SPMD worker architecture is enabled. I.e.,
+ VLLM_USE_RAY_SPMD_WORKER=1"""
+
+ policy: SchedulerPolicy = "fcfs"
+ """The scheduling policy to use:\n
+ - "fcfs" means first come first served, i.e. requests are handled in order
+ of arrival.\n
+ - "priority" means requests are handled based on given priority (lower
+ value means earlier handling) and time of arrival deciding any ties)."""
+
+ chunked_prefill_enabled: bool = field(init=False)
+ """True if chunked prefill is enabled."""
+
+ disable_chunked_mm_input: bool = False
+ """If set to true and chunked prefill is enabled, we do not want to
+ partially schedule a multimodal item. Only used in V1
+ This ensures that if a request has a mixed prompt
+ (like text tokens TTTT followed by image tokens IIIIIIIIII) where only
+ some image tokens can be scheduled (like TTTTIIIII, leaving IIIII),
+ it will be scheduled as TTTT in one step and IIIIIIIIII in the next."""
+
+ # scheduler class or path. "vllm.core.scheduler.Scheduler" (default)
+ # or "mod.custom_class".
+ scheduler_cls: Union[str, type[object]] = "vllm.core.scheduler.Scheduler"
+ """The scheduler class to use. "vllm.core.scheduler.Scheduler" is the
+ default scheduler. Can be a class directly or the path to a class of form
+ "mod.custom_class"."""
+
+ disable_hybrid_kv_cache_manager: bool = False
+ """If set to True, KV cache manager will allocate the same size of KV cache
+ for all attention layers even if there are multiple type of attention layers
+ like full attention and sliding window attention.
+ """
+
+ async_scheduling: bool = False
+ """EXPERIMENTAL: If set to True, perform async scheduling. This may help
+ reduce the CPU overheads, leading to better latency and throughput. However,
+ async scheduling is currently not supported with some features such as
+ structured outputs, speculative decoding, and pipeline parallelism.
+ """
+
+ def compute_hash(self) -> str:
+ """
+ WARNING: Whenever a new field is added to this config,
+ ensure that it is included in the factors list if
+ it affects the computation graph.
+
+ Provide a hash that uniquely identifies all the configs
+ that affect the structure of the computation
+ graph from input ids/embeddings to the final hidden states,
+ excluding anything before input ids/embeddings and after
+ the final hidden states.
+ """
+ # no factors to consider.
+ # this config will not affect the computation graph.
+ factors: list[Any] = []
+ hash_str = hashlib.md5(str(factors).encode(),
+ usedforsecurity=False).hexdigest()
+ return hash_str
+
+ def __post_init__(self) -> None:
+ if self.max_model_len is None:
+ self.max_model_len = 8192
+
+ if self.max_num_seqs is None:
+ self.max_num_seqs = 128
+
+ if self.max_num_batched_tokens is None:
+ if self.enable_chunked_prefill:
+ if self.num_scheduler_steps > 1:
+ # Multi-step Chunked-Prefill doesn't allow prompt-chunking
+ # for now. Have max_num_batched_tokens set to max_model_len
+ # so we don't reject sequences on account of a short
+ # max_num_batched_tokens.
+ self.max_num_batched_tokens = max(
+ self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS)
+ else:
+ self.max_num_batched_tokens = (
+ DEFAULT_MAX_NUM_BATCHED_TOKENS)
+ else:
+ # If max_model_len is too short, use
+ # DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
+ # for higher throughput.
+ self.max_num_batched_tokens = max(
+ self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS)
+
+ if self.runner_type == "pooling":
+ # Choose specific value for higher throughput
+ self.max_num_batched_tokens = max(
+ self.max_num_batched_tokens,
+ POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
+ )
+ if self.is_multimodal_model:
+ # The value needs to be at least the number of multimodal tokens
+ self.max_num_batched_tokens = max(
+ self.max_num_batched_tokens,
+ MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
+ )
+
+ # When using default settings,
+ # Ensure max_num_batched_tokens does not exceed model limit.
+ # Some models (e.g., Whisper) have embeddings tied to max length.
+ self.max_num_batched_tokens = min(
+ self.max_num_seqs * self.max_model_len,
+ self.max_num_batched_tokens)
+
+ self.max_num_encoder_input_tokens = self.max_num_batched_tokens
+ self.encoder_cache_size = self.max_num_batched_tokens
+
+ if self.enable_chunked_prefill:
+ logger.info(
+ "Chunked prefill is enabled with max_num_batched_tokens=%d.",
+ self.max_num_batched_tokens)
+
+ self.chunked_prefill_enabled = self.enable_chunked_prefill
+ if self.max_num_partial_prefills > 1:
+ if self.long_prefill_token_threshold == 0:
+ self.long_prefill_token_threshold = int(self.max_model_len *
+ 0.04)
+
+ logger.info(
+ "Concurrent partial prefills enabled with "
+ "max_num_partial_prefills=%d, max_long_partial_prefills=%d, "
+ "long_prefill_token_threshold=%d",
+ self.max_num_partial_prefills, self.max_long_partial_prefills,
+ self.long_prefill_token_threshold)
+
+ # NOTE: Default set cuda_graph_sizes to [min(max_num_seqs * 2, 512)].
+ # This avoids OOM in tight memory scenarios with small max_num_seqs,
+ # and prevents capture of many large graphs (>512) that would greatly
+ # increase startup time with limited performance benefit.
+ if not self.cuda_graph_sizes:
+ self.cuda_graph_sizes = [min(self.max_num_seqs * 2, 512)]
+
+ if self.async_scheduling:
+ self.scheduler_cls = (
+ "vllm.v1.core.sched.async_scheduler.AsyncScheduler")
+
+ @model_validator(mode='after')
+ def _verify_args(self) -> Self:
+ if (self.max_num_batched_tokens < self.max_model_len
+ and not self.chunked_prefill_enabled):
+ raise ValueError(
+ f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
+ f"smaller than max_model_len ({self.max_model_len}). "
+ "This effectively limits the maximum sequence length to "
+ "max_num_batched_tokens and makes vLLM reject longer "
+ "sequences. Please increase max_num_batched_tokens or "
+ "decrease max_model_len.")
+
+ if self.max_num_batched_tokens < self.max_num_seqs:
+ raise ValueError(
+ f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
+ "be greater than or equal to max_num_seqs "
+ f"({self.max_num_seqs}).")
+
+ if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len:
+ logger.warning(
+ "max_num_batched_tokens (%d) exceeds max_num_seqs "
+ "* max_model_len (%d). This may lead to unexpected behavior.",
+ self.max_num_batched_tokens,
+ self.max_num_seqs * self.max_model_len)
+
+ if self.num_lookahead_slots < 0:
+ raise ValueError(
+ "num_lookahead_slots "
+ f"({self.num_lookahead_slots}) must be greater than or "
+ "equal to 0.")
+
+ if self.num_scheduler_steps < 1:
+ raise ValueError(
+ "num_scheduler_steps "
+ f"({self.num_scheduler_steps}) must be greater than or "
+ "equal to 1.")
+
+ if self.max_num_partial_prefills < 1:
+ raise ValueError(
+ f"max_num_partial_prefills ({self.max_num_partial_prefills}) "
+ "must be greater than or equal to 1.")
+ elif self.max_num_partial_prefills > 1:
+ if not self.chunked_prefill_enabled:
+ raise ValueError("Chunked prefill must be enabled to set "
+ "max_num_partial_prefills > 1.")
+
+ if self.long_prefill_token_threshold > self.max_model_len:
+ raise ValueError(
+ "long_prefill_token_threshold "
+ f"({self.long_prefill_token_threshold}) cannot be greater "
+ f"than the max_model_len ({self.max_model_len}).")
+
+ if (self.max_long_partial_prefills
+ < 1) or (self.max_long_partial_prefills
+ > self.max_num_partial_prefills):
+ raise ValueError(
+ f"max_long_partial_prefills ({self.max_long_partial_prefills}) "
+ "must be greater than or equal to 1 and less than or equal to "
+ f"max_num_partial_prefills ({self.max_num_partial_prefills}).")
+
+ return self
+
+ @property
+ def is_multi_step(self) -> bool:
+ return self.num_scheduler_steps > 1
diff --git a/vllm/config/utils.py b/vllm/config/utils.py
new file mode 100644
index 0000000000000..98fbeb1fa86aa
--- /dev/null
+++ b/vllm/config/utils.py
@@ -0,0 +1,29 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import TYPE_CHECKING, TypeVar
+
+if TYPE_CHECKING:
+ from _typeshed import DataclassInstance
+
+ ConfigType = type[DataclassInstance]
+else:
+ ConfigType = type
+
+ConfigT = TypeVar("ConfigT", bound=ConfigType)
+
+
+def config(cls: ConfigT) -> ConfigT:
+ """
+ A decorator that ensures all fields in a dataclass have default values
+ and that each field has a docstring.
+
+ If a `ConfigT` is used as a CLI argument itself, the `type` keyword argument
+ provided by `get_kwargs` will be
+ `pydantic.TypeAdapter(ConfigT).validate_json(cli_arg)` which treats the
+ `cli_arg` as a JSON string which gets validated by `pydantic`.
+
+ Config validation is performed by the tools/validate_config.py
+ script, which is invoked during the pre-commit checks.
+ """
+ return cls
diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py
index 4ab8f3d938fcf..66d4940c9cec5 100644
--- a/vllm/distributed/device_communicators/cuda_communicator.py
+++ b/vllm/distributed/device_communicators/cuda_communicator.py
@@ -236,7 +236,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
input_size = input_.size()
if sizes is not None:
assert len(sizes) == world_size
- assert input_.shape[dim] == sizes[self.rank_in_group]
+ assert input_.shape[dim] == sizes[self.rank_in_group], (
+ f"{input_.shape[dim]} != {sizes[self.rank_in_group]}")
output_size = (sum(sizes), ) + input_size[1:]
else:
output_size = (input_size[0] * world_size, ) + input_size[1:]
diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py
index c415d409f7fef..979f2a06cec9f 100644
--- a/vllm/distributed/eplb/eplb_state.py
+++ b/vllm/distributed/eplb/eplb_state.py
@@ -259,7 +259,6 @@ class EplbState:
if global_expert_load is not None:
ep_group = get_ep_group().device_group
- assert ep_group is not None
assert global_expert_load.shape == (model.num_moe_layers,
model.num_logical_experts)
assert global_expert_load.dtype == torch.int64
@@ -366,7 +365,6 @@ class EplbState:
# Collect load metrics from all ranks
ep_group = get_ep_group().device_group
- assert ep_group is not None
all_reduce(total_expert_load_pass, group=ep_group)
# num_tokens_per_rank: (num_moe_layers, num_ranks)
@@ -422,7 +420,6 @@ class EplbState:
"""
ep_group = get_ep_group().device_group
- assert ep_group is not None
ep_rank = ep_group.rank()
time_start = None
diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py
index 01673a0d7c876..584fc1d655951 100644
--- a/vllm/distributed/kv_transfer/kv_connector/factory.py
+++ b/vllm/distributed/kv_transfer/kv_connector/factory.py
@@ -4,13 +4,17 @@
import importlib
from typing import TYPE_CHECKING, Callable
+# yapf: disable
import vllm.envs as envs
-from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
+from vllm.distributed.kv_transfer.kv_connector.base import (
+ KVConnectorBase, KVConnectorBaseType)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.logger import init_logger
+# yapf: enable
+
if TYPE_CHECKING:
- from vllm.config import VllmConfig
+ from vllm.config import KVTransferConfig, VllmConfig
logger = init_logger(__name__)
@@ -42,17 +46,7 @@ class KVConnectorFactory:
f"but found {envs.VLLM_USE_V1=}")
kv_transfer_config = config.kv_transfer_config
- connector_name = kv_transfer_config.kv_connector
- if connector_name in cls._registry:
- connector_cls = cls._registry[connector_name]()
- else:
- connector_module_path = kv_transfer_config.kv_connector_module_path
- if connector_module_path is None:
- raise ValueError(
- f"Unsupported connector type: {connector_name}")
- connector_module = importlib.import_module(connector_module_path)
- connector_cls = getattr(connector_module, connector_name)
- assert issubclass(connector_cls, KVConnectorBase)
+ connector_cls = cls.get_connector_class(kv_transfer_config)
logger.info("Creating v1 connector with name: %s and engine_id: %s",
connector_cls.__name__, kv_transfer_config.engine_id)
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
@@ -65,6 +59,23 @@ class KVConnectorFactory:
# We build separately to enforce strict separation
return connector_cls(config, role)
+ @classmethod
+ def get_connector_class(
+ cls, kv_transfer_config: "KVTransferConfig"
+ ) -> type[KVConnectorBaseType]:
+ """Get the connector class by name."""
+ connector_name = kv_transfer_config.kv_connector
+ if connector_name in cls._registry:
+ connector_cls = cls._registry[connector_name]()
+ else:
+ connector_module_path = kv_transfer_config.kv_connector_module_path
+ if connector_module_path is None:
+ raise ValueError(
+ f"Unsupported connector type: {connector_name}")
+ connector_module = importlib.import_module(connector_module_path)
+ connector_cls = getattr(connector_module, connector_name)
+ return connector_cls
+
# Register various connectors here.
# The registration should not be done in each individual file, as we want to
diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
index 1da41790f9fb1..2364400b3d350 100644
--- a/vllm/distributed/kv_transfer/kv_connector/utils.py
+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -13,8 +13,8 @@ import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, get_current_vllm_config
-from vllm.distributed.kv_transfer.kv_connector.v1.base import (
- KVConnectorBase_V1)
+from vllm.distributed.kv_transfer.kv_connector.factory import (
+ KVConnectorFactory)
from vllm.logger import init_logger
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
@@ -106,8 +106,9 @@ def get_kv_connector_cache_layout():
vllm_config = get_current_vllm_config()
kv_config = vllm_config.kv_transfer_config
if kv_config is not None:
- required_kvcache_layout = (
- KVConnectorBase_V1.get_required_kvcache_layout(vllm_config))
+ connector_cls = KVConnectorFactory.get_connector_class(kv_config)
+ required_kvcache_layout = connector_cls.get_required_kvcache_layout(
+ vllm_config)
if required_kvcache_layout is not None:
return required_kvcache_layout
logger.info_once("Connectors do not specify a " \
@@ -143,6 +144,8 @@ class KVOutputAggregator:
finished_recving = set[str]()
for output in outputs:
output = output.kv_connector_output
+ if not output:
+ continue
update_finished_set(output.finished_sending,
self._send_remaining_count, finished_sending)
update_finished_set(output.finished_recving,
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
index 7a2ccb58656fd..b72104397822b 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
@@ -12,6 +12,8 @@ The class provides the following primitives:
times for a given request and should be side-effect free.
update_state_after_alloc() - update KVConnector state after
temporary buffer alloc by the CacheManager.
+ update_connector_output() - update KVConnector state after
+ output is received from worker-side connectors.
request_finished() - called when a request is finished, with
the computed kv cache blocks for the request.
Returns whether KV cache should be freed now or will be
@@ -38,6 +40,7 @@ import torch
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
@@ -283,6 +286,16 @@ class KVConnectorBase_V1(ABC):
"""
pass
+ def update_connector_output(self, connector_output: KVConnectorOutput):
+ """
+ Update KVConnector state from worker-side connectors output.
+
+ Args:
+ connector_output (KVConnectorOutput): the worker-side
+ connectors output.
+ """
+ return
+
def request_finished(
self,
request: "Request",
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
index 62a4980bff975..7d67c76e2f052 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
@@ -14,6 +14,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
@@ -177,6 +178,10 @@ class MultiConnector(KVConnectorBase_V1):
self._extra_async_saves = {}
return metadata
+ def update_connector_output(self, connector_output: KVConnectorOutput):
+ for c in self._connectors:
+ c.update_connector_output(connector_output)
+
def request_finished(
self,
request: "Request",
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index e7fc2b118145c..a6eeb278532ee 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -29,7 +29,7 @@ from vllm.distributed.utils import divide
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.platforms import _Backend, current_platform
-from vllm.utils import make_zmq_path, make_zmq_socket, round_down
+from vllm.utils import make_zmq_path, make_zmq_socket
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus
@@ -275,10 +275,7 @@ class NixlConnectorScheduler:
if params is not None and params.get("do_remote_prefill"):
# Remote prefill: get all prompt blocks from remote.
- assert num_computed_tokens % self.block_size == 0
- rounded_num_prompt_tokens = round_down(
- len(request.prompt_token_ids), self.block_size)
- count = max(rounded_num_prompt_tokens - num_computed_tokens, 0)
+ count = len(request.prompt_token_ids) - num_computed_tokens
if count > 0:
return count, True
@@ -301,18 +298,16 @@ class NixlConnectorScheduler:
# NOTE: when accelerator is not directly supported by Nixl,
# prefilled blocks need to be saved to host memory before transfer.
- # figure out full computed blocks to save
+ # save all blocks
block_ids = blocks.get_block_ids()[0]
- all_full = request.num_tokens % self.block_size == 0
- full_block_ids = (block_ids if all_full else block_ids[:-1])
# TODO: skip the blocks that are already in the host xfer buffer.
# Currently, the host xfer buffer block is 1-to-1 mapped to device
# kv blocks, so host blocks won't be flushed as long as its device
# block is not overwritten; and it will be safe to skip saving them
# to host xfer buffer.
- if full_block_ids:
+ if block_ids:
self._reqs_need_save[request.request_id] = \
- (request, full_block_ids)
+ (request, block_ids)
elif params.get("do_remote_prefill"):
if params.get("remote_block_ids"):
if all(p in params for p in ("remote_engine_id", "remote_host",
@@ -401,12 +396,9 @@ class NixlConnectorScheduler:
or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
return False, None
- # Get computed blocks.
- all_full = request.num_computed_tokens % self.block_size == 0
- computed_block_ids = block_ids if all_full else block_ids[:-1]
-
- # If prompt < block_size, no xfer so free blocks immediately.
- delay_free_blocks = len(computed_block_ids) > 0
+ # TODO: check whether block_ids actually ever be 0. If not we could
+ # remove the conditional below
+ delay_free_blocks = len(block_ids) > 0
if delay_free_blocks:
# Prefill request on remote. It will be read from D upon completion
@@ -416,7 +408,7 @@ class NixlConnectorScheduler:
return delay_free_blocks, dict(
do_remote_prefill=True,
do_remote_decode=False,
- remote_block_ids=computed_block_ids,
+ remote_block_ids=block_ids,
remote_engine_id=self.engine_id,
remote_host=self.side_channel_host,
remote_port=self.side_channel_port,
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 6c25cdcfb7b8c..b89aee99c8d46 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -36,6 +36,7 @@ from unittest.mock import patch
import torch
import torch.distributed
from torch.distributed import Backend, ProcessGroup
+from typing_extensions import deprecated
import vllm.envs as envs
from vllm.distributed.device_communicators.base_device_communicator import (
@@ -196,11 +197,10 @@ class GroupCoordinator:
# 3 | 1 | 3 | 1 | 3
local_rank: int # local rank used to assign devices
rank_in_group: int # rank inside the group
- cpu_group: Optional[ProcessGroup] # group for CPU communication
- device_group: Optional[ProcessGroup] # group for device communication
- use_device_communicator: bool # whether to use device communicator
- device_communicator: Optional[
- DeviceCommunicatorBase] # device communicator
+ cpu_group: ProcessGroup # group for CPU communication
+ device_group: ProcessGroup # group for device communication
+ # device communicator (if use_device_communicator=True)
+ device_communicator: Optional[DeviceCommunicatorBase]
mq_broadcaster: Optional[Any] # shared memory broadcaster
def __init__(
@@ -208,7 +208,7 @@ class GroupCoordinator:
group_ranks: list[list[int]],
local_rank: int,
torch_distributed_backend: Union[str, Backend],
- use_device_communicator: bool,
+ use_device_communicator: bool, # whether to use device communicator
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
):
@@ -218,8 +218,9 @@ class GroupCoordinator:
self.rank = torch.distributed.get_rank()
self.local_rank = local_rank
- self.device_group = None
- self.cpu_group = None
+
+ self_device_group = None
+ self_cpu_group = None
for ranks in group_ranks:
device_group = torch.distributed.new_group(
@@ -231,11 +232,14 @@ class GroupCoordinator:
self.ranks = ranks
self.world_size = len(ranks)
self.rank_in_group = ranks.index(self.rank)
- self.device_group = device_group
- self.cpu_group = cpu_group
+ self_device_group = device_group
+ self_cpu_group = cpu_group
- assert self.cpu_group is not None
- assert self.device_group is not None
+ assert self_cpu_group is not None
+ assert self_device_group is not None
+
+ self.cpu_group = self_cpu_group
+ self.device_group = self_device_group
from vllm.platforms import current_platform
@@ -250,7 +254,6 @@ class GroupCoordinator:
self.device = torch.device("cpu")
self.use_device_communicator = use_device_communicator
-
self.device_communicator = None
if use_device_communicator and self.world_size > 1:
device_comm_cls = resolve_obj_by_qualname(
@@ -816,12 +819,12 @@ class GroupCoordinator:
return self.device_communicator.recv(size, dtype, src)
def destroy(self):
- if self.device_group is not None:
+ if hasattr(self, "device_group"):
torch.distributed.destroy_process_group(self.device_group)
- self.device_group = None
- if self.cpu_group is not None:
+ del self.device_group
+ if hasattr(self, "cpu_group"):
torch.distributed.destroy_process_group(self.cpu_group)
- self.cpu_group = None
+ del self.cpu_group
if self.device_communicator is not None:
self.device_communicator.destroy()
if self.mq_broadcaster is not None:
@@ -894,8 +897,12 @@ def get_tp_group() -> GroupCoordinator:
return _TP
-# kept for backward compatibility
-get_tensor_model_parallel_group = get_tp_group
+@deprecated("`get_tensor_model_parallel_group` has been replaced with "
+ "`get_tp_group` and may be removed after v0.12. Please use "
+ "`get_tp_group` instead.")
+def get_tensor_model_parallel_group():
+ return get_tp_group()
+
_PP: Optional[GroupCoordinator] = None
@@ -921,8 +928,11 @@ def get_pp_group() -> GroupCoordinator:
return _PP
-# kept for backward compatibility
-get_pipeline_model_parallel_group = get_pp_group
+@deprecated("`get_pipeline_model_parallel_group` has been replaced with "
+ "`get_pp_group` and may be removed in v0.12. Please use "
+ "`get_pp_group` instead.")
+def get_pipeline_model_parallel_group():
+ return get_pp_group()
@contextmanager
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index c0ac3ff6317fe..d74db67bda0dc 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -39,6 +39,7 @@ from vllm.plugins import load_general_plugins
from vllm.ray.lazy_utils import is_ray_initialized
from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
+from vllm.transformers_utils.config import is_interleaved
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
GiB_bytes, get_ip, is_in_ray_actor)
@@ -178,23 +179,12 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
kwargs[name] = {"default": default, "help": help}
# Set other kwargs based on the type hints
- json_tip = """Should either be a valid JSON string or JSON keys
-passed individually. For example, the following sets of arguments are
-equivalent:
-
-- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n
-- `--json-arg.key1 value1 --json-arg.key2.key3 value2`
-
-Additionally, list elements can be passed individually using `+`:
-
-- `--json-arg '{"key4": ["value3", "value4", "value5"]}'`\n
-- `--json-arg.key4+ value3 --json-arg.key4+='value4,value5'`"""
+ json_tip = ("Should either be a valid JSON string or JSON keys passed "
+ "individually.")
if dataclass_cls is not None:
def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
try:
- if hasattr(cls, "from_cli"):
- return cls.from_cli(val)
return TypeAdapter(cls).validate_json(val)
except ValidationError as e:
raise argparse.ArgumentTypeError(repr(e)) from e
@@ -455,9 +445,9 @@ class EngineArgs:
# support `EngineArgs(compilation_config={...})`
# without having to manually construct a
# CompilationConfig object
- if isinstance(self.compilation_config, (int, dict)):
- self.compilation_config = CompilationConfig.from_cli(
- str(self.compilation_config))
+ if isinstance(self.compilation_config, dict):
+ self.compilation_config = CompilationConfig(
+ **self.compilation_config)
# Setup plugins
from vllm.plugins import load_general_plugins
load_general_plugins()
@@ -836,6 +826,10 @@ class EngineArgs:
title="VllmConfig",
description=VllmConfig.__doc__,
)
+ # We construct SpeculativeConfig using fields from other configs in
+ # create_engine_config. So we set the type to a JSON string here to
+ # delay the Pydantic validation that comes with SpeculativeConfig.
+ vllm_kwargs["speculative_config"]["type"] = optional_type(json.loads)
vllm_group.add_argument("--speculative-config",
**vllm_kwargs["speculative_config"])
vllm_group.add_argument("--kv-transfer-config",
@@ -1092,6 +1086,13 @@ class EngineArgs:
"DualChunkFlashAttention is not supported on V1 engine. "
"To run the model in V0 engine, try set 'VLLM_USE_V1=0'")
+ sliding_window: Optional[int] = None
+ if not is_interleaved(model_config.hf_text_config):
+ # Only set CacheConfig.sliding_window if the model is all sliding
+ # window. Otherwise CacheConfig.sliding_window will override the
+ # global layers in interleaved sliding window models.
+ sliding_window = model_config.get_sliding_window()
+
cache_config = CacheConfig(
block_size=self.block_size,
gpu_memory_utilization=self.gpu_memory_utilization,
@@ -1099,7 +1100,7 @@ class EngineArgs:
cache_dtype=self.kv_cache_dtype,
is_attention_free=model_config.is_attention_free,
num_gpu_blocks_override=self.num_gpu_blocks_override,
- sliding_window=model_config.get_sliding_window(),
+ sliding_window=sliding_window,
enable_prefix_caching=self.enable_prefix_caching,
prefix_caching_hash_algo=self.prefix_caching_hash_algo,
cpu_offload_gb=self.cpu_offload_gb,
@@ -1603,11 +1604,10 @@ class EngineArgs:
else:
pooling_type = model_config.pooler_config.pooling_type
-
- # TODO: when encoder models are supported we'll have to
- # check for causal attention here.
- incremental_prefill_supported = (pooling_type is not None and
- pooling_type.lower() == "last")
+ is_causal = getattr(model_config.hf_config, "is_causal", True)
+ incremental_prefill_supported = (pooling_type is not None
+ and pooling_type.lower() == "last"
+ and is_causal)
action = "Enabling" if \
incremental_prefill_supported else "Disabling"
@@ -1833,13 +1833,3 @@ def human_readable_int(value):
# Regular plain number.
return int(value)
-
-
-# These functions are used by sphinx to build the documentation
-def _engine_args_parser():
- return EngineArgs.add_cli_args(FlexibleArgumentParser())
-
-
-def _async_engine_args_parser():
- return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(),
- async_args_only=True)
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 79255b031eeca..3fc4f6445df2a 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -845,7 +845,8 @@ class LLMEngine:
def reset_mm_cache(self) -> bool:
"""Reset the multi-modal cache."""
- return self.input_preprocessor.mm_registry.reset_processor_cache()
+ return self.input_preprocessor.mm_registry.reset_processor_cache(
+ self.model_config)
def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
"""Reset prefix cache for all devices."""
diff --git a/vllm/entrypoints/cli/openai.py b/vllm/entrypoints/cli/openai.py
index e71f77ba80673..7c01de94a3436 100644
--- a/vllm/entrypoints/cli/openai.py
+++ b/vllm/entrypoints/cli/openai.py
@@ -130,28 +130,33 @@ class ChatCommand(CLISubcommand):
conversation.append(response_message) # type: ignore
print(output)
- def subparser_init(
- self,
- subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
- chat_parser = subparsers.add_parser(
- "chat",
- help="Generate chat completions via the running API server.",
- description="Generate chat completions via the running API server.",
- usage="vllm chat [options]")
- _add_query_options(chat_parser)
- chat_parser.add_argument(
+ @staticmethod
+ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
+ """Add CLI arguments for the chat command."""
+ _add_query_options(parser)
+ parser.add_argument(
"--system-prompt",
type=str,
default=None,
help=("The system prompt to be added to the chat template, "
"used for models that support system prompts."))
- chat_parser.add_argument("-q",
- "--quick",
- type=str,
- metavar="MESSAGE",
- help=("Send a single prompt as MESSAGE "
- "and print the response, then exit."))
- return chat_parser
+ parser.add_argument("-q",
+ "--quick",
+ type=str,
+ metavar="MESSAGE",
+ help=("Send a single prompt as MESSAGE "
+ "and print the response, then exit."))
+ return parser
+
+ def subparser_init(
+ self,
+ subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
+ parser = subparsers.add_parser(
+ "chat",
+ help="Generate chat completions via the running API server.",
+ description="Generate chat completions via the running API server.",
+ usage="vllm chat [options]")
+ return ChatCommand.add_cli_args(parser)
class CompleteCommand(CLISubcommand):
@@ -179,25 +184,30 @@ class CompleteCommand(CLISubcommand):
output = completion.choices[0].text
print(output)
- def subparser_init(
- self,
- subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
- complete_parser = subparsers.add_parser(
- "complete",
- help=("Generate text completions based on the given prompt "
- "via the running API server."),
- description=("Generate text completions based on the given prompt "
- "via the running API server."),
- usage="vllm complete [options]")
- _add_query_options(complete_parser)
- complete_parser.add_argument(
+ @staticmethod
+ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
+ """Add CLI arguments for the complete command."""
+ _add_query_options(parser)
+ parser.add_argument(
"-q",
"--quick",
type=str,
metavar="PROMPT",
help=
"Send a single prompt and print the completion output, then exit.")
- return complete_parser
+ return parser
+
+ def subparser_init(
+ self,
+ subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
+ parser = subparsers.add_parser(
+ "complete",
+ help=("Generate text completions based on the given prompt "
+ "via the running API server."),
+ description=("Generate text completions based on the given prompt "
+ "via the running API server."),
+ usage="vllm complete [options]")
+ return CompleteCommand.add_cli_args(parser)
def cmd_init() -> list[CLISubcommand]:
diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py
index 6292306e7cdbe..e817f07ef5947 100644
--- a/vllm/entrypoints/context.py
+++ b/vllm/entrypoints/context.py
@@ -1,15 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import json
import logging
from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Union
-from openai_harmony import Message, Role, StreamState
+from openai_harmony import Author, Message, Role, StreamState, TextContent
from vllm.entrypoints.harmony_utils import (
get_encoding, get_streamable_parser_for_assistant, render_for_completion)
from vllm.entrypoints.tool import Tool
from vllm.outputs import RequestOutput
+if TYPE_CHECKING:
+ from mcp.client import ClientSession
+
logger = logging.getLogger(__name__)
@@ -71,6 +76,7 @@ class HarmonyContext(ConversationContext):
def append_output(self, output) -> None:
if isinstance(output, RequestOutput):
output_token_ids = output.outputs[0].token_ids
+ self.parser = get_streamable_parser_for_assistant()
for token_id in output_token_ids:
self.parser.process(token_id)
output_msgs = self.parser.messages
@@ -106,19 +112,41 @@ class HarmonyContext(ConversationContext):
def render_for_completion(self) -> list[int]:
return render_for_completion(self.messages)
- async def call_search_tool(
- self,
- tool_session: Tool,
- last_msg: Message,
- ) -> list[Message]:
- return await tool_session.get_result(self)
+ async def call_search_tool(self, tool_session: Union["ClientSession",
+ Tool],
+ last_msg: Message) -> list[Message]:
+ if isinstance(tool_session, Tool):
+ return await tool_session.get_result(self)
+ tool_name = last_msg.recipient.split(".")[1]
+ args = json.loads(last_msg.content[0].text)
+ result = await tool_session.call_tool(tool_name, args)
+ result_str = result.content[0].text
+ content = TextContent(text=result_str)
+ author = Author(role=Role.TOOL, name=last_msg.recipient)
+ return [
+ Message(author=author, content=[content], recipient=Role.ASSISTANT)
+ ]
- async def call_python_tool(
- self,
- tool_session: Tool,
- last_msg: Message,
- ) -> list[Message]:
- return await tool_session.get_result(self)
+ async def call_python_tool(self, tool_session: Union["ClientSession",
+ Tool],
+ last_msg: Message) -> list[Message]:
+ if isinstance(tool_session, Tool):
+ return await tool_session.get_result(self)
+ param = {
+ "code": last_msg.content[0].text,
+ }
+ result = await tool_session.call_tool("python", param)
+ result_str = result.content[0].text
+
+ content = TextContent(text=result_str)
+ author = Author(role=Role.TOOL, name="python")
+
+ return [
+ Message(author=author,
+ content=[content],
+ channel=last_msg.channel,
+ recipient=Role.ASSISTANT)
+ ]
class StreamingHarmonyContext(HarmonyContext):
diff --git a/vllm/entrypoints/harmony_utils.py b/vllm/entrypoints/harmony_utils.py
index 87e76e08a0b44..efca1472e44cf 100644
--- a/vllm/entrypoints/harmony_utils.py
+++ b/vllm/entrypoints/harmony_utils.py
@@ -237,7 +237,10 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]:
id=f"rs_{random_uuid()}",
summary=[],
type="reasoning",
- text=content.text,
+ content=[
+ ResponseReasoningTextContent(text=content.text,
+ type="reasoning_text")
+ ],
status=None,
)
output_items.append(reasoning_item)
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index ca24b0c32b73b..915f14a29b907 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -28,11 +28,15 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_mistral_chat_template,
parse_chat_messages,
resolve_chat_template_content_format)
+# yapf conflicts with isort for this block
+# yapf: disable
from vllm.entrypoints.score_utils import (ScoreContentPartParam,
ScoreMultiModalParam,
_cosine_similarity,
_validate_score_input_lens,
+ compress_token_type_ids,
get_score_prompt)
+# yapf: enable
from vllm.entrypoints.utils import (_validate_truncation_size,
log_non_default_args)
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
@@ -1096,6 +1100,10 @@ class LLM:
"Try passing `--runner pooling` to use the model as a "
"pooling model.")
+ if pooling_task not in self.supported_tasks:
+ raise ValueError(
+ f"pooling_task must be one of {self.supported_tasks}.")
+
if prompt_token_ids is not None:
parsed_prompts = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, list[str]]], prompts),
@@ -1329,6 +1337,7 @@ class LLM:
model_config = self.llm_engine.model_config
pooling_params.verify("score", model_config)
+ pooling_params_list = list[PoolingParams]()
tokenization_kwargs: dict[str, Any] = {}
@@ -1339,38 +1348,31 @@ class LLM:
input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]
- if model_config.is_multimodal_model:
- for q, d in input_pairs:
- _, engine_prompt = get_score_prompt(
- model_config=model_config,
- data_1=q,
- data_2=d,
- tokenizer=tokenizer,
- tokenization_kwargs=tokenization_kwargs,
- )
+ model_config = self.llm_engine.model_config
- parsed_prompts.append(engine_prompt)
- else:
- for q, t in input_pairs:
- if model_config.use_pad_token:
- # cross_encoder models defaults to using pad_token.
- prompt_inputs = tokenizer(
- text=q, # type: ignore[arg-type]
- text_pair=t, # type: ignore[arg-type]
- **tokenization_kwargs)
- else:
- # `llm as reranker` models defaults to not using pad_token.
- prompt_inputs = tokenizer(
- text=q + t, # type: ignore[operator]
- **tokenization_kwargs)
- engine_prompt = TokensPrompt(
- prompt_token_ids=prompt_inputs["input_ids"],
- token_type_ids=prompt_inputs.get("token_type_ids"))
- parsed_prompts.append(engine_prompt)
+ for q, d in input_pairs:
+ _, engine_prompt = get_score_prompt(
+ model_config=model_config,
+ data_1=q,
+ data_2=d,
+ tokenizer=tokenizer,
+ tokenization_kwargs=tokenization_kwargs,
+ )
+
+ if envs.VLLM_USE_V1 and (token_type_ids := engine_prompt.pop(
+ "token_type_ids", None)):
+ params = pooling_params.clone()
+ compressed = compress_token_type_ids(token_type_ids)
+ params.extra_kwargs = {"compressed_token_type_ids": compressed}
+ pooling_params_list.append(params)
+ else:
+ pooling_params_list.append(pooling_params)
+
+ parsed_prompts.append(engine_prompt)
self._validate_and_add_requests(
prompts=parsed_prompts,
- params=pooling_params,
+ params=pooling_params_list,
use_tqdm=use_tqdm,
lora_request=lora_request,
)
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index c695ea8b5a0ef..e5d31c1fd03fa 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -94,7 +94,8 @@ from vllm.entrypoints.openai.serving_tokenization import (
from vllm.entrypoints.openai.serving_transcription import (
OpenAIServingTranscription, OpenAIServingTranslation)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
-from vllm.entrypoints.tool_server import DemoToolServer, ToolServer
+from vllm.entrypoints.tool_server import (DemoToolServer, MCPToolServer,
+ ToolServer)
from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
log_non_default_args, with_cancellation)
from vllm.logger import init_logger
@@ -1635,6 +1636,9 @@ async def init_app_state(
if args.tool_server == "demo":
tool_server: Optional[ToolServer] = DemoToolServer()
+ elif args.tool_server:
+ tool_server = MCPToolServer()
+ await tool_server.add_tool_server(args.tool_server)
else:
tool_server = None
@@ -1773,6 +1777,12 @@ def create_server_socket(addr: tuple[str, int]) -> socket.socket:
return sock
+def create_server_unix_socket(path: str) -> socket.socket:
+ sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
+ sock.bind(path)
+ return sock
+
+
def validate_api_server_args(args):
valid_tool_parses = ToolParserManager.tool_parsers.keys()
if args.enable_auto_tool_choice \
@@ -1803,8 +1813,11 @@ def setup_server(args):
# workaround to make sure that we bind the port before the engine is set up.
# This avoids race conditions with ray.
# see https://github.com/vllm-project/vllm/issues/8204
- sock_addr = (args.host or "", args.port)
- sock = create_server_socket(sock_addr)
+ if args.uds:
+ sock = create_server_unix_socket(args.uds)
+ else:
+ sock_addr = (args.host or "", args.port)
+ sock = create_server_socket(sock_addr)
# workaround to avoid footguns where uvicorn drops requests with too
# many concurrent requests active
@@ -1816,12 +1829,14 @@ def setup_server(args):
signal.signal(signal.SIGTERM, signal_handler)
- addr, port = sock_addr
- is_ssl = args.ssl_keyfile and args.ssl_certfile
- host_part = f"[{addr}]" if is_valid_ipv6_address(
- addr) else addr or "0.0.0.0"
- listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}"
-
+ if args.uds:
+ listen_address = f"unix:{args.uds}"
+ else:
+ addr, port = sock_addr
+ is_ssl = args.ssl_keyfile and args.ssl_certfile
+ host_part = f"[{addr}]" if is_valid_ipv6_address(
+ addr) else addr or "0.0.0.0"
+ listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}"
return listen_address, sock
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index e89463a03cdae..e15f65b43082c 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -72,6 +72,8 @@ class FrontendArgs:
"""Host name."""
port: int = 8000
"""Port number."""
+ uds: Optional[str] = None
+ """Unix domain socket path. If set, host and port arguments are ignored."""
uvicorn_log_level: Literal["debug", "info", "warning", "error", "critical",
"trace"] = "info"
"""Log level for uvicorn."""
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 3b9f4b544e45c..543701ed144ee 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -19,8 +19,8 @@ from openai.types.chat.chat_completion_message import (
# yapf: enable
from openai.types.responses import (ResponseFunctionToolCall,
ResponseInputItemParam, ResponseOutputItem,
- ResponsePrompt, ResponseStatus,
- ResponseTextConfig)
+ ResponsePrompt, ResponseReasoningItem,
+ ResponseStatus, ResponseTextConfig)
from openai.types.responses.response import ToolChoice
from openai.types.responses.tool import Tool
from openai.types.shared import Metadata, Reasoning
@@ -239,6 +239,7 @@ def get_logits_processors(processors: Optional[LogitsProcessors],
ResponseInputOutputItem: TypeAlias = Union[ResponseInputItemParam,
+ ResponseReasoningItem,
ResponseFunctionToolCall]
diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py
index a10d57456ba08..01551a8c7f04a 100644
--- a/vllm/entrypoints/openai/run_batch.py
+++ b/vllm/entrypoints/openai/run_batch.py
@@ -20,7 +20,6 @@ from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
# yapf: disable
-from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.entrypoints.openai.protocol import (BatchRequestInput,
BatchRequestOutput,
BatchResponseData,
@@ -34,7 +33,6 @@ from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.entrypoints.openai.serving_score import ServingScores
from vllm.logger import init_logger
-from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid
from vllm.version import __version__ as VLLM_VERSION
@@ -469,6 +467,9 @@ async def run_batch(
async def main(args: Namespace):
+ from vllm.entrypoints.openai.api_server import build_async_engine_client
+ from vllm.usage.usage_lib import UsageContext
+
async with build_async_engine_client(
args,
usage_context=UsageContext.OPENAI_BATCH_RUNNER,
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index efd2f20299d09..fb9d456df78e9 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
-import base64
import io
import json
import sys
@@ -12,6 +11,7 @@ from http import HTTPStatus
from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional,
TypeVar, Union, cast, overload)
+import pybase64
import torch
from fastapi import Request
from pydantic import BaseModel, ConfigDict, Field
@@ -1008,7 +1008,8 @@ class OpenAIServing:
) -> list[EmbedsPrompt]:
def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
- tensor = torch.load(io.BytesIO(base64.b64decode(embed)),
+ tensor = torch.load(io.BytesIO(
+ pybase64.b64decode(embed, validate=True)),
weights_only=True)
assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
torch.float32,
diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py
index a7554e0d68311..86c16df40e693 100644
--- a/vllm/entrypoints/openai/serving_responses.py
+++ b/vllm/entrypoints/openai/serving_responses.py
@@ -2,17 +2,30 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
+import json
import time
from collections.abc import AsyncGenerator, AsyncIterator
+from contextlib import AsyncExitStack
from copy import copy
from http import HTTPStatus
from typing import Any, Callable, Final, Optional, Union
import jinja2
+import openai.types.responses as openai_responses_types
from fastapi import Request
-from openai.types.responses import (ResponseFunctionToolCall,
- ResponseOutputItem, ResponseOutputMessage,
- ResponseOutputText, ResponseReasoningItem)
+from openai import BaseModel
+# yapf conflicts with isort for this block
+# yapf: disable
+from openai.types.responses import (ResponseCreatedEvent,
+ ResponseFunctionToolCall,
+ ResponseInProgressEvent,
+ ResponseOutputItem,
+ ResponseOutputItemDoneEvent,
+ ResponseOutputMessage, ResponseOutputText,
+ ResponseReasoningItem,
+ ResponseReasoningTextDeltaEvent,
+ ResponseReasoningTextDoneEvent)
+# yapf: enable
from openai.types.responses.response_reasoning_item import (
Content as ResponseReasoningTextContent)
from openai_harmony import Message as OpenAIHarmonyMessage
@@ -40,7 +53,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
# yapf: enable
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
-from vllm.entrypoints.tool_server import ToolServer
+from vllm.entrypoints.tool_server import MCPToolServer, ToolServer
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.logger import init_logger
from vllm.outputs import CompletionOutput
@@ -224,67 +237,121 @@ class OpenAIServingResponses(OpenAIServing):
if raw_request:
raw_request.state.request_metadata = request_metadata
+ if self.tool_server is not None and isinstance(
+ self.tool_server, MCPToolServer
+ ) and (request.background or request.stream) and request.tools and any(
+ tool.type in ["web_search_preview", "code_interpreter"]
+ for tool in request.tools):
+ return self.create_error_response(
+ "MCP tool server is not supported in background mode and "
+ "streaming mode")
+
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[ConversationContext, None]] = []
- try:
- tool_sessions: dict[str, Any] = {}
- for i, engine_prompt in enumerate(engine_prompts):
- default_max_tokens = self.max_model_len - len(
- engine_prompt["prompt_token_ids"])
- sampling_params = request.to_sampling_params(
- default_max_tokens, self.default_sampling_params)
- trace_headers = (None if raw_request is None else await
- self._get_trace_headers(raw_request.headers))
-
- context: ConversationContext
- if self.use_harmony:
- if request.stream:
- context = StreamingHarmonyContext(
- messages, tool_sessions)
- else:
- context = HarmonyContext(messages, tool_sessions)
+ builtin_tool_list: list[str] = []
+ if self.use_harmony and self.tool_server is not None:
+ if self.tool_server.has_tool("browser"):
+ builtin_tool_list.append("browser")
+ if self.tool_server.has_tool("python"):
+ builtin_tool_list.append("python")
+ async with AsyncExitStack() as exit_stack:
+ try:
+ if self.tool_server is not None:
+ # TODO: initialize tool sessions lazily when the session
+ # is actually used.
+ tool_session_ctxs: dict[str, Any] = {
+ tool_name:
+ exit_stack.enter_async_context(
+ self.tool_server.new_session(tool_name))
+ for tool_name in builtin_tool_list
+ }
+ tool_sessions = {}
+ for tool_name in builtin_tool_list:
+ tool_sessions[tool_name] = (
+ await tool_session_ctxs[tool_name])
else:
- context = SimpleContext()
- generator = self._generate_with_builtin_tools(
- request_id=request.request_id,
- request_prompt=request_prompts[i],
- engine_prompt=engine_prompt,
- sampling_params=sampling_params,
- context=context,
- lora_request=lora_request,
- priority=request.priority,
- trace_headers=trace_headers,
+ assert len(builtin_tool_list) == 0
+ tool_sessions = {}
+ for i, engine_prompt in enumerate(engine_prompts):
+ default_max_tokens = self.max_model_len - len(
+ engine_prompt["prompt_token_ids"])
+ sampling_params = request.to_sampling_params(
+ default_max_tokens, self.default_sampling_params)
+
+ trace_headers = (None if raw_request is None else await
+ self._get_trace_headers(
+ raw_request.headers))
+
+ context: ConversationContext
+ if self.use_harmony:
+ if request.stream:
+ context = StreamingHarmonyContext(
+ messages, tool_sessions)
+ else:
+ context = HarmonyContext(messages, tool_sessions)
+ else:
+ context = SimpleContext()
+ generator = self._generate_with_builtin_tools(
+ request_id=request.request_id,
+ request_prompt=request_prompts[i],
+ engine_prompt=engine_prompt,
+ sampling_params=sampling_params,
+ context=context,
+ lora_request=lora_request,
+ priority=request.priority,
+ trace_headers=trace_headers,
+ )
+ generators.append(generator)
+ except ValueError as e:
+ # TODO: Use a vllm-specific Validation Error
+ return self.create_error_response(str(e))
+
+ assert len(generators) == 1
+ result_generator, = generators
+
+ # Store the input messages.
+ if request.store:
+ self.msg_store[request.request_id] = messages
+
+ if request.background:
+ created_time = int(time.time())
+ response = ResponsesResponse.from_request(
+ request,
+ sampling_params,
+ model_name=model_name,
+ created_time=created_time,
+ output=[],
+ status="queued",
+ usage=None,
)
- generators.append(generator)
- except ValueError as e:
- # TODO: Use a vllm-specific Validation Error
- return self.create_error_response(str(e))
+ async with self.response_store_lock:
+ self.response_store[response.id] = response
- assert len(generators) == 1
- result_generator, = generators
+ # Run the request in the background.
+ task = asyncio.create_task(
+ self._run_background_request(
+ request,
+ sampling_params,
+ result_generator,
+ context,
+ model_name,
+ tokenizer,
+ request_metadata,
+ created_time,
+ ),
+ name=f"create_{response.id}",
+ )
- # Store the input messages.
- if request.store:
- self.msg_store[request.request_id] = messages
+ # For cleanup.
+ response_id = response.id
+ self.background_tasks[response_id] = task
+ task.add_done_callback(
+ lambda _: self.background_tasks.pop(response_id, None))
+ return response
- if request.background:
- created_time = int(time.time())
- response = ResponsesResponse.from_request(
- request,
- sampling_params,
- model_name=model_name,
- created_time=created_time,
- output=[],
- status="queued",
- usage=None,
- )
- async with self.response_store_lock:
- self.response_store[response.id] = response
-
- # Run the request in the background.
- task = asyncio.create_task(
- self._run_background_request(
+ if request.stream:
+ return self.responses_stream_generator(
request,
sampling_params,
result_generator,
@@ -292,33 +359,21 @@ class OpenAIServingResponses(OpenAIServing):
model_name,
tokenizer,
request_metadata,
- created_time,
- ),
- name=f"create_{response.id}",
- )
+ )
- # For cleanup.
- response_id = response.id
- self.background_tasks[response_id] = task
- task.add_done_callback(
- lambda _: self.background_tasks.pop(response_id, None))
- return response
-
- if request.stream:
- raise NotImplementedError("Streaming responses are not supported")
-
- try:
- return await self.responses_full_generator(
- request,
- sampling_params,
- result_generator,
- context,
- model_name,
- tokenizer,
- request_metadata,
- )
- except Exception as e:
- return self.create_error_response(str(e))
+ try:
+ return await self.responses_full_generator(
+ request,
+ sampling_params,
+ result_generator,
+ context,
+ model_name,
+ tokenizer,
+ request_metadata,
+ )
+ except Exception as e:
+ return self.create_error_response(str(e))
+ return self.create_error_response("Should not reach here")
async def _make_request(
self,
@@ -717,3 +772,418 @@ class OpenAIServingResponses(OpenAIServing):
"starting the vLLM server."),
status_code=HTTPStatus.BAD_REQUEST,
)
+
+ async def responses_stream_generator(
+ self,
+ request: ResponsesRequest,
+ sampling_params: SamplingParams,
+ result_generator: AsyncIterator[Optional[ConversationContext]],
+ context: ConversationContext,
+ model_name: str,
+ tokenizer: AnyTokenizer,
+ request_metadata: RequestResponseMetadata,
+ created_time: Optional[int] = None,
+ ) -> AsyncGenerator[str, None]:
+ # TODO:
+ # 1. Handle disconnect
+
+ if not isinstance(context, StreamingHarmonyContext):
+ raise NotImplementedError(
+ "Streaming is not supported for responses API without Harmony."
+ )
+
+ created_time = created_time or int(time.time())
+
+ sequence_number = 0
+
+ def _send_event(event: BaseModel):
+ nonlocal sequence_number
+ # Set sequence_number if the event has this attribute
+ if hasattr(event, 'sequence_number'):
+ event.sequence_number = sequence_number
+ sequence_number += 1
+ # Get event type from the event's type field if it exists
+ event_type = getattr(event, 'type', 'unknown')
+ return (f"event: {event_type}\n"
+ f"data: {event.model_dump_json(indent=None)}\n\n")
+
+ current_content_index = 0 # FIXME: this number is never changed
+ current_output_index = 0
+ current_item_id = "" # FIXME: this number is never changed
+ sent_output_item_added = False
+
+ initial_response = ResponsesResponse.from_request(
+ request,
+ sampling_params,
+ model_name=model_name,
+ created_time=created_time,
+ output=[],
+ status="in_progress",
+ usage=None,
+ ).model_dump()
+ yield _send_event(
+ ResponseCreatedEvent(
+ type="response.created",
+ sequence_number=-1,
+ response=initial_response,
+ ))
+ yield _send_event(
+ ResponseInProgressEvent(
+ type="response.in_progress",
+ sequence_number=-1,
+ response=initial_response,
+ ))
+
+ async for ctx in result_generator:
+
+ assert isinstance(ctx, StreamingHarmonyContext)
+
+ if ctx.is_expecting_start():
+ current_output_index += 1
+ sent_output_item_added = False
+
+ if len(ctx.parser.messages) > 0:
+ previous_item = ctx.parser.messages[-1]
+ if previous_item.recipient is not None:
+ # Deal with tool call here
+ pass
+ elif previous_item.channel == "analysis":
+ reasoning_item = ResponseReasoningItem(
+ type="reasoning",
+ content=[
+ ResponseReasoningTextContent(
+ text=previous_item.content[0].text,
+ type="reasoning_text",
+ ),
+ ],
+ status="completed",
+ id=current_item_id,
+ summary=[],
+ )
+ yield _send_event(
+ ResponseReasoningTextDoneEvent(
+ type="response.reasoning_text.done",
+ item_id=current_item_id,
+ sequence_number=-1,
+ output_index=current_output_index,
+ content_index=current_content_index,
+ text=previous_item.content[0].text,
+ ))
+ yield _send_event(
+ ResponseOutputItemDoneEvent(
+ type="response.output_item.done",
+ sequence_number=-1,
+ output_index=current_output_index,
+ item=reasoning_item,
+ ))
+ elif previous_item.channel == "final":
+ text_content = ResponseOutputText(
+ type="output_text",
+ text=previous_item.content[0].text,
+ annotations=[],
+ )
+ yield _send_event(
+ openai_responses_types.ResponseTextDoneEvent(
+ type="response.output_text.done",
+ sequence_number=-1,
+ output_index=current_output_index,
+ content_index=current_content_index,
+ text=previous_item.content[0].text,
+ logprobs=[],
+ item_id=current_item_id,
+ ))
+ yield _send_event(
+ openai_responses_types.
+ ResponseContentPartDoneEvent(
+ type="response.content_part.done",
+ sequence_number=-1,
+ item_id=current_item_id,
+ output_index=current_output_index,
+ content_index=current_content_index,
+ part=text_content,
+ ))
+ yield _send_event(
+ openai_responses_types.ResponseOutputItemDoneEvent(
+ type="response.output_item.done",
+ sequence_number=-1,
+ output_index=current_output_index,
+ item=ResponseOutputMessage(
+ id=current_item_id,
+ type="message",
+ role="assistant",
+ content=[text_content],
+ status="completed",
+ ),
+ ))
+
+ if ctx.parser.last_content_delta:
+ if (ctx.parser.current_channel == "final"
+ and ctx.parser.current_recipient is None):
+ if not sent_output_item_added:
+ sent_output_item_added = True
+ yield _send_event(
+ openai_responses_types.
+ ResponseOutputItemAddedEvent(
+ type="response.output_item.added",
+ sequence_number=-1,
+ output_index=current_output_index,
+ item=openai_responses_types.
+ ResponseOutputMessage(
+ id=current_item_id,
+ type="message",
+ role="assistant",
+ content=[],
+ status="in_progress",
+ ),
+ ))
+ yield _send_event(
+ openai_responses_types.
+ ResponseContentPartAddedEvent(
+ type="response.content_part.added",
+ sequence_number=-1,
+ output_index=current_output_index,
+ item_id=current_item_id,
+ content_index=current_content_index,
+ part=openai_responses_types.ResponseOutputText(
+ type="output_text",
+ text="",
+ annotations=[],
+ logprobs=[],
+ ),
+ ))
+ yield _send_event(
+ openai_responses_types.ResponseTextDeltaEvent(
+ type="response.output_text.delta",
+ sequence_number=-1,
+ content_index=current_content_index,
+ output_index=current_output_index,
+ item_id=current_item_id,
+ delta=ctx.parser.last_content_delta,
+ # TODO, use logprobs from ctx.last_request_output
+ logprobs=[],
+ ))
+ elif (ctx.parser.current_channel == "analysis"
+ and ctx.parser.current_recipient is None):
+ if not sent_output_item_added:
+ sent_output_item_added = True
+ yield _send_event(
+ openai_responses_types.
+ ResponseOutputItemAddedEvent(
+ type="response.output_item.added",
+ sequence_number=-1,
+ output_index=current_output_index,
+ item=openai_responses_types.
+ ResponseReasoningItem(
+ type="reasoning",
+ id=current_item_id,
+ summary=[],
+ status="in_progress",
+ ),
+ ))
+ yield _send_event(
+ openai_responses_types.
+ ResponseContentPartAddedEvent(
+ type="response.content_part.added",
+ sequence_number=-1,
+ output_index=current_output_index,
+ item_id=current_item_id,
+ content_index=current_content_index,
+ part=openai_responses_types.ResponseOutputText(
+ type="output_text",
+ text="",
+ annotations=[],
+ logprobs=[],
+ ),
+ ))
+ yield _send_event(
+ ResponseReasoningTextDeltaEvent(
+ type="response.reasoning_text.delta",
+ item_id=current_item_id,
+ output_index=current_output_index,
+ content_index=current_content_index,
+ delta=ctx.parser.last_content_delta,
+ sequence_number=-1,
+ ))
+
+ if ctx.is_assistant_action_turn() and len(ctx.parser.messages) > 0:
+ previous_item = ctx.parser.messages[-1]
+ if (self.tool_server is not None
+ and self.tool_server.has_tool("browser")
+ and previous_item.recipient is not None
+ and previous_item.recipient.startswith("browser.")):
+ function_name = previous_item.recipient[len("browser."):]
+ action = None
+ parsed_args = json.loads(previous_item.content[0].text)
+ if function_name == "search":
+ action = (openai_responses_types.
+ response_function_web_search.ActionSearch(
+ type="search",
+ query=parsed_args["query"],
+ ))
+ elif function_name == "open":
+ action = (
+ openai_responses_types.
+ response_function_web_search.ActionOpenPage(
+ type="open_page",
+ # TODO: translate to url
+ url=f"cursor:{parsed_args.get('cursor', '')}",
+ ))
+ elif function_name == "find":
+ action = (
+ openai_responses_types.
+ response_function_web_search.ActionFind(
+ type="find",
+ pattern=parsed_args["pattern"],
+ # TODO: translate to url
+ url=f"cursor:{parsed_args.get('cursor', '')}",
+ ))
+ else:
+ raise ValueError(
+ f"Unknown function name: {function_name}")
+
+ yield _send_event(
+ openai_responses_types.ResponseOutputItemAddedEvent(
+ type="response.output_item.added",
+ sequence_number=-1,
+ output_index=current_output_index,
+ item=openai_responses_types.
+ response_function_web_search.
+ ResponseFunctionWebSearch(
+ # TODO: generate a unique id for web search call
+ type="web_search_call",
+ id=current_item_id,
+ action=action,
+ status="in_progress",
+ ),
+ ))
+ yield _send_event(
+ openai_responses_types.
+ ResponseWebSearchCallInProgressEvent(
+ type="response.web_search_call.in_progress",
+ sequence_number=-1,
+ output_index=current_output_index,
+ item_id=current_item_id,
+ ))
+ yield _send_event(
+ openai_responses_types.
+ ResponseWebSearchCallSearchingEvent(
+ type="response.web_search_call.searching",
+ sequence_number=-1,
+ output_index=current_output_index,
+ item_id=current_item_id,
+ ))
+
+ # enqueue
+ yield _send_event(
+ openai_responses_types.
+ ResponseWebSearchCallCompletedEvent(
+ type="response.web_search_call.completed",
+ sequence_number=-1,
+ output_index=current_output_index,
+ item_id=current_item_id,
+ ))
+ yield _send_event(
+ openai_responses_types.ResponseOutputItemDoneEvent(
+ type="response.output_item.done",
+ sequence_number=-1,
+ output_index=current_output_index,
+ item=openai_responses_types.
+ ResponseFunctionWebSearch(
+ type="web_search_call",
+ id=current_item_id,
+ action=action,
+ status="completed",
+ ),
+ ))
+
+ if (self.tool_server is not None
+ and self.tool_server.has_tool("python")
+ and previous_item.recipient is not None
+ and previous_item.recipient.startswith("python")):
+ yield _send_event(
+ openai_responses_types.ResponseOutputItemAddedEvent(
+ type="response.output_item.added",
+ sequence_number=-1,
+ output_index=current_output_index,
+ item=openai_responses_types.
+ ResponseCodeInterpreterToolCallParam(
+ type="code_interpreter_call",
+ id=current_item_id,
+ code="",
+ container_id="auto",
+ outputs=[],
+ status="in_progress",
+ ),
+ ))
+ yield _send_event(
+ openai_responses_types.
+ ResponseCodeInterpreterCallInProgressEvent(
+ type="response.code_interpreter_call.in_progress",
+ sequence_number=-1,
+ output_index=current_output_index,
+ item_id=current_item_id,
+ ))
+ # TODO: do we need to add delta event here?
+ yield _send_event(
+ openai_responses_types.
+ ResponseCodeInterpreterCallCodeDoneEvent(
+ type="response.code_interpreter_call_code.done",
+ sequence_number=-1,
+ output_index=current_output_index,
+ item_id=current_item_id,
+ code=previous_item.content[0].text))
+ yield _send_event(
+ openai_responses_types.
+ ResponseCodeInterpreterCallInterpretingEvent(
+ type="response.code_interpreter_call.interpreting",
+ sequence_number=-1,
+ output_index=current_output_index,
+ item_id=current_item_id,
+ ))
+ yield _send_event(
+ openai_responses_types.
+ ResponseCodeInterpreterCallCompletedEvent(
+ type="response.code_interpreter_call.completed",
+ sequence_number=-1,
+ output_index=current_output_index,
+ item_id=current_item_id,
+ ))
+ yield _send_event(
+ openai_responses_types.ResponseOutputItemDoneEvent(
+ type="response.output_item.done",
+ sequence_number=-1,
+ output_index=current_output_index,
+ item=openai_responses_types.
+ ResponseCodeInterpreterToolCallParam(
+ type="code_interpreter_call",
+ id=current_item_id,
+ code=previous_item.content[0].text,
+ container_id="auto",
+ # TODO: add outputs here
+ outputs=[],
+ status="completed",
+ ),
+ ))
+
+ async def empty_async_generator():
+ # A hack to trick Python to think this is a generator but in fact
+ # it immediately returns.
+ if False:
+ yield
+
+ final_response = await self.responses_full_generator(
+ request,
+ sampling_params,
+ empty_async_generator(),
+ context,
+ model_name,
+ tokenizer,
+ request_metadata,
+ created_time=created_time,
+ )
+ yield _send_event(
+ openai_responses_types.ResponseCompletedEvent(
+ type="response.completed",
+ sequence_number=-1,
+ response=final_response.model_dump(),
+ ))
diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py
index 4da2094147cea..c246274514dbf 100644
--- a/vllm/entrypoints/openai/serving_score.py
+++ b/vllm/entrypoints/openai/serving_score.py
@@ -7,6 +7,7 @@ from typing import Any, Optional, Union
from fastapi import Request
+from vllm import envs
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
@@ -17,11 +18,15 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument,
ScoreResponseData, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+# yapf conflicts with isort for this block
+# yapf: disable
from vllm.entrypoints.score_utils import (ScoreContentPartParam,
ScoreMultiModalParam,
_cosine_similarity,
_validate_score_input_lens,
+ compress_token_type_ids,
get_score_prompt)
+# yapf: enable
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
@@ -158,6 +163,8 @@ class ServingScores(OpenAIServing):
tokenizer=tokenizer,
tokenization_kwargs=tokenization_kwargs,
)
+ self._validate_input(request, engine_prompt["prompt_token_ids"],
+ full_prompt)
if request.mm_processor_kwargs is not None:
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
@@ -188,64 +195,27 @@ class ServingScores(OpenAIServing):
input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]
- if self.model_config.is_multimodal_model:
+ preprocess_async = make_async(self._preprocess_score,
+ executor=self._tokenizer_executor)
- preprocess_async = make_async(self._preprocess_score,
- executor=self._tokenizer_executor)
+ preprocessed_prompts = await asyncio.gather(
+ *(preprocess_async(request=request,
+ tokenizer=tokenizer,
+ tokenization_kwargs=tokenization_kwargs,
+ data_1=t1,
+ data_2=t2) for t1, t2 in input_pairs))
- preprocessed_prompts = await asyncio.gather(
- *(preprocess_async(request=request,
- tokenizer=tokenizer,
- tokenization_kwargs=tokenization_kwargs,
- data_1=t1,
- data_2=t2) for t1, t2 in input_pairs))
-
- for full_prompt, engine_prompt in preprocessed_prompts:
- request_prompts.append(full_prompt)
- engine_prompts.append(engine_prompt)
-
- else:
- tokenize_async = make_async(tokenizer.__call__,
- executor=self._tokenizer_executor)
- use_pad_token = self.model_config.use_pad_token
-
- if use_pad_token:
- # cross_encoder models defaults to using pad_token.
- tokenized_prompts = await asyncio.gather(*(
- tokenize_async(
- text=t1, # type: ignore[arg-type]
- text_pair=t2, # type: ignore[arg-type]
- **tokenization_kwargs) for t1, t2 in input_pairs))
- else:
- # `llm as reranker` models defaults to not using pad_token.
- tokenized_prompts = await asyncio.gather(*(
- tokenize_async(
- text=t1 + # type: ignore[operator]
- t2,
- **tokenization_kwargs) for t1, t2 in input_pairs))
-
- for prompt_inputs, (t1, t2) in zip(tokenized_prompts, input_pairs):
- sep_token = tokenizer.sep_token if (tokenizer.sep_token
- and use_pad_token) else ''
- request_prompt = f"{t1}{sep_token}{t2}"
-
- input_ids = prompt_inputs["input_ids"]
- text_token_prompt = \
- self._validate_input(request, input_ids, request_prompt)
- engine_prompt = TokensPrompt(
- prompt_token_ids=text_token_prompt["prompt_token_ids"],
- token_type_ids=prompt_inputs.get("token_type_ids"))
-
- request_prompts.append(request_prompt)
- engine_prompts.append(engine_prompt)
+ for full_prompt, engine_prompt in preprocessed_prompts:
+ request_prompts.append(full_prompt)
+ engine_prompts.append(engine_prompt)
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
- pooling_params = request.to_pooling_params()
+ default_pooling_params = request.to_pooling_params()
try:
- pooling_params.verify("score", self.model_config)
+ default_pooling_params.verify("score", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
@@ -254,9 +224,19 @@ class ServingScores(OpenAIServing):
self._log_inputs(request_id_item,
request_prompts[i],
- params=pooling_params,
+ params=default_pooling_params,
lora_request=lora_request)
+ if envs.VLLM_USE_V1 and (token_type_ids := engine_prompt.pop(
+ "token_type_ids", None)):
+ pooling_params = default_pooling_params.clone()
+ compressed = compress_token_type_ids(token_type_ids)
+ pooling_params.extra_kwargs = {
+ "compressed_token_type_ids": compressed
+ }
+ else:
+ pooling_params = (default_pooling_params)
+
generator = self.engine_client.encode(
engine_prompt,
pooling_params,
diff --git a/vllm/entrypoints/score_utils.py b/vllm/entrypoints/score_utils.py
index f3f042355c9eb..642d6389539bc 100644
--- a/vllm/entrypoints/score_utils.py
+++ b/vllm/entrypoints/score_utils.py
@@ -184,15 +184,49 @@ def get_score_prompt(
model_config,
tokenizer,
)
+ from vllm.model_executor.model_loader import get_model_cls
- full_prompt = apply_score_template(model_config, prompt_1, prompt_2)
-
- prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
+ model = get_model_cls(model_config)
+ if supports_score_template(model):
+ full_prompt = apply_score_template(model_config, prompt_1, prompt_2)
+ prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
+ elif model_config.use_pad_token:
+ # cross_encoder models defaults to using pad_token.
+ prompt_inputs = tokenizer(text=prompt_1,
+ text_pair=prompt_2,
+ **tokenization_kwargs)
+ full_prompt = tokenizer.decode(prompt_inputs["input_ids"])
+ else:
+ # `llm as reranker` models defaults to not using pad_token.
+ full_prompt = prompt_1 + prompt_2
+ prompt_inputs = tokenizer(text=full_prompt, **tokenization_kwargs)
engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["input_ids"])
+ if (token_type_ids := prompt_inputs.get("token_type_ids")) is not None:
+ engine_prompt["token_type_ids"] = token_type_ids
+
post_process_tokens(model_config, engine_prompt)
if mm_data is not None:
engine_prompt["multi_modal_data"] = mm_data
return full_prompt, engine_prompt
+
+
+def compress_token_type_ids(token_type_ids: list[int]) -> int:
+ """
+ Return position of the first 1 or the length of the list
+ if not found.
+ """
+ first_one = len(token_type_ids)
+ err_msg = "Token type ids are expected to be a sequence"\
+ " of zeros followed by a sequence of ones"
+ for i, type_id in enumerate(token_type_ids):
+ if type_id == 0 and first_one < i:
+ raise ValueError(err_msg)
+ elif type_id == 1 and first_one > i:
+ first_one = i
+ elif type_id > 1:
+ raise ValueError(err_msg)
+
+ return first_one
diff --git a/vllm/entrypoints/tool.py b/vllm/entrypoints/tool.py
index 01ee77414f130..723cff91d44c7 100644
--- a/vllm/entrypoints/tool.py
+++ b/vllm/entrypoints/tool.py
@@ -2,7 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Optional
+
+from openai_harmony import Message
from vllm.logger import init_logger
@@ -70,7 +72,16 @@ class HarmonyPythonTool(Tool):
"gpt_oss is not installed, code interpreter is disabled")
return
- self.python_tool = PythonTool()
+ # NOTE (Chen): as of gpt-oss 0.0.2, there is a bug in _make_response
+ # and we do the following monkey patch to fix it.
+ class PatchedGptOssPythonTool(PythonTool):
+
+ def _make_response(self,
+ output: str,
+ channel: Optional[str] = None) -> Message:
+ return super()._make_response(output)
+
+ self.python_tool = PatchedGptOssPythonTool()
logger.info_once("Code interpreter tool initialized")
async def get_result(self, context: "ConversationContext") -> Any:
diff --git a/vllm/entrypoints/tool_server.py b/vllm/entrypoints/tool_server.py
index 769c40e8cc588..2f28595f27c6a 100644
--- a/vllm/entrypoints/tool_server.py
+++ b/vllm/entrypoints/tool_server.py
@@ -2,15 +2,70 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from contextlib import AbstractAsyncContextManager, asynccontextmanager
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
-from openai_harmony import ToolNamespaceConfig
+from openai_harmony import ToolDescription, ToolNamespaceConfig
from vllm.entrypoints.tool import HarmonyBrowserTool, HarmonyPythonTool, Tool
from vllm.logger import init_logger
logger = init_logger(__name__)
+if TYPE_CHECKING:
+ from mcp.types import ListToolsResult
+
+
+async def list_server_and_tools(server_url: str):
+ from mcp import ClientSession
+ from mcp.client.sse import sse_client
+
+ async with sse_client(url=server_url) as streams, ClientSession(
+ *streams) as session:
+ initialize_response = await session.initialize()
+ list_tools_response = await session.list_tools()
+ return initialize_response, list_tools_response
+
+
+def trim_schema(schema: dict) -> dict:
+ # Turn JSON Schema from MCP generated into Harmony's variant.
+ if "title" in schema:
+ del schema["title"]
+ if "default" in schema and schema["default"] is None:
+ del schema["default"]
+ if "anyOf" in schema:
+ # Turn "anyOf": [{"type": "type-1"}, {"type": "type-2"}]
+ # into "type": ["type-1", "type-2"]
+ # if there's more than 1 types, also remove "null" type as Harmony will
+ # just ignore it
+ types = [
+ type_dict["type"] for type_dict in schema["anyOf"]
+ if type_dict["type"] != 'null'
+ ]
+ schema["type"] = types
+ del schema["anyOf"]
+ if "properties" in schema:
+ schema["properties"] = {
+ k: trim_schema(v)
+ for k, v in schema["properties"].items()
+ }
+ return schema
+
+
+def post_process_tools_description(
+ list_tools_result: "ListToolsResult") -> "ListToolsResult":
+ # Adapt the MCP tool result for Harmony
+ for tool in list_tools_result.tools:
+ tool.inputSchema = trim_schema(tool.inputSchema)
+
+ # Some tools schema don't need to be part of the prompt (e.g. simple text
+ # in text out for Python)
+ list_tools_result.tools = [
+ tool for tool in list_tools_result.tools
+ if getattr(tool.annotations, "include_in_prompt", True)
+ ]
+
+ return list_tools_result
+
class ToolServer(ABC):
@@ -38,6 +93,67 @@ class ToolServer(ABC):
...
+class MCPToolServer(ToolServer):
+
+ def __init__(self):
+ try:
+ import mcp # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "mcp is not installed. Please run `pip install mcp` to use "
+ "MCPToolServer.") from None
+ self.harmony_tool_descriptions = {}
+
+ async def add_tool_server(self, server_url: str):
+ tool_urls = server_url.split(",")
+ self.harmony_tool_descriptions = {}
+ self.urls: dict[str, str] = {}
+ for url in tool_urls:
+ url = f"http://{url}/sse"
+ initialize_response, list_tools_response = (
+ await list_server_and_tools(url))
+
+ list_tools_response = post_process_tools_description(
+ list_tools_response)
+
+ tool_from_mcp = ToolNamespaceConfig(
+ name=initialize_response.serverInfo.name,
+ description=initialize_response.instructions,
+ tools=[
+ ToolDescription.new(name=tool.name,
+ description=tool.description,
+ parameters=tool.inputSchema)
+ for tool in list_tools_response.tools
+ ])
+ self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp
+ if tool_from_mcp.name not in self.urls:
+ self.urls[tool_from_mcp.name] = url
+ else:
+ logger.warning(
+ "Tool %s already exists. Ignoring duplicate tool server %s",
+ tool_from_mcp.name, url)
+ logger.info("MCPToolServer initialized with tools: %s",
+ list(self.harmony_tool_descriptions.keys()))
+
+ def has_tool(self, tool_name: str):
+ return tool_name in self.harmony_tool_descriptions
+
+ def get_tool_description(self, tool_name: str):
+ return self.harmony_tool_descriptions.get(tool_name)
+
+ @asynccontextmanager
+ async def new_session(self, tool_name: str):
+ from mcp import ClientSession
+ from mcp.client.sse import sse_client
+ url = self.urls.get(tool_name)
+ if not url:
+ raise KeyError(f"Tool '{tool_name}' is not supported")
+ async with sse_client(url=url) as streams, ClientSession(
+ *streams) as session:
+ await session.initialize()
+ yield session
+
+
class DemoToolServer(ToolServer):
def __init__(self):
@@ -67,4 +183,6 @@ class DemoToolServer(ToolServer):
@asynccontextmanager
async def new_session(self, tool_name: str):
+ if tool_name not in self.tools:
+ raise KeyError(f"Tool '{tool_name}' is not supported")
yield self.tools[tool_name]
diff --git a/vllm/envs.py b/vllm/envs.py
index 8b12a7ee2b988..931edcfa7f1e8 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -70,6 +70,7 @@ if TYPE_CHECKING:
MAX_JOBS: Optional[str] = None
NVCC_THREADS: Optional[str] = None
VLLM_USE_PRECOMPILED: bool = False
+ VLLM_DOCKER_BUILD_CONTEXT: bool = False
VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL: bool = False
VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: bool = False
CMAKE_BUILD_TYPE: Optional[str] = None
@@ -126,9 +127,11 @@ if TYPE_CHECKING:
VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
VLLM_TPU_USING_PATHWAYS: bool = False
VLLM_USE_DEEP_GEMM: bool = False
+ VLLM_USE_DEEP_GEMM_E8M0: bool = True
VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
VLLM_USE_FLASHINFER_MOE_FP8: bool = False
VLLM_USE_FLASHINFER_MOE_FP4: bool = False
+ VLLM_FLASHINFER_MOE_BACKEND: str = "throughput"
VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
@@ -233,8 +236,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If set, vllm will use precompiled binaries (*.so)
"VLLM_USE_PRECOMPILED":
- lambda: bool(os.environ.get("VLLM_USE_PRECOMPILED")) or bool(
- os.environ.get("VLLM_PRECOMPILED_WHEEL_LOCATION")),
+ lambda: os.environ.get("VLLM_USE_PRECOMPILED", "").strip().lower() in
+ ("1", "true") or bool(os.environ.get("VLLM_PRECOMPILED_WHEEL_LOCATION")),
+
+ # Used to mark that setup.py is running in a Docker build context,
+ # in order to force the use of precompiled binaries.
+ "VLLM_DOCKER_BUILD_CONTEXT":
+ lambda: os.environ.get("VLLM_DOCKER_BUILD_CONTEXT", "").strip().lower() in
+ ("1", "true"),
# Whether to force using nightly wheel in python build.
# This is used for testing the nightly wheel in python build.
@@ -917,6 +926,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_DEEP_GEMM":
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
+ # Whether to use E8M0 scaling when DeepGEMM is used on Blackwell GPUs.
+ # E8M0 is faster on B200 but may reduce accuracy.
+ "VLLM_USE_DEEP_GEMM_E8M0":
+ lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1"))),
# DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm
# JIT all the required kernels before model execution so there is no
# JIT'ing in the hot-path. However, this warmup increases the engine
@@ -982,6 +995,20 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ALL2ALL_BACKEND":
lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
+ # Flashinfer MoE backend for vLLM's fused Mixture-of-Experts support. Both
+ # require compute capability 10.0 or above.
+ # Available options:
+ # - "throughput": [default]
+ # Uses CUTLASS kernels optimized for high-throughput batch inference.
+ # - "latency":
+ # Uses TensorRT-LLM kernels optimized for low-latency inference.
+ # To set this backend, define the environment variable:
+ # export VLLM_FLASHINFER_MOE_BACKEND=latency.
+ # If not set, defaults to "throughput".
+ "VLLM_FLASHINFER_MOE_BACKEND": lambda: os.getenv(
+ "VLLM_FLASHINFER_MOE_BACKEND", "throughput"
+ ),
+
# Control the maximum number of tokens per expert supported by the
# NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for
# the blockscale tensor of activations NVFP4 Quantization.
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index dd55b19feeaf6..4686ba24e65f3 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -26,10 +26,26 @@ batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
batchsize_forward_time: defaultdict = defaultdict(list)
+def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
+ max_num_tokens: int,
+ chunk_idx: int) -> list[int]:
+ dp_size = len(num_tokens_across_dp_cpu)
+
+ local_size = [-1] * dp_size
+ for i in range(dp_size):
+ dp_tokens = num_tokens_across_dp_cpu[i]
+ local_size[i] = min(max_num_tokens,
+ dp_tokens - (max_num_tokens * chunk_idx))
+ if local_size[i] <= 0:
+ local_size[i] = 1 # ensure lockstep even if done
+ return local_size
+
+
@dataclass
class DPMetadata:
max_tokens_across_dp_cpu: torch.Tensor
cu_tokens_across_dp_cpu: torch.Tensor
+ local_sizes: Optional[list[int]] = None
@staticmethod
def num_tokens_across_dp(num_tokens: int, dp_size: int,
@@ -78,6 +94,48 @@ class DPMetadata:
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu)
+ @contextmanager
+ def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int):
+ """
+ Context manager to compute and temporarily set the per-rank local token
+ sizes for a specific chunk during chunked forward execution.
+
+ This is necessary to ensure each DP (data parallel) rank processes its
+ designated portion of tokens in lockstep with others, even when the
+ token counts are uneven or some ranks have completed their input early.
+
+ For chunked execution, we break up the total tokens on each rank into
+ multiple chunks (of at most `max_chunk_size_per_rank`), and for a given
+ `chunk_idx`, this context manager sets `self.local_sizes` to the number
+ of tokens to process in that chunk on each rank.
+
+ It uses cumulative sizes (`cu_tokens_across_dp_cpu`) to derive the
+ number of tokens per rank, and calls `_compute_chunked_local_num_tokens`
+ to determine the chunk-wise split.
+
+ `self.local_sizes` is only valid inside the context.
+
+ Args:
+ max_chunk_size_per_rank: The max number of tokens each rank is
+ allowed to process in this chunk.
+ chunk_idx: The index of the chunk to compute sizes for.
+ """
+ cu_sizes = self.cu_tokens_across_dp_cpu
+ num_tokens_across_dp_cpu = [
+ (cu_sizes[i] -
+ cu_sizes[i - 1]).item() if i > 0 else cu_sizes[0].item()
+ for i in range(len(cu_sizes))
+ ]
+ self.local_sizes = _compute_chunked_local_num_tokens(
+ num_tokens_across_dp_cpu, max_chunk_size_per_rank, chunk_idx)
+ try:
+ yield self.local_sizes
+ finally:
+ self.local_sizes = None
+
+ def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
+ return self.local_sizes
+
@dataclass
class ForwardContext:
diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py
index 37bf2b7a44366..aef7841e71b71 100644
--- a/vllm/inputs/__init__.py
+++ b/vllm/inputs/__init__.py
@@ -1,10 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs,
- ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
- SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
- TokensPrompt, build_explicit_enc_dec_prompt, embeds_inputs,
+from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
+ EncoderDecoderInputs, ExplicitEncoderDecoderPrompt,
+ ProcessorInputs, PromptType, SingletonInputs,
+ SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
+ build_explicit_enc_dec_prompt, embeds_inputs,
to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
from .registry import (DummyData, InputContext, InputProcessingContext,
InputRegistry)
@@ -24,6 +25,7 @@ __all__ = [
"ExplicitEncoderDecoderPrompt",
"TokenInputs",
"EmbedsInputs",
+ "EmbedsPrompt",
"token_inputs",
"embeds_inputs",
"DecoderOnlyInputs",
diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py
index 6331a70b469aa..dc3236508348f 100644
--- a/vllm/inputs/registry.py
+++ b/vllm/inputs/registry.py
@@ -8,10 +8,10 @@ import torch
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from typing_extensions import TypeVar
-from vllm.jsontree import JSONTree, json_map_leaves
from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.utils import get_allowed_kwarg_only_overrides
+from vllm.utils.jsontree import JSONTree, json_map_leaves
if TYPE_CHECKING:
from vllm.config import ModelConfig
diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
index 3ccddb52998b2..c48a0137c3060 100644
--- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
@@ -12,7 +12,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked,
- is_blackwell_deep_gemm_used)
+ is_blackwell_deep_gemm_e8m0_used)
logger = init_logger(__name__)
@@ -176,7 +176,7 @@ def silu_mul_fp8_quant_deep_gemm(
eps,
fp8_min,
fp8_max,
- is_blackwell_deep_gemm_used(),
+ is_blackwell_deep_gemm_e8m0_used(),
BLOCK=group_size,
NUM_STAGES=8,
num_warps=1,
diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index 9e4ee5a3d7b95..31ea826f1f97a 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -192,7 +192,8 @@ class FusedMoEParallelConfig:
@property
def use_flashinfer_cutlass_kernels(self):
return (envs.VLLM_USE_FLASHINFER_MOE_FP4
- and has_flashinfer_cutlass_fused_moe())
+ and has_flashinfer_cutlass_fused_moe()
+ and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")
@staticmethod
def make(tp_size_: int, dp_size_: int,
@@ -323,6 +324,8 @@ class FusedMoEConfig:
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
+ has_bias: bool = False
+
def __post_init__(self):
if self.dp_size > 1:
logger.debug_once("Using FusedMoEConfig::max_num_tokens=%d",
@@ -412,7 +415,8 @@ class FusedMoEConfig:
in_dtype: torch.dtype,
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE,
quant_config: Optional[Union[FusedMoEQuantConfig,
- QuantizationConfig]] = None
+ QuantizationConfig]] = None,
+ has_bias: bool = False,
) -> "FusedMoEConfig":
_quant_config: Optional[FusedMoEQuantConfig] = None
@@ -481,4 +485,5 @@ class FusedMoEConfig:
in_dtype=in_dtype,
quant_config=_quant_config,
max_num_tokens=max_num_tokens,
+ has_bias=has_bias,
)
diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
index ba7105c83a92f..9b8175f42a9d2 100644
--- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
@@ -237,18 +237,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert w1_scale is not None
assert w2_scale is not None
- if not env.VLLM_SKIP_DEEP_GEMM_WARMUP:
- # DeepGemm JITs the grouped-gemm kernels. We don't want the JIT'ing
- # to happen during actual model-inference. The
- # `warmup_deepgemm_kernels` function is a `run_once` decorated
- # function that executes during the model profile run. This warmup
- # should create all the required JITs for the current model.
- warmup_deepgemm_gg_contiguous_kernels(w1,
- w2,
- w1_scale,
- w2_scale,
- num_topk=topk_ids.size(1))
-
a1q = hidden_states
_, N, K = w1.size()
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
index 3e79a1a8c24b2..4e3e15a35ada2 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
@@ -170,8 +170,6 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
"w1_scale and w2_scale must not "
"be None for FlashInferExperts")
- assert not apply_router_weight_on_input
-
quant_scales = [
a1_gscale,
w1_scale.view(torch.int32),
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
index 02e1d1f1fd02e..36aca8cf74b6d 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
@@ -4,7 +4,6 @@ from typing import Any, Optional
import torch
-import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.distributed import get_dp_group
from vllm.forward_context import get_forward_context
@@ -14,20 +13,8 @@ from vllm.model_executor.layers.fused_moe.utils import (
from vllm.utils.flashinfer import nvfp4_block_scale_interleave
-def get_local_sizes(local_tokens):
- cu_sizes = get_forward_context().dp_metadata.cu_tokens_across_dp_cpu
- sizes = [cu_sizes[0].item()]
- for i in range(1, len(cu_sizes)):
- sizes.append((cu_sizes[i] - cu_sizes[i - 1]).item())
- max_num_tokens = envs.VLLM_MOE_DP_CHUNK_SIZE
- sizes_chunked = [max_num_tokens] * len(sizes)
- if local_tokens < max_num_tokens:
- # When the number of local tokens is less than max_num_tokens, all other
- # ranks will also have fewer than max_num_tokens. The remaining tokens
- # are accounted for as residual.
- sizes_chunked = [x % max_num_tokens for x in sizes]
-
- return sizes_chunked
+def get_local_sizes():
+ return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
@@ -73,7 +60,12 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
- assert not apply_router_weight_on_input
+ if apply_router_weight_on_input:
+ topk = topk_ids.size(1)
+ # TODO: this only works for topK=1, will need to update for topK>1
+ assert topk == 1, \
+ "apply_router_weight_on_input is only implemented for topk=1"
+ a1.mul_(topk_weights.to(a1.dtype))
(a1_gscale, use_dp, local_tokens) = extract_required_args(
extra_prepare_args, ['a1_gscale', 'use_dp', 'local_tokens'])
@@ -90,7 +82,7 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights, topk_ids, a1q, a1q_scale = \
get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale], # noqa: E501
dim=0,
- sizes=get_local_sizes(local_tokens))
+ sizes=get_local_sizes())
a1_m, a1_n = a1q.shape
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
@@ -107,8 +99,5 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
['use_dp', 'local_tokens'])
if use_dp:
fused_expert_output = get_dp_group().reduce_scatterv(
- fused_expert_output,
- dim=0,
- sizes=get_local_sizes(local_tokens),
- )
+ fused_expert_output, dim=0, sizes=get_local_sizes())
output.copy_(fused_expert_output)
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 597af08c3c9fa..ad094c37f9478 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -4,6 +4,9 @@
import functools
import json
import os
+# torch.compile needs typing.List. It will fail torch.library.infer_schema
+# otherwise
+from typing import List # noqa: UP035
from typing import Any, Callable, Optional
import torch
@@ -37,7 +40,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
-from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
+from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
@@ -272,6 +275,7 @@ def fused_moe_kernel(
a_ptr,
b_ptr,
c_ptr,
+ b_bias_ptr,
a_scale_ptr,
b_scale_ptr,
topk_weights_ptr,
@@ -299,6 +303,8 @@ def fused_moe_kernel(
stride_bse,
stride_bsk,
stride_bsn,
+ stride_bbe, # bias expert stride
+ stride_bbn, # bias N stride
# Block size for block-wise quantization
group_n: tl.constexpr,
group_k: tl.constexpr,
@@ -314,6 +320,7 @@ def fused_moe_kernel(
use_int8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr,
per_channel_quant: tl.constexpr,
+ HAS_BIAS: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
@@ -411,7 +418,10 @@ def fused_moe_kernel(
else:
a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr + off_experts)
-
+ if HAS_BIAS:
+ # bias shape: [num_experts, N]
+ bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn
+ bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
@@ -453,7 +463,8 @@ def fused_moe_kernel(
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
-
+ if HAS_BIAS:
+ accumulator = accumulator + bias[None, :]
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token,
mask=token_mask,
@@ -468,6 +479,7 @@ def fused_moe_kernel(
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
else:
accumulator = accumulator.to(compute_type)
+
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -496,7 +508,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
- block_shape: Optional[list[int]] = None) -> None:
+ block_shape: Optional[list[int]] = None,
+ B_bias: Optional[torch.Tensor] = None) -> None:
assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
@@ -528,7 +541,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
A.size(0) * top_k * config['BLOCK_SIZE_M'])
grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
B.size(1), META['BLOCK_SIZE_N']), )
-
+ HAS_BIAS = B_bias is not None
if (use_int8_w8a16 or use_int4_w4a16) and \
block_shape is not None and block_shape[1] > 0:
assert B_scale is not None and B_scale.ndim == 3
@@ -608,6 +621,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
A,
B,
C,
+ B_bias,
A_scale,
B_scale,
topk_weights,
@@ -635,6 +649,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
if B_scale is not None and B_scale.ndim == 3 else 0,
B_scale.stride(1)
if B_scale is not None and B_scale.ndim >= 2 else 0,
+ B_bias.stride(0) if B_bias is not None else 0,
+ B_bias.stride(1) if B_bias is not None else 0,
0 if block_shape is None else block_shape[0],
0 if block_shape is None else block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight,
@@ -644,6 +660,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
per_channel_quant=per_channel_quant,
+ HAS_BIAS=HAS_BIAS,
BLOCK_SIZE_K=BLOCK_SIZE_K,
**config,
)
@@ -998,39 +1015,7 @@ def get_config_dtype_str(
return None
-def inplace_fused_experts(hidden_states: torch.Tensor,
- w1: torch.Tensor,
- w2: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- activation: str = "silu",
- is_act_and_mul: bool = True,
- apply_router_weight_on_input: bool = False,
- use_fp8_w8a8: bool = False,
- use_int8_w8a8: bool = False,
- use_int8_w8a16: bool = False,
- use_int4_w4a16: bool = False,
- use_mxfp4_w4a4: bool = False,
- per_channel_quant: bool = False,
- global_num_experts: int = -1,
- expert_map: Optional[torch.Tensor] = None,
- w1_scale: Optional[torch.Tensor] = None,
- w2_scale: Optional[torch.Tensor] = None,
- w1_zp: Optional[torch.Tensor] = None,
- w2_zp: Optional[torch.Tensor] = None,
- a1_scale: Optional[torch.Tensor] = None,
- a2_scale: Optional[torch.Tensor] = None,
- block_shape: Optional[list[int]] = None) -> None:
- fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
- activation, is_act_and_mul,
- apply_router_weight_on_input, use_fp8_w8a8,
- use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
- use_mxfp4_w4a4, per_channel_quant, global_num_experts,
- expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
- a2_scale, block_shape)
-
-
-def inplace_fused_experts_fake(
+def inplace_fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
@@ -1053,7 +1038,43 @@ def inplace_fused_experts_fake(
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
- block_shape: Optional[list[int]] = None) -> None:
+ block_shape: Optional[List[int]] = None, #noqa: UP006
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None) -> None:
+ fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
+ activation, is_act_and_mul,
+ apply_router_weight_on_input, use_fp8_w8a8,
+ use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
+ use_mxfp4_w4a4, per_channel_quant, global_num_experts,
+ expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
+ a2_scale, block_shape, w1_bias, w2_bias)
+
+
+def inplace_fused_experts_fake(hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ activation: str = "silu",
+ is_act_and_mul: bool = True,
+ apply_router_weight_on_input: bool = False,
+ use_fp8_w8a8: bool = False,
+ use_int8_w8a8: bool = False,
+ use_int8_w8a16: bool = False,
+ use_int4_w4a16: bool = False,
+ use_mxfp4_w4a4: bool = False,
+ per_channel_quant: bool = False,
+ global_num_experts: int = -1,
+ expert_map: Optional[torch.Tensor] = None,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ w1_zp: Optional[torch.Tensor] = None,
+ w2_zp: Optional[torch.Tensor] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ block_shape: Optional[list[int]] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None) -> None:
pass
@@ -1082,7 +1103,7 @@ def flashinfer_fused_moe_blockscale_fp8(
intermediate_size: int,
expert_offset: int,
local_num_experts: int,
- block_shape: list[int],
+ block_shape: List[int], #noqa: UP006
routed_scaling: float = 1.0) -> torch.Tensor:
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
assert top_k <= global_num_experts
@@ -1242,35 +1263,38 @@ direct_register_custom_op(
def outplace_fused_experts(
- hidden_states: torch.Tensor,
- w1: torch.Tensor,
- w2: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- activation: str = "silu",
- is_act_and_mul: bool = True,
- apply_router_weight_on_input: bool = False,
- use_fp8_w8a8: bool = False,
- use_int8_w8a8: bool = False,
- use_int8_w8a16: bool = False,
- use_int4_w4a16: bool = False,
- use_mxfp4_w4a4: bool = False,
- per_channel_quant: bool = False,
- global_num_experts: int = -1,
- expert_map: Optional[torch.Tensor] = None,
- w1_scale: Optional[torch.Tensor] = None,
- w2_scale: Optional[torch.Tensor] = None,
- w1_zp: Optional[torch.Tensor] = None,
- w2_zp: Optional[torch.Tensor] = None,
- a1_scale: Optional[torch.Tensor] = None,
- a2_scale: Optional[torch.Tensor] = None,
- block_shape: Optional[list[int]] = None) -> torch.Tensor:
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ activation: str = "silu",
+ is_act_and_mul: bool = True,
+ apply_router_weight_on_input: bool = False,
+ use_fp8_w8a8: bool = False,
+ use_int8_w8a8: bool = False,
+ use_int8_w8a16: bool = False,
+ use_int4_w4a16: bool = False,
+ use_mxfp4_w4a4: bool = False,
+ per_channel_quant: bool = False,
+ global_num_experts: int = -1,
+ expert_map: Optional[torch.Tensor] = None,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ w1_zp: Optional[torch.Tensor] = None,
+ w2_zp: Optional[torch.Tensor] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ block_shape: Optional[List[int]] = None, #noqa: UP006
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
return fused_experts_impl(
hidden_states, w1, w2, topk_weights, topk_ids, False, activation,
is_act_and_mul, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4,
per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale,
- w1_zp, w2_zp, a1_scale, a2_scale, block_shape)
+ w1_zp, w2_zp, a1_scale, a2_scale, block_shape, w1_bias, w2_bias)
def outplace_fused_experts_fake(
@@ -1295,7 +1319,9 @@ def outplace_fused_experts_fake(
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
- block_shape: Optional[list[int]] = None) -> torch.Tensor:
+ block_shape: Optional[list[int]] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.empty_like(hidden_states)
@@ -1327,41 +1353,42 @@ def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
# TODO (bnell): replace this with modular op. Can get rid of inplace/outplace
# torch ops.
-def fused_experts(
- hidden_states: torch.Tensor,
- w1: torch.Tensor,
- w2: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- inplace: bool = False,
- activation: str = "silu",
- is_act_and_mul: bool = True,
- apply_router_weight_on_input: bool = False,
- use_fp8_w8a8: bool = False,
- use_int8_w8a8: bool = False,
- use_int8_w8a16: bool = False,
- use_int4_w4a16: bool = False,
- use_mxfp4_w4a4: bool = False,
- per_channel_quant: bool = False,
- global_num_experts: int = -1,
- expert_map: Optional[torch.Tensor] = None,
- w1_scale: Optional[torch.Tensor] = None,
- w2_scale: Optional[torch.Tensor] = None,
- w1_zp: Optional[torch.Tensor] = None,
- w2_zp: Optional[torch.Tensor] = None,
- a1_scale: Optional[torch.Tensor] = None,
- a2_scale: Optional[torch.Tensor] = None,
- block_shape: Optional[list[int]] = None,
- allow_deep_gemm: bool = False,
- allow_cutlass_block_scaled_grouped_gemm: bool = False) -> torch.Tensor:
+def fused_experts(hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ inplace: bool = False,
+ activation: str = "silu",
+ is_act_and_mul: bool = True,
+ apply_router_weight_on_input: bool = False,
+ use_fp8_w8a8: bool = False,
+ use_int8_w8a8: bool = False,
+ use_int8_w8a16: bool = False,
+ use_int4_w4a16: bool = False,
+ use_mxfp4_w4a4: bool = False,
+ per_channel_quant: bool = False,
+ global_num_experts: int = -1,
+ expert_map: Optional[torch.Tensor] = None,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ w1_zp: Optional[torch.Tensor] = None,
+ w2_zp: Optional[torch.Tensor] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ block_shape: Optional[list[int]] = None,
+ allow_deep_gemm: bool = False,
+ allow_cutlass_block_scaled_grouped_gemm: bool = False,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available.
# However, on B200, we use DeepGemm for all cases because they only support
# E8M0 scale, which means we requantize the weight and input to the specific
# scale. Fallen back to cutlass or triton for some cases would cause
# accuracy issue.
- should_use_deep_gemm = is_blackwell_deep_gemm_used() or _valid_deep_gemm(
- hidden_states, w1, w2)
+ should_use_deep_gemm = is_blackwell_deep_gemm_e8m0_used(
+ ) or _valid_deep_gemm(hidden_states, w1, w2)
if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm):
assert apply_router_weight_on_input is False
assert is_act_and_mul, (
@@ -1418,7 +1445,10 @@ def fused_experts(
w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
- block_shape=block_shape)
+ block_shape=block_shape,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ )
def fused_experts_impl(
@@ -1446,6 +1476,8 @@ def fused_experts_impl(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Check constraints.
if use_int4_w4a16:
@@ -1586,7 +1618,19 @@ def fused_experts_impl(
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
- block_shape=block_shape)
+ block_shape=block_shape,
+ B_bias=w1_bias)
+
+ # TODO fused kernel
+ def swiglu_oai(gate_up):
+ alpha = 1.702
+ limit = 7.0
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+ glu = gate * torch.sigmoid(gate * alpha)
+ gated_output = (up + 1) * glu
+ return gated_output
# Activation function with multiplication
if activation == "silu" and is_act_and_mul:
@@ -1600,6 +1644,8 @@ def fused_experts_impl(
intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
elif activation == "gelu":
intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))
+ elif activation == "swiglu_oai":
+ intermediate_cache2 = swiglu_oai(intermediate_cache1.view(-1, N))
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}, "
f"with is_act_and_mul={is_act_and_mul}.")
@@ -1630,7 +1676,8 @@ def fused_experts_impl(
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
- block_shape=block_shape)
+ block_shape=block_shape,
+ B_bias=w2_bias)
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx])
@@ -1667,6 +1714,8 @@ def fused_moe(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1761,7 +1810,9 @@ def fused_moe(
w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
- block_shape=block_shape)
+ block_shape=block_shape,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias)
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
@@ -1932,7 +1983,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
per_channel_quant=self.per_act_token_quant,
- block_shape=self.block_shape)
+ block_shape=self.block_shape,
+ B_bias=None # TODO support B_bias
+ )
self.activation(activation, intermediate_cache2,
intermediate_cache1.view(-1, N))
@@ -1943,26 +1996,29 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
intermediate_cache2, a2_scale, self.quant_dtype,
self.per_act_token_quant, self.block_shape)
- invoke_fused_moe_kernel(qintermediate_cache2,
- w2,
- intermediate_cache3,
- a2q_scale,
- w2_scale,
- w2_zp,
- topk_weights,
- sorted_token_ids,
- expert_ids,
- num_tokens_post_padded,
- not apply_router_weight_on_input,
- 1,
- config,
- compute_type=compute_type,
- use_fp8_w8a8=self.use_fp8_w8a8,
- use_int8_w8a8=self.use_int8_w8a8,
- use_int8_w8a16=self.use_int8_w8a16,
- use_int4_w4a16=self.use_int4_w4a16,
- per_channel_quant=self.per_act_token_quant,
- block_shape=self.block_shape)
+ invoke_fused_moe_kernel(
+ qintermediate_cache2,
+ w2,
+ intermediate_cache3,
+ a2q_scale,
+ w2_scale,
+ w2_zp,
+ topk_weights,
+ sorted_token_ids,
+ expert_ids,
+ num_tokens_post_padded,
+ not apply_router_weight_on_input,
+ 1,
+ config,
+ compute_type=compute_type,
+ use_fp8_w8a8=self.use_fp8_w8a8,
+ use_int8_w8a8=self.use_int8_w8a8,
+ use_int8_w8a16=self.use_int8_w8a16,
+ use_int4_w4a16=self.use_int4_w4a16,
+ per_channel_quant=self.per_act_token_quant,
+ block_shape=self.block_shape,
+ B_bias=None # TODO support B_bias
+ )
ops.moe_sum(intermediate_cache3, output)
diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
new file mode 100644
index 0000000000000..6b5284dc6c96c
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
@@ -0,0 +1,248 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import TYPE_CHECKING, Any, Optional
+
+import torch
+
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+ TopKWeightAndReduceDelegate)
+from vllm.model_executor.layers.fused_moe.utils import extract_required_args
+from vllm.utils import has_triton_kernels
+
+logger = init_logger(__name__)
+
+if has_triton_kernels():
+ try:
+ import triton_kernels.swiglu
+ from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation,
+ matmul_ogs)
+ from triton_kernels.routing import routing
+ except ModuleNotFoundError:
+ logger.error(
+ "Failed to import Triton kernels. Please make sure your triton "
+ "version is compatible.")
+
+if TYPE_CHECKING:
+ from triton_kernels.matmul_ogs import PrecisionConfig
+
+
+def triton_kernel_moe_forward(
+ hidden_states: torch.Tensor,
+ w1, # Tensor or triton_kernels.Tensor
+ w2, # Tensor or triton_kernels.Tensor
+ gating_output: torch.Tensor,
+ topk: int,
+ renormalize: bool,
+ activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
+ use_fp8_w8a8: bool = False,
+ per_channel_quant: bool = False,
+ global_num_experts: int = -1,
+ expert_map: Optional[torch.Tensor] = None,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
+ w1_precision: Optional["PrecisionConfig"] = None,
+ w2_precision: Optional["PrecisionConfig"] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ block_shape: Optional[list[int]] = None,
+) -> torch.Tensor:
+
+ routing_data, gather_idx, scatter_idx = routing(gating_output,
+ topk,
+ sm_first=not renormalize)
+
+ return triton_kernel_fused_experts(
+ None,
+ hidden_states,
+ w1,
+ w2,
+ routing_data,
+ gather_idx,
+ scatter_idx,
+ activation=activation,
+ apply_router_weight_on_input=apply_router_weight_on_input,
+ use_fp8_w8a8=use_fp8_w8a8,
+ per_channel_quant=per_channel_quant,
+ global_num_experts=global_num_experts,
+ expert_map=expert_map,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ w1_precision=w1_precision,
+ w2_precision=w2_precision,
+ a1_scale=a1_scale,
+ a2_scale=a2_scale,
+ block_shape=block_shape)
+
+
+# This is a triton implementation of the fused_experts function
+def triton_kernel_fused_experts(
+ output_tensor: torch.Tensor,
+ hidden_states: torch.Tensor,
+ w1, # Tensor or triton_kernels.Tensor
+ w2, # Tensor or triton_kernels.Tensor
+ routing_data, # RoutingData
+ gather_indx, # GatherIndx
+ scatter_indx, # ScatterIndx
+ activation: str = "silu",
+ swiglu_alpha: float = 1.702,
+ swiglu_limit: float = 7.0,
+ apply_router_weight_on_input: bool = False,
+ use_fp8_w8a8: bool = False,
+ per_channel_quant: bool = False,
+ global_num_experts: int = -1,
+ expert_map: Optional[torch.Tensor] = None,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
+ w1_precision: Optional["PrecisionConfig"] = None,
+ w2_precision: Optional["PrecisionConfig"] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
+ block_shape: Optional[list[int]] = None,
+) -> torch.Tensor:
+
+ # type check, uint8 means mxfp4
+ assert hidden_states.dtype == torch.bfloat16
+ assert w1_bias is None or w1_bias.dtype == torch.float32
+ assert w2_bias is None or w2_bias.dtype == torch.float32
+
+ # Shape check, only check non-mxfp4
+ assert hidden_states.shape[-1] == w1.shape[-2]
+ assert w2.shape[-1] == w1.shape[1]
+
+ E, _, N = w1.shape
+
+ if global_num_experts == -1:
+ global_num_experts = E
+
+ act = FusedActivation(
+ FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
+ (swiglu_alpha, swiglu_limit), 2)
+ gammas = routing_data.gate_scal if routing_data else None
+
+ intermediate_cache1 = matmul_ogs(
+ hidden_states,
+ w1,
+ w1_bias,
+ routing_data,
+ gather_indx=gather_indx,
+ precision_config=w1_precision,
+ gammas=gammas if apply_router_weight_on_input else None,
+ fused_activation=act)
+
+ intermediate_cache3 = matmul_ogs(
+ intermediate_cache1,
+ w2,
+ w2_bias,
+ routing_data,
+ scatter_indx=scatter_indx,
+ precision_config=w2_precision,
+ gammas=None if apply_router_weight_on_input else gammas,
+ y=output_tensor,
+ )
+ return intermediate_cache3
+
+
+class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
+
+ def __init__(
+ self,
+ quant_config,
+ max_num_tokens: int,
+ num_dispatchers: int,
+ w1_precision: "PrecisionConfig",
+ w2_precision: "PrecisionConfig",
+ ):
+ super().__init__(quant_config)
+ self.max_num_tokens = max_num_tokens
+ self.num_dispatchers = num_dispatchers
+ self.w1_precision = w1_precision
+ self.w2_precision = w2_precision
+
+ @property
+ def activation_formats(
+ self
+ ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+ return (mk.FusedMoEActivationFormat.BatchedExperts,
+ mk.FusedMoEActivationFormat.BatchedExperts)
+
+ def supports_chunking(self) -> bool:
+ return False
+
+ def supports_expert_map(self) -> bool:
+ return False
+
+ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+ # Let PrepareAndFinalize::finalize() decide the impl.
+ return TopKWeightAndReduceDelegate()
+
+ def workspace_shapes(
+ self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
+ topk: int, global_num_experts: int, local_num_experts: int,
+ expert_tokens_meta: Optional[mk.ExpertTokensMetadata]
+ ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
+ # workspace are allocated inside the kernel
+ assert a.dim() == 2
+ num_dp = self.num_dispatchers
+ num_experts = local_num_experts
+ max_num_tokens = self.max_num_tokens
+ workspace2 = (0, 0, 0)
+ output = (num_experts, max_num_tokens * num_dp, N)
+ return (output, workspace2, output, a.dtype)
+
+ def apply(
+ self,
+ output: torch.Tensor,
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ activation: str,
+ global_num_experts: int,
+ expert_map: Optional[torch.Tensor],
+ w1_scale: Optional[torch.Tensor],
+ w2_scale: Optional[torch.Tensor],
+ w1_zp: Optional[torch.Tensor],
+ w2_zp: Optional[torch.Tensor],
+ a1q_scale: Optional[torch.Tensor],
+ a2_scale: Optional[torch.Tensor],
+ workspace13: torch.Tensor,
+ workspace2: torch.Tensor,
+ expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+ apply_router_weight_on_input: bool,
+ extra_expert_args: Optional[dict[str, Any]],
+ ):
+ w1_bias, w2_bias = (extract_required_args(extra_expert_args,
+ ["w1_bias", "w2_bias"]))
+
+ return triton_kernel_fused_experts(
+ output,
+ hidden_states,
+ w1,
+ w2,
+ None,
+ None,
+ None,
+ activation=activation,
+ apply_router_weight_on_input=False,
+ use_fp8_w8a8=False,
+ per_channel_quant=False,
+ global_num_experts=global_num_experts,
+ expert_map=expert_map,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ w1_precision=self.w1_precision,
+ w2_precision=self.w2_precision,
+ a1_scale=a1q_scale,
+ a2_scale=a2_scale)
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 76cedb3ed3487..d5a89655e36d6 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -36,7 +36,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx,
- round_up)
+ has_triton_kernels, is_torch_equal_or_newer, round_up)
from vllm.utils.flashinfer import has_flashinfer
if current_platform.is_cuda_alike():
@@ -255,7 +255,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self.fused_experts = fused_experts # type: ignore
self.topk_indices_dtype = None
self.moe = moe
-
+ self.has_bias = self.moe.has_bias
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
if self.rocm_aiter_moe_enabled:
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
@@ -291,7 +291,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
-
+ if self.has_bias:
+ w13_bias = torch.nn.Parameter(torch.zeros(
+ num_experts,
+ 2 * intermediate_size_per_partition,
+ dtype=params_dtype),
+ requires_grad=False)
+ layer.register_parameter("w13_bias", w13_bias)
+ set_weight_attrs(w13_bias, extra_weight_attrs)
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
@@ -301,6 +308,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
+ if self.has_bias:
+ w2_bias = torch.nn.Parameter(torch.zeros(num_experts,
+ hidden_size,
+ dtype=params_dtype),
+ requires_grad=False)
+ layer.register_parameter("w2_bias", w2_bias)
+ set_weight_attrs(w2_bias, extra_weight_attrs)
def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
# Pad the weight tensor. This is an optimization on ROCm platform, which
@@ -465,6 +479,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
+ w1_bias=layer.w13_bias if self.has_bias else None,
+ w2_bias=layer.w2_bias if self.has_bias else None,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
@@ -702,6 +718,7 @@ class FusedMoE(torch.nn.Module):
activation: str = "silu",
enable_eplb: bool = False,
num_redundant_experts: int = 0,
+ has_bias: bool = False,
):
super().__init__()
if params_dtype is None:
@@ -723,10 +740,17 @@ class FusedMoE(torch.nn.Module):
self.global_num_experts = num_experts + num_redundant_experts
# we padding globally so EP buffer allocation works
- if quant_config and quant_config.get_name() == "mxfp4" and (
- envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
- or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
- hidden_size = round_up(hidden_size, 256)
+ if quant_config and quant_config.get_name() == "mxfp4":
+ if not is_torch_equal_or_newer("2.8.0"):
+ raise RuntimeError("Mxfp4 on hopper requires torch >= 2.8.0")
+ if current_platform.is_device_capability(
+ 90) and not has_triton_kernels():
+ raise NotImplementedError(
+ "Triton kernels must be installed for mxfp4 on hopper")
+ if (current_platform.is_rocm()
+ or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
+ or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
+ hidden_size = round_up(hidden_size, 256)
# For smuggling this layer into the fused moe custom op
compilation_config = vllm_config.compilation_config
@@ -786,16 +810,15 @@ class FusedMoE(torch.nn.Module):
# since model_config is not set in the pytest test.
model_dtype = params_dtype
- moe = FusedMoEConfig.make(
- num_experts=self.global_num_experts,
- experts_per_token=top_k,
- hidden_dim=hidden_size,
- num_local_experts=self.local_num_experts,
- moe_parallel_config=self.moe_parallel_config,
- in_dtype=model_dtype,
- max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
- quant_config=quant_config,
- )
+ moe = FusedMoEConfig.make(num_experts=self.global_num_experts,
+ experts_per_token=top_k,
+ hidden_dim=hidden_size,
+ num_local_experts=self.local_num_experts,
+ moe_parallel_config=self.moe_parallel_config,
+ in_dtype=model_dtype,
+ max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
+ quant_config=quant_config,
+ has_bias=has_bias)
self.moe_config = moe
self.quant_config = quant_config
@@ -1570,18 +1593,19 @@ class FusedMoE(torch.nn.Module):
max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
num_tokens = full_hidden_states.size(0)
- for chunk_start_ in range(0, max_tokens_across_dp,
- moe_dp_chunk_size_per_rank):
+ for chunk_idx, chunk_start_ in enumerate(
+ range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank)):
chunk_start = chunk_start_
chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
max_tokens_across_dp)
# clamp start and end
chunk_start = min(chunk_start, num_tokens - 1)
chunk_end = min(chunk_end, num_tokens)
-
- process_chunk(chunk_start,
- chunk_end,
- skip_result_store=chunk_start_ >= num_tokens)
+ with ctx.dp_metadata.chunked_sizes(moe_dp_chunk_size_per_rank,
+ chunk_idx):
+ process_chunk(chunk_start,
+ chunk_end,
+ skip_result_store=chunk_start_ >= num_tokens)
return full_final_hidden_states
diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
index c67f7e808301a..9d0ff2e06190e 100644
--- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
@@ -10,7 +10,7 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape,
deep_gemm_block_shape)
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
-from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
+from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
@@ -107,7 +107,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set.
- if self.allow_deep_gemm and (is_blackwell_deep_gemm_used()
+ if self.allow_deep_gemm and (is_blackwell_deep_gemm_e8m0_used()
or _valid_deep_gemm_shape(M, N, K)):
assert self.deep_gemm_expert is not None
return self.deep_gemm_expert.workspace_shapes(
@@ -133,7 +133,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
extra_expert_args: Optional[dict[str, Any]]):
use_deep_gemm = (self.allow_deep_gemm
and (_valid_deep_gemm(hidden_states, w1, w2)
- or is_blackwell_deep_gemm_used()))
+ or is_blackwell_deep_gemm_e8m0_used()))
experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
assert experts is not None
diff --git a/vllm/model_executor/layers/lightning_attn.py b/vllm/model_executor/layers/lightning_attn.py
index 978086d1909d1..8ffc700ca5cde 100644
--- a/vllm/model_executor/layers/lightning_attn.py
+++ b/vllm/model_executor/layers/lightning_attn.py
@@ -532,7 +532,7 @@ def _linear_attn_decode_kernel(
pid_d = tl.program_id(2) # dimension block index
# Load slot index for the current batch
- slot_id = tl.load(slot_idx + pid_b)
+ slot_id = tl.load(slot_idx + pid_b).to(tl.int64)
# Skip if slot_id is -1 (padding)
if slot_id == -1:
diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py
index 42c815b08f042..ad14017912381 100644
--- a/vllm/model_executor/layers/mamba/mamba_utils.py
+++ b/vllm/model_executor/layers/mamba/mamba_utils.py
@@ -5,6 +5,17 @@ from vllm.distributed import divide
class MambaStateShapeCalculator:
+ @classmethod
+ def linear_attention_state_shape(
+ cls,
+ num_heads: int,
+ tp_size: int,
+ head_dim: int,
+ ) -> tuple[tuple[int, int, int], ...]:
+
+ state_shape = (num_heads // tp_size, head_dim, head_dim)
+ return (state_shape, )
+
@classmethod
def mamba1_state_shape(
cls,
diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
index fc2b3b25fd0a8..365139e237c66 100644
--- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
+++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
@@ -290,10 +290,8 @@ def _chunk_scan_fwd_kernel(
# get the cs at the offset boundary
# - c_off == 0 is a passthrough
dA_cs_m_boundary = tl.load(
- dA_cumsum_ptr +
- (pid_m * BLOCK_SIZE_M + c_off - 1) * stride_dA_cs_csize,
- mask=(((pid_m * BLOCK_SIZE_M + c_off - 1) > -1)
- and ((pid_m * BLOCK_SIZE_M + c_off) < chunk_size)),
+ dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize,
+ mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)),
other=0.0).to(tl.float32)
if HAS_SEQ_IDX:
diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py
index ad2853a3d8a8b..fd74cb837290b 100644
--- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py
+++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py
@@ -21,6 +21,10 @@ from .ssd_state_passing import _state_passing_fwd
TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
+def is_int_pow_2(n):
+ return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0
+
+
def _mamba_chunk_scan_combined_fwd(x,
dt,
A,
@@ -38,6 +42,7 @@ def _mamba_chunk_scan_combined_fwd(x,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
out=None):
+ assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
batch, seqlen, nheads, headdim = x.shape
_, _, ngroups, dstate = B.shape
assert nheads % ngroups == 0
diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py
index 0f2e58eb9b5d9..e2162e5cbf956 100644
--- a/vllm/model_executor/layers/pooler.py
+++ b/vllm/model_executor/layers/pooler.py
@@ -44,15 +44,14 @@ class ResolvedPoolingConfig:
task: PoolingTask
@classmethod
- def from_config_with_defaults(
+ def from_config(
cls,
task: PoolingTask,
pooler_config: PoolerConfig,
- pooling_type: PoolingType,
) -> "ResolvedPoolingConfig":
+ assert pooler_config.pooling_type is not None
return cls(task=task,
- pooling_type=PoolingType[pooler_config.pooling_type]
- if pooler_config.pooling_type is not None else pooling_type)
+ pooling_type=PoolingType[pooler_config.pooling_type])
@dataclass(frozen=True)
@@ -68,32 +67,20 @@ class Pooler(nn.Module, ABC):
"""The interface required for all poolers used in pooling models in vLLM."""
@staticmethod
- def for_encode(
- pooler_config: PoolerConfig,
- *,
- default_pooling_type: PoolingType = PoolingType.ALL,
- ):
- resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
- task="encode",
- pooler_config=pooler_config,
- pooling_type=default_pooling_type,
- )
-
- if resolved_config.pooling_type == PoolingType.STEP:
+ def for_encode(pooler_config: PoolerConfig):
+ if pooler_config.pooling_type == "STEP":
return StepPooler()
+ resolved_config = ResolvedPoolingConfig(task="encode",
+ pooling_type=PoolingType.ALL)
+
return SimplePooler.from_config(resolved_config)
@staticmethod
- def for_embed(
- pooler_config: PoolerConfig,
- *,
- default_pooling_type: PoolingType = PoolingType.LAST,
- ):
- resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
+ def for_embed(pooler_config: PoolerConfig):
+ resolved_config = ResolvedPoolingConfig.from_config(
task="embed",
pooler_config=pooler_config,
- pooling_type=default_pooling_type,
)
return SimplePooler.from_config(resolved_config)
@@ -102,13 +89,10 @@ class Pooler(nn.Module, ABC):
def for_classify(
pooler_config: PoolerConfig,
classifier: Optional[ClassifierFn],
- *,
- default_pooling_type: PoolingType = PoolingType.LAST,
):
- resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
+ resolved_config = ResolvedPoolingConfig.from_config(
task="classify",
pooler_config=pooler_config,
- pooling_type=default_pooling_type,
)
pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type)
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index 0fdded0b5a7fc..6cf02658a94c5 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -10,7 +10,8 @@ import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (
- FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
+ FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
+ UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod,
set_weight_attrs)
@@ -141,6 +142,9 @@ class AWQMarlinConfig(QuantizationConfig):
elif isinstance(layer, FusedMoE):
from vllm.model_executor.layers.quantization.moe_wna16 import (
MoeWNA16Config)
+ if is_layer_skipped_awq(
+ prefix, getattr(self, "modules_to_not_convert", [])):
+ return UnquantizedFusedMoEMethod(layer.moe_config)
if not check_moe_marlin_supports_layer(layer, self.group_size):
logger.warning_once(
f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
@@ -520,4 +524,4 @@ class AWQMoEMethod(FusedMoEMethodBase):
expert_map=expert_map,
w1_zeros=layer.w13_qzeros,
w2_zeros=layer.w2_qzeros,
- workspace=layer.workspace)
+ workspace=layer.workspace)
\ No newline at end of file
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 09d8890888fa8..c04f7c39a5f5d 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -105,7 +105,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
detect_nvfp4_moe_support)
_nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
- self.allow_flashinfer_cutlass = _nvfp4.allow_flashinfer_cutlass
+ self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin
self.group_size = 16
self.fused_experts = None # type: ignore[assignment]
@@ -212,7 +212,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
requires_grad=False)
# reorder GEMM1 weights and block scales for FlashInfer CUTLASS kernel.
- if self.allow_flashinfer_cutlass:
+ if self.allow_flashinfer:
w, s = reorder_w1w3_to_w3w1(layer.w13_weight.data,
layer.w13_weight_scale.data,
dim=-2)
@@ -266,7 +266,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
(layer.w2_input_global_scale), requires_grad=False)
def maybe_swap_experts_impl(self, moe_parallel_config):
- if not self.allow_flashinfer_cutlass:
+ if not self.allow_flashinfer:
return
self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
moe_parallel_config)
@@ -277,8 +277,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( # noqa: E501
select_nvfp4_gemm_impl)
- return select_nvfp4_gemm_impl(self.allow_flashinfer_cutlass, moe,
- logger)
+ return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger)
def apply(
self,
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 8b6ed154bdbe4..9577fa025b707 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -45,7 +45,8 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
+from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
+ is_deep_gemm_supported)
from vllm.utils.flashinfer import has_flashinfer_moe
if TYPE_CHECKING:
@@ -415,10 +416,10 @@ class Fp8LinearMethod(LinearMethodBase):
# Activations not quantized for marlin.
del layer.input_scale
- # On B200, DeepGemm only support E8M0 scale, which means we need to
+ # On B200, if E8M0 for DeepGemm is used, we need to
# requantize the weight and input to the specific scale
# at the same time.
- if is_blackwell_deep_gemm_used():
+ if is_blackwell_deep_gemm_e8m0_used():
assert layer.weight_block_size is not None
block_sz = tuple(layer.weight_block_size)
requant_weight_ue8m0_inplace(
@@ -505,15 +506,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
elif not self.block_quant:
logger.warning_once("Model is not block quantized. Not using "
"DeepGemm kernels")
- elif (current_platform.is_cuda()
- and current_platform.is_device_capability(90)):
+ elif (is_deep_gemm_supported()):
logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
self.allow_deep_gemm = True
- elif (current_platform.is_cuda()
- and is_blackwell_deep_gemm_used()):
- logger.info_once("Using DeepGemm SM100 kernels for "
- "Fp8MoEMethod.")
- self.allow_deep_gemm = True
else:
logger.warning_once(
"DeepGemm not supported on the current platform.")
@@ -725,7 +720,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# DeepGemm scales need to be transposed and aligned. We try to do
# it ahead of time for performance reasons.
- if self.allow_deep_gemm and not is_blackwell_deep_gemm_used():
+ if self.allow_deep_gemm and not is_blackwell_deep_gemm_e8m0_used():
# Lazy import to avoid CUDA initialization problems.
if _is_col_major(layer.w13_weight_scale_inv):
layer.w13_weight_scale_inv = \
@@ -851,7 +846,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del layer.w13_input_scale
del layer.w2_input_scale
- if is_blackwell_deep_gemm_used():
+ if is_blackwell_deep_gemm_e8m0_used():
assert layer.weight_block_size is not None
# Re-quantise the expert weights so their scales are UE8M0.
block_sz = tuple(layer.weight_block_size)
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 0334a2824512d..bed502226716c 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from enum import Enum
from typing import Any, Callable, Optional, Union
import torch
@@ -36,6 +37,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.model_executor.parameter import (ModelWeightParameter,
PerTensorScaleParameter)
from vllm.scalar_type import scalar_types
+from vllm.utils import next_power_of_2
from vllm.utils.flashinfer import has_flashinfer_moe
logger = init_logger(__name__)
@@ -44,6 +46,11 @@ QUANT_ALGOS = ["FP8", "NVFP4"]
KV_CACHE_QUANT_ALGOS = ["FP8"]
+class FlashinferMoeBackend(Enum):
+ TENSORRT_LLM = "TensorRT-LLM"
+ CUTLASS = "CUTLASS"
+
+
class ModelOptFp8Config(QuantizationConfig):
"""Config class for ModelOpt FP8."""
@@ -185,7 +192,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
Args: quant_config: The ModelOpt quantization config.
"""
- def __init__(self, quant_config: ModelOptFp8Config):
+ def __init__(self, quant_config: ModelOptFp8Config) -> None:
self.quant_config = quant_config
self.fp8_linear = Fp8LinearOp(
act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR)
@@ -265,7 +272,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
quant_config: The ModelOpt quantization config.
"""
- def __init__(self, quant_config: ModelOptFp8Config):
+ def __init__(self, quant_config: ModelOptFp8Config) -> None:
self.quant_config = quant_config
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_fp8_supported)
@@ -670,7 +677,8 @@ class ModelOptNvFp4Config(QuantizationConfig):
return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
exclude_modules, group_size)
- def is_layer_excluded(self, prefix: str, exclude_modules: list):
+ def is_layer_excluded(self, prefix: str,
+ exclude_modules: list[str]) -> bool:
import regex as re
for pattern in exclude_modules:
regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
@@ -714,7 +722,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
Args: quant_config: The ModelOpt quantization config.
"""
- def __init__(self, quant_config: ModelOptNvFp4Config):
+ def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
self.quant_config = quant_config
self.cutlass_nvfp4_supported = cutlass_fp4_supported()
self.use_marlin = False
@@ -859,6 +867,16 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
return out.view(*output_shape)
+def _get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
+ # Guess tokens per expert assuming perfect expert distribution first.
+ num_tokens_per_expert = (num_tokens * top_k) // num_experts
+ # And pad the number to the next power of 2.
+ tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
+ # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+ tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+ return tile_tokens_dim
+
+
class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
"""
MoE Method for FP4 Quantization.
@@ -866,22 +884,40 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
quant_config: NVFP4 Quant Config
"""
- def __init__(self, quant_config: ModelOptNvFp4Config):
+ def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
self.quant_config = quant_config
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
detect_nvfp4_moe_support)
_nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
- self.allow_flashinfer_cutlass = _nvfp4.allow_flashinfer_cutlass
+ self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin
+ self.flashinfer_moe_backend = None
- self.fused_experts = None # type: ignore
+ if self.allow_flashinfer:
+ flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
+ if flashinfer_moe_backend == "throughput":
+ self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
+ logger.info_once("Using FlashInfer CUTLASS kernels for "
+ "ModelOptNvFp4FusedMoE.")
+ elif flashinfer_moe_backend == "latency":
+ self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
+ logger.info_once("Using FlashInfer TensorRT-LLM kernels for "
+ "ModelOptNvFp4FusedMoE.")
+ else:
+ allowed_backends = ["throughput", "latency"]
+ raise ValueError(
+ f"Unknown flashinfer moe backend: {flashinfer_moe_backend}"
+ f" expected one of {allowed_backends}")
+
+ self.fused_experts: Optional[
+ mk.FusedMoEModularKernel] = None # type: ignore[assignment]
def maybe_swap_experts_impl(
self,
moe_parallel_config: FusedMoEParallelConfig,
):
- if not self.allow_flashinfer_cutlass:
+ if not self.allow_flashinfer:
return
self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
moe_parallel_config)
@@ -897,8 +933,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( # noqa: E501
select_nvfp4_gemm_impl)
- return select_nvfp4_gemm_impl(self.allow_flashinfer_cutlass, moe,
- logger)
+ return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger)
def uses_weight_scale_2_pattern(self) -> bool:
"""
@@ -996,14 +1031,101 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
weight_loader=weight_loader)
layer.register_parameter("w2_input_scale", w2_input_scale)
+ def prepare_static_weight_layouts_for_trtllm_moe(
+ self,
+ gemm1_weights: torch.Tensor,
+ gemm2_weights: torch.Tensor,
+ gemm1_scales_linear_fp4_bytes: torch.Tensor,
+ gemm2_scales_linear_fp4_bytes: torch.Tensor,
+ hidden_size: int,
+ intermediate_size: int,
+ num_experts: int,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Prepare quantized weights for kernel (done offline with weights)."""
+ from flashinfer import (reorder_rows_for_gated_act_gemm,
+ shuffle_matrix_a, shuffle_matrix_sf_a)
+ epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
+
+ # Convert quantized weights to proper formats
+ gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
+ num_experts, 2 * intermediate_size, hidden_size // 2) # packed fp4
+ gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
+ torch.float8_e4m3fn).reshape(num_experts, 2 * intermediate_size,
+ hidden_size //
+ 16) # fp8 scaling factors
+
+ gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
+ num_experts, hidden_size, intermediate_size // 2) # packed fp4
+ gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
+ torch.float8_e4m3fn).reshape(num_experts, hidden_size,
+ intermediate_size //
+ 16) # fp8 scaling factors
+
+ # Reorder rows of W1 and scales for fused gated activation
+ gemm1_weights_fp4_interleaved = []
+ gemm1_scales_fp4_interleaved = []
+ for i in range(num_experts):
+ gemm1_weights_fp4_interleaved.append(
+ reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone()))
+ gemm1_scales_fp4_interleaved.append(
+ reorder_rows_for_gated_act_gemm(
+ gemm1_scales_linear_fp4[i].clone()))
+
+ # Stack weights and scales for all experts
+ gemm1_weights_fp4_interleaved = torch.stack(
+ gemm1_weights_fp4_interleaved).reshape(num_experts,
+ 2 * intermediate_size,
+ hidden_size // 2)
+ gemm1_scales_fp4_interleaved = torch.stack(
+ gemm1_scales_fp4_interleaved).reshape(num_experts,
+ 2 * intermediate_size,
+ hidden_size // 16)
+
+ # Shuffle weights and scaling factors for transposed mma output
+ gemm1_weights_fp4_shuffled = []
+ gemm1_scales_fp4_shuffled = []
+ gemm2_weights_fp4_shuffled = []
+ gemm2_scales_fp4_shuffled = []
+ for i in range(num_experts):
+ gemm1_weights_fp4_shuffled.append(
+ shuffle_matrix_a(
+ gemm1_weights_fp4_interleaved[i].view(torch.uint8),
+ epilogue_tile_m))
+ gemm1_scales_fp4_shuffled.append(
+ shuffle_matrix_sf_a(
+ gemm1_scales_fp4_interleaved[i].view(torch.uint8),
+ epilogue_tile_m))
+
+ gemm2_weights_fp4_shuffled.append(
+ shuffle_matrix_a(gemm2_weights_fp4[i].view(torch.uint8),
+ epilogue_tile_m))
+ gemm2_scales_fp4_shuffled.append(
+ shuffle_matrix_sf_a(
+ gemm2_scales_linear_fp4[i].view(torch.uint8),
+ epilogue_tile_m))
+
+ # Stack weights for all experts
+ gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
+ gemm1_scales_fp4_shuffled = (
+ torch.stack(gemm1_scales_fp4_shuffled).view(
+ torch.float8_e4m3fn).reshape(num_experts,
+ 2 * intermediate_size,
+ hidden_size // 16))
+
+ gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
+ gemm2_scales_fp4_shuffled = (
+ torch.stack(gemm2_scales_fp4_shuffled).view(
+ torch.float8_e4m3fn).reshape(num_experts, hidden_size,
+ intermediate_size // 16))
+ return (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
+ gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled)
+
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
- # GEMM 1
- # The FlashInfer Cutlass fused MoE kernel expects the combined weights
- # to be ordered as [w3, w1], unlike the standard [w1, w3] layout.
+ # GEMM 1 processing
gemm1_weight = layer.w13_weight.data
gemm1_weight_scale = layer.w13_weight_scale.data
- if self.allow_flashinfer_cutlass:
+ if self.allow_flashinfer:
gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
gemm1_weight, gemm1_weight_scale, dim=-2)
@@ -1011,6 +1133,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer.w13_weight_scale = Parameter(gemm1_weight_scale,
requires_grad=False)
+ # Common processing for w13_weight_scale_2
if not torch.allclose(layer.w13_weight_scale_2[:, 0],
layer.w13_weight_scale_2[:, 1]):
logger.warning_once(
@@ -1021,26 +1144,18 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2,
requires_grad=False)
+ # Common processing for input scales and alphas
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
torch.float32)
layer.g1_alphas = Parameter(
(w13_input_scale * w13_weight_scale_2).to(torch.float32),
requires_grad=False)
- assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
- "Expected weight_scale.dim(1) to be divisible by 16")
- assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), (
- "Weight Blockscale must be represented as FP8-E4M3")
- w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale)
-
- layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
- requires_grad=False)
-
# This is for quantization, so we need to invert it.
layer.w13_input_scale_quant = Parameter(
(1 / w13_input_scale).to(torch.float32), requires_grad=False)
- # GEMM 2
+ # GEMM 2 processing
layer.g2_alphas = Parameter(
(layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
requires_grad=False)
@@ -1049,15 +1164,63 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer.w2_input_scale_quant = Parameter(
(1 / layer.w2_input_scale).to(torch.float32), requires_grad=False)
- assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
- "Expected weight_scale.dim(1) to be divisible by 16")
- assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
- "Weight Blockscale must be represented as FP8-E4M3")
- w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
+ # TensorRT-LLM specific processing
+ if self.allow_flashinfer and \
+ self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
+ # Prepare static weights for TRT-LLM kernel
+ (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
+ gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled
+ ) = self.prepare_static_weight_layouts_for_trtllm_moe(
+ layer.w13_weight,
+ layer.w2_weight,
+ layer.w13_weight_scale,
+ layer.w2_weight_scale,
+ layer.w2_weight.size(-2), # hidden_size
+ layer.w13_weight.size(-2) // 2, # intermediate_size
+ layer.w13_weight.size(0), # num_experts
+ )
- layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
- requires_grad=False)
- layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
+ layer.gemm1_weights_fp4_shuffled = Parameter(
+ gemm1_weights_fp4_shuffled, requires_grad=False)
+ layer.gemm2_weights_fp4_shuffled = Parameter(
+ gemm2_weights_fp4_shuffled, requires_grad=False)
+ layer.gemm1_scales_fp4_shuffled = Parameter(
+ gemm1_scales_fp4_shuffled, requires_grad=False)
+ layer.gemm2_scales_fp4_shuffled = Parameter(
+ gemm2_scales_fp4_shuffled, requires_grad=False)
+
+ # Additional parameter needed for TRT-LLM
+ layer.g1_scale_c = Parameter(
+ (layer.w2_input_scale_quant * layer.g1_alphas).to(
+ torch.float32),
+ requires_grad=False,
+ )
+
+ # Clean up weights that won't be used by TRT-LLM
+ del layer.w2_weight
+ del layer.w2_weight_scale
+ del layer.w13_weight
+ del layer.w13_weight_scale
+ else:
+ # Non-TRT-LLM processing (Cutlass or non-flashinfer)
+ assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
+ "Expected weight_scale.dim(1) to be divisible by 16")
+ assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), (
+ "Weight Blockscale must be represented as FP8-E4M3")
+ w13_blockscale_swizzled = swizzle_blockscale(
+ layer.w13_weight_scale)
+ layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
+ requires_grad=False)
+
+ assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
+ "Expected weight_scale.dim(1) to be divisible by 16")
+ assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
+ "Weight Blockscale must be represented as FP8-E4M3")
+ w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
+ layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
+ requires_grad=False)
+ layer.w2_weight = Parameter(layer.w2_weight.data,
+ requires_grad=False)
if self.use_marlin:
prepare_moe_fp4_layer_for_marlin(layer)
@@ -1095,6 +1258,61 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet.")
assert activation == "silu", "Only SiLU activation is supported."
+ if self.allow_flashinfer and \
+ self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
+ import flashinfer
+
+ from vllm.model_executor.models.llama4 import Llama4MoE
+
+ a1_gscale = layer.w13_input_scale_quant
+ (hidden_states_fp4,
+ hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
+ x,
+ a1_gscale,
+ is_sf_swizzled_layout=False,
+ )
+ use_llama4_routing = \
+ custom_routing_function is Llama4MoE.custom_routing_function
+ routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3
+ if use_llama4_routing:
+ routing_method_type = flashinfer.RoutingMethodType.Llama4
+ out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
+ routing_logits=router_logits
+ if use_llama4_routing else router_logits.to(torch.float32),
+ routing_bias=e_score_correction_bias,
+ hidden_states=hidden_states_fp4,
+ hidden_states_scale=hidden_states_scale_linear_fp4.view(
+ torch.float8_e4m3fn).flatten(),
+ gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
+ gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
+ torch.float8_e4m3fn),
+ gemm1_bias=None,
+ gemm1_alpha=None,
+ gemm1_beta=None,
+ gemm1_clamp_limit=None,
+ gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
+ gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
+ torch.float8_e4m3fn),
+ gemm2_bias=None,
+ output1_scale_scalar=layer.g1_scale_c.data,
+ output1_scale_gate_scalar=layer.g1_alphas.data,
+ output2_scale_scalar=layer.g2_alphas.data,
+ num_experts=global_num_experts,
+ top_k=top_k,
+ n_group=num_expert_group
+ if num_expert_group is not None else 0,
+ topk_group=topk_group if topk_group is not None else 0,
+ intermediate_size=layer.intermediate_size_per_partition,
+ local_expert_offset=layer.ep_rank * layer.local_num_experts,
+ local_num_experts=layer.local_num_experts,
+ routed_scaling_factor=None,
+ tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
+ layer.local_num_experts),
+ routing_method_type=routing_method_type,
+ do_finalize=True,
+ )[0]
+ return out
+
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
@@ -1149,6 +1367,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)
else:
+ assert self.allow_flashinfer and \
+ self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
out = flashinfer_fp4_cutlass_moe_forward(
self.fused_experts,
layer,
@@ -1160,4 +1380,5 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
+
return out
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index 068af027398ba..03fbcf158338e 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -8,16 +8,19 @@ from torch.nn.parameter import Parameter
from vllm import envs
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase)
+from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
+ triton_kernel_moe_forward)
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
- _can_support_mxfp4)
+ _can_support_mxfp4, _swizzle_mxfp4)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
from vllm.utils import next_power_of_2, round_up
if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
@@ -39,7 +42,7 @@ class Mxfp4Config(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
- return 100
+ return 80
@classmethod
def get_name(cls) -> QuantizationMethods:
@@ -100,11 +103,18 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
intermediate_size_per_partition
# pad the intermediate size to be a multiple of 2 * mxfp4_block
# for to hold non-uniform sharded tensor as well as swizzling
+ # other padding to increase performance
if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 256)
hidden_size = round_up(hidden_size, 256)
+ elif current_platform.is_rocm():
+ intermediate_size_per_partition_after_pad = round_up(
+ intermediate_size_per_partition, 128)
+ else:
+ intermediate_size_per_partition_after_pad = round_up(
+ intermediate_size_per_partition, 64)
self.intermediate_size = intermediate_size_per_partition_after_pad
self.hidden_size = hidden_size
@@ -303,7 +313,41 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape(
self.num_experts, -1),
requires_grad=False)
- return
+ else:
+ from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
+
+ w13_bias = layer.w13_bias.to(torch.float32)
+ w2_bias = layer.w2_bias.to(torch.float32)
+
+ layer.w13_bias = Parameter(w13_bias, requires_grad=False)
+ layer.w2_bias = Parameter(w2_bias, requires_grad=False)
+
+ # FIXME warp need to be adjusted based on batch size
+ # only apply to batched mode
+ if self.moe.use_ep:
+ num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
+ else:
+ num_warps = 8
+
+ w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
+ layer.w13_weight, layer.w13_weight_scale, num_warps)
+ w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
+ layer.w2_weight, layer.w2_weight_scale, num_warps)
+
+ self.w13_precision_config = PrecisionConfig(
+ weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex))
+ self.w2_precision_config = PrecisionConfig(
+ weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex))
+
+ self.w13_weight_triton_tensor = w13_weight
+ self.w2_weight_triton_tensor = w2_weight
+
+ # need to delete the original weights to save memory on single GPU
+ del layer.w13_weight
+ del layer.w2_weight
+ layer.w13_weight = None
+ layer.w2_weight = None
+ torch.cuda.empty_cache()
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
# Number of tokens in the input tensor.
@@ -404,3 +448,19 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
True, # do finalize
)[0]
return trtllm_gen_output
+ else:
+ return triton_kernel_moe_forward(
+ hidden_states=x,
+ w1=self.w13_weight_triton_tensor,
+ w2=self.w2_weight_triton_tensor,
+ gating_output=router_logits,
+ topk=top_k,
+ renormalize=renormalize,
+ global_num_experts=global_num_experts,
+ expert_map=expert_map,
+ w1_bias=layer.w13_bias,
+ w2_bias=layer.w2_bias,
+ w1_precision=self.w13_precision_config,
+ w2_precision=self.w2_precision_config,
+ apply_router_weight_on_input=apply_router_weight_on_input,
+ )
diff --git a/vllm/model_executor/layers/quantization/tpu_int8.py b/vllm/model_executor/layers/quantization/tpu_int8.py
index 83c8a98eac913..38de4b54fb191 100644
--- a/vllm/model_executor/layers/quantization/tpu_int8.py
+++ b/vllm/model_executor/layers/quantization/tpu_int8.py
@@ -13,7 +13,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.parameter import ModelWeightParameter
-ACTIVATION_SCHEMES = ["none"]
+ACTIVATION_SCHEMES = ["none", "dynamic"]
class Int8TpuConfig(QuantizationConfig):
@@ -61,6 +61,9 @@ class TPUInt8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: Int8TpuConfig):
self.quant_config = quant_config
+ self.quantize_activation = False
+ if self.quant_config.activation_scheme == 'dynamic':
+ self.quantize_activation = True
def create_weights(self, layer: Module, input_size_per_partition: int,
output_partition_sizes: list[int], input_size: int,
@@ -107,7 +110,7 @@ class TPUInt8LinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
try:
- import torch_xla.experimental.xla_quantized_matmul # noqa: F401
+ import torch_xla.experimental.custom_kernel # noqa: F401
except ImportError as err:
raise ImportError(
"Please install torch_xla by following the instructions at "
@@ -115,7 +118,8 @@ class TPUInt8LinearMethod(LinearMethodBase):
"to run vLLM on TPU.") from err
weight = layer.weight
scale = layer.scale
- out = torch.ops.xla.quantized_matmul(x, weight, scale)
+ out = torch.ops.xla.quantized_matmul_int8(
+ x, weight, scale, quantize_activation=self.quantize_activation)
if bias is not None:
out = out + bias
return out
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
index 4c617e226041f..8ef91eeed406f 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
@@ -126,7 +126,7 @@ def flashinfer_fp4_cutlass_moe_forward(
def select_nvfp4_gemm_impl(
- allow_flashinfer_cutlass: bool,
+ allow_flashinfer: bool,
moe, # FusedMoEConfig
logger):
"""Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""
@@ -137,8 +137,14 @@ def select_nvfp4_gemm_impl(
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
- if allow_flashinfer_cutlass:
- logger.debug_once("Using FlashInferExperts")
+ if allow_flashinfer:
+ flashinfer_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
+ if flashinfer_backend != "throughput":
+ raise ValueError(
+ f"Only throughput backend is supported for FlashInferExperts, "
+ f"but got {flashinfer_backend}.")
+ logger.debug_once(
+ "Initializing FlashInferExperts with throughput backend.")
return FlashInferExperts(
use_nvfp4_w4a4=True,
use_dp=moe.moe_parallel_config.dp_size > 1,
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
index c6f914febc0a2..9fb194767e4a4 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
@@ -6,14 +6,22 @@ import torch
def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
- from flashinfer import next_positive_power_of_2
- # Guess tokens per expert assuming perfect expert distribution first.
- num_tokens_per_expert = (num_tokens * top_k) // num_experts
- # And pad the number to the next power of 2.
- tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
- # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
- tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+ # FlashInfer 0.2.10 has issues with larger tile sizes. Set to 8 for now.
+ # TODO: Revert this to dynamic calculation once a new version of FlashInfer
+ # with the necessary kernels is released.
+ tile_tokens_dim = 8
+
+ # from flashinfer import next_positive_power_of_2
+
+ # # Guess tokens per expert assuming perfect expert distribution first.
+ # num_tokens_per_expert = (num_tokens * top_k) // num_experts
+ # # And pad the number to the next power of 2.
+ # tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
+ # # Cap to 8-64 tokens per CTA tile as it's the range supported by the
+ # # kernel.
+ # tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+
return tile_tokens_dim
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 68a061968aa99..2fb7ef29e4684 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm
-from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
+from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
logger = init_logger(__name__)
@@ -394,10 +394,8 @@ def per_token_group_quant_fp8(
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor.
"""
- # TODO(wentao): refactor this
- # use_ue8m0 should be a global flag that could be set by user
if use_ue8m0 is None:
- use_ue8m0 = is_blackwell_deep_gemm_used()
+ use_ue8m0 = is_blackwell_deep_gemm_e8m0_used()
dtype = current_platform.fp8_dtype() if dtype is None else dtype
assert (x.shape[-1] % group_size == 0), (
f"the last dimension of `x` {x.shape[-1]} must be divisible "
diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
index 4a4e199e13187..95eabe149d89c 100644
--- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
@@ -4,11 +4,55 @@ from typing import Callable, Optional
import torch
-from vllm.utils import direct_register_custom_op
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
+
+logger = init_logger(__name__)
OCP_MX_BLOCK_SIZE = 32
+def _swizzle_mxfp4(quant_tensor, scale, num_warps):
+ """ weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel
+ """
+ import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
+ from triton_kernels.numerics import InFlexData
+ from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
+ from triton_kernels.tensor_details import layout
+ from triton_kernels.tensor_details.layout import StridedLayout
+ if (current_platform.is_cuda()
+ and current_platform.is_device_capability(90)
+ and not is_torch_equal_or_newer("2.8.1")):
+ logger.warning_once(
+ "Mxfp4 on hopper is running on torch < 2.8.1, "
+ "this cause swizling to be disabled, which may "
+ "cause performance degradation. Please upgrade to torch nightly")
+ value_layout, value_layout_opts = StridedLayout, dict()
+ scale_layout, scale_layout_opts = StridedLayout, dict()
+ else:
+ value_layout, value_layout_opts = \
+ layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
+ scale_layout, scale_layout_opts = (
+ layout.make_default_matmul_mxfp4_w_scale_layout(
+ mx_axis=1, num_warps=num_warps))
+ if current_platform.is_cuda() and \
+ current_platform.is_device_capability(100):
+ constraints = {
+ "is_persistent": True,
+ "epilogue_subtile": 1,
+ }
+ opt_flags.update_opt_flags_constraints(constraints)
+ # transpose the tensor so that the quantization axis is on dim1
+ quant_tensor = quant_tensor.transpose(-2, -1)
+ scale = scale.transpose(-2, -1)
+ quant_tensor = convert_layout(wrap_torch_tensor(quant_tensor, dtype=FP4),
+ value_layout, **value_layout_opts)
+ scale = convert_layout(wrap_torch_tensor(scale), scale_layout,
+ **scale_layout_opts)
+ return quant_tensor, InFlexData(), scale
+
+
def _can_support_mxfp4(use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
@@ -24,7 +68,7 @@ def _can_support_mxfp4(use_grouped_topk: bool = False,
return not (use_grouped_topk or topk_group or num_expert_group
or expert_map or custom_routing_function
or e_score_correction_bias or apply_router_weight_on_input
- or scoring_func != "softmax" or activation != "silu"
+ or scoring_func != "softmax" or activation != "swiglu_oai"
or expert_load_view or logical_to_physical_map
or logical_replica_count)
diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py b/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py
index 23a749467f193..21af74c6b72b5 100644
--- a/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py
+++ b/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py
@@ -21,7 +21,7 @@ class NvFp4Support:
"""Result container for NV-FP4 capability probing."""
cutlass_supported: bool
- allow_flashinfer_cutlass: bool
+ allow_flashinfer: bool
use_marlin: bool
@@ -54,6 +54,6 @@ def detect_nvfp4_moe_support(class_name: str = "") -> NvFp4Support:
return NvFp4Support(
cutlass_supported=cutlass_supported,
- allow_flashinfer_cutlass=allow_flashinfer,
+ allow_flashinfer=allow_flashinfer,
use_marlin=use_marlin,
)
diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py
index 10fce857a8ae2..6dfc28be7da1a 100644
--- a/vllm/model_executor/layers/rotary_embedding/base.py
+++ b/vllm/model_executor/layers/rotary_embedding/base.py
@@ -8,6 +8,7 @@ import torch
from vllm.model_executor.custom_op import CustomOp
from .common import apply_rotary_emb_dispatch, apply_rotary_emb_torch
+from .rocm_aiter_rope_ops import is_rocm_rotary_embedding_enabled
@CustomOp.register("rotary_embedding")
@@ -35,6 +36,7 @@ class RotaryEmbedding(CustomOp):
cache = cache.to(dtype)
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
+ self.is_rocm_aiter_enabled = is_rocm_rotary_embedding_enabled()
def _compute_inv_freq(self, base: float) -> torch.Tensor:
"""Compute the inverse frequency."""
@@ -119,6 +121,75 @@ class RotaryEmbedding(CustomOp):
self.cos_sin_cache, self.is_neox_style)
return query, key
+ def forward_hip(
+ self,
+ positions: torch.Tensor,
+ query: torch.Tensor,
+ key: Optional[torch.Tensor] = None,
+ offsets: Optional[torch.Tensor] = None,
+ is_nope_first=False,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ # currently only rotary embedding ops from AITER package are
+ # supported for HiP forward.
+ if self.is_rocm_aiter_enabled:
+ return self.forward_hip_rocm_aiter(positions, query, key, offsets,
+ is_nope_first)
+ return self.forward_native(positions, query, key, offsets)
+
+ def forward_hip_rocm_aiter(
+ self,
+ positions: torch.Tensor,
+ # if is_nope_first
+ # [[batch_size, seq_len, num_heads, nope_size+rope_size]
+ # if NOT is_nope_first
+ # [[batch_size, seq_len, num_heads, rope_size+nope_size],
+ query: torch.Tensor,
+ key: Optional[torch.Tensor] = None,
+ offsets: Optional[torch.Tensor] = None,
+ is_nope_first: bool = False,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ if self.cos_sin_cache.device != query.device or \
+ self.cos_sin_cache.dtype != query.dtype:
+ self.cos_sin_cache = self.cos_sin_cache.to(query.device,
+ dtype=query.dtype)
+ cos, sin = self.cos_sin_cache.chunk(2, dim=-1)
+
+ cos = cos.unsqueeze(-2).unsqueeze(-2)
+ sin = sin.unsqueeze(-2).unsqueeze(-2)
+
+ rotate_style = 0 if self.is_neox_style else 1
+
+ num_tokens = positions.numel()
+
+ query_shape = query.shape
+ query = query.view(1, num_tokens, -1, self.head_size)
+ if key is not None:
+ key_shape = key.shape
+ key = key.view(1, num_tokens, -1, self.head_size)
+
+ positions = positions.view(*query.shape[:2])
+ if offsets is not None:
+ offsets = offsets.view(*query.shape[:2])
+
+ if not is_nope_first:
+ query_ = query[..., :self.rotary_dim]
+ key_ = key[..., :self.rotary_dim] if key is not None else None
+ else:
+ query_ = query[..., -self.rotary_dim:]
+ key_ = key[..., -self.rotary_dim:] if key is not None else None
+
+ if key_ is None:
+ torch.ops.vllm.rocm_aiter_rotary_emb_without_key_forward_hip(
+ positions, sin, cos, query_, offsets, rotate_style,
+ is_nope_first)
+ return query.view(query_shape), None
+
+ torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_hip(
+ positions, sin, cos, query_, key_, offsets, rotate_style,
+ is_nope_first)
+
+ return query.view(query_shape), key.view(key_shape)
+
def forward_xpu(
self,
positions: torch.Tensor,
diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py
index 8d821bea19e3e..99b6bb2120333 100644
--- a/vllm/model_executor/layers/rotary_embedding/common.py
+++ b/vllm/model_executor/layers/rotary_embedding/common.py
@@ -99,7 +99,7 @@ def yarn_linear_ramp_mask(low: float, high: float, dim: int,
return ramp_func
-def yarn_get_mscale(scale: float = 1) -> float:
+def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
if scale <= 1:
return 1.0
- return 0.1 * math.log(scale) + 1.0
+ return 0.1 * mscale * math.log(scale) + 1.0
diff --git a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
index cd888b733426b..5af671703a3f4 100644
--- a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
+++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import math
from typing import Optional
import torch
@@ -10,13 +9,7 @@ from vllm.platforms import current_platform
from .base import RotaryEmbedding
from .common import (rotate_gptj, rotate_neox, yarn_find_correction_range,
- yarn_linear_ramp_mask)
-
-
-def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
- if scale <= 1:
- return 1.0
- return 0.1 * mscale * math.log(scale) + 1.0
+ yarn_get_mscale, yarn_linear_ramp_mask)
class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
@@ -96,6 +89,9 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
+ if self.is_rocm_aiter_enabled:
+ return self.forward_hip_rocm_aiter(positions, query, key, offsets)
+
assert key is not None
query_rot = query[..., :self.rotary_dim]
key_rot = key[..., :self.rotary_dim]
diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py
index a75b9e5eb435c..a091cfb743291 100644
--- a/vllm/model_executor/layers/rotary_embedding/mrope.py
+++ b/vllm/model_executor/layers/rotary_embedding/mrope.py
@@ -8,10 +8,176 @@ import numpy as np
import torch
from transformers import PretrainedConfig
+from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
+
from .base import RotaryEmbedding
from .common import apply_rotary_emb_dispatch
+@triton.jit
+def _triton_qwen2vl_mrope_forward(
+ q_ptr,
+ k_ptr,
+ cos,
+ sin,
+ num_tokens,
+ n_qh: tl.constexpr,
+ n_kh: tl.constexpr,
+ hd: tl.constexpr,
+ rd: tl.constexpr,
+ pad_n_qh: tl.constexpr,
+ pad_n_kh: tl.constexpr,
+ pad_hd: tl.constexpr,
+ mrope_section_t: tl.constexpr,
+ mrope_section_h: tl.constexpr,
+):
+ # Adapted from
+ # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
+ # This version supports flatten input tensors from vllm
+ # and supports cos and sin cache with shape (3, num_tokens, head_dim // 2)
+ # instead of (3, bsz, seq_len, head_dim)
+ pid = tl.program_id(0)
+ # locate start address
+ q_ptr = q_ptr + pid * (n_qh * hd)
+ k_ptr = k_ptr + pid * (n_kh * hd)
+
+ # ####################################################################
+ # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
+ # m of this program instance
+ # ####################################################################
+ # Note: cos and sin now have shape (3, num_tokens, head_dim // 2)
+
+ t_end = mrope_section_t
+ h_end = t_end + mrope_section_h
+
+ # Updated stride calculation for half head_dim
+ half_rd = rd // 2
+ t_cos = cos + pid * half_rd
+ h_cos = t_cos + num_tokens * half_rd
+ w_cos = h_cos + num_tokens * half_rd
+ t_sin = sin + pid * half_rd
+ h_sin = t_sin + num_tokens * half_rd
+ w_sin = h_sin + num_tokens * half_rd
+
+ # Updated offsets for half head_dim
+ cos_offsets = tl.arange(0, pad_hd // 2)
+ t_mask = cos_offsets < t_end
+ h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
+ w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd)
+
+ t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
+ h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
+ w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0)
+ t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0)
+ h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0)
+ w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0)
+
+ cos_row = t_cos_row + h_cos_row + w_cos_row
+ sin_row = t_sin_row + h_sin_row + w_sin_row
+
+ # ####################################################################
+ # Load the left and right half of q and k for the current
+ # program instance (i.e. for the current token) separately
+ # ####################################################################
+ # left half of the head
+ first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(
+ 0, pad_hd // 2)[None, :]
+ first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(
+ 0, pad_hd // 2)[None, :]
+ first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(
+ 0, pad_hd // 2)[None, :] < rd // 2)
+ first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(
+ 0, pad_hd // 2)[None, :] < rd // 2)
+
+ q_tile_1 = tl.load(q_ptr + first_half_q_offsets,
+ mask=first_q_mask,
+ other=0).to(sin_row.dtype)
+ k_tile_1 = tl.load(k_ptr + first_half_k_offsets,
+ mask=first_k_mask,
+ other=0).to(sin_row.dtype)
+
+ # right half of the head
+ second_half_q_offsets = first_half_q_offsets + (rd // 2)
+ second_half_k_offsets = first_half_k_offsets + (rd // 2)
+ second_q_mask = first_q_mask
+ second_k_mask = first_k_mask
+
+ q_tile_2 = tl.load(q_ptr + second_half_q_offsets,
+ mask=second_q_mask,
+ other=0).to(sin_row.dtype)
+ k_tile_2 = tl.load(k_ptr + second_half_k_offsets,
+ mask=second_k_mask,
+ other=0).to(sin_row.dtype)
+
+ # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
+ # Since cos and sin are now half-size,
+ # we use the same cos_row and sin_row for both halves
+ new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
+ tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
+ new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
+ tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
+
+ new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
+ tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
+ new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
+ tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
+
+
+def triton_mrope(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ mrope_section: list[int],
+ head_size: int,
+ rotary_dim: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """Qwen2VL mrope kernel.
+
+ Args:
+ query: [num_tokens, num_heads * head_size]
+ key: [num_tokens, num_kv_heads * head_size]
+ cos: [3, num_tokens, head_size //2 ]
+ (T/H/W positions with multimodal inputs)
+ sin: [3, num_tokens, head_size //2 ]
+ (T/H/W positions with multimodal inputs)
+ mrope_section: [t, h, w]
+ head_size: int
+ """
+ n_row, n_q_head_head_dim = q.shape
+ n_q_head = n_q_head_head_dim // head_size
+ n_kv_head = k.shape[1] // head_size
+ pad_hd = triton.next_power_of_2(head_size)
+ pad_n_q_head = triton.next_power_of_2(n_q_head)
+ pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+
+ # ensure tensors passed into the kernel are contiguous.
+ # It will be no-op if they are already contiguous
+ q = q.contiguous()
+ k = k.contiguous()
+ cos = cos.contiguous()
+ sin = sin.contiguous()
+
+ _triton_qwen2vl_mrope_forward[(n_row, )](
+ q,
+ k,
+ cos,
+ sin,
+ n_row,
+ n_q_head,
+ n_kv_head,
+ head_size,
+ rotary_dim,
+ pad_n_q_head,
+ pad_n_kv_head,
+ pad_hd,
+ mrope_section[0],
+ mrope_section[1],
+ )
+ return q, k
+
+
class MRotaryEmbedding(RotaryEmbedding):
"""Rotary Embedding with Multimodal Sections."""
@@ -36,11 +202,34 @@ class MRotaryEmbedding(RotaryEmbedding):
if self.mrope_section:
assert sum(self.mrope_section) == rotary_dim // 2
+ self.use_triton = current_platform.is_cuda_alike()
+
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """MRope forward.
+
+ Args:
+ positions:
+ [num_tokens,] (text only) or
+ [3, num_tokens] (T/H/W positions with multimodal inputs)
+ query: [num_tokens, num_heads * head_size]
+ key: [num_tokens, num_kv_heads * head_size]
+ """
+ if self.use_triton:
+ return self.forward_cuda(positions, query, key)
+ else:
+ return self.forward_native(positions, query, key)
+
+ def forward_native(
+ self,
+ positions: torch.Tensor,
+ query: torch.Tensor,
+ key: Optional[torch.Tensor] = None,
+ offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward().
@@ -88,6 +277,52 @@ class MRotaryEmbedding(RotaryEmbedding):
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
+ def forward_cuda(
+ self,
+ positions: torch.Tensor,
+ query: torch.Tensor,
+ key: Optional[torch.Tensor] = None,
+ offsets: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+
+ assert positions.ndim == 1 or positions.ndim == 2
+ assert key is not None
+
+ num_tokens = positions.shape[-1]
+ cos_sin = self.cos_sin_cache[positions]
+ cos, sin = cos_sin.chunk(2, dim=-1)
+ query_shape = query.shape
+ key_shape = key.shape
+ if positions.ndim == 2:
+ assert self.mrope_section
+
+ q, k = triton_mrope(
+ query,
+ key,
+ cos,
+ sin,
+ self.mrope_section,
+ self.head_size,
+ self.rotary_dim,
+ )
+
+ return q.reshape(query_shape), k.reshape(key_shape)
+
+ query = query.view(num_tokens, -1, self.head_size)
+ query_rot = query[..., :self.rotary_dim]
+ query_pass = query[..., self.rotary_dim:]
+ query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin,
+ self.is_neox_style)
+ query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
+
+ key = key.view(num_tokens, -1, self.head_size)
+ key_rot = key[..., :self.rotary_dim]
+ key_pass = key[..., self.rotary_dim:]
+ key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin,
+ self.is_neox_style)
+ key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
+ return query, key
+
@classmethod
def get_input_positions(
cls,
diff --git a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py b/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py
new file mode 100644
index 0000000000000..91a2318badb40
--- /dev/null
+++ b/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py
@@ -0,0 +1,127 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Optional
+
+import torch
+
+import vllm.envs as envs
+from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
+
+
+def is_rocm_rotary_embedding_enabled() -> bool:
+ return (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER)
+
+
+def rocm_aiter_rotary_emb_without_key_forward_hip_impl(
+ positions: torch.Tensor,
+ sin: torch.Tensor,
+ cos: torch.Tensor,
+ query: torch.Tensor,
+ offsets: Optional[torch.Tensor] = None,
+ rotate_style: int = 0,
+ is_nope_first: bool = False,
+) -> None:
+ import aiter as ops
+ if offsets is None:
+ ops.rope_cached_positions_fwd_inplace(
+ query,
+ cos,
+ sin,
+ positions,
+ rotate_style,
+ reuse_freqs_front_part=True,
+ nope_first=is_nope_first,
+ )
+ else:
+ ops.rope_cached_positions_offsets_fwd_inplace(
+ query,
+ cos,
+ sin,
+ positions,
+ offsets,
+ rotate_style,
+ reuse_freqs_front_part=True,
+ nope_first=is_nope_first,
+ )
+
+
+def rocm_aiter_rotary_emb_with_key_forward_hip_impl(
+ positions: torch.Tensor,
+ sin: torch.Tensor,
+ cos: torch.Tensor,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ offsets: Optional[torch.Tensor] = None,
+ rotate_style: int = 0,
+ is_nope_first: bool = False,
+) -> None:
+ import aiter as ops
+ if offsets is None:
+ ops.rope_cached_positions_2c_fwd_inplace(
+ query,
+ key,
+ cos,
+ sin,
+ positions,
+ rotate_style,
+ reuse_freqs_front_part=True,
+ nope_first=is_nope_first,
+ )
+ else:
+ ops.rope_cached_positions_offsets_2c_fwd_inplace(
+ query,
+ key,
+ cos,
+ sin,
+ positions,
+ offsets,
+ rotate_style,
+ reuse_freqs_front_part=True,
+ nope_first=is_nope_first,
+ )
+
+
+def rocm_aiter_rotary_emb_with_key_forward_hip_fake(
+ positions: torch.Tensor,
+ sin: torch.Tensor,
+ cos: torch.Tensor,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ offsets: Optional[torch.Tensor] = None,
+ rotate_style: int = 0,
+ is_nope_first: bool = False,
+) -> None:
+ pass
+
+
+def rocm_aiter_rotary_emb_without_key_forward_hip_fake(
+ positions: torch.Tensor,
+ sin: torch.Tensor,
+ cos: torch.Tensor,
+ query: torch.Tensor,
+ offsets: Optional[torch.Tensor] = None,
+ rotate_style: int = 0,
+ is_nope_first: bool = False,
+) -> None:
+ pass
+
+
+if is_rocm_rotary_embedding_enabled():
+
+ direct_register_custom_op(
+ op_name="rocm_aiter_rotary_emb_with_key_forward_hip",
+ op_func=rocm_aiter_rotary_emb_with_key_forward_hip_impl,
+ mutates_args=["key", "query"],
+ fake_impl=rocm_aiter_rotary_emb_with_key_forward_hip_fake,
+ dispatch_key=current_platform.dispatch_key,
+ )
+
+ direct_register_custom_op(
+ op_name="rocm_aiter_rotary_emb_without_key_forward_hip",
+ op_func=rocm_aiter_rotary_emb_without_key_forward_hip_impl,
+ mutates_args=["query"],
+ fake_impl=rocm_aiter_rotary_emb_without_key_forward_hip_fake,
+ dispatch_key=current_platform.dispatch_key,
+ )
\ No newline at end of file
diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py
index cd32f12f3c269..48a347a8f5611 100644
--- a/vllm/model_executor/layers/utils.py
+++ b/vllm/model_executor/layers/utils.py
@@ -11,6 +11,27 @@ from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
+def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
+ # Shuffle weight along the last dimension so that
+ # we folded the weights to adjance location
+ # Example:
+ # input:
+ # [[1, 2, 3, 4, 5, 6],
+ # [7, 8, 9, 10, 11, 12]]
+ # output:
+ # [[1, 4, 2, 5, 3, 6],
+ # [7, 10, 8, 11, 9, 12]]
+ # This will be used together with triton swiglu kernel
+ shape = w.shape
+ N = shape[-1]
+ first = w[..., :N // 2]
+ second = w[..., N // 2:]
+
+ stacked = torch.stack((first, second), dim=-1)
+ w_shuffled = stacked.reshape(shape)
+ return w_shuffled
+
+
def get_token_bin_counts_and_mask(
tokens: torch.Tensor,
vocab_size: int,
diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py
index ea2fb2e3ac14e..b8393956eed3f 100644
--- a/vllm/model_executor/model_loader/bitsandbytes_loader.py
+++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py
@@ -427,14 +427,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
elif isinstance(module, FusedMoE) and hasattr(
module.quant_method, "quant_config"):
# TODO: support FusedMoE with prequant and 8bit.
- if self.pre_quant:
+ if self.pre_quant and self.load_8bit:
raise ValueError(
- "Prequant BitsAndBytes models with FusedMoE is not "
- "supported yet.")
- if self.load_8bit:
- raise ValueError(
- "BitsAndBytes 8bit quantization with FusedMoE is not "
- "supported yet.")
+ "Prequant BitsAndBytes 8bit models with FusedMoE "
+ "is not supported yet.")
# Get the corresponding weight name using module name and
# expert_params_mapping.
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index 074126fa669e9..78b186265dd04 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -764,39 +764,41 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
return None
return remapped_name
- possible_scale_names = [".k_scale", ".v_scale"]
- modelopt_scale_names = [
- ".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"
+ # Define scale name mapping patterns in order of precedence
+ scale_mapping_patterns = [
+ # ModelOpt format: .self_attn.{k,v}_proj.{k,v}_scale ->
+ # .self_attn.attn.{k,v}_scale
+ (r"\.self_attn\.([kv])_proj\.([kv])_scale$",
+ r".self_attn.attn.\2_scale"),
+ # QKV proj format: .self_attn.qkv_proj.{k,v}_scale ->
+ # .self_attn.attn.{k,v}_scale
+ (r"\.self_attn\.qkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"),
+ # Qwen3 MoE format: .self_attn.qkqkv_proj.{k,v}_scale ->
+ # .self_attn.attn.{k,v}_scale
+ (r"\.self_attn\.qkqkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"
+ ),
+ # Default format: .{k,v}_scale -> .attn.{k,v}_scale
+ (r"\.([kv])_scale$", r".attn.\1_scale"),
]
- # Also support qkv_proj scale parameters (from stacked parameter processing)
- qkv_proj_scale_names = [
- ".self_attn.qkv_proj.k_scale", ".self_attn.qkv_proj.v_scale"
- ]
- for scale_name in possible_scale_names:
- if name.endswith(scale_name):
- if any(mo_scale_name in name
- for mo_scale_name in modelopt_scale_names):
- remapped_name = name.replace(
- f".self_attn.{scale_name[1]}_proj{scale_name}",
- f".self_attn.attn{scale_name}")
- elif any(qkv_scale_name in name
- for qkv_scale_name in qkv_proj_scale_names):
- # Handle qkv_proj scale parameters
- remapped_name = name.replace(
- f".self_attn.qkv_proj{scale_name}",
- f".self_attn.attn{scale_name}")
- else:
- remapped_name = name.replace(scale_name, f".attn{scale_name}")
- if remapped_name not in params_dict:
- logger.warning_once(
- "Found %s in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). %s is not loaded.", # noqa: E501
- scale_name,
- name,
- remapped_name,
- scale_name,
- )
- return None
- return remapped_name
+
+ # Check if name ends with k_scale or v_scale
+ if name.endswith((".k_scale", ".v_scale")):
+ import regex as re
+
+ for pattern, replacement in scale_mapping_patterns:
+ if re.search(pattern, name):
+ remapped_name = re.sub(pattern, replacement, name)
+ if remapped_name not in params_dict:
+ scale_type = name.split(".")[-1]
+ logger.warning_once(
+ "Found %s in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). %s is not loaded.", # noqa: E501
+ scale_type,
+ name,
+ remapped_name,
+ scale_type,
+ )
+ return None
+ return remapped_name
# If there were no matches, return the untouched param name
return name
diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py
index 867de2c68b4c5..1dbe70f84a626 100644
--- a/vllm/model_executor/models/adapters.py
+++ b/vllm/model_executor/models/adapters.py
@@ -182,8 +182,8 @@ def as_seq_cls_model(cls: _T) -> _T:
assert pooler_config is not None
pooling_type_str = pooler_config.pooling_type
- pooling_type = (PoolingType.LAST if pooling_type_str is None else
- PoolingType[pooling_type_str])
+ assert pooling_type_str is not None
+ pooling_type = PoolingType[pooling_type_str]
self.pooler = DispatchPooler({
"encode":
diff --git a/vllm/model_executor/models/aimv2.py b/vllm/model_executor/models/aimv2.py
index d2307bb464bdb..b13d863ebb744 100644
--- a/vllm/model_executor/models/aimv2.py
+++ b/vllm/model_executor/models/aimv2.py
@@ -8,7 +8,6 @@ from typing import Optional
import torch
import torch.nn as nn
-from transformers import PretrainedConfig
from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import get_tensor_model_parallel_world_size
@@ -21,12 +20,13 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.transformers_utils.configs.ovis import AIMv2Config
class AIMv2SwiGLUFFN(nn.Module):
- def __init__(self, config: PretrainedConfig,
- quant_config: QuantizationConfig, prefix: str):
+ def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
+ prefix: str):
super().__init__()
hidden_features = config.intermediate_size
in_features = config.hidden_size
@@ -57,7 +57,7 @@ class AIMv2SwiGLUFFN(nn.Module):
class AIMv2PatchEmbed(nn.Module):
- def __init__(self, config: PretrainedConfig):
+ def __init__(self, config: AIMv2Config):
super().__init__()
self.proj = nn.Conv2d(
config.num_channels,
@@ -75,7 +75,7 @@ class AIMv2PatchEmbed(nn.Module):
class AIMv2ViTPreprocessor(nn.Module):
- def __init__(self, config: PretrainedConfig):
+ def __init__(self, config: AIMv2Config):
super().__init__()
num_patches = (config.image_size // config.patch_size)**2
@@ -93,8 +93,8 @@ class AIMv2ViTPreprocessor(nn.Module):
class AIMv2Attention(nn.Module):
- def __init__(self, config: PretrainedConfig,
- quant_config: QuantizationConfig, prefix: str):
+ def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
+ prefix: str):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
@@ -141,8 +141,8 @@ class AIMv2Attention(nn.Module):
class AIMv2Block(nn.Module):
- def __init__(self, config: PretrainedConfig,
- quant_config: QuantizationConfig, prefix: str):
+ def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
+ prefix: str):
super().__init__()
self.attn = AIMv2Attention(config,
quant_config=quant_config,
@@ -163,7 +163,7 @@ class AIMv2Transformer(nn.Module):
def __init__(
self,
- config: PretrainedConfig,
+ config: AIMv2Config,
quant_config: QuantizationConfig,
*,
require_post_norm: Optional[bool] = None,
@@ -193,7 +193,7 @@ class AIMv2Transformer(nn.Module):
class AIMv2Model(torch.nn.Module):
def __init__(self,
- config: PretrainedConfig,
+ config: AIMv2Config,
quant_config: QuantizationConfig,
*,
require_post_norm: Optional[bool] = None,
diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py
index b476a4f918bc3..5cd74bbba4827 100644
--- a/vllm/model_executor/models/aya_vision.py
+++ b/vllm/model_executor/models/aya_vision.py
@@ -16,7 +16,6 @@ from transformers.models.got_ocr2.image_processing_got_ocr2 import (
get_optimal_tiled_canvas)
from vllm.config import VllmConfig
-from vllm.jsontree import json_map_leaves
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargs
@@ -29,6 +28,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
+from vllm.utils.jsontree import json_map_leaves
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py
index 504621c8abd8f..6638f06f98261 100644
--- a/vllm/model_executor/models/bert.py
+++ b/vllm/model_executor/models/bert.py
@@ -28,7 +28,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask
-from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
+from .interfaces import (SupportsCrossEncoding, SupportsQuant,
+ default_pooling_type)
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
@@ -60,21 +61,13 @@ class BertEmbedding(nn.Module):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
- token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
- input_shape = input_ids.size()
- # Input embeddings.
+ token_type_ids = _decode_token_type_ids(input_ids)
+
inputs_embeds = self.word_embeddings(input_ids)
-
- # Position embeddings.
position_embeddings = self.position_embeddings(position_ids)
- if token_type_ids is None:
- token_type_ids = torch.zeros(input_shape,
- dtype=torch.long,
- device=inputs_embeds.device)
-
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings + position_embeddings
@@ -335,6 +328,7 @@ class BertOutput(nn.Module):
@support_torch_compile
+@default_pooling_type("CLS")
class BertModel(nn.Module, SupportsQuant):
is_pooling_model = True
@@ -350,25 +344,23 @@ class BertModel(nn.Module, SupportsQuant):
) -> None:
super().__init__()
- config = vllm_config.model_config.hf_config
- self.embeddings = embedding_class(config)
+ self.config = vllm_config.model_config.hf_config
+ self.embeddings = embedding_class(self.config)
self.encoder = BertEncoder(vllm_config=vllm_config,
prefix=f"{prefix}.encoder")
def forward(
self,
input_ids: torch.Tensor,
- position_ids: torch.Tensor,
+ positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
- token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embeddings(input_ids=input_ids,
- position_ids=position_ids,
- token_type_ids=token_type_ids)
+ position_ids=positions)
return self.encoder(hidden_states)
def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
@@ -411,6 +403,7 @@ class BertModel(nn.Module, SupportsQuant):
return loaded_params
+@default_pooling_type("ALL")
class BertPoolingModel(BertModel):
is_pooling_model = True
@@ -441,6 +434,7 @@ class BertPoolingModel(BertModel):
return loaded_params
+@default_pooling_type("CLS")
class BertEmbeddingModel(nn.Module, SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.
@@ -466,15 +460,13 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
def forward(
self,
- input_ids: Optional[torch.Tensor],
+ input_ids: torch.Tensor,
positions: torch.Tensor,
- token_type_ids: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.model(input_ids=input_ids,
- position_ids=positions,
- token_type_ids=token_type_ids,
+ positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)
@@ -498,18 +490,59 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
return DispatchPooler({
- "encode":
- Pooler.for_encode(pooler_config),
- "embed":
- Pooler.for_embed(
- pooler_config,
- default_pooling_type=PoolingType.CLS,
- ),
+ "encode": Pooler.for_encode(pooler_config),
+ "embed": Pooler.for_embed(pooler_config),
})
-class BertForSequenceClassification(nn.Module, SupportsV0Only,
- SupportsCrossEncoding, SupportsQuant):
+# Here we encode the token type ids together with the input ids.
+# Since we use int 32 for the input IDs and the vocabulary size
+# is way lower than 2**31, there is room to encode additional
+# bits. At the same time, for cross-encoder use cases, the
+# token type ids are only 0 or 1, requiring only 1 bit.
+# This means that we can store the token type ids in the 31st
+# bit. We void the 32nd bit because that would produce a negative
+# number, which could be used to signal other things.
+#
+# The reason for all of this is that all the tensors that are
+# passed as input to the forward function of a module marked
+# with @support_torch_compile have to be persistent. So to
+# avoid adding more persistent tensors in the model runner, we
+# encode more information in the same persistent tensor.
+#
+# Since the *ForClassification module is outside of the BertModel
+# which is compiled, we can do the encoding here and then separate
+# the information again in the Embedding layer. Since with bit masks
+# we can do this entirely with torch operations and without branching,
+# it works with torch compile.
+
+TOKEN_TYPE_SHIFT = 30
+
+
+def _encode_token_type_ids(input_ids: torch.Tensor,
+ token_type_ids: torch.Tensor) -> None:
+ # input_ids can be padded to the right
+ input_ids[:token_type_ids.shape[0]].bitwise_or_(
+ token_type_ids << TOKEN_TYPE_SHIFT)
+
+
+def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
+
+ ids_mask = torch.ones(input_ids.shape,
+ dtype=torch.int32,
+ device=input_ids.device) << TOKEN_TYPE_SHIFT
+ tokens_mask = ids_mask.bitwise_not()
+
+ token_type_ids = input_ids.bitwise_and(ids_mask) >> TOKEN_TYPE_SHIFT
+
+ input_ids.bitwise_and_(tokens_mask)
+
+ return token_type_ids
+
+
+@default_pooling_type("CLS")
+class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
+ SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for
@@ -567,8 +600,13 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
inputs_embeds: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
+
+ if token_type_ids is not None:
+ assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
+ assert input_ids is not None
+ _encode_token_type_ids(input_ids, token_type_ids)
+
return self.bert(input_ids=input_ids,
- position_ids=positions,
+ positions=positions,
inputs_embeds=inputs_embeds,
- intermediate_tensors=intermediate_tensors,
- token_type_ids=token_type_ids)
+ intermediate_tensors=intermediate_tensors)
diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py
index 59033cb74a338..e18b7b7ffabab 100644
--- a/vllm/model_executor/models/bert_with_rope.py
+++ b/vllm/model_executor/models/bert_with_rope.py
@@ -8,13 +8,15 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionType
+from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
get_act_fn)
-from vllm.model_executor.layers.fused_moe import fused_moe
+from vllm.model_executor.layers.fused_moe.fused_moe import (
+ fused_topk, torch_vllm_outplace_fused_experts)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
@@ -25,7 +27,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import SupportsQuant
+from vllm.model_executor.models.interfaces import (SupportsQuant,
+ default_pooling_type)
from vllm.model_executor.models.utils import WeightsMapper
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
@@ -284,15 +287,22 @@ class NomicMoE(nn.Module):
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.router(hidden_states)
- final_hidden_states = fused_moe(hidden_states,
- self.w1,
- self.w2,
- router_logits,
- self.top_k,
- renormalize=False,
- inplace=False,
- activation=self.hidden_act,
- is_act_and_mul=False)
+ # FIXME(Isotr0py): This implementation is too tricky,
+ # we should use FusedMoE instead in the future
+ # after supporting ungated activation for it.
+ topk_weights, topk_ids, _ = fused_topk(hidden_states,
+ router_logits,
+ self.top_k,
+ renormalize=False)
+ final_hidden_states = torch_vllm_outplace_fused_experts(
+ hidden_states=hidden_states,
+ w1=self.w1,
+ w2=self.w2,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ activation=self.hidden_act,
+ is_act_and_mul=False,
+ )
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
@@ -391,6 +401,8 @@ class BertWithRopeEncoder(nn.Module):
return hidden_states
+@support_torch_compile
+@default_pooling_type("CLS")
class BertWithRope(nn.Module, SupportsQuant):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
@@ -407,7 +419,7 @@ class BertWithRope(nn.Module, SupportsQuant):
def forward(
self,
- input_ids: Optional[torch.Tensor],
+ input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
@@ -554,20 +566,6 @@ class JinaRobertaModel(BertWithRope):
"norm2": "mlp_ln",
})
- def forward(
- self,
- input_ids: torch.Tensor,
- position_ids: torch.Tensor,
- intermediate_tensors: Optional[IntermediateTensors] = None,
- inputs_embeds: Optional[torch.Tensor] = None,
- token_type_ids: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- return super().forward(input_ids=input_ids,
- positions=position_ids,
- intermediate_tensors=intermediate_tensors,
- inputs_embeds=inputs_embeds,
- token_type_ids=token_type_ids)
-
@torch.inference_mode()
def jina_merge_lora_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]):
diff --git a/vllm/model_executor/models/cohere2_vision.py b/vllm/model_executor/models/cohere2_vision.py
new file mode 100644
index 0000000000000..f17583768f795
--- /dev/null
+++ b/vllm/model_executor/models/cohere2_vision.py
@@ -0,0 +1,445 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Adapted from vllm/model_executor/models/aya_vision.py
+"""Command-A-Vision (Cohere2Vision) multimodal model implementation for vLLM."""
+
+from collections.abc import Iterable, Mapping, Sequence
+from typing import Annotated, Literal, Optional, Union
+
+import torch
+from torch import nn
+from transformers import BatchFeature, PretrainedConfig
+from transformers.models.cohere2_vision import Cohere2VisionConfig
+from transformers.models.cohere2_vision.processing_cohere2_vision import (
+ Cohere2VisionProcessor)
+
+from vllm.config import VllmConfig
+from vllm.model_executor.layers.activation import MulAndSilu
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+ RowParallelLinear)
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization.awq import AWQConfig
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargs
+from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
+ MultiModalDataItems)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+ BaseProcessingInfo,
+ MultiModalFieldConfig,
+ PromptReplacement, PromptUpdate,
+ PromptUpdateDetails)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
+
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .siglip import SiglipVisionModel
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+ init_vllm_registered_model, maybe_prefix,
+ merge_multimodal_embeddings)
+
+
+class Cohere2VisionImagePixelInputs(TensorSchema):
+ """
+ Dimensions:
+ - np: The total number of patches over each image over each prompt in
+ the batch
+ - c: Number of channels
+ - h: Height of each image patch
+ - w: Width of each image patch
+ - bn: Batch size * number of images
+ """
+
+ type: Literal["pixel_values"]
+
+ pixel_values: Annotated[
+ torch.Tensor,
+ TensorShape("np", 3, "h", "w"),
+ ]
+
+ num_patches: Annotated[
+ torch.Tensor,
+ TensorShape("bn"),
+ ]
+
+
+class Cohere2VisionMultiModalProjector(nn.Module):
+ """Multimodal projector that maps vision features to text embedding space.
+
+ Uses pixel shuffle downsampling followed by SwiGLU activation.
+ """
+
+ def __init__(self, config: Cohere2VisionConfig, prefix: str = ""):
+ super().__init__()
+ self.downsample_factor = config.downsample_factor
+
+ # Input dimension after pixel shuffle downsampling
+ input_dim = config.vision_config.hidden_size * (
+ config.downsample_factor**2)
+ # MergedColumnParallelLinear expects the intermediate size to be a list
+ # of sizes, so that it will load the weights as two separate linear
+ # layers before applying any parallelism.
+ # We need to divide the alignment intermediate size by 2 because
+ # the weights are merged weights of two linear layers for SwiGLU.
+ self.intermediate_size = config.alignment_intermediate_size // 2
+
+ self.linear_1 = MergedColumnParallelLinear(
+ input_dim,
+ [self.intermediate_size] * 2,
+ bias=True,
+ return_bias=False,
+ prefix=f"{prefix}.linear_1",
+ )
+ self.act = MulAndSilu()
+ self.linear_2 = RowParallelLinear(
+ self.intermediate_size,
+ config.text_config.hidden_size,
+ bias=True,
+ return_bias=False,
+ prefix=f"{prefix}.linear_2",
+ )
+
+ def forward(self, image_features):
+ image_features = self.pixel_shuffle(image_features)
+ hidden_states = self.linear_1(image_features)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+ def pixel_shuffle(self, image_features: torch.Tensor) -> torch.Tensor:
+ """Apply pixel shuffle downsampling to reduce spatial dimensions.
+
+ Args:
+ image_features: Input tensor of shape [B, S, D] where S = H*W
+
+ Returns:
+ Downsampled tensor with increased channel dimension
+ """
+ height = width = int(image_features.shape[1]**0.5)
+ x = image_features.reshape(image_features.shape[0], width, height, -1)
+ n, h, w, c = x.size()
+ scale_factor = 1. / self.downsample_factor
+ nh = int(h * scale_factor)
+ nw = int(w * scale_factor)
+ x = x.reshape(n, nh, self.downsample_factor, nw,
+ self.downsample_factor, c)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
+ x = x.reshape(n, nh, nw, -1)
+ return x
+
+
+class Cohere2VisionProcessingInfo(BaseProcessingInfo):
+
+ def get_hf_config(self) -> Cohere2VisionConfig:
+ return self.ctx.get_hf_config(Cohere2VisionConfig)
+
+ def get_hf_processor(self, **kwargs: object) -> Cohere2VisionProcessor:
+ return self.ctx.get_hf_processor(Cohere2VisionProcessor, **kwargs)
+
+ def get_image_processor(self, **kwargs: object):
+ return self.get_hf_processor(**kwargs).image_processor
+
+ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+ return {"image": None}
+
+ def get_image_size_with_most_features(self) -> ImageSize:
+ image_processor = self.get_image_processor()
+ height = image_processor.size['height']
+ width = image_processor.size['width']
+ max_patches = image_processor.max_patches
+ return ImageSize(height=height * max_patches, width=width)
+
+ def get_num_patches(self, image_width: int, image_height: int) -> int:
+ """
+ Calculate the number of image patches for a given image.
+ Uses the HF processor to determine the actual number of patches.
+ """
+ return self.get_hf_processor(
+ ).image_processor.get_number_of_image_patches(image_height,
+ image_width, {})
+
+
+class Cohere2VisionDummyInputsBuilder(
+ BaseDummyInputsBuilder[Cohere2VisionProcessingInfo]):
+
+ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+ num_images = mm_counts.get("image", 0)
+
+ processor = self.info.get_hf_processor()
+ image_token = processor.image_token
+
+ return image_token * num_images
+
+ def get_dummy_mm_data(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> MultiModalDataDict:
+ num_images = mm_counts.get("image", 0)
+ image_size = \
+ self.info.get_image_size_with_most_features()
+
+ return {
+ "image":
+ self._get_dummy_images(width=image_size.width,
+ height=image_size.height,
+ num_images=num_images)
+ }
+
+
+class Cohere2VisionMultiModalProcessor(
+ BaseMultiModalProcessor[Cohere2VisionProcessingInfo]):
+
+ def _call_hf_processor(
+ self,
+ prompt: str,
+ mm_data: Mapping[str, object],
+ mm_kwargs: Mapping[str, object],
+ tok_kwargs: Mapping[str, object],
+ ) -> BatchFeature:
+ processed_outputs = super()._call_hf_processor(
+ prompt,
+ mm_data,
+ mm_kwargs,
+ tok_kwargs,
+ )
+
+ # Ensure num_patches is available for proper tensor splitting
+ if "num_patches" not in processed_outputs and (
+ images := mm_data.get("images")) is not None:
+ # Fallback calculation if HF processor didn't provide num_patches
+ parsed_images = self._get_data_parser().parse_mm_data({
+ "image":
+ images
+ }).get_items("image", ImageProcessorItems)
+
+ num_patches = [
+ self.info.get_num_patches(
+ image_width=parsed_images.get_image_size(i).width,
+ image_height=parsed_images.get_image_size(i).height)
+ for i in range(len(parsed_images))
+ ]
+ processed_outputs["num_patches"] = torch.tensor(num_patches)
+
+ return processed_outputs
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ num_patches = hf_inputs.get("num_patches", torch.empty(0))
+ return dict(
+ pixel_values=MultiModalFieldConfig.flat_from_sizes(
+ "image", num_patches),
+ num_patches=MultiModalFieldConfig.batched("image"),
+ image_embeds=MultiModalFieldConfig.batched("image"),
+ )
+
+ def _get_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ out_mm_kwargs: MultiModalKwargs,
+ ) -> Sequence[PromptUpdate]:
+ hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+ image_token = hf_processor.image_token
+ img_line_break_token = hf_processor.img_line_break_token
+ boi_token = hf_processor.boi_token
+ eoi_token = hf_processor.eoi_token
+
+ def get_replacement(item_idx: int):
+ images: ImageProcessorItems = mm_items.get("image",
+ ImageProcessorItems)
+ image_size: ImageSize = images.get_image_size(item_idx)
+
+ num_patches = self.info.get_num_patches(image_size.height,
+ image_size.width)
+ img_tokens_per_tile = int(hf_processor.patch_size**2)
+ single_tile_tokens = image_token * img_tokens_per_tile + \
+ img_line_break_token
+ img_string = f"{boi_token}\
+ {single_tile_tokens * num_patches}\
+ {eoi_token}"
+
+ return PromptUpdateDetails.select_text(img_string, image_token)
+
+ return [
+ PromptReplacement(
+ modality="image",
+ target=image_token,
+ replacement=get_replacement,
+ )
+ ]
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ Cohere2VisionMultiModalProcessor,
+ info=Cohere2VisionProcessingInfo,
+ dummy_inputs=Cohere2VisionDummyInputsBuilder)
+class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal,
+ SupportsPP):
+
+ hf_to_vllm_mapper = WeightsMapper(
+ orig_to_new_prefix={
+ "model.vision_tower.": "vision_tower.",
+ "model.multi_modal_projector.": "multi_modal_projector.",
+ "model.language_model.": "language_model.model.",
+ "lm_head.": "language_model.lm_head.",
+ })
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config: Cohere2VisionConfig = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ multimodal_config = vllm_config.model_config.multimodal_config
+ self.config = config
+ self.quant_config = quant_config
+ self.multimodal_config = multimodal_config
+ self._patch_quant_config(config, quant_config)
+
+ self.vision_tower = SiglipVisionModel(config.vision_config,
+ quant_config,
+ prefix=maybe_prefix(
+ prefix, "vision_tower"))
+ self.vocab_size = config.text_config.vocab_size
+ self.multi_modal_projector = \
+ Cohere2VisionMultiModalProjector(
+ config, prefix=maybe_prefix(prefix, "multi_modal_projector"))
+ self.language_model = init_vllm_registered_model(
+ vllm_config=vllm_config,
+ hf_config=config.text_config,
+ prefix=maybe_prefix(prefix, "language_model"),
+ architectures=["Cohere2ForCausalLM"])
+
+ @property
+ def dtype(self):
+ return next(self.parameters()).dtype
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ loader = AutoWeightsLoader(self)
+ return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
+ def _process_image_input(self, image_input: Cohere2VisionImagePixelInputs,
+ **kwargs) -> list[torch.Tensor]:
+ """Process image pixels through vision tower and projector.
+
+ Args:
+ image_input: Validated image input containing pixel values and
+ patch counts
+
+ Returns:
+ List of flattened image embeddings, one per image
+ """
+ assert self.vision_tower is not None, "Vision tower is required"
+
+ pixel_values = image_input["pixel_values"]
+ num_patches = image_input["num_patches"]
+
+ # Extract visual features
+ image_features = self.vision_tower(pixel_values)
+
+ # Project to text embedding space
+ image_embeds = self.multi_modal_projector(image_features)
+
+ # Split and flatten embeddings per image
+ return [
+ e.flatten(0, 2) for e in image_embeds.split(num_patches.tolist())
+ ]
+
+ def _parse_and_validate_image_input(
+ self, **kwargs: object) -> Optional[Cohere2VisionImagePixelInputs]:
+ pixel_values = kwargs.pop("pixel_values", None)
+ num_patches = kwargs.pop("num_patches", None)
+ image_embeds = kwargs.pop("image_embeds", None)
+ assert image_embeds is None, \
+ "Cohere2Vision does not support image_embeds."
+
+ if pixel_values is None:
+ return None
+
+ return Cohere2VisionImagePixelInputs(
+ type="pixel_values",
+ pixel_values=flatten_bn(pixel_values, concat=True),
+ num_patches=flatten_bn(num_patches, concat=True),
+ resolve_bindings={
+ "h": self.config.vision_config.image_size,
+ "w": self.config.vision_config.image_size,
+ })
+
+ def _patch_quant_config(self, config: PretrainedConfig,
+ quant_config: QuantizationConfig):
+ # the awq models from OpenGVLab missing `modules_to_not_convert`
+ # patch the quant_config to add `modules_to_not_convert` back
+ if isinstance(quant_config, AWQConfig):
+ text_config = config.text_config
+ llm_quant_config = getattr(text_config, "quantization_config",
+ None)
+ if (not quant_config.modules_to_not_convert) and (llm_quant_config
+ is not None):
+ quant_config.modules_to_not_convert.append("vision_tower")
+
+ def get_language_model(self) -> torch.nn.Module:
+ return self.language_model
+
+ def get_multimodal_embeddings(self,
+ **kwargs: object) -> MultiModalEmbeddings:
+ image_input = self._parse_and_validate_image_input(**kwargs)
+ if image_input is None:
+ return []
+
+ return self._process_image_input(image_input, **kwargs)
+
+ def get_input_embeddings(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+ ) -> torch.Tensor:
+ inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+ if multimodal_embeddings is not None \
+ and len(multimodal_embeddings) != 0:
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids=input_ids,
+ inputs_embeds=inputs_embeds,
+ multimodal_embeddings=multimodal_embeddings,
+ placeholder_token_id=self.config.image_token_id,
+ )
+
+ return inputs_embeds
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs: object,
+ ) -> Union[torch.Tensor, IntermediateTensors]:
+ if intermediate_tensors is not None:
+ inputs_embeds = None
+
+ # NOTE: In v1, inputs_embeds is always generated at model runner, this
+ # condition is for v0 compatibility.
+ elif inputs_embeds is None:
+ vision_embeddings = self.get_multimodal_embeddings(**kwargs)
+ inputs_embeds = self.get_input_embeddings(input_ids,
+ vision_embeddings)
+ input_ids = None
+
+ hidden_states = self.language_model.model(
+ input_ids=input_ids,
+ positions=positions,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ )
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[torch.Tensor]:
+ return self.language_model.compute_logits(hidden_states,
+ sampling_metadata)
diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py
index c4f6144ed91f0..4dd84b8f8fdd5 100644
--- a/vllm/model_executor/models/commandr.py
+++ b/vllm/model_executor/models/commandr.py
@@ -27,7 +27,7 @@ from typing import Optional, Union
import torch
from torch import nn
-from transformers import CohereConfig
+from transformers import Cohere2Config, CohereConfig
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
@@ -89,7 +89,7 @@ class CohereMLP(nn.Module):
def __init__(
self,
- config: CohereConfig,
+ config: Union[CohereConfig, Cohere2Config],
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
@@ -124,7 +124,7 @@ class CohereAttention(nn.Module):
def __init__(
self,
- config: CohereConfig,
+ config: Union[CohereConfig, Cohere2Config],
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
@@ -182,21 +182,13 @@ class CohereAttention(nn.Module):
)
# Model v2 has interleaved sliding windows, v1 does not
- interleaved_sliding_window = getattr(config,
- "interleaved_sliding_window",
- None)
- self.v1 = interleaved_sliding_window is None
+ self.v1 = isinstance(config, CohereConfig)
- layer_idx = extract_layer_index(prefix)
- layer_has_sliding_window = (
- getattr(config, "sliding_window_pattern", False) and
- (layer_idx + 1) % self.config.sliding_window_pattern
- != 0) or (getattr(config, "layer_types", False)
- and config.layer_types[layer_idx] == "sliding_attention")
-
- self.sliding_window = (interleaved_sliding_window
- or config.sliding_window
- if layer_has_sliding_window else None)
+ self.sliding_window = None
+ if not self.v1:
+ layer_idx = extract_layer_index(prefix)
+ if config.layer_types[layer_idx] == "sliding_attention":
+ self.sliding_window = config.sliding_window
self.attn = Attention(self.num_heads,
self.head_dim,
@@ -242,7 +234,7 @@ class CohereAttention(nn.Module):
class CohereDecoderLayer(nn.Module):
def __init__(self,
- config: CohereConfig,
+ config: Union[CohereConfig, Cohere2Config],
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py
index 360c7e66bf5ce..e74d90e0b1d7d 100644
--- a/vllm/model_executor/models/dbrx.py
+++ b/vllm/model_executor/models/dbrx.py
@@ -6,7 +6,7 @@ from typing import Optional, Union
import torch
import torch.nn as nn
-from transformers import PretrainedConfig
+from transformers import DbrxConfig
from vllm.attention import Attention
from vllm.config import CacheConfig, VllmConfig
@@ -39,7 +39,7 @@ class DbrxRouter(nn.Module):
def __init__(
self,
- config: PretrainedConfig,
+ config: DbrxConfig,
params_dtype: Optional[torch.dtype] = None,
):
super().__init__()
@@ -63,7 +63,7 @@ class DbrxExperts(FusedMoE):
def __init__(
self,
- config: PretrainedConfig,
+ config: DbrxConfig,
quant_config: Optional[QuantizationConfig] = None,
params_dtype: Optional[torch.dtype] = None,
prefix: str = "",
@@ -138,7 +138,7 @@ class DbrxMoE(nn.Module):
def __init__(
self,
- config: PretrainedConfig,
+ config: DbrxConfig,
quant_config: Optional[QuantizationConfig] = None,
params_dtype: Optional[torch.dtype] = None,
prefix: str = "",
@@ -169,7 +169,7 @@ class DbrxAttention(nn.Module):
def __init__(
self,
- config: PretrainedConfig,
+ config: DbrxConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
@@ -249,7 +249,7 @@ class DbrxFusedNormAttention(nn.Module):
def __init__(
self,
- config: PretrainedConfig,
+ config: DbrxConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
@@ -284,7 +284,7 @@ class DbrxBlock(nn.Module):
def __init__(
self,
- config: PretrainedConfig,
+ config: DbrxConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index c2880c33cb65d..f199da135ec76 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -29,7 +29,7 @@ from typing import Any, Optional, Union
import torch
from torch import nn
-from transformers import PretrainedConfig
+from transformers import DeepseekV2Config, DeepseekV3Config
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
@@ -100,7 +100,7 @@ class DeepseekV2MoE(nn.Module):
def __init__(
self,
- config: PretrainedConfig,
+ config: Union[DeepseekV2Config, DeepseekV3Config],
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
@@ -221,7 +221,7 @@ class DeepseekV2Attention(nn.Module):
def __init__(
self,
- config: PretrainedConfig,
+ config: Union[DeepseekV2Config, DeepseekV3Config],
hidden_size: int,
num_heads: int,
qk_nope_head_dim: int,
@@ -373,7 +373,7 @@ class DeepseekV2MLAAttention(nn.Module):
def __init__(
self,
- config: PretrainedConfig,
+ config: Union[DeepseekV2Config, DeepseekV3Config],
hidden_size: int,
num_heads: int,
qk_nope_head_dim: int,
@@ -538,7 +538,7 @@ class DeepseekV2DecoderLayer(nn.Module):
def __init__(
self,
- config: PretrainedConfig,
+ config: Union[DeepseekV2Config, DeepseekV3Config],
prefix: str,
model_config: ModelConfig,
cache_config: Optional[CacheConfig] = None,
@@ -973,7 +973,10 @@ class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
pass
-def get_spec_layer_idx_from_weight_name(config: PretrainedConfig,
+# Compatibility with
+# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/configuration_deepseek.py
+def get_spec_layer_idx_from_weight_name(config: Union[DeepseekV2Config,
+ DeepseekV3Config],
weight_name: str) -> Optional[int]:
if (hasattr(config, "num_nextn_predict_layers")
and config.num_nextn_predict_layers > 0):
diff --git a/vllm/model_executor/models/dots1.py b/vllm/model_executor/models/dots1.py
index 9b21a79446138..5f410c0ae5fb0 100644
--- a/vllm/model_executor/models/dots1.py
+++ b/vllm/model_executor/models/dots1.py
@@ -29,7 +29,7 @@ from typing import Any, Optional, Union
import torch
from torch import nn
-from transformers import PretrainedConfig
+from transformers import Dots1Config
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
@@ -99,7 +99,7 @@ class Dots1MoE(nn.Module):
def __init__(
self,
- config: PretrainedConfig,
+ config: Dots1Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
@@ -174,7 +174,7 @@ class Dots1Attention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
- config: PretrainedConfig,
+ config: Dots1Config,
rope_theta: float = 10000,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192,
@@ -260,7 +260,7 @@ class Dots1DecoderLayer(nn.Module):
def __init__(
self,
- config: PretrainedConfig,
+ config: Dots1Config,
prefix: str,
model_config: ModelConfig,
cache_config: Optional[CacheConfig] = None,
diff --git a/vllm/model_executor/models/exaone4.py b/vllm/model_executor/models/exaone4.py
index 3d6ce3e8895fb..827e9014184b5 100644
--- a/vllm/model_executor/models/exaone4.py
+++ b/vllm/model_executor/models/exaone4.py
@@ -26,7 +26,7 @@ from typing import Any, Optional, Union
import torch
from torch import nn
-from transformers import PretrainedConfig
+from transformers import Exaone4Config
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
@@ -96,7 +96,7 @@ class Exaone4Attention(nn.Module):
def __init__(
self,
- config: PretrainedConfig,
+ config: Exaone4Config,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
@@ -159,25 +159,12 @@ class Exaone4Attention(nn.Module):
if quant_config is not None and quant_config.get_name() == "gguf":
is_neox_style = False
- self.apply_all_layers = False # apply rotary embeddings to every layer.
layer_idx = extract_layer_index(prefix)
- interleaved_sliding_window = getattr(config,
- "interleaved_sliding_window",
- 4096)
- sliding_window_pattern = getattr(config, "sliding_window_pattern",
- "LLLG")
+ is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+ self.sliding_window = config.sliding_window if is_sliding else None
- if sliding_window_pattern:
- layer_has_sliding_window = (
- layer_idx + 1) % sliding_window_pattern.__len__() != 0
- else:
- layer_has_sliding_window = False
- self.apply_all_layers = True
-
- if layer_has_sliding_window:
- self.sliding_window = interleaved_sliding_window
- else:
- self.sliding_window = None
+ # apply rotary embeddings to every layer
+ self.apply_all_layers = not is_sliding
self.rotary_emb = get_rope(
self.head_dim,
@@ -224,7 +211,7 @@ class Exaone4DecoderLayer(nn.Module):
def __init__(
self,
- config: PretrainedConfig,
+ config: Exaone4Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py
index 8beefb2cd0bd8..8cfe92c64540f 100644
--- a/vllm/model_executor/models/gemma2.py
+++ b/vllm/model_executor/models/gemma2.py
@@ -144,13 +144,10 @@ class Gemma2Attention(nn.Module):
is_neox_style=True,
)
- # reference:
- # https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa
layer_idx = extract_layer_index(prefix)
- use_sliding_window = (layer_idx % 2 == 0 and getattr(
- config, "interleaved_sliding_window", None) is not None)
- sliding_window = config.interleaved_sliding_window if \
- use_sliding_window else None
+ is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+ sliding_window = config.sliding_window if is_sliding else None
+
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py
index 1a2ce65d1e4c2..b762be3c52925 100644
--- a/vllm/model_executor/models/gemma3.py
+++ b/vllm/model_executor/models/gemma3.py
@@ -146,25 +146,19 @@ class Gemma3Attention(nn.Module):
self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
- # TODO(woosuk): Add reference to the original HF implementation.
layer_idx = extract_layer_index(prefix)
- self.is_sliding = (getattr(
- config, "interleaved_sliding_window", None) is not None and (bool(
- (layer_idx + 1) % config.sliding_window_pattern))) or (
- getattr(config, "layer_types", None) is not None
- and config.layer_types[layer_idx] == "sliding_attention")
+ self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+ sliding_window = config.sliding_window if self.is_sliding else None
+
# Initialize the rotary embedding.
if self.is_sliding:
# Local attention. Override the values in config.json.
self.rope_theta = config.rope_local_base_freq
self.rope_scaling = {"rope_type": "default"}
- self.sliding_window = (config.interleaved_sliding_window
- or config.sliding_window)
else:
# Global attention. Use the values in config.json.
self.rope_theta = config.rope_theta
self.rope_scaling = config.rope_scaling
- self.sliding_window = None
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
@@ -182,7 +176,7 @@ class Gemma3Attention(nn.Module):
cache_config=cache_config,
quant_config=quant_config,
logits_soft_cap=attn_logits_soft_cap,
- per_layer_sliding_window=self.sliding_window,
+ per_layer_sliding_window=sliding_window,
prefix=f"{prefix}.attn")
def forward(
diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
index e9ee1ebdcc680..9871b11b37991 100644
--- a/vllm/model_executor/models/gemma3_mm.py
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -502,8 +502,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
self.config = config
self.quant_config = quant_config
self.multimodal_config = multimodal_config
- self.sliding_window = getattr(config.text_config,
- "interleaved_sliding_window", None)
self.vision_tower = SiglipVisionModel(config.vision_config,
quant_config,
@@ -690,11 +688,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
global_attn_masks.append(global_attn_mask)
- if self.sliding_window is not None:
+ if (sliding_window := self.config.sliding_window) is not None:
# Create a local causal mask with sliding window (1024).
local_attn_mask = torch.ones_like(global_attn_mask)
local_attn_mask = torch.tril(local_attn_mask,
- diagonal=-self.sliding_window)
+ diagonal=-sliding_window)
local_attn_mask = torch.where(local_attn_mask == 0,
global_attn_mask, float("-inf"))
local_attn_masks.append(local_attn_mask)
diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py
index e16c03c8d3b57..ffec3408702c9 100644
--- a/vllm/model_executor/models/gemma3n.py
+++ b/vllm/model_executor/models/gemma3n.py
@@ -313,17 +313,16 @@ class Gemma3nAttention(nn.Module):
has_weight=False)
layer_idx = extract_layer_index(prefix)
+ is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+ self.sliding_window = config.sliding_window if is_sliding else None
- is_sliding_window = (
- getattr(config, "interleaved_sliding_window", None) is not None
- and config.layer_types[layer_idx] == "sliding_attention")
-
- if is_sliding_window:
- self.sliding_window = config.interleaved_sliding_window
+ # Initialize the rotary embedding.
+ if is_sliding:
+ # Local attention. Override the values in config.json.
rope_theta = config.rope_local_base_freq
rope_scaling = {"rope_type": "default"}
else:
- self.sliding_window = None
+ # Global attention. Use the values in config.json.
rope_theta = config.rope_theta
rope_scaling = config.rope_scaling
@@ -331,14 +330,15 @@ class Gemma3nAttention(nn.Module):
config.num_kv_shared_layers)
self.is_kv_shared = layer_idx >= first_kv_shared_layer_idx
+ kv_sharing_target_layer_name = None
if self.is_kv_shared:
# Last full attention layer is 1 before sharing
# Last sliding attention layer is 2 before sharing
offset = 2 if self.sliding_window is not None else 1
kv_shared_layer_index = first_kv_shared_layer_idx - offset
- kv_sharing_target_layer_name = f"model.language_model.layers.{kv_shared_layer_index}.self_attn.attn" # noqa: E501
- else:
- kv_sharing_target_layer_name = None
+ if kv_shared_layer_index >= 0:
+ # Only the greater layer is required to specify sharing.
+ kv_sharing_target_layer_name = f"language_model.model.layers.{kv_shared_layer_index}.self_attn.attn" # noqa: E501
self.rotary_emb = get_rope(
self.head_dim,
@@ -396,6 +396,7 @@ class Gemma3nDecoderLayer(nn.Module):
prefix: str = "",
) -> None:
super().__init__()
+ assert isinstance(config, Gemma3nTextConfig)
self.altup_active_idx = config.altup_active_idx
assert config.altup_correct_scale
@@ -537,7 +538,7 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
- config = vllm_config.model_config.hf_config.text_config
+ config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
@@ -553,6 +554,7 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
config.hidden_size**0.5,
dtype=self.embed_tokens.weight.dtype,
)
+ # Additional per-layer embeddings (PLE)
self.embed_tokens_per_layer = VocabParallelEmbedding(
config.vocab_size_per_layer_input,
config.num_hidden_layers * config.hidden_size_per_layer_input,
@@ -636,6 +638,8 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
+ per_layer_inputs: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
@@ -644,13 +648,6 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
else:
hidden_states_0 = self.get_input_embeddings(input_ids)
- # Per layer inputs.
- if input_ids is None:
- raise ValueError("Passing None for input ids is not supported.")
- per_layer_inputs = self.get_per_layer_input_embeddings(input_ids)
- per_layer_inputs = per_layer_inputs.reshape(
- -1, self.config.num_hidden_layers,
- self.config.hidden_size_per_layer_input)
per_layer_projection = self.per_layer_model_projection(hidden_states_0)
per_layer_projection = per_layer_projection.reshape(
*hidden_states_0.shape[:-1],
@@ -659,8 +656,13 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
)
per_layer_projection = self.per_layer_projection_norm(
per_layer_projection)
- per_layer_inputs = per_layer_projection + per_layer_inputs
- per_layer_inputs *= self.per_layer_input_scale
+
+ if per_layer_inputs is not None:
+ # Profiling run does not compute per_layer_inputs
+ per_layer_inputs = per_layer_projection + per_layer_inputs
+ per_layer_inputs *= self.per_layer_input_scale
+ else:
+ per_layer_inputs = per_layer_projection
# Altup embed.
hidden_states = [hidden_states_0] * self.config.altup_num_inputs
@@ -760,29 +762,7 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
return loaded_params
-class Gemma3nModel(nn.Module):
-
- def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
- super().__init__()
- self.language_model = Gemma3nTextModel(vllm_config=vllm_config,
- prefix=maybe_prefix(
- prefix, "language_model"))
-
- def forward(
- self,
- input_ids: Optional[torch.Tensor],
- positions: torch.Tensor,
- intermediate_tensors: Optional[IntermediateTensors] = None,
- inputs_embeds: Optional[torch.Tensor] = None,
- **kwargs,
- ) -> torch.Tensor:
- return self.language_model(input_ids=input_ids,
- positions=positions,
- inputs_embeds=inputs_embeds,
- **kwargs)
-
-
-class Gemma3nForConditionalGeneration(nn.Module, SupportsQuant):
+class Gemma3nForCausalLM(nn.Module):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -802,25 +782,33 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsQuant):
super().__init__()
self.config = config
self.cache_config = vllm_config.cache_config
- self.model = Gemma3nModel(vllm_config=vllm_config,
- prefix=maybe_prefix(prefix, "model"))
+ self.model = Gemma3nTextModel(vllm_config=vllm_config,
+ prefix=maybe_prefix(prefix, "model"))
self.logits_processor = LogitsProcessor(
- config.text_config.vocab_size,
- soft_cap=config.text_config.final_logit_softcapping)
+ config.vocab_size, soft_cap=config.final_logit_softcapping)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
- return self.model.language_model.get_input_embeddings(input_ids)
+ return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
+ *,
+ per_layer_inputs: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
- hidden_states = self.model(input_ids, positions, intermediate_tensors,
- inputs_embeds, **kwargs)
+
+ hidden_states = self.model(
+ input_ids,
+ positions,
+ per_layer_inputs=per_layer_inputs,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ **kwargs,
+ )
return hidden_states
def compute_logits(
@@ -828,8 +816,8 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsQuant):
hidden_states: torch.Tensor,
sampling_metadata: Optional[SamplingMetadata],
) -> Optional[torch.Tensor]:
- logits = self.logits_processor(self.model.language_model.embed_tokens,
- hidden_states, sampling_metadata)
+ logits = self.logits_processor(self.model.embed_tokens, hidden_states,
+ sampling_metadata)
return logits
def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py
new file mode 100644
index 0000000000000..a0c3bb50070b3
--- /dev/null
+++ b/vllm/model_executor/models/gemma3n_mm.py
@@ -0,0 +1,700 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Iterable, Mapping, Sequence
+from typing import Any, Optional, TypedDict, Union, cast
+
+import torch
+from torch import nn
+from transformers import AutoModel, BatchFeature
+from transformers.models.gemma3n import (Gemma3nAudioConfig,
+ Gemma3nAudioFeatureExtractor,
+ Gemma3nConfig, Gemma3nProcessor,
+ Gemma3nTextConfig,
+ Gemma3nVisionConfig)
+from transformers.models.siglip import SiglipImageProcessorFast
+
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import RowParallelLinear
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ VocabParallelEmbedding)
+from vllm.model_executor.models.gemma3n import Gemma3nForCausalLM
+from vllm.model_executor.models.module_mapping import MultiModelKeys
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+ MultiModalKwargs)
+from vllm.multimodal.parse import (ImageProcessorItems, MultiModalDataItems,
+ MultiModalDataParser)
+# yapf: disable
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+ BaseProcessingInfo, BoundPromptUpdate,
+ PlaceholderFeaturesInfo,
+ PromptReplacement, PromptTargetMatch,
+ PromptUpdate, PromptUpdateDetails,
+ find_mm_placeholders,
+ replace_token_matches)
+# yapf: enable
+from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.sequence import IntermediateTensors
+
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+ init_vllm_registered_model, maybe_prefix,
+ merge_multimodal_embeddings)
+
+logger = init_logger(__name__)
+
+# This should be based on model config but we hardcode them for now.
+TOKENS_PER_IMAGE = 256
+TOKENS_PER_AUDIO = 188
+
+
+class Gemma3nImagePixelInputs(TypedDict):
+ pixel_values: torch.Tensor
+ """Shape: `(batch_size * num_images, num_channels, height, width)`"""
+
+
+class Gemma3nAudioInputs(TypedDict):
+ input_features: torch.Tensor
+ """Shape: `(batch_size * num_audio, seq_length, num_features)`"""
+ input_features_mask: torch.Tensor
+ """Shape: `(batch_size * num_audio, seq_length)`"""
+
+
+Gemma3nImageInputs = Gemma3nImagePixelInputs
+
+
+class Gemma3nProcessingInfo(BaseProcessingInfo):
+
+ def get_hf_config(self):
+ return self.ctx.get_hf_config(Gemma3nConfig)
+
+ def get_hf_processor(self, **kwargs: object):
+ return self.ctx.get_hf_processor(Gemma3nProcessor, **kwargs)
+
+ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+ return {"image": None, "audio": None}
+
+ def get_max_tokens_per_item(
+ self, seq_len: int,
+ mm_counts: Mapping[str, int]) -> Optional[Mapping[str, int]]:
+
+ return {"image": TOKENS_PER_IMAGE, "audio": TOKENS_PER_AUDIO}
+
+ def get_image_repl(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ processor: Optional[Gemma3nProcessor],
+ ) -> str:
+ """
+ Get the replacement text for image tokens.
+
+ For Gemma3n, this should return the full_image_sequence which includes
+ BOI token, repeated image tokens, and EOI token.
+ """
+ if processor is None:
+ processor = self.get_hf_processor()
+
+ return PromptUpdateDetails.select_token_id(
+ processor.full_image_sequence, processor.image_token_id)
+
+ def get_audio_repl(
+ self,
+ *,
+ processor: Optional[Gemma3nProcessor],
+ ) -> str:
+ """
+ Get the replacement text for audio tokens.
+
+ For Gemma3n, this should return the full_audio_sequence which includes
+ BOA token, repeated audio tokens, and EOA token.
+ """
+ if processor is None:
+ processor = self.get_hf_processor()
+
+ # Return the full audio sequence as defined by the processor
+ return PromptUpdateDetails.select_token_id(
+ processor.full_audio_sequence, processor.audio_token_id)
+
+
+class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]):
+
+ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+ num_images = mm_counts.get("image", 0)
+ num_audios = mm_counts.get("audio", 0)
+
+ processor = self.info.get_hf_processor()
+ image_token = processor.image_token
+ audio_token = processor.audio_token
+
+ return image_token * num_images + audio_token * num_audios
+
+ def get_dummy_mm_data(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> MultiModalDataDict:
+ num_images = mm_counts.get("image", 0)
+ num_audios = mm_counts.get("audio", 0)
+ processor = self.info.get_hf_processor()
+ audio_feature_extractor: Gemma3nAudioFeatureExtractor = processor.feature_extractor # noqa: E501
+ audio_len = audio_feature_extractor.fft_length
+ image_processor: SiglipImageProcessorFast = processor.image_processor
+ img_width = image_processor.size.get("width", 224)
+ img_height = image_processor.size.get("height", 224)
+
+ return {
+ "image":
+ self._get_dummy_images(width=img_width,
+ height=img_height,
+ num_images=num_images),
+ "audio":
+ self._get_dummy_audios(length=audio_len, num_audios=num_audios)
+ }
+
+
+class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
+ ):
+
+ def _get_data_parser(self) -> MultiModalDataParser:
+ feature_extractor = self.info.get_hf_processor().feature_extractor
+ return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
+
+ def _call_hf_processor(
+ self,
+ prompt: str,
+ mm_data: Mapping[str, object],
+ mm_kwargs: Mapping[str, object],
+ tok_kwargs: Mapping[str, object],
+ ) -> BatchFeature:
+
+ # HF Transformers audio processor no longer accepts `audios` key.
+ # We pop `audios` and replace it with `audio` key to surpress
+ # the warning.
+ if 'audios' in mm_data:
+ mm_data['audio'] = mm_data.pop('audios')
+ processed_outputs = super()._call_hf_processor(
+ prompt,
+ mm_data,
+ mm_kwargs,
+ tok_kwargs,
+ )
+ if 'input_features' in processed_outputs:
+ # Avoid padding since we need the output of each item to be
+ # independent of other items for the cache to work correctly
+ unpadded_features = [
+ f[mask] for f, mask in zip(
+ processed_outputs["input_features"],
+ processed_outputs["input_features_mask"],
+ )
+ ]
+ processed_outputs["input_features"] = unpadded_features
+ return processed_outputs
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+
+ return dict(pixel_values=MultiModalFieldConfig.batched("image"),
+ input_features=MultiModalFieldConfig.batched("audio"),
+ input_features_mask=MultiModalFieldConfig.batched("audio"))
+
+ def _get_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, Any],
+ out_mm_kwargs: MultiModalKwargs,
+ ) -> Sequence[PromptUpdate]:
+ hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+
+ prompt_updates = []
+
+ # Handle image tokens
+ if "image" in mm_items:
+ image_token = hf_processor.image_token
+
+ def get_replacement_image(item_idx: int):
+ images = mm_items.get_items("image", ImageProcessorItems)
+ image_size = images.get_image_size(item_idx)
+ return self.info.get_image_repl(
+ image_width=image_size.width,
+ image_height=image_size.height,
+ processor=hf_processor,
+ )
+
+ prompt_updates.append(
+ PromptReplacement(
+ modality="image",
+ target=image_token,
+ replacement=get_replacement_image,
+ ))
+
+ # Handle audio tokens
+ if "audio" in mm_items:
+ audio_token = hf_processor.audio_token
+
+ def get_replacement_audio(item_idx: int):
+ return self.info.get_audio_repl(processor=hf_processor, )
+
+ prompt_updates.append(
+ PromptReplacement(
+ modality="audio",
+ target=audio_token,
+ replacement=get_replacement_audio,
+ ))
+
+ return prompt_updates
+
+ def _apply_token_matches(
+ self,
+ prompt: list[int],
+ mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
+ mm_item_counts: Mapping[str, int],
+ ) -> list[int]:
+ token_ids = super()._apply_token_matches(
+ prompt,
+ mm_matches,
+ mm_item_counts,
+ )
+
+ # "\n\n\n" and "\n\n\n\n" are single tokens
+ # Since our replacement can insert "\n\n" next to "\n"
+ # tokens, we have to combine them to be consistent with
+ # the output of the tokenizer
+ tokenizer = self.info.get_tokenizer()
+ vocab = tokenizer.get_vocab()
+ newline_1 = vocab["\n"]
+ newline_2 = vocab["\n\n"]
+ newline_3 = vocab["\n\n\n"]
+ newline_4 = vocab["\n\n\n\n"]
+
+ token_ids = replace_token_matches(
+ token_ids,
+ [newline_1, newline_2],
+ [newline_3],
+ )
+ token_ids = replace_token_matches(
+ token_ids,
+ [newline_2, newline_1],
+ [newline_3],
+ )
+ token_ids = replace_token_matches(
+ token_ids,
+ [newline_2, newline_2],
+ [newline_4],
+ )
+
+ return token_ids
+
+ def _find_mm_placeholders(
+ self,
+ mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
+ new_token_ids: list[int],
+ mm_item_counts: Mapping[str, int],
+ ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
+ # We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n"
+ tokenizer = self.info.get_tokenizer()
+ vocab = tokenizer.get_vocab()
+ newline_1 = vocab["\n"]
+ newline_2 = vocab["\n\n"]
+ newline_3 = vocab["\n\n\n"]
+ newline_4 = vocab["\n\n\n\n"]
+
+ def get_repl_toks(tok: int) -> list[int]:
+ if tok == newline_3:
+ return [newline_1, newline_2]
+ if tok == newline_4:
+ return [newline_2, newline_2]
+
+ return [tok]
+
+ repl_token_ids = list[int]()
+ repl_orig_idxs = list[int]()
+ for orig_idx, orig_tok in enumerate(new_token_ids):
+ repl_toks = get_repl_toks(orig_tok)
+ repl_token_ids.extend(repl_toks)
+ repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
+
+ repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids,
+ mm_item_counts)
+
+ return {
+ modality: [
+ PlaceholderFeaturesInfo(
+ modality=p.modality,
+ item_idx=p.item_idx,
+ start_idx=repl_orig_idxs[p.start_idx],
+ tokens=p.tokens,
+ is_embed=p.is_embed,
+ ) for p in placeholders
+ ]
+ for modality, placeholders in repls.items()
+ }
+
+
+class Gemma3nMultimodalEmbedder(nn.Module):
+ """Embeds token ids or soft tokens for multimodal content into language
+ model space."""
+
+ def __init__(
+ self,
+ multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig],
+ text_config: Gemma3nTextConfig,
+ ):
+ super().__init__()
+
+ self.multimodal_hidden_size = multimodal_config.hidden_size
+ self.eps = multimodal_config.rms_norm_eps
+ self.vocab_offset = multimodal_config.vocab_offset
+ self.vocab_size = multimodal_config.vocab_size
+ self.text_hidden_size = text_config.hidden_size
+
+ self.embedding = VocabParallelEmbedding(
+ self.vocab_size,
+ self.multimodal_hidden_size,
+ )
+
+ self.hard_embedding_norm = RMSNorm(
+ self.multimodal_hidden_size,
+ eps=self.eps,
+ )
+
+ self.soft_embedding_norm = RMSNorm(
+ self.multimodal_hidden_size,
+ eps=self.eps,
+ )
+
+ self.embedding_projection = RowParallelLinear(
+ self.multimodal_hidden_size,
+ self.text_hidden_size,
+ bias=False,
+ )
+
+ self.embedding_post_projection_norm = RMSNorm(
+ self.text_hidden_size,
+ eps=self.eps,
+ has_weight=False,
+ )
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """Embeds token ids or soft tokens for multimodal content into language model space.
+
+ Args:
+ input_ids: A torch.LongTensor containing the token ids to embed. Values should be in the range
+ `[vocab_offset, vocab_offset + vocab_size)`.
+ inputs_embeds: A torch.Tensor containing the soft tokens to embed.
+
+ Returns:
+ A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
+ """ # noqa: E501
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is not None:
+ emb_norm = self.soft_embedding_norm(inputs_embeds)
+ else:
+ hard_emb = self.embedding(input_ids - self.vocab_offset)
+ emb_norm = self.hard_embedding_norm(hard_emb)
+
+ emb_norm_proj, _ = self.embedding_projection(emb_norm)
+ return self.embedding_post_projection_norm(emb_norm_proj)
+
+
+@MULTIMODAL_REGISTRY.register_processor(Gemma3nMultiModalProcessor,
+ info=Gemma3nProcessingInfo,
+ dummy_inputs=Gemma3nDummyInputsBuilder)
+class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
+ packed_modules_mapping = {
+ "qkv_proj": [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ ],
+ "gate_up_proj": [
+ "gate_proj",
+ "up_proj",
+ ],
+ }
+
+ hf_to_vllm_mapper = WeightsMapper(
+ orig_to_new_prefix={
+ # mapping for new names in checkpoint saved after transformers v4.52
+ "model.embed_audio.": "embed_audio.",
+ "model.embed_vision.": "embed_vision.",
+ "model.language_model.": "language_model.model.",
+ "model.vision_tower.": "vision_tower.",
+ "model.audio_tower.": "audio_tower.",
+ "model.multi_modal_projector.": "multi_modal_projector.",
+ "lm_head.": "language_model.lm_head.",
+ "model": "language_model.model",
+ })
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ multimodal_config = vllm_config.model_config.multimodal_config
+ self.config = config
+ self.quant_config = quant_config
+ self.multimodal_config = multimodal_config
+ self.vocab_size = config.text_config.vocab_size
+
+ self.sliding_window = getattr(config.text_config,
+ "interleaved_sliding_window", None)
+
+ self.vision_tower = AutoModel.from_config(config=config.vision_config)
+ self.audio_tower = AutoModel.from_config(config=config.audio_config)
+ self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config,
+ config.text_config)
+ self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config,
+ config.text_config)
+
+ self.language_model: nn.Module = init_vllm_registered_model(
+ vllm_config=vllm_config,
+ hf_config=config.text_config,
+ prefix=maybe_prefix(prefix, "language_model"),
+ architectures=["Gemma3nForCausalLM"],
+ )
+ self.language_model = cast(Gemma3nForCausalLM, self.language_model)
+ # NOTE (NickLucche) In order to be compatible with cudagraph, the
+ # buffer needs to be consistent, so we pre-allocate here.
+ self.per_layer_embeddings = torch.zeros(
+ vllm_config.scheduler_config.max_num_batched_tokens,
+ self.config.text_config.num_hidden_layers,
+ self.config.text_config.hidden_size_per_layer_input,
+ device=self.language_model.model.embed_tokens.weight.device,
+ dtype=self.language_model.model.embed_tokens.weight.dtype)
+
+ @property
+ def dtype(self):
+ return next(self.parameters()).dtype
+
+ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
+ # TODO check if there are any
+ return data
+
+ def _parse_and_validate_image_input(
+ self, **kwargs: object) -> Optional[Gemma3nImageInputs]:
+ pixel_values = kwargs.pop("pixel_values", None)
+ image_embeds = kwargs.pop("image_embeds", None)
+ # TODO is this the case?
+ assert image_embeds is None, "Gemma3n does not support image_embeds."
+ if pixel_values is None:
+ return None
+
+ if not isinstance(pixel_values, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of pixel values. "
+ f"Got type: {type(pixel_values)}")
+
+ pixel_values = flatten_bn(pixel_values, concat=True)
+ pixel_values = pixel_values.contiguous()
+
+ return Gemma3nImagePixelInputs(
+ pixel_values=self._validate_pixel_values(pixel_values), )
+
+ def _parse_and_validate_audio_input(
+ self, **kwargs: object) -> Optional[Gemma3nAudioInputs]:
+ input_features = kwargs.pop("input_features", None)
+ if input_features is None:
+ return None
+
+ input_features_mask = kwargs.pop("input_features_mask", None)
+ if input_features_mask is None:
+ return None
+
+ return Gemma3nAudioInputs(
+ input_features=input_features,
+ input_features_mask=input_features_mask,
+ )
+
+ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
+ mm_input_by_modality = {}
+
+ # Preserve the order of modalities if there are multiple of them
+ # from the order of kwargs.
+ for input_key in kwargs:
+ if input_key in ("pixel_values", "image_embeds"
+ ) and "image" not in mm_input_by_modality:
+ mm_input_by_modality[
+ "image"] = self._parse_and_validate_image_input(**kwargs)
+ if input_key == "input_features" \
+ and "audio" not in mm_input_by_modality:
+ mm_input_by_modality[
+ "audio"] = self._parse_and_validate_audio_input(**kwargs)
+ return mm_input_by_modality
+
+ def _process_image_input(
+ self,
+ image_input: Gemma3nImageInputs,
+ ) -> list[torch.Tensor]:
+ assert self.vision_tower is not None
+
+ pixel_values = image_input["pixel_values"]
+ vision_outputs = self.vision_tower(pixel_values=pixel_values,
+ do_pooling=False,
+ return_dict=True).last_hidden_state
+ # TODO try to avoid copy here
+ # (batch, channels, height, width) to (batch, height * width, channels)
+ vision_outputs = vision_outputs.reshape(
+ vision_outputs.shape[0],
+ self.config.vision_config.hidden_size,
+ self.config.vision_soft_tokens_per_image,
+ ).permute(0, 2, 1).contiguous()
+ # Normalize and embed the soft tokens into language model space.
+ vision_outputs *= self.config.vision_config.hidden_size**0.5
+ # Return a list of embeddings instead of a batched tensor
+ return self.embed_vision(inputs_embeds=vision_outputs).unbind(0)
+
+ def _process_audio_input(
+ self,
+ audio_input: Gemma3nAudioInputs,
+ ) -> list[torch.Tensor]:
+ assert self.audio_tower is not None
+ input_features = audio_input["input_features"].squeeze(1)
+ input_features_mask = audio_input["input_features_mask"].squeeze(1)
+ audio_outputs, audio_mask = self.audio_tower(input_features,
+ ~input_features_mask)
+ audio_features = self.embed_audio(inputs_embeds=audio_outputs)
+
+ # ruff: noqa
+ # The Gemma3nProcessor expects all audio will be 30s in length and inserts 188 audio soft tokens into the
+ # text to account for this. However, the audio preprocessing and encoder do not gurarantee they will
+ # produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens
+ # depending on the length of the longest audio input in the batch. When we encounter this situation, we pad
+ # the audio feature out to 188 soft tokens with the emebedding of the last token in the embed_audio vocab.
+ # TODO precompute and cache padding
+ audio_padding_toks = torch.tensor([[self.vocab_size - 1]],
+ dtype=torch.long,
+ device=audio_features.device)
+ audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks)
+ audio_features = torch.where(audio_mask.unsqueeze(-1),
+ audio_padding_embs, audio_features)
+
+ audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
+ extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len # noqa: E501
+ extra_padding_features = audio_padding_embs.expand(
+ audio_batch_size, extra_padding_tokens, audio_embed_dim)
+
+ audio_features = torch.cat((audio_features, extra_padding_features),
+ dim=1)
+ # Return a list of embeddings instead of a batched tensor
+ return audio_features.unbind(0)
+
+ def get_language_model(self) -> torch.nn.Module:
+ return self.language_model
+
+ def get_multimodal_embeddings(self,
+ **kwargs: object) -> MultiModalEmbeddings:
+ mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
+ **kwargs)
+ if mm_input_by_modality is None:
+ return []
+
+ multimodal_embeddings: list[torch.Tensor] = []
+
+ # NOTE: It is important to iterate over the keys in this dictionary
+ # to preserve the order of the modalities.
+ for modality in mm_input_by_modality:
+ multimodal_input = mm_input_by_modality[modality]
+ if modality == "image":
+ vision_embeddings = self._process_image_input(multimodal_input)
+ multimodal_embeddings.extend(vision_embeddings)
+ if modality == "audio":
+ audio_embeddings = self._process_audio_input(multimodal_input)
+ multimodal_embeddings.extend(audio_embeddings)
+ return multimodal_embeddings
+
+ def get_input_embeddings(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+ ) -> torch.Tensor:
+ inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+ # NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache
+ # them here, as the model forward has only access to the input_embeds.
+ if input_ids is not None:
+ per_layer_inputs = self.language_model.model.get_per_layer_input_embeddings(
+ input_ids)
+ per_layer_inputs = per_layer_inputs.reshape(
+ -1, self.config.text_config.num_hidden_layers,
+ self.config.text_config.hidden_size_per_layer_input)
+ self.per_layer_embeddings[:per_layer_inputs.shape[0]].copy_(
+ per_layer_inputs)
+
+ if multimodal_embeddings is not None \
+ and len(multimodal_embeddings) != 0:
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids,
+ inputs_embeds,
+ multimodal_embeddings,
+ # NOTE: this order of processing mm items is important
+ [self.config.image_token_id, self.config.audio_token_id])
+ return inputs_embeds
+
+ def forward(self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs: object) -> IntermediateTensors:
+ if intermediate_tensors is not None:
+ inputs_embeds = None
+
+ # NOTE (NickLucche) During profiling, `get_input_embeddings` is not
+ # called, hence we don't have input_ids to compute PLEs. We simply
+ # select a chunk of pre-allocated PLEs. During normal execution,
+ # `get_input_embeddings` is called before forward, hence this slice
+ # will contain PLEs computed from the actual input_ids.
+ per_layer_inputs = self.per_layer_embeddings[:inputs_embeds.shape[0]]
+
+ hidden_states = self.language_model.model(
+ input_ids,
+ positions,
+ per_layer_inputs=per_layer_inputs,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ **kwargs)
+
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[torch.Tensor]:
+ return self.language_model.compute_logits(hidden_states,
+ sampling_metadata)
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ loader = AutoWeightsLoader(self)
+ return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
+ def get_mm_mapping(self) -> MultiModelKeys:
+ """
+ Get the module prefix in multimodal models
+ """
+ return MultiModelKeys.from_string_field(
+ language_model="language_model",
+ connector="multi_modal_projector",
+ tower_model="vision_tower")
+
+ @classmethod
+ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
+ if modality == "image":
+ return ""
+ elif modality == "audio":
+ return ""
+ else:
+ raise ValueError(f"Unsupported modality: {modality}")
diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py
index bd3e27662ee7c..131c042c3c2db 100644
--- a/vllm/model_executor/models/glm4_moe.py
+++ b/vllm/model_executor/models/glm4_moe.py
@@ -28,7 +28,7 @@ from typing import Any, Optional, Union
import torch
from torch import nn
-from transformers import PretrainedConfig
+from transformers.models.glm4_moe import Glm4MoeConfig
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
@@ -100,7 +100,7 @@ class Glm4MoE(nn.Module):
def __init__(
self,
- config: PretrainedConfig,
+ config: Glm4MoeConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
@@ -198,7 +198,7 @@ class Glm4MoeAttention(nn.Module):
def __init__(
self,
- config: PretrainedConfig,
+ config: Glm4MoeConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
@@ -297,7 +297,7 @@ class Glm4MoeDecoderLayer(nn.Module):
def __init__(
self,
- config: PretrainedConfig,
+ config: Glm4MoeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
@@ -372,7 +372,13 @@ class Glm4MoeDecoderLayer(nn.Module):
return hidden_states, residual
-@support_torch_compile
+@support_torch_compile(
+ dynamic_arg_dims={
+ "input_ids": 0,
+ "positions": -1,
+ "intermediate_tensors": 0,
+ "inputs_embeds": 0,
+ })
class Glm4MoeModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -601,8 +607,6 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
quant_config=quant_config)
else:
self.lm_head = PPMissingLayer()
- if self.config.tie_word_embeddings:
- self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
@@ -683,7 +687,7 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
return self.model.get_expert_mapping()
-def get_spec_layer_idx_from_weight_name(config: PretrainedConfig,
+def get_spec_layer_idx_from_weight_name(config: Glm4MoeConfig,
weight_name: str) -> Optional[int]:
if hasattr(config,
"num_nextn_predict_layers") and (config.num_nextn_predict_layers
diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index 661a67bdc0db0..036ded530f97d 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -45,7 +45,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
- make_empty_intermediate_tensors_factory, make_layers)
+ make_empty_intermediate_tensors_factory, make_layers,
+ maybe_prefix)
class GPTBigCodeAttention(nn.Module):
@@ -83,6 +84,7 @@ class GPTBigCodeAttention(nn.Module):
total_num_kv_heads,
bias=True,
quant_config=quant_config,
+ prefix=f"{prefix}.c_attn",
)
self.c_proj = RowParallelLinear(
@@ -90,6 +92,7 @@ class GPTBigCodeAttention(nn.Module):
self.hidden_size,
bias=True,
quant_config=quant_config,
+ prefix=f"{prefix}.c_proj",
)
self.attn = Attention(self.num_heads,
self.head_dim,
@@ -123,6 +126,7 @@ class GPTBigMLP(nn.Module):
intermediate_size: int,
config: GPTBigCodeConfig,
quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
):
super().__init__()
hidden_size = config.hidden_size
@@ -131,12 +135,14 @@ class GPTBigMLP(nn.Module):
intermediate_size,
bias=True,
quant_config=quant_config,
+ prefix=f"{prefix}.c_fc",
)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=True,
quant_config=quant_config,
+ prefix=f"{prefix}.c_proj",
)
self.act = get_act_fn(config.activation_function)
@@ -167,7 +173,10 @@ class GPTBigCodeBlock(nn.Module):
quant_config,
prefix=f"{prefix}.attn")
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
- self.mlp = GPTBigMLP(inner_dim, config, quant_config)
+ self.mlp = GPTBigMLP(inner_dim,
+ config,
+ quant_config,
+ prefix=f"{prefix}.mlp")
def forward(
self,
@@ -260,7 +269,7 @@ class GPTBigCodeModel(nn.Module):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
# TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
- if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
+ if "c_attn.input_scale" in name:
weight_loader(param, loaded_weight, 'q')
weight_loader(param, loaded_weight, 'k')
weight_loader(param, loaded_weight, 'v')
@@ -284,7 +293,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.quant_config = quant_config
self.transformer = GPTBigCodeModel(vllm_config=vllm_config,
- prefix=prefix)
+ prefix=maybe_prefix(
+ prefix, "transformer"))
if self.config.tie_word_embeddings:
self.lm_head = self.transformer.wte
else:
diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py
index feb323a04524b..6a65bbbe2e0db 100644
--- a/vllm/model_executor/models/gpt_oss.py
+++ b/vllm/model_executor/models/gpt_oss.py
@@ -160,7 +160,9 @@ class MLPBlock(torch.nn.Module):
renormalize=True,
quant_config=quant_config,
prefix=f"{prefix}.experts",
- apply_router_weight_on_input=False)
+ apply_router_weight_on_input=False,
+ has_bias=True,
+ activation="swiglu_oai")
def forward(self, x: torch.Tensor) -> torch.Tensor:
t = self.norm(x)
@@ -262,8 +264,8 @@ class GptOssForCausalLM(nn.Module):
sampling_metadata)
return logits
- def load_weights(self, weights: Iterable[tuple[str,
- torch.Tensor]]) -> set[str]:
+ def _load_weights_mxfp4(
+ self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
rename_mapping = {
"self_attn": "attn",
"input_layernorm.weight": "attn.norm.weight",
@@ -469,3 +471,147 @@ class GptOssForCausalLM(nn.Module):
loaded_params.add(renamed_name)
return loaded_params
+
+ def _load_weights_other(
+ self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ rename_mapping = {
+ "self_attn": "attn",
+ "input_layernorm.weight": "attn.norm.weight",
+ "post_attention_layernorm.weight": "mlp.norm.weight",
+ "embed_tokens": "embedding",
+ }
+
+ def maybe_rename(name: str) -> str:
+ for remap_name, new_name in rename_mapping.items():
+ if remap_name in name:
+ return name.replace(remap_name, new_name)
+ return name
+
+ params_dict = dict(self.named_parameters())
+ loaded_params: set[str] = set()
+
+ tp_rank = get_tensor_model_parallel_rank()
+ tp_size = get_tensor_model_parallel_world_size()
+ intermediate_size = self.model_config.intermediate_size
+
+ per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
+ # Calculate common slicing bounds for current rank
+ tp_rank_start = tp_rank * per_rank_intermediate_size
+ tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
+ intermediate_size)
+
+ # Attention heads per rank
+ heads_per_rank = self.model_config.num_attention_heads // tp_size
+ head_start = tp_rank * heads_per_rank
+
+ use_ep = self.vllm_config.parallel_config.enable_expert_parallel
+ ep_size = get_ep_group().world_size
+ ep_rank = get_ep_group().rank
+ num_experts = self.model_config.num_local_experts
+ experts_per_rank = num_experts // ep_size
+ ep_rank_start = ep_rank * experts_per_rank
+ ep_rank_end = (ep_rank + 1) * experts_per_rank
+
+ for name, weight in weights:
+ if ".experts.gate_up_proj" in name and "bias" not in name:
+ # Handle MLP gate and up projection weights
+ new_name = name.replace(".experts.gate_up_proj",
+ ".experts.w13_weight")
+
+ # Extract gate and up projection parts
+ # since the weight is shuffled, we can slice directly
+ if use_ep:
+ narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
+ else:
+ narrow_weight = weight[:, :,
+ 2 * tp_rank_start:2 * tp_rank_end]
+
+ narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
+ param = params_dict[new_name]
+
+ param.copy_(narrow_weight)
+ loaded_params.add(new_name)
+
+ elif ".experts.down_proj" in name and "bias" not in name:
+ # Handle MLP down projection weights
+ new_name = name.replace(".experts.down_proj",
+ ".experts.w2_weight")
+
+ if use_ep:
+ narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
+ else:
+ narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
+ narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
+ param = params_dict[new_name]
+
+ param.copy_(narrow_weight)
+ loaded_params.add(new_name)
+
+ elif "gate_up_proj_bias" in name:
+ # Handle MLP gate and up projection biases
+ new_name = name.replace("gate_up_proj_bias", "w13_bias")
+
+ # Extract gate and up projection bias parts
+ if use_ep:
+ narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
+ else:
+ narrow_weight = weight[:,
+ 2 * tp_rank_start:2 * tp_rank_end]
+
+ param = params_dict[new_name]
+
+ param.copy_(narrow_weight)
+ loaded_params.add(new_name)
+
+ elif "down_proj_bias" in name:
+ # Handle MLP down projection bias
+ new_name = name.replace("down_proj_bias", "w2_bias")
+
+ if use_ep:
+ weight = weight[ep_rank_start:ep_rank_end, ...]
+ else:
+ # (only load on rank 0 to avoid duplication)
+ if tp_rank != 0:
+ weight.zero_()
+ param = params_dict[new_name]
+ param.copy_(weight)
+ loaded_params.add(new_name)
+ elif "sinks" in name:
+ # Handle attention sinks (distributed across ranks)
+ name = name.replace("self_attn", "attn")
+ param = params_dict[name]
+ narrow_weight = weight.narrow(0, head_start, heads_per_rank)
+ param.data.copy_(narrow_weight)
+ loaded_params.add(name)
+ elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
+ shard_id = ("q" if "q_proj" in name else
+ "k" if "k_proj" in name else "v")
+ name = name.replace("self_attn", "attn")
+ param_name = name.replace(f"{shard_id}_proj", "qkv")
+ param = params_dict[param_name]
+ weight_loader = param.weight_loader
+ weight_loader(param, weight, loaded_shard_id=shard_id)
+ loaded_params.add(param_name)
+ else:
+ # Handle all other weights with potential renaming
+
+ renamed_name = maybe_rename(name)
+ if renamed_name not in params_dict:
+ continue
+ param = params_dict[renamed_name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ weight_loader(param, weight)
+ loaded_params.add(renamed_name)
+
+ return loaded_params
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ quant_method = (self.model_config.quantization_config['quant_method']
+ if hasattr(self.model_config, "quantization_config")
+ else None)
+ if quant_method == "mxfp4":
+ return self._load_weights_mxfp4(weights)
+ else:
+ return self._load_weights_other(weights)
diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py
index c99970284a953..9e7490e3c4f07 100644
--- a/vllm/model_executor/models/gritlm.py
+++ b/vllm/model_executor/models/gritlm.py
@@ -248,9 +248,7 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
vllm_config.cache_config.sliding_window = None
- for attr in ("sliding_window", "interleaved_sliding_window"):
- if hasattr(hf_config, attr):
- delattr(hf_config, attr)
+ hf_config.sliding_window = None
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index b6d9877cd01b6..c425488f834b5 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -641,6 +641,20 @@ def supports_cross_encoding(
return is_pooling_model(model) and _supports_cross_encoding(model)
+def default_pooling_type(pooling_type: str) -> object:
+ """Set default_pooling_type decorator. """
+
+ def func(model: object):
+ model.default_pooling_type = pooling_type
+ return model
+
+ return func
+
+
+def get_default_pooling_type(model: Union[type[object], object]) -> str:
+ return getattr(model, "default_pooling_type", "LAST")
+
+
class SupportsQuant:
"""The interface required for all models that support quantization."""
@@ -809,3 +823,56 @@ def supports_v0_only(
model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsV0Only]], TypeIs[SupportsV0Only]]:
return getattr(model, "supports_v0_only", False)
+
+
+@runtime_checkable
+class SupportsEagle3(Protocol):
+ """The interface required for models that support
+ EAGLE3 speculative decoding."""
+
+ supports_eagle3: ClassVar[Literal[True]] = True
+ """
+ A flag that indicates this model supports EAGLE3
+ speculative decoding.
+
+ Note:
+ There is no need to redefine this flag if this class is in the
+ MRO of your model class.
+ """
+
+ def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+ """
+ Set which layers should output auxiliary
+ hidden states for EAGLE3.
+
+ Args:
+ layers: Tuple of layer indices that should output auxiliary
+ hidden states.
+ """
+ ...
+
+ def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+ """
+ Get the layer indices that should output auxiliary hidden states
+ for EAGLE3.
+
+ Returns:
+ Tuple of layer indices for auxiliary hidden state outputs.
+ """
+ ...
+
+
+@overload
+def supports_eagle3(model: type[object]) -> TypeIs[type[SupportsEagle3]]:
+ ...
+
+
+@overload
+def supports_eagle3(model: object) -> TypeIs[SupportsEagle3]:
+ ...
+
+
+def supports_eagle3(
+ model: Union[type[object], object],
+) -> Union[TypeIs[type[SupportsEagle3]], TypeIs[SupportsEagle3]]:
+ return isinstance(model, SupportsEagle3)
diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py
index d29779a35e5c9..d0c4bf5450d6d 100644
--- a/vllm/model_executor/models/internlm2.py
+++ b/vllm/model_executor/models/internlm2.py
@@ -31,7 +31,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import SupportsLoRA, SupportsPP, default_pooling_type
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@@ -401,6 +401,7 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
return loaded_params
+@default_pooling_type("ALL")
class InternLM2ForRewardModel(InternLM2ForCausalLM):
is_pooling_model = True
diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py
index c1033aff07204..fbd310121ad47 100644
--- a/vllm/model_executor/models/jamba.py
+++ b/vllm/model_executor/models/jamba.py
@@ -22,8 +22,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
-from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
- PoolingType)
+from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
@@ -604,6 +603,5 @@ class JambaForSequenceClassification(JambaForCausalLM):
Pooler.for_classify(
pooler_config,
classifier=self.score,
- default_pooling_type=PoolingType.LAST,
),
})
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 48ec611df12dd..24cd448d8361f 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -49,7 +49,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
@@ -167,18 +167,11 @@ class LlamaAttention(nn.Module):
rope_scaling=rope_scaling,
quant_config=quant_config)
- if hasattr(config, "interleaved_sliding_window"):
- interleaved_sliding_window = config.interleaved_sliding_window
- if isinstance(interleaved_sliding_window, int):
- sliding_window = interleaved_sliding_window
- elif isinstance(interleaved_sliding_window, list):
- sw_idx = layer_idx % len(interleaved_sliding_window)
- sliding_window = interleaved_sliding_window[sw_idx]
- else:
- raise ValueError(
- f"{type(interleaved_sliding_window)} is not supported.")
- else:
- sliding_window = None
+ sliding_window = None
+ if layer_types := getattr(config, "layer_types", None):
+ is_sliding = layer_types[layer_idx] == "sliding_attention"
+ if is_sliding:
+ sliding_window = config.sliding_window
self.attn = Attention(
self.num_heads,
@@ -470,7 +463,7 @@ class LlamaModel(nn.Module):
return loaded_params
-class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index c863ba406422d..89d2817b57e0e 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -3,7 +3,7 @@
from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence
-from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
+from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
Union, cast)
import torch
@@ -16,7 +16,6 @@ from transformers.models.pixtral import PixtralProcessor
from vllm.config import VllmConfig
from vllm.inputs import InputProcessingContext
-from vllm.jsontree import json_map_leaves
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
@@ -33,6 +32,8 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
+from vllm.utils.jsontree import json_map_leaves
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -44,35 +45,46 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
from .vision import get_vision_encoder_info
-class LlavaImagePixelInputs(TypedDict):
- type: Literal["pixel_values"]
- pixel_values: torch.Tensor
+class LlavaImagePixelInputs(TensorSchema):
"""
- Shape: `(batch_size * num_images, num_channels, height, width)`
-
+ Dimensions:
+ - bn: Batch size * number of images
+ - c: Number of channels (3)
+ - h: Height
+ - w: Width
+
Note that `height` or `width` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
"""
+ type: Literal["pixel_values"] = "pixel_values"
+ pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
-class PixtralHFImagePixelInputs(TypedDict):
- type: Literal["pixel_values_pixtral"]
- pixel_values: Union[torch.Tensor, list[torch.Tensor]]
+class PixtralHFImagePixelInputs(TensorSchema):
"""
- Shape: `(batch_size * num_images, num_channels, height, width)`
-
+ Dimensions:
+ - bn: Batch size * number of images
+ - c: Number of channels
+ - h: Height
+ - w: Width
+
Note that `height` or `width` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
"""
+ type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral"
+ pixel_values: Annotated[Union[torch.Tensor, list[torch.Tensor]],
+ TensorShape("bn", "c", "h", "w")]
-class LlavaImageEmbeddingInputs(TypedDict):
- type: Literal["image_embeds"]
- data: torch.Tensor
- """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
-
- `hidden_size` must match the hidden size of language model backbone.
+class LlavaImageEmbeddingInputs(TensorSchema):
"""
+ Dimensions:
+ - bn: Batch size * number of images
+ - ifs: Image feature size
+ - hs: Hidden size (must match language model backbone)
+ """
+ type: Literal["image_embeds"] = "image_embeds"
+ data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
LlavaImageInputs = Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs,
@@ -521,18 +533,22 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
config.projector_hidden_act = "gelu"
# TODO: Optionally initializes this for supporting embeddings.
- self.vision_tower = init_vision_tower_for_llava(
- config,
- quant_config,
- require_post_norm=False,
- prefix=maybe_prefix(prefix, "vision_tower"))
- self.multi_modal_projector = LlavaMultiModalProjector(
- vision_hidden_size=config.vision_config.hidden_size,
- text_hidden_size=config.text_config.hidden_size,
- projector_hidden_act=config.projector_hidden_act,
- multimodal_projector_bias=config.multimodal_projector_bias,
- quant_config=quant_config,
- prefix=maybe_prefix(prefix, "multi_modal_projector"))
+ if multimodal_config.get_limit_per_prompt("image"):
+ self.vision_tower = init_vision_tower_for_llava(
+ config,
+ quant_config,
+ require_post_norm=False,
+ prefix=maybe_prefix(prefix, "vision_tower"))
+ self.multi_modal_projector = LlavaMultiModalProjector(
+ vision_hidden_size=config.vision_config.hidden_size,
+ text_hidden_size=config.text_config.hidden_size,
+ projector_hidden_act=config.projector_hidden_act,
+ multimodal_projector_bias=config.multimodal_projector_bias,
+ quant_config=quant_config,
+ prefix=maybe_prefix(prefix, "multi_modal_projector"))
+ else:
+ self.vision_tower = None
+ self.multi_modal_projector = None
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
@@ -543,19 +559,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
- def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
- h = w = self.config.vision_config.image_size
- expected_dims = (3, h, w)
- actual_dims = tuple(data.shape[1:])
-
- if actual_dims != expected_dims:
- expected_expr = ("batch_size", *map(str, expected_dims))
- raise ValueError(
- f"The expected shape of pixel values is {expected_expr}. "
- f"You supplied {tuple(data.shape)}.")
-
- return data
-
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
@@ -575,10 +578,14 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
pixel_values=flatten_bn(pixel_values),
)
+ expected_h = expected_w = self.config.vision_config.image_size
return LlavaImagePixelInputs(
type="pixel_values",
- pixel_values=self._validate_pixel_values(
- flatten_bn(pixel_values, concat=True)),
+ pixel_values=flatten_bn(pixel_values, concat=True),
+ resolve_bindings={
+ "h": expected_h,
+ "w": expected_w
+ },
)
if image_embeds is not None:
@@ -756,7 +763,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
- loader = AutoWeightsLoader(self)
+ skip_prefixes = []
+ if self.vision_tower is None and self.multi_modal_projector is None:
+ skip_prefixes.extend(["vision_tower.", "multi_modal_projector."])
+
+ loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py
index 04fb6b5736a5c..a63c18493df5e 100644
--- a/vllm/model_executor/models/llava_next.py
+++ b/vllm/model_executor/models/llava_next.py
@@ -3,7 +3,7 @@
from abc import abstractmethod
from collections.abc import Iterable, Mapping
-from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
+from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
Union)
import torch
@@ -11,7 +11,6 @@ import torch.nn as nn
from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor
from transformers.models.llava_next.modeling_llava_next import (
get_anyres_image_grid_shape, unpad_image)
-from typing_extensions import NotRequired
from vllm.config import VllmConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -19,6 +18,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.multimodal.parse import ImageSize
from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -30,32 +30,36 @@ from .utils import (AutoWeightsLoader, WeightsMapper, embed_multimodal,
flatten_bn, init_vllm_registered_model, maybe_prefix)
-class LlavaNextImagePixelInputs(TypedDict):
- type: Literal["pixel_values"]
- pixel_values: Union[torch.Tensor, list[torch.Tensor]]
+class LlavaNextImagePixelInputs(TensorSchema):
"""
- Shape:
- `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
-
+ Dimensions:
+ - bn: Batch size * number of images
+ - np: Number of patches + 1
+ - c: Number of channels (3)
+ - h: Height
+ - w: Width
+
Note that `num_patches` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
"""
+ type: Literal["pixel_values"] = "pixel_values"
+ pixel_values: Annotated[
+ Union[torch.Tensor, list[torch.Tensor]],
+ TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"})]
- image_sizes: NotRequired[torch.Tensor]
+ image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)]
+ # This should be in `(height, width)` format.
+
+
+class LlavaNextImageEmbeddingInputs(TensorSchema):
"""
- Shape: `(batch_size * num_images, 2)`
-
- This should be in `(height, width)` format.
- """
-
-
-class LlavaNextImageEmbeddingInputs(TypedDict):
- type: Literal["image_embeds"]
- data: torch.Tensor
- """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
-
- `hidden_size` must match the hidden size of language model backbone.
+ Dimensions:
+ - bn: Batch size * number of images
+ - ifs: Image feature size
+ - hs: Hidden size (must match language model backbone)
"""
+ type: Literal["image_embeds"] = "image_embeds"
+ data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
@@ -269,44 +273,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
- def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
- expected_dims = (2, )
-
- def _validate_shape(d: torch.Tensor):
- actual_dims = tuple(d.shape)
-
- if actual_dims != expected_dims:
- expected_expr = str(expected_dims)
- raise ValueError(
- f"The expected shape of image sizes per image per batch "
- f"is {expected_expr}. You supplied {tuple(d.shape)}.")
-
- for d in data:
- _validate_shape(d)
-
- return data
-
- def _validate_pixel_values(
- self, data: Union[torch.Tensor, list[torch.Tensor]]
- ) -> Union[torch.Tensor, list[torch.Tensor]]:
-
- h = w = self.config.vision_config.image_size
- expected_dims = (3, h, w)
-
- def _validate_shape(d: torch.Tensor):
- actual_dims = tuple(d.shape[1:])
-
- if actual_dims != expected_dims:
- expected_expr = ("num_patches", *map(str, expected_dims))
- raise ValueError(
- "The expected shape of pixel values per image per batch "
- f"is {expected_expr}. You supplied {tuple(d.shape)}.")
-
- for d in data:
- _validate_shape(d)
-
- return data
-
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
@@ -325,13 +291,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")
+ expected_h = expected_w = self.config.vision_config.image_size
return LlavaNextImagePixelInputs(
type="pixel_values",
- pixel_values=self._validate_pixel_values(
- flatten_bn(pixel_values)),
- image_sizes=self._validate_image_sizes(
- flatten_bn(image_sizes, concat=True)),
- )
+ pixel_values=flatten_bn(pixel_values),
+ image_sizes=flatten_bn(image_sizes, concat=True),
+ resolve_bindings={
+ "h": expected_h,
+ "w": expected_w,
+ })
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py
index a96df0b6f572e..abc519edadcca 100644
--- a/vllm/model_executor/models/llava_next_video.py
+++ b/vllm/model_executor/models/llava_next_video.py
@@ -3,7 +3,7 @@
import math
from collections.abc import Iterable, Mapping, Sequence
-from typing import Literal, Optional, TypedDict, Union
+from typing import Annotated, Literal, Optional, Union
import torch
import torch.nn as nn
@@ -25,6 +25,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .llava import init_vision_tower_for_llava
@@ -35,17 +36,25 @@ from .utils import (AutoWeightsLoader, WeightsMapper,
from .vision import get_vision_encoder_info
-class LlavaNextVideoPixelInputs(TypedDict):
- type: Literal["pixel_values_videos"]
- data: Union[torch.Tensor, list[torch.Tensor]]
- """
- Shape: `(batch_size, num_frames, num_channels, height, width)`
+class LlavaNextVideoPixelInputs(TensorSchema):
+ """
+ Dimensions:
+ - bs: Batch size
+ - nv: Number of videos
+ - nf: Number of frames
+ - nc: Number of channels (3)
+ - h: Height of each frame
+ - w: Width of each frame
Note that `num_frames` may be different for each batch, in which case
the data is passed as a list instead of a batched tensor.
Note that it only supports one video input for one batch.
"""
+ type: Literal["pixel_values_videos"] = "pixel_values_videos"
+
+ data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
+ TensorShape("bs", "nv", "nf", 3, "h", "w")]
class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
@@ -320,27 +329,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = (
self.language_model.model.make_empty_intermediate_tensors)
- def _validate_video_pixel_values(
- self, data: Union[torch.Tensor, list[torch.Tensor]]
- ) -> Union[torch.Tensor, list[torch.Tensor]]:
-
- h = w = self.config.vision_config.image_size
- expected_dims = (3, h, w)
-
- def _validate_shape(d: torch.Tensor):
- actual_dims = tuple(d.shape[2:])
-
- if actual_dims != expected_dims:
- expected_expr = ("num_frames", *map(str, expected_dims))
- raise ValueError(
- "The expected shape of pixel values in each video frame "
- f"is {expected_expr}. You supplied {tuple(d.shape)}.")
-
- for d in data:
- _validate_shape(d)
-
- return data
-
def _parse_and_validate_video_input(
self, **kwargs: object) -> Optional[LlavaNextVideoPixelInputs]:
"""
@@ -355,14 +343,13 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
if pixel_values_videos is None:
return None
- if not isinstance(pixel_values_videos, (torch.Tensor, list)):
- raise ValueError("Incorrect type of pixel_values_videos. "
- f"Got type: {type(pixel_values_videos)}")
-
- return LlavaNextVideoPixelInputs(
- type="pixel_values_videos",
- data=pixel_values_videos,
- )
+ expected_h = expected_w = self.config.vision_config.image_size
+ return LlavaNextVideoPixelInputs(type="pixel_values_videos",
+ data=pixel_values_videos,
+ resolve_bindings={
+ "h": expected_h,
+ "w": expected_w,
+ })
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py
index 4e4fc3d5c7621..e1746695bd5db 100644
--- a/vllm/model_executor/models/minicpmo.py
+++ b/vllm/model_executor/models/minicpmo.py
@@ -587,15 +587,28 @@ class MiniCPMO(MiniCPMV2_6):
num_lookhead: int = 0,
) -> torch.Tensor:
ret = torch.zeros(size, size, device=device, dtype=torch.bool)
- for i in range(size):
- if num_left_chunks < 0:
- start = 0
- else:
- start = max((i // chunk_size - num_left_chunks) * chunk_size,
- 0)
- ending = min((i // chunk_size + 1) * chunk_size + num_lookhead,
- size)
- ret[i, start:ending] = True
+ # Vectorized computation of row indices and chunk boundaries
+ row_indices = torch.arange(size, device=device)
+ chunk_indices = row_indices // chunk_size
+ if num_left_chunks < 0:
+ # If num_left_chunks < 0, start is always 0 for all rows
+ start_indices = torch.zeros_like(row_indices)
+ else:
+ # Compute start indices vectorially
+ start_chunk_indices = torch.clamp(chunk_indices - num_left_chunks,
+ min=0)
+ start_indices = start_chunk_indices * chunk_size
+ # Compute ending indices vectorially
+ end_chunk_indices = chunk_indices + 1
+ end_indices = torch.clamp(end_chunk_indices * chunk_size +
+ num_lookhead,
+ max=size)
+ # Create column indices for broadcasting
+ col_indices = torch.arange(size, device=device).unsqueeze(0)
+ start_indices = start_indices.unsqueeze(1)
+ end_indices = end_indices.unsqueeze(1)
+ # Vectorized mask creation
+ ret = (col_indices >= start_indices) & (col_indices < end_indices)
return ret
def _get_feat_extract_output_lengths(self,
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index 3aa16bb9abe45..7db3a1bb90b47 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -27,7 +27,7 @@ import math
from collections import defaultdict
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
-from typing import Any, Callable, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Callable, Literal, Optional, Union
import numpy as np
import torch
@@ -63,6 +63,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@@ -74,36 +75,47 @@ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
_MAX_FRAMES_PER_VIDEO = 16
-class MiniCPMVImagePixelInputs(TypedDict):
- type: Literal["pixel_values"]
- pixel_values: list[torch.Tensor]
+class MiniCPMVImagePixelInputs(TensorSchema):
"""
- Shape: `(batch_size * num_images * num_slices, num_channels, height, width)`
-
- Note that the image size may vary, so we pass it as a list
- instead of a batched tensor.
+ Dimensions:
+ - bns: Batch size * number of images * number of slices
+ - bn: Batch size * number of images
+ - c: Number of channels
+ - h: Height
+ - w: Width
"""
- tgt_sizes: torch.Tensor
- """
- Shape: `(batch_size * num_images * num_slices, 2)`
+ type: Literal["pixel_values"] = "pixel_values"
- This should be in `(height, width)` format.
+ # Note that the image size may vary, so we pass it as a list instead of a
+ # batched tensor.
+ pixel_values: Annotated[
+ list[torch.Tensor],
+ TensorShape("bns", "c", "h", "w"),
+ ]
+ tgt_sizes: Annotated[
+ torch.Tensor,
+ TensorShape("bns", 2), # This should be in `(height, width)` format.
+ ]
+ num_slices: Annotated[
+ torch.Tensor,
+ TensorShape("bn"),
+ ]
+
+
+class MiniCPMVImageEmbeddingInputs(TensorSchema):
+ """
+ Dimensions:
+ - bn: Batch size * number of images
+ - ns: Number of slices
+ - hs: Hidden size (must match language model backbone)
"""
- num_slices: torch.Tensor
- """Shape: `(batch_size * num_images)`"""
-
-
-class MiniCPMVImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
- image_embeds: Union[torch.Tensor, list[torch.Tensor]]
- """
- Shape: `(batch_size * num_images, num_slices, hidden_size)`
-
- `hidden_size` must match the hidden size of language model backbone.
- instead of a batched tensor.
- """
+ image_embeds: Annotated[
+ Union[torch.Tensor, list[torch.Tensor]],
+ TensorShape("bn", "ns", "hs"),
+ ]
MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
@@ -832,11 +844,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
pixel_values_flat = flatten_bn(flatten_2d_lists(pixel_values))
tgt_sizes_flat = flatten_bn(flatten_2d_lists(tgt_sizes), concat=True)
- if len(pixel_values_flat) != len(tgt_sizes_flat):
- raise ValueError("Inconsistent flattened lengths, found: "
- f"{len(pixel_values_flat)} vs. "
- f"{len(tgt_sizes_flat)}")
-
return MiniCPMVImagePixelInputs(
type="pixel_values",
pixel_values=pixel_values_flat,
diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py
index f2773af490c53..3d14a6ad5c3a4 100644
--- a/vllm/model_executor/models/minimax_text_01.py
+++ b/vllm/model_executor/models/minimax_text_01.py
@@ -12,10 +12,11 @@ import torch.distributed
import torch.nn.functional as F
from einops import rearrange
from torch import nn
-from transformers.configuration_utils import PretrainedConfig
+from transformers import MiniMaxConfig
+from vllm import envs
from vllm.attention import Attention, AttentionMetadata
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_pp_group, get_tensor_model_parallel_rank,
@@ -33,6 +34,9 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.abstract import MambaBase
+from vllm.model_executor.layers.mamba.mamba_utils import (
+ MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -41,8 +45,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
+from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
-from .interfaces import HasInnerState, IsHybrid, SupportsV0Only
+from .interfaces import HasInnerState, IsHybrid
from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
@@ -327,7 +332,17 @@ class MiniMaxText01LinearKernel:
return rearrange(output.squeeze(0), "h n d -> n (h d)")
-class MiniMaxText01LinearAttention(nn.Module):
+class MiniMaxText01LinearAttention(nn.Module, MambaBase):
+
+ @property
+ def mamba_type(self) -> str:
+ return "linear_attention"
+
+ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
+ return MambaStateShapeCalculator.linear_attention_state_shape(
+ num_heads=self.num_heads,
+ tp_size=self.tp_size,
+ head_dim=self.head_dim)
def __init__(
self,
@@ -359,6 +374,7 @@ class MiniMaxText01LinearAttention(nn.Module):
self.tp_heads = self.total_num_heads // self.tp_size
self.qkv_size = self.num_heads * self.head_dim
self.tp_hidden = self.head_dim * self.tp_heads
+ self.prefix = prefix
self.qkv_proj = ColumnParallelLinear(
hidden_size,
@@ -397,6 +413,12 @@ class MiniMaxText01LinearAttention(nn.Module):
self.tp_heads:(self.tp_rank + 1) *
self.tp_heads].contiguous()
+ if envs.VLLM_USE_V1:
+ compilation_config = get_current_vllm_config().compilation_config
+ if prefix in compilation_config.static_forward_context:
+ raise ValueError(f"Duplicate layer name: {prefix}")
+ compilation_config.static_forward_context[prefix] = self
+
@staticmethod
def weight_direct_load(param: torch.Tensor,
loaded_weight: torch.Tensor) -> None:
@@ -434,13 +456,14 @@ class MiniMaxText01LinearAttention(nn.Module):
break
if _prefill_idx >= len(state_indices_tensor):
break
- _start = attn_metadata.query_start_loc[_prefill_idx]
- _end = attn_metadata.query_start_loc[_prefill_idx + 1]
- slot_id = state_indices_tensor[_prefill_idx]
+ # prefills are packed at end of batch in V1
+ offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0
+ _start = attn_metadata.query_start_loc[offset + _prefill_idx]
+ _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
+ slot_id = state_indices_tensor[offset + _prefill_idx]
qs = q[_start:_end].transpose(0, 1).contiguous()
ks = k[_start:_end].transpose(0, 1).contiguous()
vs = v[_start:_end].transpose(0, 1).contiguous()
- slot_id = state_indices_tensor[_prefill_idx]
slice_layer_cache = kv_cache[slot_id, ...]
out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
@@ -453,9 +476,13 @@ class MiniMaxText01LinearAttention(nn.Module):
layer_idx=self.layer_idx)
hidden.append(out_slice.contiguous())
if attn_metadata.num_decode_tokens > 0:
- hidden.append(
- self._decode_infer(q, k, v, kv_cache, state_indices_tensor,
- attn_metadata))
+ hidden_decode = self._decode_infer(q, k, v, kv_cache,
+ state_indices_tensor,
+ attn_metadata)
+ if envs.VLLM_USE_V1:
+ hidden.insert(0, hidden_decode)
+ else:
+ hidden.append(hidden_decode)
if not hidden:
return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
@@ -465,11 +492,17 @@ class MiniMaxText01LinearAttention(nn.Module):
def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
attn_metadata):
- q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
- k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
- v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
- slot_id = state_indices_tensor[getattr(attn_metadata, "num_prefills", 0
- ):]
+ if not envs.VLLM_USE_V1:
+ q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
+ k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
+ v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
+ num_prefills = getattr(attn_metadata, "num_prefills", 0)
+ slot_id = state_indices_tensor[num_prefills:]
+ else:
+ q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
+ k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
+ v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
+ slot_id = state_indices_tensor[:attn_metadata.num_decodes]
hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
slot_id, 32)
return hidden
@@ -483,17 +516,49 @@ class MiniMaxText01LinearAttention(nn.Module):
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
- kv_cache = kv_caches.minimax_cache
- state_indices_tensor = kv_caches.state_indices_tensor
+ if envs.VLLM_USE_V1:
+ if attn_metadata is not None:
+ assert isinstance(attn_metadata, dict)
+ attn_metadata = attn_metadata[self.prefix]
+ assert isinstance(attn_metadata, LinearAttentionMetadata)
+ kv_cache = self.kv_cache[forward_context.virtual_engine][0]
+ state_indices_tensor = attn_metadata.state_indices_tensor
+
+ num_prefills = getattr(attn_metadata, "num_prefills", 0)
+ if num_prefills > 0:
+ num_decode_tokens = getattr(attn_metadata,
+ "num_decode_tokens", 0)
+ for prefill_idx in range(num_prefills):
+ q_start = attn_metadata.query_start_loc[
+ num_decode_tokens + prefill_idx]
+ q_end = attn_metadata.query_start_loc[num_decode_tokens
+ + prefill_idx +
+ 1]
+ query_len = q_end - q_start
+ context_len = attn_metadata.seq_lens[
+ num_decode_tokens + prefill_idx] - query_len
+ if context_len == 0:
+ block_to_clear = state_indices_tensor[
+ num_decode_tokens + prefill_idx]
+ kv_cache[block_to_clear, ...] = 0
+ else:
+ kv_cache = kv_caches.minimax_cache
+ state_indices_tensor = kv_caches.state_indices_tensor
decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
- if not decode_only:
- hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
- state_indices_tensor,
- attn_metadata)
+ if attn_metadata is None:
+ hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]),
+ device=q.device,
+ dtype=q.dtype)
else:
- hidden = self._decode_infer(q, k, v, kv_cache,
- state_indices_tensor, attn_metadata)
+ if not decode_only:
+ hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
+ state_indices_tensor,
+ attn_metadata)
+ else:
+ hidden = self._decode_infer(q, k, v, kv_cache,
+ state_indices_tensor,
+ attn_metadata)
hidden = self.norm._forward(hidden)
gate, _ = self.output_gate(hidden_states)
@@ -541,6 +606,7 @@ class MiniMaxText01Attention(nn.Module):
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.sliding_window = sliding_window
+ self.prefix = prefix
self.qkv_proj = QKVParallelLinear(
hidden_size,
@@ -575,7 +641,12 @@ class MiniMaxText01Attention(nn.Module):
attn_metadata = forward_context.attn_metadata
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
- q, k = attn_metadata.rotary_emb(positions, q, k)
+ if envs.VLLM_USE_V1:
+ if attn_metadata is not None:
+ q, k = attn_metadata[f"{self.prefix}.attn"].rotary_emb(
+ positions, q, k)
+ else:
+ q, k = attn_metadata.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
@@ -585,7 +656,7 @@ class MiniMaxText01DecoderLayer(nn.Module):
def __init__(
self,
- config: PretrainedConfig,
+ config: MiniMaxConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
expert_num: int = 1,
@@ -595,6 +666,7 @@ class MiniMaxText01DecoderLayer(nn.Module):
) -> None:
self._ilayer = layer_id
self._irank = get_tensor_model_parallel_rank()
+ self.prefix = prefix
super().__init__()
self.hidden_size = config.hidden_size
@@ -788,7 +860,7 @@ class MiniMaxText01Model(nn.Module):
def __init__(
self,
- config: PretrainedConfig,
+ config: MiniMaxConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
scheduler_config=None,
@@ -876,8 +948,9 @@ class MiniMaxText01Model(nn.Module):
self._dtype = _dummy.dtype
del _dummy
- self.minimax_cache = MinimaxCacheManager(dtype=torch.float32,
- cache_shape=self.cache_shape)
+ if not envs.VLLM_USE_V1:
+ self.minimax_cache = MinimaxCacheManager(
+ dtype=torch.float32, cache_shape=self.cache_shape)
rope_theta = getattr(config, "rope_theta", 10000)
head_dim = getattr(config, "head_dim", None)
@@ -944,23 +1017,27 @@ class MiniMaxText01Model(nn.Module):
**kwargs) -> Union[torch.Tensor, IntermediateTensors]:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
- if attn_metadata is None:
+ if not envs.VLLM_USE_V1 and attn_metadata is None:
return None
if "request_ids_to_seq_ids" not in kwargs:
kwargs["request_ids_to_seq_ids"] = {}
if "finished_requests_ids" not in kwargs:
kwargs["finished_requests_ids"] = []
- (
- minimax_cache_tensors,
- state_indices_tensor,
- ) = self.minimax_cache.current_run_tensors(**kwargs)
- if getattr(attn_metadata, "num_prefills", 0) > 0:
- self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
- **kwargs)
+ if not envs.VLLM_USE_V1:
+ (
+ minimax_cache_tensors,
+ state_indices_tensor,
+ ) = self.minimax_cache.current_run_tensors(**kwargs)
+ if getattr(attn_metadata, "num_prefills", 0) > 0:
+ self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
+ **kwargs)
+
+ minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
+ state_indices_tensor)
+ else:
+ minimax_cache_params = None
- minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
- state_indices_tensor)
if get_pp_group().is_first_rank:
if inputs_embeds is None:
hidden_states = self.embed_scale * self.embed_tokens(input_ids)
@@ -973,11 +1050,22 @@ class MiniMaxText01Model(nn.Module):
residual = intermediate_tensors["residual"]
minimax_cache_index = 0
- attn_metadata.rotary_emb = self.rotary_emb
+
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
+ if attn_metadata is not None:
+ # TODO (tdoublep): this whole thing with the rotary_emb is
+ # weird. we shouldn't be passing it via attn_metadata imo.
+ if envs.VLLM_USE_V1:
+ if isinstance(layer.self_attn, MiniMaxText01Attention):
+ attn_metadata[layer.prefix +
+ ".attn"].rotary_emb = self.rotary_emb
+ else:
+ attn_metadata.rotary_emb = self.rotary_emb
+
_caches = None
- if isinstance(layer.self_attn, MiniMaxText01LinearAttention):
+ if not envs.VLLM_USE_V1 and isinstance(
+ layer.self_attn, MiniMaxText01LinearAttention):
current_state_layer = minimax_cache_index
_caches = minimax_cache_params.at_layer_idx(
current_state_layer)
@@ -1002,8 +1090,7 @@ class MiniMaxText01Model(nn.Module):
return hidden_states
-class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
- SupportsV0Only):
+class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
@@ -1321,3 +1408,28 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
load_basic_weight(name, loaded_weight, self)
return loaded_params
+
+ @classmethod
+ def get_mamba_state_shape_from_config(
+ cls,
+ vllm_config: "VllmConfig",
+ use_v1: bool = True,
+ ) -> tuple[tuple[int, ...], ...]:
+ """Calculate shape for MiniMaxText01LinearAttention cache.
+
+ Args:
+ vllm_config: vLLM config
+ use_v1: Get shapes for V1 (or V0)
+
+ Returns:
+ Tuple containing:
+ - state_shape: Shape of the cache
+ """
+ parallel_config = vllm_config.parallel_config
+ hf_config = vllm_config.model_config.hf_config
+
+ return MambaStateShapeCalculator.linear_attention_state_shape(
+ num_heads=hf_config.num_attention_heads,
+ tp_size=parallel_config.tensor_parallel_size,
+ head_dim=hf_config.head_dim,
+ )
diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py
index 62a7d37ec9d33..8107c6e8a04a1 100644
--- a/vllm/model_executor/models/minimax_vl_01.py
+++ b/vllm/model_executor/models/minimax_vl_01.py
@@ -8,7 +8,6 @@ import torch.nn as nn
from transformers import BatchFeature, PretrainedConfig
from vllm.config import VllmConfig
-from vllm.jsontree import json_map_leaves
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
@@ -17,6 +16,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.sequence import IntermediateTensors
+from vllm.utils.jsontree import json_map_leaves
from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py
index 88c3823eaa193..9e29a96c6e44a 100644
--- a/vllm/model_executor/models/mistral3.py
+++ b/vllm/model_executor/models/mistral3.py
@@ -428,20 +428,24 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
config.projector_hidden_act = "gelu"
# TODO: Optionally initializes this for supporting embeddings.
- self.vision_tower = init_vision_tower_for_llava(
- config,
- quant_config,
- require_post_norm=False,
- prefix=maybe_prefix(prefix, "vision_tower"))
- self.multi_modal_projector = Mistral3MultiModalProjector(
- vision_hidden_size=config.vision_config.hidden_size,
- text_hidden_size=config.text_config.hidden_size,
- projector_hidden_act=config.projector_hidden_act,
- spatial_merge_size=config.spatial_merge_size,
- patch_size=config.vision_config.patch_size,
- multimodal_projector_bias=config.multimodal_projector_bias,
- quant_config=quant_config,
- prefix=maybe_prefix(prefix, "multi_modal_projector"))
+ if multimodal_config.get_limit_per_prompt("image"):
+ self.vision_tower = init_vision_tower_for_llava(
+ config,
+ quant_config,
+ require_post_norm=False,
+ prefix=maybe_prefix(prefix, "vision_tower"))
+ self.multi_modal_projector = Mistral3MultiModalProjector(
+ vision_hidden_size=config.vision_config.hidden_size,
+ text_hidden_size=config.text_config.hidden_size,
+ projector_hidden_act=config.projector_hidden_act,
+ spatial_merge_size=config.spatial_merge_size,
+ patch_size=config.vision_config.patch_size,
+ multimodal_projector_bias=config.multimodal_projector_bias,
+ quant_config=quant_config,
+ prefix=maybe_prefix(prefix, "multi_modal_projector"))
+ else:
+ self.vision_tower = None
+ self.multi_modal_projector = None
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
@@ -611,7 +615,11 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
- loader = AutoWeightsLoader(self)
+ skip_prefixes = []
+ if self.vision_tower is None and self.multi_modal_projector is None:
+ skip_prefixes = ["vision_tower.", "multi_modal_projector."]
+
+ loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py
index e73dc0c2be82e..b405dfca6d39b 100644
--- a/vllm/model_executor/models/mllama4.py
+++ b/vllm/model_executor/models/mllama4.py
@@ -737,16 +737,20 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
self.config = config
self.quant_config = quant_config
self.multimodal_config = multimodal_config
- self.vision_model = Llama4VisionModel(
- config.vision_config,
- None,
- prefix=maybe_prefix(prefix, "vision_model"),
- use_data_parallel=self.use_data_parallel,
- )
- self.multi_modal_projector = Llama4MultiModalProjector(
- self.config,
- None,
- prefix=maybe_prefix(prefix, "multi_modal_projector"))
+ if multimodal_config.get_limit_per_prompt("image"):
+ self.vision_model = Llama4VisionModel(
+ config.vision_config,
+ None,
+ prefix=maybe_prefix(prefix, "vision_model"),
+ use_data_parallel=self.use_data_parallel,
+ )
+ self.multi_modal_projector = Llama4MultiModalProjector(
+ self.config,
+ None,
+ prefix=maybe_prefix(prefix, "multi_modal_projector"))
+ else:
+ self.vision_model = None
+ self.multi_modal_projector = None
self.language_model = initialize_model(
vllm_config=vllm_config.with_hf_config(config.text_config,
["LlamaForCausalLM"]),
@@ -783,6 +787,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
def _process_image_input(
self, image_input: Llama4ImagePatchInputs) -> MultiModalEmbeddings:
+
+ assert self.vision_model and self.multi_modal_projector
flat_data = image_input["flat_data"]
patches_per_image = image_input["patches_per_image"].tolist()
@@ -1048,6 +1054,10 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
language_model_weights, other_weights = (
self._separate_and_rename_weights(weights))
+ # Skip loading vision model and projector if they're not initialized.
+ if self.vision_model is None and self.multi_modal_projector is None:
+ other_weights = []
+
# Handle expert scale parameters
regular_weights, expert_scale_weights, updated_params_from_experts = (
self._handle_expert_scale_broadcasting(language_model_weights,
diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py
index 4967032a244ec..c6e84e2d4e040 100644
--- a/vllm/model_executor/models/modernbert.py
+++ b/vllm/model_executor/models/modernbert.py
@@ -8,6 +8,7 @@ from torch import nn
from transformers import ModernBertConfig
from vllm.attention import Attention, AttentionType
+from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -25,7 +26,7 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask
-from .interfaces import SupportsCrossEncoding, SupportsV0Only
+from .interfaces import SupportsCrossEncoding, default_pooling_type
from .utils import WeightsMapper, maybe_prefix
@@ -46,7 +47,7 @@ class ModernBertEmbeddings(nn.Module):
input_ids: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
- if inputs_embeds:
+ if inputs_embeds is not None:
return self.norm(inputs_embeds)
else:
inputs_embeds = self.tok_embeddings(input_ids)
@@ -91,16 +92,14 @@ class ModernBertAttention(nn.Module):
bias=config.attention_bias,
)
+ sliding_window = None
if layer_id % config.global_attn_every_n_layers != 0:
- self.local_attention = (config.local_attention // 2,
- config.local_attention // 2)
+ sliding_window = config.local_attention // 2
+ rope_theta = config.local_rope_theta if config.local_rope_theta \
+ is not None else config.global_rope_theta
else:
- self.local_attention = (-1, -1)
+ rope_theta = config.global_rope_theta
- rope_theta = config.global_rope_theta
- if self.local_attention != (
- -1, -1) and config.local_rope_theta is not None:
- rope_theta = config.local_rope_theta
self.rotary_emb = ModernBertRotaryEmbedding(config=config,
head_size=self.head_dim,
dim=self.head_dim,
@@ -109,7 +108,8 @@ class ModernBertAttention(nn.Module):
self.head_dim,
self.scaling,
prefix=f"{layer_id}.attn",
- attn_type=AttentionType.ENCODER_ONLY)
+ attn_type=AttentionType.ENCODER_ONLY,
+ per_layer_sliding_window=sliding_window)
self.Wo = RowParallelLinear(config.hidden_size,
config.hidden_size,
bias=config.attention_bias)
@@ -117,7 +117,7 @@ class ModernBertAttention(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
- position_ids: Optional[torch.LongTensor] = None,
+ position_ids: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.Wqkv(hidden_states)
q, k, v = qkv.split([self.all_head_size] * 3, dim=-1)
@@ -169,9 +169,9 @@ class ModernBertLayer(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
- position_ids: Optional[torch.LongTensor] = None,
- ):
- attn_outputs = self.attn(self.attn_norm(hidden_states),
+ position_ids: torch.Tensor,
+ ) -> torch.Tensor:
+ attn_outputs = self.attn(hidden_states=self.attn_norm(hidden_states),
position_ids=position_ids)
hidden_states = hidden_states + attn_outputs
mlp_output = self.mlp(self.mlp_norm(hidden_states))
@@ -192,13 +192,15 @@ class ModernBertEncoderLayer(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
- position_ids: Optional[torch.LongTensor] = None,
+ position_ids: torch.Tensor,
) -> torch.Tensor:
for i, layer in enumerate(self.layers):
hidden_states = layer(hidden_states, position_ids)
return hidden_states
+@support_torch_compile
+@default_pooling_type("CLS")
class ModernBertModel(nn.Module):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={"layers.": "encoder_layer.layers."})
@@ -234,13 +236,11 @@ class ModernBertModel(nn.Module):
def forward(
self,
- input_ids: Optional[torch.LongTensor] = None,
- positions: Optional[torch.Tensor] = None,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
- position_ids = positions if positions is not None else position_ids
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
@@ -249,7 +249,7 @@ class ModernBertModel(nn.Module):
outputs = self.encoder_layer(
hidden_states=hidden_states,
- position_ids=position_ids,
+ position_ids=positions,
)
norm_outputs = self.final_norm(outputs)
return norm_outputs
@@ -264,7 +264,6 @@ class ModernBertPooler(Pooler):
self.pooling = PoolingMethod.from_pooling_type(pooling_type)
self.dense = nn.Linear(config.hidden_size, config.hidden_size,
config.classifier_bias)
- self.pooling_type = config.classifier_pooling
self.act = nn.GELU()
self.norm = nn.LayerNorm(config.hidden_size,
eps=config.norm_eps,
@@ -277,6 +276,7 @@ class ModernBertPooler(Pooler):
return self.pooling.get_pooling_updates(task)
def _head(self, pooled_output: torch.Tensor):
+ pooled_output = pooled_output.to(self.dense.weight.dtype)
return self.norm(self.act(self.dense(pooled_output)))
def forward(
@@ -294,8 +294,8 @@ class ModernBertPooler(Pooler):
return pooled_output
-class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
- SupportsCrossEncoding):
+@default_pooling_type("CLS")
+class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
is_pooling_model = True
@@ -306,6 +306,7 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
self.model = ModernBertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "modernbert"))
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+ self.pooling = ModernBertPooler(config)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
@@ -315,14 +316,14 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
Pooler.for_encode(pooler_config),
"classify":
ClassifierPooler(
- pooling=ModernBertPooler(config),
+ pooling=self.pooling,
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls(
vllm_config.model_config),
),
"score":
ClassifierPooler(
- pooling=ModernBertPooler(config),
+ pooling=self.pooling,
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config),
@@ -351,7 +352,7 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
default_weight_loader)
weight_loader(param, loaded_weight)
if name.startswith("head"):
- param = params_dict["_pooler.pooler." + name[len("head") + 1:]]
+ param = params_dict["pooling." + name[len("head") + 1:]]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
@@ -366,5 +367,5 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
return self.model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
- position_ids=positions,
+ positions=positions,
)
diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py
index eb62d5a53c1a3..08315a13853c0 100644
--- a/vllm/model_executor/models/nemotron_h.py
+++ b/vllm/model_executor/models/nemotron_h.py
@@ -64,20 +64,32 @@ class NemotronHMLP(nn.Module):
def __init__(
self,
config: NemotronHConfig,
+ layer_idx: int,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
prefix: str = "",
) -> None:
super().__init__()
+
+ hybrid_override_pattern = config.hybrid_override_pattern
+ mlp_index = hybrid_override_pattern[:layer_idx + 1].count("-") - 1
+ if isinstance(config.intermediate_size, list):
+ if len(config.intermediate_size) == 1:
+ intermediate_size = config.intermediate_size[0]
+ else:
+ intermediate_size = config.intermediate_size[mlp_index]
+ else:
+ intermediate_size = config.intermediate_size
+
self.up_proj = ColumnParallelLinear(
input_size=config.hidden_size,
- output_size=config.intermediate_size,
+ output_size=intermediate_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.up_proj",
)
self.down_proj = RowParallelLinear(
- input_size=config.intermediate_size,
+ input_size=intermediate_size,
output_size=config.hidden_size,
bias=bias,
quant_config=quant_config,
@@ -110,6 +122,7 @@ class NemotronHMLPDecoderLayer(nn.Module):
quant_config=quant_config,
bias=config.mlp_bias,
prefix=f"{prefix}.mixer",
+ layer_idx=layer_idx,
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -146,7 +159,7 @@ class NemotronHMambaDecoderLayer(nn.Module):
hidden_size=config.hidden_size,
ssm_state_size=config.ssm_state_size,
conv_kernel_size=config.conv_kernel,
- intermediate_size=config.expand * config.hidden_size,
+ intermediate_size=config.mamba_num_heads * config.mamba_head_dim,
use_conv_bias=config.use_conv_bias,
use_bias=config.use_bias,
n_groups=config.n_groups,
@@ -205,7 +218,10 @@ class NemotronHAttention(nn.Module):
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
- self.head_dim = config.hidden_size // self.total_num_heads
+ if hasattr(config, "head_dim") and config.head_dim is not None:
+ self.head_dim = config.head_dim
+ else:
+ self.head_dim = config.hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
@@ -481,7 +497,7 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
"""
parallel_config = vllm_config.parallel_config
hf_config = vllm_config.model_config.hf_config
- intermediate_size = hf_config.expand * hf_config.hidden_size
+ intermediate_size = hf_config.mamba_num_heads * hf_config.mamba_head_dim
return MambaStateShapeCalculator.mamba2_state_shape(
intermediate_size=intermediate_size,
diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py
index 7552f64c423ea..a47c3bd416459 100644
--- a/vllm/model_executor/models/olmoe.py
+++ b/vllm/model_executor/models/olmoe.py
@@ -19,7 +19,7 @@ from typing import Any, Optional, Union
import torch
from torch import nn
-from transformers import PretrainedConfig
+from transformers import OlmoeConfig
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
@@ -205,7 +205,7 @@ class OlmoeDecoderLayer(nn.Module):
def __init__(
self,
- config: PretrainedConfig,
+ config: OlmoeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
diff --git a/vllm/model_executor/models/phi4flash.py b/vllm/model_executor/models/phi4flash.py
index 1a761d01fc066..493a4192d35ad 100644
--- a/vllm/model_executor/models/phi4flash.py
+++ b/vllm/model_executor/models/phi4flash.py
@@ -116,13 +116,8 @@ class SambaYAttention(nn.Module):
self.Wqkv = nn.Linear(self.hidden_size, op_size, bias=True)
# disable sliding window for the second half of the model
- sliding_window = config.interleaved_sliding_window[layer_idx]
- if layer_idx >= config.num_hidden_layers // 2:
- assert sliding_window is None, \
- "sliding_window must be none for the second decoder"
- else:
- assert sliding_window is not None, \
- "sliding_window must be set for the first decoder"
+ is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+ sliding_window = config.sliding_window if is_sliding else None
assert self.num_heads % 2 == 0, 'num_heads should be even'
assert self.num_key_value_heads % 2 == 0, 'num_heads should be even'
diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py
index 304a9e987ee03..20f423cc7603d 100644
--- a/vllm/model_executor/models/prithvi_geospatial_mae.py
+++ b/vllm/model_executor/models/prithvi_geospatial_mae.py
@@ -25,11 +25,11 @@ import torch.nn as nn
from transformers import BatchFeature
from vllm.config import VllmConfig
-from vllm.model_executor.layers.pooler import (AllPool, PoolerHead,
- PoolerIdentity, SimplePooler)
+from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (
- IsAttentionFree, MultiModalEmbeddings, SupportsMultiModalWithRawInput)
+ IsAttentionFree, MultiModalEmbeddings, SupportsMultiModalWithRawInput,
+ default_pooling_type)
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
@@ -142,6 +142,7 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
)
+@default_pooling_type("All")
@MULTIMODAL_REGISTRY.register_processor(
PrithviGeoSpatialMAEMultiModalProcessor,
info=PrithviGeoSpatialMAEProcessingInfo,
@@ -198,7 +199,11 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree,
"Only SemanticSegmentationTask is supported for now "
"by PrithviGeospatialMAE.")
- self.pooler = SimplePooler(AllPool(), PoolerHead(PoolerIdentity()))
+ pooler_config = vllm_config.model_config.pooler_config
+ assert pooler_config is not None
+
+ self.pooler = DispatchPooler(
+ {"encode": Pooler.for_encode(pooler_config)}, )
def _parse_and_validate_multimodal_data(
self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index 0e7507a4570be..7304fbf120ccd 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -49,6 +49,7 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.config import is_interleaved
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
@@ -285,8 +286,7 @@ class Qwen2Model(nn.Module):
quant_config = vllm_config.quant_config
# TODO (@robertgshaw2): see if this can be moved out
- if (cache_config.sliding_window is not None
- and hasattr(config, "max_window_layers")):
+ if is_interleaved(vllm_config.model_config.hf_text_config):
assert config.max_window_layers == config.num_hidden_layers, (
"Sliding window for some but all layers is not supported. "
"This model uses sliding window but `max_window_layers` = {} "
@@ -408,9 +408,18 @@ class Qwen2Model(nn.Module):
continue
if is_pp_missing_parameter(name, self):
continue
+ if name.endswith("scale"):
+ # Remapping the name of FP8 kv-scale.
+ name = maybe_remap_kv_scale_name(name, params_dict)
+ if name is None:
+ continue
param = params_dict[name]
- weight_loader = param.weight_loader
- weight_loader(param, loaded_weight, shard_id)
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ if weight_loader == default_weight_loader:
+ weight_loader(param, loaded_weight)
+ else:
+ weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py
index a3af541d20676..e95295c31885a 100644
--- a/vllm/model_executor/models/qwen2_5_omni_thinker.py
+++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py
@@ -722,13 +722,24 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
"exactly same result as the transformers implementation "
"in the audio tower part.")
- self.audio_tower = Qwen2_5OmniAudioEncoder(thinker_config.audio_config)
- self.visual = Qwen2_5_VisionTransformer(
- vision_config=thinker_config.vision_config,
- norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
- quant_config=quant_config,
- prefix=maybe_prefix(prefix, "visual"),
- )
+ if multimodal_config.get_limit_per_prompt("audio"):
+ self.audio_tower = Qwen2_5OmniAudioEncoder(
+ thinker_config.audio_config)
+ else:
+ self.audio_tower = None
+
+ if multimodal_config.get_limit_per_prompt(
+ "image") or multimodal_config.get_limit_per_prompt("video"):
+ self.visual = Qwen2_5_VisionTransformer(
+ vision_config=thinker_config.vision_config,
+ norm_eps=getattr(thinker_config.text_config, "rms_norm_eps",
+ 1e-6),
+ quant_config=quant_config,
+ prefix=maybe_prefix(prefix, "visual"),
+ )
+ else:
+ self.visual = None
+
self.quant_config = quant_config
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
@@ -886,9 +897,15 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
+ skip_prefixes = ["talker.", "token2wav."]
+ if self.audio_tower is None:
+ skip_prefixes.extend(["audio_tower."])
+ if self.visual is None:
+ skip_prefixes.extend(["visual."])
+
loader = AutoWeightsLoader(
self,
- skip_prefixes=["talker.", "token2wav."],
+ skip_prefixes=skip_prefixes,
)
loaded_weights = loader.load_weights(weights,
mapper=self.hf_to_vllm_mapper)
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 79c5c77f6de69..6bea180ffec90 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -843,12 +843,17 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.config = config
self.multimodal_config = multimodal_config
- self.visual = Qwen2_5_VisionTransformer(
- config.vision_config,
- norm_eps=getattr(config, "rms_norm_eps", 1e-6),
- quant_config=self._maybe_ignore_quant_config(self.quant_config),
- prefix=maybe_prefix(prefix, "visual"),
- )
+ if multimodal_config.get_limit_per_prompt("image") or \
+ multimodal_config.get_limit_per_prompt("video"):
+ self.visual = Qwen2_5_VisionTransformer(
+ config.vision_config,
+ norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+ quant_config=self._maybe_ignore_quant_config(
+ self.quant_config),
+ prefix=maybe_prefix(prefix, "visual"),
+ )
+ else:
+ self.visual = None
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
@@ -1152,7 +1157,10 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
- loader = AutoWeightsLoader(self)
+ skip_prefixes = []
+ if self.visual is None:
+ skip_prefixes.extend(["visual."])
+ loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py
index b061e2f69a6c6..5c4ad34246d66 100644
--- a/vllm/model_executor/models/qwen2_moe.py
+++ b/vllm/model_executor/models/qwen2_moe.py
@@ -30,7 +30,7 @@ from typing import Any, Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
-from transformers import PretrainedConfig
+from transformers import Qwen2MoeConfig
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
@@ -98,7 +98,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
def __init__(
self,
- config: PretrainedConfig,
+ config: Qwen2MoeConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
@@ -256,7 +256,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
def __init__(
self,
- config: PretrainedConfig,
+ config: Qwen2MoeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py
index 9b6b70c75c341..e0a30e04c602a 100644
--- a/vllm/model_executor/models/qwen2_rm.py
+++ b/vllm/model_executor/models/qwen2_rm.py
@@ -15,11 +15,10 @@ from torch import nn
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
-from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
- PoolingType)
+from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.sequence import IntermediateTensors
-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import SupportsLoRA, SupportsPP, default_pooling_type
from .qwen2 import Qwen2Model
from .utils import AutoWeightsLoader, maybe_prefix
@@ -90,6 +89,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
return loader.load_weights(weights)
+@default_pooling_type("ALL")
class Qwen2ForRewardModel(Qwen2RewardBaseModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -103,6 +103,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
{"encode": Pooler.for_encode(pooler_config)}, )
+@default_pooling_type("STEP")
class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -112,10 +113,5 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
- self.pooler = DispatchPooler({
- "encode":
- Pooler.for_encode(
- pooler_config,
- default_pooling_type=PoolingType.STEP,
- )
- })
+ self.pooler = DispatchPooler(
+ {"encode": Pooler.for_encode(pooler_config)})
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 633f8598e879d..f2d438b3850b8 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -1049,12 +1049,16 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.config = config
self.multimodal_config = multimodal_config
- self.visual = Qwen2VisionTransformer(
- config.vision_config,
- norm_eps=getattr(config, "rms_norm_eps", 1e-6),
- quant_config=self._maybe_ignore_quant_config(quant_config),
- prefix=maybe_prefix(prefix, "visual"),
- )
+ if multimodal_config.get_limit_per_prompt("image") or \
+ multimodal_config.get_limit_per_prompt("video"):
+ self.visual = Qwen2VisionTransformer(
+ config.vision_config,
+ norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+ quant_config=self._maybe_ignore_quant_config(quant_config),
+ prefix=maybe_prefix(prefix, "visual"),
+ )
+ else:
+ self.visual = None
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
@@ -1350,7 +1354,10 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
- loader = AutoWeightsLoader(self)
+ skip_prefixes = []
+ if self.visual is None:
+ skip_prefixes.extend(["visual."])
+ loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
@@ -1445,5 +1452,8 @@ class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
- loader = AutoWeightsLoader(self)
+ skip_prefixes = []
+ if self.visual is None:
+ skip_prefixes.extend(["visual."])
+ loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py
index 0ad50640bb3bc..2060206633702 100644
--- a/vllm/model_executor/models/qwen3.py
+++ b/vllm/model_executor/models/qwen3.py
@@ -44,7 +44,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
from .qwen2 import Qwen2MLP as Qwen3MLP
from .qwen2 import Qwen2Model
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
@@ -261,7 +261,7 @@ class Qwen3Model(Qwen2Model):
decoder_layer_type=Qwen3DecoderLayer)
-class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py
index 7410589190bac..085fc90b47b53 100644
--- a/vllm/model_executor/models/qwen3_moe.py
+++ b/vllm/model_executor/models/qwen3_moe.py
@@ -28,7 +28,7 @@ from typing import Any, Optional, Union
import torch
from torch import nn
-from transformers import PretrainedConfig
+from transformers import Qwen3MoeConfig
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
@@ -48,7 +48,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+ default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
@@ -100,7 +101,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
def __init__(
self,
- config: PretrainedConfig,
+ config: Qwen3MoeConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
@@ -148,7 +149,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
self.gate = ReplicatedLinear(config.hidden_size,
config.num_experts,
bias=False,
- quant_config=None,
+ quant_config=quant_config,
prefix=f"{prefix}.gate")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -277,7 +278,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
def __init__(
self,
- config: PretrainedConfig,
+ config: Qwen3MoeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
@@ -471,12 +472,21 @@ class Qwen3MoeModel(nn.Module):
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
+ if name.endswith("scale"):
+ # Remapping the name of FP8 kv-scale.
+ name = maybe_remap_kv_scale_name(name, params_dict)
+ if name is None:
+ continue
if name not in params_dict:
continue
param = params_dict[name]
- weight_loader = param.weight_loader
- weight_loader(param, loaded_weight, shard_id)
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ if weight_loader == default_weight_loader:
+ weight_loader(param, loaded_weight)
+ else:
+ weight_loader(param, loaded_weight, shard_id)
break
else:
is_expert_weight = False
@@ -674,4 +684,4 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
return loader.load_weights(weights)
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
- return self.model.get_expert_mapping()
+ return self.model.get_expert_mapping()
\ No newline at end of file
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index c746e8ec3f294..64dbde4916a26 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -25,8 +25,8 @@ from vllm.logger import init_logger
from vllm.transformers_utils.dynamic_module import (
try_get_class_from_dynamic_module)
-from .interfaces import (has_inner_state, has_noops, is_attention_free,
- is_hybrid, supports_cross_encoding,
+from .interfaces import (get_default_pooling_type, has_inner_state, has_noops,
+ is_attention_free, is_hybrid, supports_cross_encoding,
supports_multimodal, supports_multimodal_raw_input,
supports_pp, supports_transcription, supports_v0_only)
from .interfaces_base import is_pooling_model, is_text_generation_model
@@ -69,8 +69,7 @@ _TEXT_GENERATION_MODELS = {
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
- #TODO(ywang96): Support multimodal gemma3n
- "Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"), # noqa: E501
+ "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
"Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
"Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
@@ -202,12 +201,14 @@ _MULTIMODAL_MODELS = {
"AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"), # noqa: E501
"Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
"ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501
+ "Cohere2VisionForConditionalGeneration": ("cohere2_vision", "Cohere2VisionForConditionalGeneration"), # noqa: E501
"DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
"Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501
+ "Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"), # noqa: E501
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501
- "Glm4v_moeForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501
+ "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501
"GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"), # noqa: E501
"H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
"InternVLChatModel": ("internvl", "InternVLChatModel"),
@@ -259,6 +260,8 @@ _SPECULATIVE_DECODING_MODELS = {
"EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
"EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
+ # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
+ # "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
"Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
"MedusaModel": ("medusa", "Medusa"),
@@ -268,6 +271,9 @@ _SPECULATIVE_DECODING_MODELS = {
}
_TRANSFORMERS_SUPPORTED_MODELS = {
+ # Text generation models
+ "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
+ # Multimodal models
"Emu3ForConditionalGeneration": ("transformers", "TransformersForMultimodalLM"), # noqa: E501
}
@@ -304,6 +310,7 @@ class _ModelInfo:
architecture: str
is_text_generation_model: bool
is_pooling_model: bool
+ default_pooling_type: str
supports_cross_encoding: bool
supports_multimodal: bool
supports_multimodal_raw_input: bool
@@ -322,6 +329,7 @@ class _ModelInfo:
architecture=model.__name__,
is_text_generation_model=is_text_generation_model(model),
is_pooling_model=is_pooling_model(model),
+ default_pooling_type=get_default_pooling_type(model),
supports_cross_encoding=supports_cross_encoding(model),
supports_multimodal=supports_multimodal(model),
supports_multimodal_raw_input=supports_multimodal_raw_input(model),
diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py
index 77e072c792755..32a4a2c9a2694 100644
--- a/vllm/model_executor/models/roberta.py
+++ b/vllm/model_executor/models/roberta.py
@@ -14,13 +14,16 @@ from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
DispatchPooler, Pooler)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
-from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
+from vllm.model_executor.models.bert import (TOKEN_TYPE_SHIFT,
+ BertEmbeddingModel, BertModel,
+ _decode_token_type_ids,
+ _encode_token_type_ids)
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
maybe_prefix)
from vllm.sequence import IntermediateTensors
from .bert_with_rope import BertWithRope, JinaRobertaModel
-from .interfaces import SupportsCrossEncoding, SupportsV0Only
+from .interfaces import SupportsCrossEncoding, default_pooling_type
class RobertaEmbedding(nn.Module):
@@ -53,17 +56,12 @@ class RobertaEmbedding(nn.Module):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
- token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
- input_shape = input_ids.size()
- inputs_embeds = self.word_embeddings(input_ids)
- # Position embeddings.
+ token_type_ids = _decode_token_type_ids(input_ids)
+
+ inputs_embeds = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
- if token_type_ids is None:
- token_type_ids = torch.zeros(input_shape,
- dtype=torch.long,
- device=inputs_embeds.device)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings + position_embeddings
@@ -88,6 +86,7 @@ class RobertaClassificationHead(nn.Module):
return x
+@default_pooling_type("CLS")
class RobertaEmbeddingModel(BertEmbeddingModel):
"""A model that uses Roberta to provide embedding functionalities.
@@ -105,9 +104,8 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
def forward(
self,
- input_ids: Optional[torch.Tensor],
+ input_ids: torch.Tensor,
positions: torch.Tensor,
- token_type_ids: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
@@ -120,8 +118,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
padding_idx=self.padding_idx)
return self.model(input_ids=input_ids,
- position_ids=positions,
- token_type_ids=token_type_ids,
+ positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)
@@ -153,8 +150,8 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
return loader.load_weights(weights_list, mapper=mapper)
-class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
- SupportsV0Only):
+@default_pooling_type("CLS")
+class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
"""A model that uses Roberta to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for
@@ -226,11 +223,14 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
replace_roberta_positions(input_ids=input_ids,
position_ids=positions,
padding_idx=self.padding_idx)
+ if token_type_ids is not None:
+ assert self.roberta.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
+ assert input_ids is not None
+ _encode_token_type_ids(input_ids, token_type_ids)
return self.roberta(input_ids=input_ids,
- position_ids=positions,
+ positions=positions,
inputs_embeds=inputs_embeds,
- intermediate_tensors=intermediate_tensors,
- token_type_ids=token_type_ids)
+ intermediate_tensors=intermediate_tensors)
# Adapted from transformers
diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py
index 363c12a4bf2b8..41dba312cb422 100644
--- a/vllm/model_executor/models/step3_vl.py
+++ b/vllm/model_executor/models/step3_vl.py
@@ -837,27 +837,35 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.config = config
self.multimodal_config = multimodal_config
- self.vision_model = Step3VisionTransformer(config.vision_config,
- None,
- prefix=maybe_prefix(
- prefix, "vision_model"))
- self.vit_downsampler = nn.Conv2d(
- config.vision_config.hidden_size,
- config.vision_config.output_hidden_size,
- kernel_size=2,
- stride=config.understand_projector_stride)
- self.vit_downsampler2 = nn.Conv2d(
- config.vision_config.output_hidden_size,
- config.vision_config.output_hidden_size * 2,
- kernel_size=3,
- stride=2,
- padding=1,
- )
- self.vit_large_projector = nn.Linear(
- config.vision_config.output_hidden_size * 2,
- config.hidden_size,
- bias=config.projector_bias,
- )
+ if multimodal_config.get_limit_per_prompt("image"):
+ self.vision_model = Step3VisionTransformer(config.vision_config,
+ None,
+ prefix=maybe_prefix(
+ prefix,
+ "vision_model"))
+ self.vit_downsampler = nn.Conv2d(
+ config.vision_config.hidden_size,
+ config.vision_config.output_hidden_size,
+ kernel_size=2,
+ stride=config.understand_projector_stride)
+ self.vit_downsampler2 = nn.Conv2d(
+ config.vision_config.output_hidden_size,
+ config.vision_config.output_hidden_size * 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ )
+ self.vit_large_projector = nn.Linear(
+ config.vision_config.output_hidden_size * 2,
+ config.hidden_size,
+ bias=config.projector_bias,
+ )
+ else:
+ self.vision_model = None
+ self.vit_downsampler = None
+ self.vit_downsampler2 = None
+ self.vit_large_projector = None
+
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
@@ -1046,7 +1054,15 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
- loader = AutoWeightsLoader(self)
+
+ skip_prefixes = []
+ if self.vision_model is None and self.vit_large_projector is None:
+ skip_prefixes = [
+ "vision_model.", "vit_downsampler.", "vit_downsampler2.",
+ "vit_large_projector."
+ ]
+
+ loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
loaded_weights = loader.load_weights(weights,
mapper=self.hf_to_vllm_mapper)
return loaded_weights
diff --git a/vllm/model_executor/models/tarsier.py b/vllm/model_executor/models/tarsier.py
index 70cf5e95a54e1..c8709d866b1e7 100644
--- a/vllm/model_executor/models/tarsier.py
+++ b/vllm/model_executor/models/tarsier.py
@@ -18,7 +18,6 @@ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from vllm.config import VllmConfig
from vllm.inputs import InputProcessingContext
-from vllm.jsontree import json_map_leaves
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
@@ -34,6 +33,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
+from vllm.utils.jsontree import json_map_leaves
from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py
index 92e132045c278..25b8b69e081b6 100644
--- a/vllm/model_executor/models/transformers.py
+++ b/vllm/model_executor/models/transformers.py
@@ -16,7 +16,7 @@
# limitations under the License.
"""Wrapper around `transformers` models"""
from collections.abc import Iterable, Mapping
-from contextlib import contextmanager, nullcontext
+from contextlib import contextmanager
from typing import Literal, Optional, Union
import regex as re
@@ -107,10 +107,17 @@ def replace_linear_class(
raise ValueError(
f"Unsupported parallel style type {type(style)}, expected str")
- vllm_linear_cls = {
- "colwise": ColumnParallelLinear,
- "rowwise": RowParallelLinear,
- }.get(style, ReplicatedLinear)
+ vllm_linear_cls, vllm_linear_kwargs = {
+ "colwise": (ColumnParallelLinear, {}),
+ "colwise_rep": (ColumnParallelLinear, {
+ "gather_output": True
+ }),
+ "rowwise": (RowParallelLinear, {}),
+ "rowwise_rep": (RowParallelLinear, {
+ "input_is_parallel": False
+ }),
+ "replicate": (ReplicatedLinear, {}),
+ }.get(style, (ReplicatedLinear, {}))
return vllm_linear_cls(
input_size=linear.in_features,
@@ -118,6 +125,7 @@ def replace_linear_class(
bias=linear.bias is not None,
quant_config=quant_config,
return_bias=False,
+ **vllm_linear_kwargs,
)
@@ -382,33 +390,6 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
)
-class ConfigOverride:
- """Context manager to temporarily override config attributes."""
-
- def __init__(self, config: PretrainedConfig, **kwargs):
- self.config = config
- self.kwargs = kwargs
- self.kwargs_original = {}
- self.kwargs_delete = set()
-
- def __enter__(self):
- """Override config attributes."""
- for key, value in self.kwargs.items():
- if not hasattr(self.config, key):
- self.kwargs_delete.add(key)
- self.kwargs_original[key] = getattr(self.config, key, None)
- setattr(self.config, key, value)
- return self.config
-
- def __exit__(self, exc_type, exc_value, traceback):
- """Restore original config attributes."""
- for key, value in self.kwargs_original.items():
- if key in self.kwargs_delete:
- delattr(self.config, key)
- else:
- setattr(self.config, key, value)
-
-
class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens"
@@ -434,21 +415,11 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
# To be updated in child classes for use in `load_weights`
self.skip_prefixes: Optional[list[str]] = None
- # vLLM handles interleaved sliding window attention by creating a new
- # interleaved_sliding_window attribute and deleting the sliding_window
- # attribute. This breaks the constructors in Transformers so we
- # temporarily add the attribute back to construct the model.
- config_override = nullcontext()
- if hasattr(self.config, "interleaved_sliding_window"):
- config_override = ConfigOverride(
- self.config,
- sliding_window=self.config.interleaved_sliding_window)
-
# Set correct attn and init on "meta" to delay allocating GPU tensors
# TODO: @raushan, use the public `model.set_attn_implementation()`
# method once its checks are fixed in Transformers.
self.text_config._attn_implementation = "vllm"
- with init_on_device_without_buffers("meta"), config_override:
+ with init_on_device_without_buffers("meta"):
self.model: PreTrainedModel = AutoModel.from_config(
self.config,
torch_dtype=self.model_config.dtype,
@@ -543,7 +514,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
# Some weight loaders expect linear layers to inherit from vLLM's
# LinearBase class, so we set a default style which causes any
# unspecified linear layers to be replaced with ReplicatedLinear
- tp_plan[".*"] = "replicated"
+ tp_plan[".*"] = "replicate"
def _tensor_parallel(module: nn.Module, prefix: str = ""):
for child_name, child_module in module.named_children():
@@ -575,11 +546,10 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
attention_instances = {}
for i in range(start, end):
# Handle interleaved sliding window attention
- sliding_window = None
- if (hasattr(self.config, "interleaved_sliding_window")
- and hasattr(self.config, "sliding_window_pattern")
- and ((i + 1) % self.config.sliding_window_pattern > 0)):
- sliding_window = self.config.interleaved_sliding_window
+ per_layer_sliding_window = None
+ if (hasattr(self.config, "layer_types")
+ and self.config.layer_types[i] == "sliding_attention"):
+ per_layer_sliding_window = self.config.sliding_window
attention_instances[i] = Attention(
num_heads=num_heads,
@@ -590,7 +560,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
num_kv_heads=num_kv_heads,
cache_config=self.cache_config,
quant_config=self.quant_config,
- per_layer_sliding_window=sliding_window,
+ per_layer_sliding_window=per_layer_sliding_window,
prefix=f"{i}.attn")
return attention_instances
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index fecd14dde4a81..6c27fedc61b17 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -401,7 +401,7 @@ def merge_multimodal_embeddings_from_map(
"""
flattened_embeddings = _flatten_embeddings(multimodal_embeddings)
inputs_embeds[placeholder_map.dest] = flattened_embeddings[
- placeholder_map.src]
+ placeholder_map.src].to(dtype=inputs_embeds.dtype)
return inputs_embeds
@@ -421,7 +421,8 @@ def _merge_multimodal_embeddings(
flattened = _flatten_embeddings(multimodal_embeddings)
try:
# This is equivalent to: inputs_embeds[is_multimodal] = flattened.
- inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), flattened)
+ inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1),
+ flattened.to(dtype=inputs_embeds.dtype))
except RuntimeError as e:
num_expected_tokens = is_multimodal.sum().item()
assert isinstance(num_expected_tokens, int)
@@ -735,7 +736,23 @@ def cast_overflow_tensors(
return tensors
-def fast_topk(values, topk, dim):
+def fast_topk(values: torch.Tensor, topk: int,
+ dim: int) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Optimized topk implementation that uses torch.max for k=1 case.
+
+ This function provides better performance for the common case of k=1
+ by using torch.max instead of the more general torch.topk.
+
+ Args:
+ values: Input tensor to find top-k values from
+ topk: Number of top values to return (k). Must be > 0.
+ dim: Dimension along which to compute topk
+
+ Returns:
+ Tuple of (values, indices) where values are the top-k values
+ and indices are their corresponding indices in the input tensor
+ """
if topk == 1:
# Use max along the specified dimension to get both value and index
return torch.max(values, dim=dim, keepdim=True)
diff --git a/vllm/model_executor/warmup/__init__.py b/vllm/model_executor/warmup/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py
new file mode 100644
index 0000000000000..74599fa44c88c
--- /dev/null
+++ b/vllm/model_executor/warmup/deep_gemm_warmup.py
@@ -0,0 +1,219 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Warmup deep_gemm kernels.
+DeepGEMM JIT's the kernels. The warmup aims to JIT all the kernels that would
+be used during model execution beforehand.
+"""
+
+import torch
+from tqdm import tqdm
+
+import vllm.envs as envs
+from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
+from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
+ compute_aligned_M, deep_gemm_block_shape)
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+from vllm.model_executor.layers.fused_moe.modular_kernel import (
+ FusedMoEModularKernel)
+from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
+ TritonOrDeepGemmExperts)
+from vllm.model_executor.layers.linear import LinearBase
+from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
+from vllm.utils.deep_gemm import fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous
+
+
+def _extract_data_from_linear_base_module(
+ m: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
+ """
+ Extract weights, weight scales and quantization block sizes from the given
+ LinearBase module.
+ """
+ assert isinstance(m, LinearBase)
+ assert isinstance(m.quant_method, Fp8LinearMethod)
+ assert m.quant_method.block_quant
+ assert m.quant_method.quant_config is not None
+
+ w = m.weight
+ ws = m.weight_scale_inv
+ quant_block_size = m.quant_method.quant_config.weight_block_size
+
+ assert isinstance(w, torch.Tensor)
+ assert isinstance(ws, torch.Tensor)
+ assert quant_block_size is not None
+ return (w, ws, quant_block_size)
+
+
+def _extract_data_from_fused_moe_module(
+ m: torch.nn.Module
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]:
+ """
+ Extract weights, weight scales and num_topk from FusedMoE module.
+ """
+ assert isinstance(m, FusedMoE)
+ w13 = m.w13_weight
+ w13_s = m.w13_weight_scale_inv
+ w2 = m.w2_weight
+ w2_s = m.w2_weight_scale_inv
+ num_topk = m.top_k
+
+ assert isinstance(w13, torch.Tensor)
+ assert isinstance(w13_s, torch.Tensor)
+ assert isinstance(w2, torch.Tensor)
+ assert isinstance(w2_s, torch.Tensor)
+ return w13, w13_s, w2, w2_s, num_topk
+
+
+def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool:
+ """
+ Return True if the input module/layer could be processed with DeepGEMM.
+ """
+ block_size = deep_gemm_block_shape()[0]
+ if not (isinstance(module, LinearBase)
+ and isinstance(module.quant_method, Fp8LinearMethod)
+ and module.quant_method.block_quant):
+ return False
+
+ w, _, block_sizes = _extract_data_from_linear_base_module(module)
+ return (block_sizes == deep_gemm_block_shape() and w.ndim == 2
+ and w.shape[0] % block_size == 0 and w.shape[1] % block_size == 0)
+
+
+def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
+ if not (isinstance(module, FusedMoE)
+ and module.moe_config.quant_dtype == torch.float8_e4m3fn
+ and module.moe_config.block_shape == deep_gemm_block_shape()):
+ return False
+
+ if not isinstance(module.quant_method.fused_experts,
+ FusedMoEModularKernel):
+ # fused_experts could invoke deep_gemm_moe_fp8
+ return True
+
+ mk: FusedMoEModularKernel = module.quant_method.fused_experts
+ # Further check if the ModularKernel implementation uses the DeepGemmExperts
+ return isinstance(mk.fused_experts,
+ (DeepGemmExperts, TritonOrDeepGemmExperts))
+
+
+FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set()
+
+
+def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor,
+ max_tokens: int):
+ if w.size() in FP8_GEMM_NT_WARMUP_CACHE:
+ return
+
+ n, k = w.size()
+ block_m = deep_gemm_block_shape()[0]
+
+ device = w.device
+ a1q = torch.empty((max_tokens, k),
+ device=device,
+ dtype=torch.float8_e4m3fn)
+ a1q_scales = torch.empty((max_tokens, k // block_m),
+ device=device,
+ dtype=torch.float32)
+ out = torch.empty((max_tokens, n), device=device, dtype=torch.bfloat16)
+
+ pbar = tqdm(total=max_tokens,
+ desc=f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()})")
+ num_tokens = max_tokens
+ while num_tokens > 0:
+ fp8_gemm_nt((a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws),
+ out[:num_tokens])
+ pbar.update(1)
+ num_tokens -= 1
+
+ FP8_GEMM_NT_WARMUP_CACHE.add(w.size())
+
+
+GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: set[torch.Size] = set()
+
+
+def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(w1: torch.Tensor,
+ w2: torch.Tensor,
+ w1_scale: torch.Tensor,
+ w2_scale: torch.Tensor,
+ num_topk: int):
+ if (w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
+ and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE):
+ return
+
+ assert w1.size(0) == w2.size(0), (
+ "w1 and w2 must have the same number of experts")
+
+ block_m = deep_gemm_block_shape()[0]
+ num_experts = w1.size(0)
+ device = w1.device
+
+ # This is the maximum GroupedGemm M size that we expect to run
+ # the grouped_gemm with.
+ MAX_M = compute_aligned_M(envs.VLLM_FUSED_MOE_CHUNK_SIZE,
+ num_topk,
+ num_experts,
+ block_m,
+ expert_tokens_meta=None)
+ # Distribute expert-ids evenly.
+ MAX_BLOCKS = MAX_M // block_m
+ expert_ids_block = torch.randint(low=0,
+ high=num_experts,
+ size=(MAX_BLOCKS, ),
+ device=device,
+ dtype=torch.int32)
+ expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0)
+
+ def _warmup(w: torch.Tensor, w_scale: torch.Tensor):
+
+ _, n, k = w.size()
+ a1q = torch.empty((MAX_M, k), device=device, dtype=torch.float8_e4m3fn)
+ a1q_scales = torch.empty((MAX_M, k // block_m),
+ device=device,
+ dtype=torch.float32)
+ out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)
+
+ pbar = tqdm(
+ total=MAX_BLOCKS,
+ desc=
+ f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()})"
+ )
+ num_tokens = MAX_M
+ while num_tokens > 0:
+ m_grouped_fp8_gemm_nt_contiguous(
+ (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, w_scale),
+ out[:num_tokens], expert_ids[:num_tokens])
+ pbar.update(1)
+ num_tokens = num_tokens - block_m
+
+ for w, ws in [(w1, w1_scale), (w2, w2_scale)]:
+ if w.size() not in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE:
+ _warmup(w, ws)
+ GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE.add(w.size())
+
+
+def deepgemm_fp8_gemm_nt_warmup(model: torch.nn.Module, max_tokens: int):
+ dg_modules = [
+ m for m in model.modules() if _fp8_linear_may_use_deep_gemm(m)
+ ]
+
+ for dgm in dg_modules:
+ w, ws, _ = _extract_data_from_linear_base_module(dgm)
+ _deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens)
+
+
+def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module):
+ dg_modules = [
+ m for m in model.modules()
+ if _fused_moe_grouped_gemm_may_use_deep_gemm(m)
+ ]
+
+ for dgm in dg_modules:
+ w13, w13_scale, w2, w2_scale, num_topk = (
+ _extract_data_from_fused_moe_module(dgm))
+ _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
+ w13, w2, w13_scale, w2_scale, num_topk)
+
+
+def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int):
+ deepgemm_fp8_gemm_nt_warmup(model, max_tokens)
+ deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model)
diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py
new file mode 100644
index 0000000000000..10f2dc0252a1d
--- /dev/null
+++ b/vllm/model_executor/warmup/kernel_warmup.py
@@ -0,0 +1,20 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Warmup kernels used during model execution.
+This is useful specifically for JIT'ed kernels as we don't want JIT'ing to
+happen during model execution.
+"""
+import torch
+
+import vllm.envs as envs
+from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup
+from vllm.utils.deep_gemm import is_deep_gemm_supported
+
+
+def kernel_warmup(model: torch.nn.Module, max_tokens: int):
+ do_deep_gemm_warmup = (envs.VLLM_USE_DEEP_GEMM
+ and is_deep_gemm_supported()
+ and not envs.VLLM_SKIP_DEEP_GEMM_WARMUP)
+ if do_deep_gemm_warmup:
+ deep_gemm_warmup(model, max_tokens)
diff --git a/vllm/multimodal/cache.py b/vllm/multimodal/cache.py
index 262b22e554b9a..6074a4d54f223 100644
--- a/vllm/multimodal/cache.py
+++ b/vllm/multimodal/cache.py
@@ -7,9 +7,9 @@ from typing import TypeVar, Union
import torch
-from vllm.jsontree import json_map_leaves, json_reduce_leaves
from vllm.logger import init_logger
from vllm.utils import GiB_bytes, LRUCache
+from vllm.utils.jsontree import json_map_leaves, json_reduce_leaves
from .inputs import MultiModalKwargs, MultiModalKwargsItem, NestedTensors
diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
index 18aae35c6fd42..6d4bcef3206c1 100644
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -13,8 +13,8 @@ from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
import numpy as np
from typing_extensions import NotRequired, TypeAlias
-from vllm.jsontree import JSONTree, json_map_leaves
from vllm.utils import LazyLoader, full_groupby, is_list_of
+from vllm.utils.jsontree import JSONTree, json_map_leaves
if TYPE_CHECKING:
import torch
diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py
index dca04e9a1e22b..ded56cca80999 100644
--- a/vllm/multimodal/registry.py
+++ b/vllm/multimodal/registry.py
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from dataclasses import dataclass
+from functools import lru_cache
from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar
import torch.nn as nn
@@ -86,6 +87,13 @@ class _ProcessorFactories(Generic[_I]):
return self.processor(info, dummy_inputs_builder, cache=cache)
+# Make sure a different cache is used for each model config
+# NOTE: ModelConfig is not hashable so it cannot be passed directly
+@lru_cache(maxsize=1)
+def _get_processor_cache(model_id: str, capacity_gb: int):
+ return ProcessingCache(capacity_gb) if capacity_gb > 0 else None
+
+
class MultiModalRegistry:
"""
A registry that dispatches data processing according to the model.
@@ -95,25 +103,57 @@ class MultiModalRegistry:
self._processor_factories = ClassRegistry[nn.Module,
_ProcessorFactories]()
- self._processor_cache: Optional[ProcessingCache] = None
-
def _get_processor_cache(self, model_config: "ModelConfig"):
+ model_id = model_config.model
capacity_gb = model_config.mm_processor_cache_gb
- if capacity_gb is None:
- return None # Overrides `disable_cache` argument
+ return _get_processor_cache(model_id, capacity_gb)
- if self._processor_cache is None:
- self._processor_cache = ProcessingCache(capacity_gb)
-
- return self._processor_cache
-
- def reset_processor_cache(self) -> bool:
+ def reset_processor_cache(self, model_config: "ModelConfig") -> bool:
"""Reset the multi-modal processing cache."""
- if self._processor_cache:
- self._processor_cache.reset()
+ if processor_cache := self._get_processor_cache(model_config):
+ processor_cache.reset()
return True # Success
+ def enable_mm_input_cache(self, model_config: "ModelConfig") -> bool:
+ """Whether the multi-modal input cache should be enabled.
+ NOTE: This is put under MultiModalRegistry on purpose to respect
+ text-only mode for multimodal models.
+ """
+
+ if not self.supports_multimodal_inputs(model_config):
+ return False
+
+ mm_config = model_config.get_multimodal_config()
+
+ return mm_config.mm_processor_cache_gb > 0
+
+ def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
+ """
+ Checks if the model supports multimodal inputs.
+ Returns True if the model is multimodal with any non-zero supported
+ modalities, otherwise returns False, effectively running in
+ text-only mode.
+ """
+ if not model_config.is_multimodal_model:
+ return False
+
+ info = self._create_processing_info(model_config, tokenizer=None)
+ supported_modalities = info.get_supported_mm_limits()
+
+ mm_config = model_config.get_multimodal_config()
+
+ # Check if all supported modalities have limit == 0
+ if all(
+ mm_config.get_limit_per_prompt(modality) == 0
+ for modality in supported_modalities):
+ logger.info_once(
+ "All limits of multimodal modalities supported by the model "
+ "are set to 0, running in text-only mode.")
+ return False
+
+ return True
+
def get_max_tokens_per_item_by_modality(
self,
model_config: "ModelConfig",
@@ -238,6 +278,26 @@ class MultiModalRegistry:
model_cls, _ = get_model_architecture(model_config)
return model_cls
+ def _create_processing_ctx(
+ self,
+ model_config: "ModelConfig",
+ tokenizer: Optional[AnyTokenizer] = None,
+ ) -> InputProcessingContext:
+ if tokenizer is None and not model_config.skip_tokenizer_init:
+ tokenizer = cached_tokenizer_from_config(model_config)
+ return InputProcessingContext(model_config, tokenizer)
+
+ def _create_processing_info(
+ self,
+ model_config: "ModelConfig",
+ *,
+ tokenizer: Optional[AnyTokenizer] = None,
+ ) -> BaseProcessingInfo:
+ model_cls = self._get_model_cls(model_config)
+ factories = self._processor_factories[model_cls]
+ ctx = self._create_processing_ctx(model_config, tokenizer)
+ return factories.info(ctx)
+
def create_processor(
self,
model_config: "ModelConfig",
@@ -251,15 +311,13 @@ class MultiModalRegistry:
if not model_config.is_multimodal_model:
raise ValueError(f"{model_config.model} is not a multimodal model")
- if tokenizer is None and not model_config.skip_tokenizer_init:
- tokenizer = cached_tokenizer_from_config(model_config)
if disable_cache is None:
disable_cache = not model_config.enable_mm_processor_cache
model_cls = self._get_model_cls(model_config)
factories = self._processor_factories[model_cls]
- ctx = InputProcessingContext(model_config, tokenizer)
+ ctx = self._create_processing_ctx(model_config, tokenizer)
cache = None if disable_cache else self._get_processor_cache(
model_config)
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index 31a67183ff12c..0b16a8e1d1d8b 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -91,8 +91,8 @@ class CpuPlatform(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
- block_size: int, use_v1: bool,
- use_mla: bool) -> str:
+ block_size: int, use_v1: bool, use_mla: bool,
+ has_sink: bool) -> str:
if selected_backend and selected_backend != _Backend.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla:
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index dd9356e399c9d..c876c52a2e9c9 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -222,8 +222,8 @@ class CudaPlatformBase(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
- kv_cache_dtype, block_size, use_v1,
- use_mla) -> str:
+ kv_cache_dtype, block_size, use_v1, use_mla,
+ has_sink) -> str:
if use_mla:
# TODO(lucas): refactor to be more concise
# we should probably consider factoring out V1 here
@@ -321,6 +321,9 @@ class CudaPlatformBase(Platform):
# FlashAttention is the default for SM 8.0+ GPUs
if cls.has_device_capability(80):
+ if has_sink:
+ logger.info_once("Using Triton backend on V1 engine.")
+ return TRITON_ATTN_VLLM_V1
if is_default_backend_supported := is_attn_backend_supported(
FLASH_ATTN_V1, head_size, dtype,
allow_import_error=False):
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index a85b583abc2ce..91d5314900c87 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -196,8 +196,8 @@ class Platform:
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
- block_size: int, use_v1: bool,
- use_mla: bool) -> str:
+ block_size: int, use_v1: bool, use_mla: bool,
+ has_sink: bool) -> str:
"""Get the attention backend class of a device."""
return ""
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index d26e4b3350381..8005830f55cef 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -188,8 +188,8 @@ class RocmPlatform(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
- kv_cache_dtype, block_size, use_v1,
- use_mla) -> str:
+ kv_cache_dtype, block_size, use_v1, use_mla,
+ has_sink) -> str:
if use_mla:
from vllm.attention.backends.rocm_aiter_mla import (
is_aiter_mla_enabled)
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index 146801c9d7739..c56096d93612d 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -46,8 +46,8 @@ class TpuPlatform(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
- block_size: int, use_v1: bool,
- use_mla: bool) -> str:
+ block_size: int, use_v1: bool, use_mla: bool,
+ has_sink) -> str:
if (selected_backend != _Backend.PALLAS
and selected_backend != _Backend.PALLAS_VLLM_V1):
logger.info("Cannot use %s backend on TPU.", selected_backend)
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index d8a663f2f0c4a..abd58dbbcbf45 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -35,8 +35,8 @@ class XPUPlatform(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
- block_size: int, use_v1: bool,
- use_mla: bool) -> str:
+ block_size: int, use_v1: bool, use_mla: bool,
+ has_sink: bool) -> str:
if selected_backend is not None and selected_backend != _Backend.IPEX:
logger.info("Cannot use %s backend on XPU.", selected_backend)
use_v1 = envs.VLLM_USE_V1
diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py
index 51c78ddc1a9d5..1a1760df82c03 100644
--- a/vllm/plugins/__init__.py
+++ b/vllm/plugins/__init__.py
@@ -4,8 +4,6 @@
import logging
from typing import Any, Callable
-import torch
-
import vllm.envs as envs
logger = logging.getLogger(__name__)
@@ -68,13 +66,6 @@ def load_general_plugins():
return
plugins_loaded = True
- # some platform-specific configurations
- from vllm.platforms import current_platform
-
- if current_platform.is_xpu():
- # see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158
- torch._dynamo.config.disable = True
-
plugins = load_plugins_by_group(group=DEFAULT_PLUGINS_GROUP)
# general plugins, we only need to execute the loaded functions
for func in plugins.values():
diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py
index 7077f68353fc5..29f037b4372cd 100644
--- a/vllm/pooling_params.py
+++ b/vllm/pooling_params.py
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import deepcopy
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional
import msgspec
@@ -46,6 +46,9 @@ class PoolingParams(
requires_token_ids: bool = False
"""Internal use only."""
+ extra_kwargs: Optional[dict[str, Any]] = None
+ """Internal use only."""
+
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
@property
@@ -167,7 +170,8 @@ class PoolingParams(
f"softmax={self.softmax}, "
f"step_tag_id={self.step_tag_id}, "
f"returned_token_ids={self.returned_token_ids}, "
- f"requires_token_ids={self.requires_token_ids})")
+ f"requires_token_ids={self.requires_token_ids}, "
+ f"extra_kwargs={self.extra_kwargs})")
def __post_init__(self) -> None:
assert self.output_kind == RequestOutputKind.FINAL_ONLY,\
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 52e4cbd096153..df4cca9ba1147 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -103,113 +103,89 @@ class SamplingParams(
Overall, we follow the sampling parameters from the OpenAI text completion
API (https://platform.openai.com/docs/api-reference/completions/create).
In addition, we support beam search, which is not supported by OpenAI.
-
- Args:
- n: Number of output sequences to return for the given prompt.
- best_of: Number of output sequences that are generated from the prompt.
- From these `best_of` sequences, the top `n` sequences are returned.
- `best_of` must be greater than or equal to `n`. By default,
- `best_of` is set to `n`. Warning, this is only supported in V0.
- presence_penalty: Float that penalizes new tokens based on whether they
- appear in the generated text so far. Values > 0 encourage the model
- to use new tokens, while values < 0 encourage the model to repeat
- tokens.
- frequency_penalty: Float that penalizes new tokens based on their
- frequency in the generated text so far. Values > 0 encourage the
- model to use new tokens, while values < 0 encourage the model to
- repeat tokens.
- repetition_penalty: Float that penalizes new tokens based on whether
- they appear in the prompt and the generated text so far. Values > 1
- encourage the model to use new tokens, while values < 1 encourage
- the model to repeat tokens.
- temperature: Float that controls the randomness of the sampling. Lower
- values make the model more deterministic, while higher values make
- the model more random. Zero means greedy sampling.
- top_p: Float that controls the cumulative probability of the top tokens
- to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
- top_k: Integer that controls the number of top tokens to consider. Set
- to 0 (or -1) to consider all tokens.
- min_p: Float that represents the minimum probability for a token to be
- considered, relative to the probability of the most likely token.
- Must be in [0, 1]. Set to 0 to disable this.
- seed: Random seed to use for the generation.
- stop: list of strings that stop the generation when they are generated.
- The returned output will not contain the stop strings.
- stop_token_ids: list of tokens that stop the generation when they are
- generated. The returned output will contain the stop tokens unless
- the stop tokens are special tokens.
- bad_words: list of words that are not allowed to be generated.
- More precisely, only the last token of a corresponding
- token sequence is not allowed when the next generated token
- can complete the sequence.
- include_stop_str_in_output: Whether to include the stop strings in
- output text. Defaults to False.
- ignore_eos: Whether to ignore the EOS token and continue generating
- tokens after the EOS token is generated.
- max_tokens: Maximum number of tokens to generate per output sequence.
- min_tokens: Minimum number of tokens to generate per output sequence
- before EOS or stop_token_ids can be generated
- logprobs: Number of log probabilities to return per output token.
- When set to None, no probability is returned. If set to a non-None
- value, the result includes the log probabilities of the specified
- number of most likely tokens, as well as the chosen tokens.
- Note that the implementation follows the OpenAI API: The API will
- always return the log probability of the sampled token, so there
- may be up to `logprobs+1` elements in the response.
- When set to -1, return all `vocab_size` log probabilities.
- prompt_logprobs: Number of log probabilities to return per prompt token.
- detokenize: Whether to detokenize the output. Defaults to True.
- skip_special_tokens: Whether to skip special tokens in the output.
- spaces_between_special_tokens: Whether to add spaces between special
- tokens in the output. Defaults to True.
- logits_processors: list of functions that modify logits based on
- previously generated tokens, and optionally prompt tokens as
- a first argument.
- truncate_prompt_tokens: If set to -1, will use the truncation size
- supported by the model. If set to an integer k, will use only
- the last k tokens from the prompt (i.e., left truncation).
- Defaults to None (i.e., no truncation).
- guided_decoding: If provided, the engine will construct a guided
- decoding logits processor from these parameters. Defaults to None.
- logit_bias: If provided, the engine will construct a logits processor
- that applies these logit biases. Defaults to None.
- allowed_token_ids: If provided, the engine will construct a logits
- processor which only retains scores for the given token ids.
- Defaults to None.
- extra_args: Arbitrary additional args, that can be used by custom
- sampling implementations, plugins, etc. Not used by any in-tree
- sampling implementations.
"""
n: int = 1
+ """Number of output sequences to return for the given prompt."""
best_of: Optional[int] = None
+ """Number of output sequences that are generated from the prompt. From
+ these `best_of` sequences, the top `n` sequences are returned. `best_of`
+ must be greater than or equal to `n`. By default, `best_of` is set to `n`.
+ Warning, this is only supported in V0."""
_real_n: Optional[int] = None
presence_penalty: float = 0.0
+ """Penalizes new tokens based on whether they appear in the generated text
+ so far. Values > 0 encourage the model to use new tokens, while values < 0
+ encourage the model to repeat tokens."""
frequency_penalty: float = 0.0
+ """Penalizes new tokens based on their frequency in the generated text so
+ far. Values > 0 encourage the model to use new tokens, while values < 0
+ encourage the model to repeat tokens."""
repetition_penalty: float = 1.0
+ """Penalizes new tokens based on whether they appear in the prompt and the
+ generated text so far. Values > 1 encourage the model to use new tokens,
+ while values < 1 encourage the model to repeat tokens."""
temperature: float = 1.0
+ """Controls the randomness of the sampling. Lower values make the model
+ more deterministic, while higher values make the model more random. Zero
+ means greedy sampling."""
top_p: float = 1.0
+ """Controls the cumulative probability of the top tokens to consider. Must
+ be in (0, 1]. Set to 1 to consider all tokens."""
top_k: int = 0
+ """Controls the number of top tokens to consider. Set to 0 (or -1) to
+ consider all tokens."""
min_p: float = 0.0
+ """Represents the minimum probability for a token to be considered,
+ relative to the probability of the most likely token. Must be in [0, 1].
+ Set to 0 to disable this."""
seed: Optional[int] = None
+ """Random seed to use for the generation."""
stop: Optional[Union[str, list[str]]] = None
+ """String(s) that stop the generation when they are generated. The returned
+ output will not contain the stop strings."""
stop_token_ids: Optional[list[int]] = None
+ """Token IDs that stop the generation when they are generated. The returned
+ output will contain the stop tokens unless the stop tokens are special
+ tokens."""
ignore_eos: bool = False
+ """Whether to ignore the EOS token and continue generating
+ tokens after the EOS token is generated."""
max_tokens: Optional[int] = 16
+ """Maximum number of tokens to generate per output sequence."""
min_tokens: int = 0
+ """Minimum number of tokens to generate per output sequence before EOS or
+ `stop_token_ids` can be generated"""
logprobs: Optional[int] = None
+ """Number of log probabilities to return per output token. When set to
+ `None`, no probability is returned. If set to a non-`None` value, the
+ result includes the log probabilities of the specified number of most
+ likely tokens, as well as the chosen tokens. Note that the implementation
+ follows the OpenAI API: The API will always return the log probability of
+ the sampled token, so there may be up to `logprobs+1` elements in the
+ response. When set to -1, return all `vocab_size` log probabilities."""
prompt_logprobs: Optional[int] = None
+ """Number of log probabilities to return per prompt token."""
# NOTE: This parameter is only exposed at the engine level for now.
# It is not exposed in the OpenAI API server, as the OpenAI API does
# not support returning only a list of token IDs.
detokenize: bool = True
+ """Whether to detokenize the output."""
skip_special_tokens: bool = True
+ """Whether to skip special tokens in the output."""
spaces_between_special_tokens: bool = True
+ """Whether to add spaces between special tokens in the output."""
# Optional[list[LogitsProcessor]] type. We use Any here because
# Optional[list[LogitsProcessor]] type is not supported by msgspec.
logits_processors: Optional[Any] = None
+ """Functions that modify logits based on previously generated tokens, and
+ optionally prompt tokens as a first argument."""
include_stop_str_in_output: bool = False
+ """Whether to include the stop strings in output text."""
truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None
+ """If set to -1, will use the truncation size supported by the model. If
+ set to an integer k, will use only the last k tokens from the prompt
+ (i.e., left truncation). If set to `None`, truncation is disabled."""
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE
# The below fields are not supposed to be used as an input.
@@ -219,12 +195,24 @@ class SamplingParams(
# Fields used to construct logits processors
guided_decoding: Optional[GuidedDecodingParams] = None
+ """If provided, the engine will construct a guided decoding logits
+ processor from these parameters."""
logit_bias: Optional[dict[int, float]] = None
+ """If provided, the engine will construct a logits processor that applies
+ these logit biases."""
allowed_token_ids: Optional[list[int]] = None
+ """If provided, the engine will construct a logits processor which only
+ retains scores for the given token ids."""
extra_args: Optional[dict[str, Any]] = None
+ """Arbitrary additional args, that can be used by custom sampling
+ implementations, plugins, etc. Not used by any in-tree sampling
+ implementations."""
# Fields used for bad words
bad_words: Optional[list[str]] = None
+ """Words that are not allowed to be generated. More precisely, only the
+ last token of a corresponding token sequence is not allowed when the next
+ generated token can complete the sequence."""
_bad_words_token_ids: Optional[list[list[int]]] = None
@staticmethod
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index bce24ef74cdef..02ea0814ddefa 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -32,11 +32,10 @@ from vllm.logger import init_logger
from vllm.transformers_utils.configs import (ChatGLMConfig, DeepseekVLV2Config,
EAGLEConfig, JAISConfig,
KimiVLConfig, MedusaConfig,
- MllamaConfig, MLPSpeculatorConfig,
+ MLPSpeculatorConfig,
Nemotron_Nano_VL_Config,
- NemotronConfig, NVLM_D_Config,
- OvisConfig, RWConfig,
- SpeculatorsConfig,
+ NemotronConfig, OvisConfig,
+ RWConfig, SpeculatorsConfig,
Step3TextConfig, Step3VLConfig,
UltravoxConfig)
# yapf: enable
@@ -68,10 +67,6 @@ def _get_hf_token() -> Optional[str]:
return None
-_CONFIG_REGISTRY_OVERRIDE_HF: dict[str, type[PretrainedConfig]] = {
- "mllama": MllamaConfig
-}
-
_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
"chatglm": ChatGLMConfig,
"deepseek_vl_v2": DeepseekVLV2Config,
@@ -85,18 +80,30 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
"eagle": EAGLEConfig,
"speculators": SpeculatorsConfig,
"nemotron": NemotronConfig,
- "NVLM_D": NVLM_D_Config,
"ovis": OvisConfig,
"ultravox": UltravoxConfig,
"step3_vl": Step3VLConfig,
"step3_text": Step3TextConfig,
- **_CONFIG_REGISTRY_OVERRIDE_HF
}
_CONFIG_ATTRS_MAPPING: dict[str, str] = {
"llm_config": "text_config",
}
+_AUTO_CONFIG_KWARGS_OVERRIDES: dict[str, dict[str, Any]] = {
+ "internvl_chat": {
+ "has_no_defaults_at_init": True
+ },
+ # transformers regards mllama as is_encoder_decoder=False
+ # vllm needs is_encoder_decoder=True to enable cross-attention
+ "mllama": {
+ "is_encoder_decoder": True
+ },
+ "NVLM_D": {
+ "has_no_defaults_at_init": True
+ },
+}
+
class ConfigFormat(str, enum.Enum):
AUTO = "auto"
@@ -254,7 +261,8 @@ def _uses_mrope(config: PretrainedConfig) -> bool:
def uses_mrope(config: PretrainedConfig) -> bool:
"""Detect if the model with this config uses M-ROPE."""
- return _uses_mrope(config) or thinker_uses_mrope(config)
+ return _uses_mrope(config) or _uses_mrope(
+ config.get_text_config()) or thinker_uses_mrope(config)
def thinker_uses_mrope(config: PretrainedConfig) -> bool:
@@ -272,11 +280,32 @@ def thinker_uses_mrope(config: PretrainedConfig) -> bool:
def is_encoder_decoder(config: PretrainedConfig) -> bool:
"""Detect if the model with this config is used as an encoder/decoder."""
- text_config = getattr(config, "text_config", None)
- if text_config is not None:
- return is_encoder_decoder(text_config)
- return getattr(config, "is_encoder_decoder", False)
+ def _is_encoder_decoder(config: PretrainedConfig) -> bool:
+ return getattr(config, "is_encoder_decoder", False)
+
+ return (_is_encoder_decoder(config)
+ or _is_encoder_decoder(config.get_text_config()))
+
+
+def is_interleaved(config: PretrainedConfig) -> bool:
+ """
+ Detect if the model with this config is used with interleaved attention.
+ """
+ text_config = config.get_text_config()
+ if layer_types := getattr(text_config, "layer_types", None):
+ interleaved_types = {"full_attention", "sliding_attention"}
+ return interleaved_types.issubset(layer_types)
+ return False
+
+
+def _maybe_update_auto_config_kwargs(kwargs: dict[str, Any], model_type: str):
+ """
+ Update kwargs for AutoConfig initialization based on model_type
+ """
+ if model_type in _AUTO_CONFIG_KWARGS_OVERRIDES:
+ kwargs.update(_AUTO_CONFIG_KWARGS_OVERRIDES[model_type])
+ return kwargs
def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
@@ -285,7 +314,6 @@ def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
if hasattr(config, old_attr):
if not hasattr(config, new_attr):
config.update({new_attr: getattr(config, old_attr)})
- delattr(config, old_attr)
logger.debug("Remapped config attribute '%s' to '%s'", old_attr,
new_attr)
return config
@@ -396,15 +424,14 @@ def get_config(
)
else:
try:
+ kwargs = _maybe_update_auto_config_kwargs(
+ kwargs, model_type=model_type)
config = AutoConfig.from_pretrained(
model,
trust_remote_code=trust_remote_code,
revision=revision,
code_revision=code_revision,
token=_get_hf_token(),
- # some old custom model's config needs
- # `has_no_defaults_at_init=True` to work.
- has_no_defaults_at_init=trust_remote_code,
**kwargs,
)
except ValueError as e:
@@ -422,6 +449,23 @@ def get_config(
raise e
config = _maybe_remap_hf_config_attrs(config)
+ # Phi4Flash misuses this config as list[int]. Convert it to int and add
+ # the layer_types list[str] to make it HF compatible
+ if (config.model_type == "phi4flash"):
+ # TODO: Remove after the following PR is merged:
+ # https://huggingface.co/microsoft/Phi-4-mini-flash-reasoning/discussions/6
+ if not hasattr(config, "layer_types"):
+ config.layer_types = [
+ "sliding_attention" if i < config.num_hidden_layers // 2
+ and i % 2 == 1 else "full_attention"
+ for i in range(config.num_hidden_layers)
+ ]
+ # TODO: Remove after the following PR is merged:
+ # https://huggingface.co/microsoft/Phi-4-mini-flash-reasoning/discussions/7
+ if isinstance(config.sliding_window, list):
+ config.sliding_window = next(
+ filter(None, config.sliding_window), None)
+
elif config_format == ConfigFormat.MISTRAL:
# This function loads a params.json config which
# should be used when loading models in mistral format
@@ -433,6 +477,18 @@ def get_config(
config_dict["max_position_embeddings"] = max_position_embeddings
config = adapt_config_dict(config_dict)
+
+ # Mistral configs may define sliding_window as list[int]. Convert it
+ # to int and add the layer_types list[str] to make it HF compatible
+ if ((sliding_window := getattr(config, "sliding_window", None))
+ and isinstance(sliding_window, list)):
+ pattern_repeats = config.num_hidden_layers // len(sliding_window)
+ layer_types = sliding_window * pattern_repeats
+ config.layer_types = [
+ "full_attention" if layer_type is None else "sliding_attention"
+ for layer_type in layer_types
+ ]
+ config.sliding_window = next(filter(None, sliding_window), None)
else:
supported_formats = [
fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index 82d24bb16ba5a..8339c55bcf808 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -17,13 +17,11 @@ from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig
from vllm.transformers_utils.configs.medusa import MedusaConfig
-from vllm.transformers_utils.configs.mllama import MllamaConfig
from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
from vllm.transformers_utils.configs.moonvit import MoonViTConfig
from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig
from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config
-from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
from vllm.transformers_utils.configs.ovis import OvisConfig
from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig
from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig,
@@ -34,18 +32,16 @@ from vllm.transformers_utils.configs.ultravox import UltravoxConfig
__all__ = [
"ChatGLMConfig",
"DeepseekVLV2Config",
+ "EAGLEConfig",
"RWConfig",
"JAISConfig",
"MedusaConfig",
- "EAGLEConfig",
- "MllamaConfig",
"MLPSpeculatorConfig",
"MoonViTConfig",
"KimiVLConfig",
"NemotronConfig",
"NemotronHConfig",
"Nemotron_Nano_VL_Config",
- "NVLM_D_Config",
"OvisConfig",
"SpeculatorsConfig",
"UltravoxConfig",
diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py
index 5445a333c493e..bc249c5836034 100644
--- a/vllm/transformers_utils/configs/eagle.py
+++ b/vllm/transformers_utils/configs/eagle.py
@@ -45,6 +45,7 @@ class EAGLEConfig(PretrainedConfig):
# Eagle model name should follow naming convention of
# LlamaForCausalLM -> EagleLlamaForCausalLM
+ # LlamaForCausalLM -> Eagle3LlamaForCausalLM
if method == "eagle":
assert self.model is not None, \
"model should not be None when method is eagle"
@@ -56,8 +57,8 @@ class EAGLEConfig(PretrainedConfig):
assert self.model is not None, \
"model should not be None when method is eagle3"
kwargs["architectures"] = [
- f"Eagle3{arch}" if not arch.startswith("Eagle3") \
- else arch for arch in self.model.architectures
+ arch if arch.startswith("Eagle3") or arch.endswith("Eagle3")
+ else f"Eagle3{arch}" for arch in self.model.architectures
]
else:
raise ValueError(f"Invalid method {method}. \
diff --git a/vllm/transformers_utils/configs/mllama.py b/vllm/transformers_utils/configs/mllama.py
deleted file mode 100644
index f0cd2d52a529e..0000000000000
--- a/vllm/transformers_utils/configs/mllama.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from transformers.models.mllama import configuration_mllama as mllama_hf_config
-
-
-class MllamaTextConfig(mllama_hf_config.MllamaTextConfig):
- '''
- Use this class to override is_encoder_decoder:
- - transformers regards mllama as is_encoder_decoder=False
- - vllm needs is_encoder_decoder=True to enable cross-attention
- '''
-
- def __init__(
- self,
- **kwargs,
- ):
- super().__init__(**kwargs)
- self.is_encoder_decoder = True
-
-
-class MllamaConfig(mllama_hf_config.MllamaConfig):
-
- def __init__(
- self,
- text_config=None,
- **kwargs,
- ):
- if isinstance(text_config, dict):
- text_config = MllamaTextConfig(**text_config)
- super().__init__(text_config=text_config, **kwargs)
diff --git a/vllm/transformers_utils/configs/nemotron_h.py b/vllm/transformers_utils/configs/nemotron_h.py
index 457b3371e90db..027f2911543f5 100644
--- a/vllm/transformers_utils/configs/nemotron_h.py
+++ b/vllm/transformers_utils/configs/nemotron_h.py
@@ -151,7 +151,7 @@ class NemotronHConfig(PretrainedConfig):
num_hidden_layers=52,
hybrid_override_pattern="M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-",
num_attention_heads=32,
- attention_head_dim=128,
+ head_dim=128,
num_key_value_heads=8, # nemo: num_query_groups
mlp_hidden_act="relu2",
attention_bias=False,
@@ -194,7 +194,7 @@ class NemotronHConfig(PretrainedConfig):
self.num_hidden_layers = num_hidden_layers
self.hybrid_override_pattern = hybrid_override_pattern
self.num_attention_heads = num_attention_heads
- self.attention_head_dim = attention_head_dim
+ self.head_dim = head_dim
self.sliding_window = sliding_window
self.max_position_embeddings = max_position_embeddings
self.attention_dropout = attention_dropout
diff --git a/vllm/transformers_utils/configs/nvlm_d.py b/vllm/transformers_utils/configs/nvlm_d.py
deleted file mode 100644
index edfc506882ff5..0000000000000
--- a/vllm/transformers_utils/configs/nvlm_d.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-# Adapted from
-# https://huggingface.co/nvidia/NVLM-D-72B/blob/main/configuration_nvlm_d.py
-# --------------------------------------------------------
-# NVLM-D
-# Copyright (c) 2024 NVIDIA
-# Licensed under Apache 2.0 License [see LICENSE for details]
-# --------------------------------------------------------
-from transformers import Qwen2Config
-from transformers.configuration_utils import PretrainedConfig
-
-
-class NVLM_D_Config(PretrainedConfig):
- model_type = 'NVLM_D'
- is_composition = True
-
- def __init__(self, vision_config=None, llm_config=None, **kwargs):
- super().__init__(**kwargs)
-
- # Handle vision_config initialization
- if vision_config is None:
- vision_config = {}
-
- # Handle llm_config initialization
- if llm_config is None:
- llm_config = {}
-
- self.vision_config = PretrainedConfig(**vision_config)
- self.text_config = Qwen2Config(**llm_config)
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index ce62282c2199f..095829db83944 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -47,7 +47,7 @@ from dataclasses import dataclass, field
from functools import cache, lru_cache, partial, wraps
from types import MappingProxyType
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
- Optional, TextIO, Tuple, TypeVar, Union, cast, overload)
+ Optional, TextIO, TypeVar, Union, cast, overload)
from urllib.parse import urlparse
from uuid import uuid4
@@ -687,19 +687,30 @@ class AsyncMicrobatchTokenizer:
max_length = kwargs.get("max_length")
if not truncation:
- return ("encode", add_special_tokens, False, None)
+ return "encode", add_special_tokens, False, None
model_max = getattr(self.tokenizer, "model_max_length", None)
if max_length is None or (model_max is not None
and max_length == model_max):
- return ("encode", add_special_tokens, True, "model_max")
+ return "encode", add_special_tokens, True, "model_max"
- return ("encode", "other")
+ return "encode", "other"
def __del__(self):
- for task in self._batcher_tasks:
- if not task.done():
- task.cancel()
+ if ((tasks := getattr(self, "_batcher_tasks", None))
+ and (loop := getattr(self, "_loop", None))
+ and not loop.is_closed()):
+
+ def cancel_tasks():
+ for task in tasks:
+ task.cancel()
+
+ loop.call_soon_threadsafe(cancel_tasks)
+
+
+def cancel_task_threadsafe(task: Task):
+ if task and not task.done() and not (loop := task.get_loop()).is_closed():
+ loop.call_soon_threadsafe(task.cancel)
def make_async(
@@ -850,7 +861,7 @@ def is_valid_ipv6_address(address: str) -> bool:
return False
-def split_host_port(host_port: str) -> Tuple[str, int]:
+def split_host_port(host_port: str) -> tuple[str, int]:
# ipv6
if host_port.startswith('['):
host, port = host_port.rsplit(']', 1)
@@ -1658,11 +1669,21 @@ class FlexibleArgumentParser(ArgumentParser):
"""ArgumentParser that allows both underscore and dash in names."""
_deprecated: set[Action] = set()
+ _json_tip: str = (
+ "When passing JSON CLI arguments, the following sets of arguments "
+ "are equivalent:\n"
+ ' --json-arg \'{"key1": "value1", "key2": {"key3": "value2"}}\'\n'
+ " --json-arg.key1 value1 --json-arg.key2.key3 value2\n\n"
+ "Additionally, list elements can be passed individually using +:\n"
+ ' --json-arg \'{"key4": ["value3", "value4", "value5"]}\'\n'
+ " --json-arg.key4+ value3 --json-arg.key4+=\'value4,value5\'\n\n")
def __init__(self, *args, **kwargs):
- # Set the default 'formatter_class' to SortedHelpFormatter
- if 'formatter_class' not in kwargs:
- kwargs['formatter_class'] = SortedHelpFormatter
+ # Set the default "formatter_class" to SortedHelpFormatter
+ if "formatter_class" not in kwargs:
+ kwargs["formatter_class"] = SortedHelpFormatter
+ # Pop kwarg "add_json_tip" to control whether to add the JSON tip
+ self.add_json_tip = kwargs.pop("add_json_tip", True)
super().__init__(*args, **kwargs)
if sys.version_info < (3, 13):
@@ -1704,6 +1725,14 @@ class FlexibleArgumentParser(ArgumentParser):
self._action_groups.append(group)
return group
+ def format_help(self) -> str:
+ # Add tip about JSON arguments to the epilog
+ epilog = self.epilog or ""
+ if (self.add_json_tip
+ and not epilog.startswith(FlexibleArgumentParser._json_tip)):
+ self.epilog = FlexibleArgumentParser._json_tip + epilog
+ return super().format_help()
+
def parse_args( # type: ignore[override]
self,
args: list[str] | None = None,
@@ -3243,6 +3272,12 @@ def has_deep_gemm() -> bool:
return _has_module("deep_gemm")
+def has_triton_kernels() -> bool:
+ """Whether the optional `triton_kernels` package is available."""
+
+ return _has_module("triton_kernels")
+
+
def set_process_title(name: str,
suffix: str = "",
append: bool = False) -> None:
diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py
index 0edfb01cde9d6..861d9c0c0005d 100644
--- a/vllm/utils/deep_gemm.py
+++ b/vllm/utils/deep_gemm.py
@@ -14,6 +14,7 @@ from typing import Any, Callable, NoReturn
import torch
import vllm.envs as envs
+from vllm.logger import logger
from vllm.platforms import current_platform
from vllm.utils import cdiv, has_deep_gemm
@@ -30,19 +31,37 @@ def is_deep_gemm_supported() -> bool:
@functools.cache
-def is_blackwell_deep_gemm_used() -> bool:
- """Return ``True`` if vLLM is configured to use DeepGEMM on a
- Blackwell-class GPU.
+def is_blackwell_deep_gemm_e8m0_used() -> bool:
+ """Return ``True`` if vLLM is configured to use DeepGEMM "
+ "E8M0 scale on a Blackwell-class GPU.
"""
- if not (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm()):
+ if not (envs.VLLM_USE_DEEP_GEMM):
+ logger.debug_once("DeepGEMM E8M0 disabled: VLLM_USE_DEEP_GEMM=0.")
+ return False
+
+ if not has_deep_gemm():
+ logger.debug_once("DeepGEMM E8M0 disabled: DeepGEMM backend missing.")
+ return False
+
+ if not envs.VLLM_USE_DEEP_GEMM_E8M0:
+ logger.debug_once("DeepGEMM E8M0 disabled: VLLM_USE_DEEP_GEMM_E8M0=0.")
return False
_lazy_init()
+
if _fp8_gemm_nt_impl is None:
+ logger.debug_once(
+ "DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
return False
- return (current_platform.is_cuda()
- and current_platform.is_device_capability(100))
+ enabled = (current_platform.is_cuda()
+ and current_platform.has_device_capability(100))
+ if enabled:
+ logger.debug_once("DeepGEMM E8M0 enabled on Blackwell GPU.")
+ else:
+ logger.debug_once(
+ "DeepGEMM E8M0 disabled: not running on Blackwell GPU.")
+ return enabled
def _missing(*_: Any, **__: Any) -> NoReturn:
@@ -57,6 +76,14 @@ def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
if hasattr(module, new):
return getattr(module, new)
if hasattr(module, old):
+ # TODO(wentao): deprecate old symbol in the future.
+ logger.warning_once(
+ "Found legacy DeepGEMM symbol `%s`. Please upgrade the `deep_gemm` "
+ "package so that `%s` is available. Support for the legacy symbol "
+ "will be removed in a future vLLM release.",
+ old,
+ new,
+ )
return getattr(module, old)
return None
@@ -100,21 +127,30 @@ def fp8_gemm_nt(*args, **kwargs):
_lazy_init()
if _fp8_gemm_nt_impl is None:
return _missing(*args, **kwargs)
- return _fp8_gemm_nt_impl(*args, **kwargs)
+ return _fp8_gemm_nt_impl(
+ *args,
+ disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(),
+ **kwargs)
def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
_lazy_init()
if _grouped_impl is None:
return _missing(*args, **kwargs)
- return _grouped_impl(*args, **kwargs)
+ return _grouped_impl(
+ *args,
+ disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(),
+ **kwargs)
def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
_lazy_init()
if _grouped_masked_impl is None:
return _missing(*args, **kwargs)
- return _grouped_masked_impl(*args, **kwargs)
+ return _grouped_masked_impl(
+ *args,
+ disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(),
+ **kwargs)
def _ceil_to_ue8m0(x: torch.Tensor):
@@ -172,6 +208,6 @@ __all__ = [
"m_grouped_fp8_gemm_nt_contiguous",
"fp8_m_grouped_gemm_nt_masked",
"per_block_cast_to_fp8",
- "is_blackwell_deep_gemm_used",
+ "is_blackwell_deep_gemm_e8m0_used",
"is_deep_gemm_supported",
]
diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py
index 32c52612ca16f..5998d4c3127f6 100644
--- a/vllm/utils/flashinfer.py
+++ b/vllm/utils/flashinfer.py
@@ -86,6 +86,8 @@ flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe",
fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
nvfp4_block_scale_interleave = _lazy_import_wrapper(
"flashinfer", "nvfp4_block_scale_interleave")
+trtllm_fp4_block_scale_moe = _lazy_import_wrapper(
+ "flashinfer", "trtllm_fp4_block_scale_moe")
# Special case for autotune since it returns a context manager
autotune = _lazy_import_wrapper(
@@ -112,6 +114,7 @@ def has_flashinfer_cutlass_fused_moe() -> bool:
("flashinfer.fused_moe", "cutlass_fused_moe"),
("flashinfer", "fp4_quantize"),
("flashinfer", "nvfp4_block_scale_interleave"),
+ ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
]
for module_name, attr_name in required_functions:
@@ -188,6 +191,7 @@ __all__ = [
"flashinfer_cutlass_fused_moe",
"fp4_quantize",
"nvfp4_block_scale_interleave",
+ "trtllm_fp4_block_scale_moe",
"autotune",
"has_flashinfer_moe",
"has_flashinfer_cutlass_fused_moe",
diff --git a/vllm/jsontree.py b/vllm/utils/jsontree.py
similarity index 100%
rename from vllm/jsontree.py
rename to vllm/utils/jsontree.py
diff --git a/vllm/utils/tensor_schema.py b/vllm/utils/tensor_schema.py
index 343df71e1058d..4c3acf0094c74 100644
--- a/vllm/utils/tensor_schema.py
+++ b/vllm/utils/tensor_schema.py
@@ -60,6 +60,9 @@ class TensorSchema:
def __getitem__(self, item) -> Any:
return getattr(self, item)
+ def get(self, item, default=None) -> Any:
+ return getattr(self, item, default)
+
def _match_shape_with_dynamic(self, actual: tuple[int, ...],
reference: tuple[int, ...],
expected_shape: tuple[Union[int, str], ...],
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 95ba56b359379..a411477bc3e33 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -384,6 +384,8 @@ class FlashAttentionImpl(AttentionImpl):
self.alibi_slopes = alibi_slopes
if sliding_window is None:
self.sliding_window = (-1, -1)
+ elif attn_type == AttentionType.ENCODER_ONLY:
+ self.sliding_window = (sliding_window - 1, sliding_window - 1)
else:
self.sliding_window = (sliding_window - 1, 0)
self.kv_cache_dtype = kv_cache_dtype
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 1fcb190286329..c85d8bce31f5d 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -524,7 +524,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
head_dim = self.kv_cache_spec.head_size
# currently prefill trtllm attention does not support fp8 kv cache
- prefill_use_trtllm = use_trtllm_attention(
+ prefill_use_trtllm = not cache_dtype.startswith("fp8") \
+ and use_trtllm_attention(
num_prefill_tokens, max_seq_len, cache_dtype,
num_qo_heads, num_kv_heads, head_dim)
decode_use_trtllm = use_trtllm_attention(
diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py
new file mode 100644
index 0000000000000..f08b6d7f177c7
--- /dev/null
+++ b/vllm/v1/attention/backends/linear_attn.py
@@ -0,0 +1,67 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from dataclasses import dataclass
+from typing import ClassVar
+
+import torch
+
+from vllm.attention.backends.abstract import AttentionBackend
+from vllm.config import VllmConfig
+from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+ CommonAttentionMetadata,
+ split_decodes_and_prefills)
+from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
+
+
+class LinearAttentionBackend(AttentionBackend):
+
+ @staticmethod
+ def get_builder_cls() -> type["LinearAttentionMetadataBuilder"]:
+ return LinearAttentionMetadataBuilder
+
+
+@dataclass
+class LinearAttentionMetadata:
+ num_prefills: int
+ num_prefill_tokens: int
+ num_decodes: int
+ num_decode_tokens: int
+ query_start_loc: torch.Tensor
+ seq_lens: torch.Tensor
+
+ state_indices_tensor: torch.Tensor # shape: [batch,]
+
+
+class LinearAttentionMetadataBuilder(
+ AttentionMetadataBuilder[LinearAttentionMetadata]):
+
+ reorder_batch_threshold: ClassVar[int] = 1
+
+ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+ vllm_config: VllmConfig, device: torch.device):
+ assert isinstance(kv_cache_spec, MambaSpec)
+ self.kv_cache_spec = kv_cache_spec
+
+ def build(self,
+ common_prefix_len: int,
+ common_attn_metadata: CommonAttentionMetadata,
+ fast_build: bool = False) -> LinearAttentionMetadata:
+ query_start_loc = common_attn_metadata.query_start_loc
+ seq_lens = common_attn_metadata.seq_lens
+
+ state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
+
+ num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+ split_decodes_and_prefills(common_attn_metadata,
+ decode_threshold=1))
+
+ attn_metadata = LinearAttentionMetadata(
+ num_prefills=num_prefills,
+ num_prefill_tokens=num_prefill_tokens,
+ num_decodes=num_decodes,
+ num_decode_tokens=num_decode_tokens,
+ query_start_loc=query_start_loc,
+ seq_lens=seq_lens,
+ state_indices_tensor=state_indices_tensor,
+ )
+ return attn_metadata
diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py
index 66a8d91db89c2..7c1226049f696 100644
--- a/vllm/v1/attention/backends/mamba_attn.py
+++ b/vllm/v1/attention/backends/mamba_attn.py
@@ -7,8 +7,10 @@ from typing import ClassVar, Optional
import torch
from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+from vllm.v1.attention.backends.utils import (AttentionCGSupport,
+ AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
@@ -82,6 +84,8 @@ class Mamba2AttentionMetadata:
class Mamba2AttentionMetadataBuilder(
AttentionMetadataBuilder[Mamba2AttentionMetadata]):
+ attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+ AttentionCGSupport.PURE_DECODE_ONLY
reorder_batch_threshold: ClassVar[int] = 1
@@ -90,8 +94,18 @@ class Mamba2AttentionMetadataBuilder(
assert isinstance(kv_cache_spec, MambaSpec)
self.kv_cache_spec = kv_cache_spec
self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
+ self.vllm_config = vllm_config
+ self.compilation_config = vllm_config.compilation_config
assert self.chunk_size is not None, (
"chunk_size needs to be set in the model config for Mamba2 models")
+ self.decode_cudagraph_max_bs = min(
+ self.vllm_config.scheduler_config.max_num_seqs,
+ self.compilation_config.max_capture_size)
+ self.state_indices_tensor = torch.empty(
+ (self.decode_cudagraph_max_bs, ),
+ dtype=torch.int32,
+ device=device,
+ )
def build(self,
common_prefix_len: int,
@@ -144,6 +158,14 @@ class Mamba2AttentionMetadataBuilder(
query_start_loc_p, self.chunk_size,
num_prefill_tokens))
+ elif num_decodes <= self.decode_cudagraph_max_bs:
+ # Pad state tensor for CUDA graph
+ num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
+ self.state_indices_tensor[:num_decodes].copy_(state_indices_tensor,
+ non_blocking=True)
+ state_indices_tensor = self.state_indices_tensor[:num_input_tokens]
+ state_indices_tensor[num_decodes:] = PAD_SLOT_ID
+
attn_metadata = Mamba2AttentionMetadata(
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
@@ -160,3 +182,23 @@ class Mamba2AttentionMetadataBuilder(
state_indices_tensor=state_indices_tensor,
)
return attn_metadata
+
+ def build_for_cudagraph_capture(
+ self, common_attn_metadata: CommonAttentionMetadata):
+ """
+ This method builds the metadata for full cudagraph capture.
+ Currently, only decode is supported for full cudagraphs with Mamba.
+ """
+ m = common_attn_metadata
+
+ assert m.num_reqs == m.num_actual_tokens, \
+ "Mamba only supports decode-only full CUDAGraph capture. " \
+ "Make sure all cudagraph capture sizes <= max_num_seq."
+
+ m.max_query_len = 1 # decode-only
+
+ return self.build(0, m)
+
+ def can_run_in_cudagraph(
+ self, common_attn_metadata: CommonAttentionMetadata) -> bool:
+ return common_attn_metadata.max_query_len == 1
diff --git a/vllm/v1/attention/backends/mamba_selectors.py b/vllm/v1/attention/backends/mamba_selectors.py
index f56f2fb7bf699..852e0dfe1b312 100644
--- a/vllm/v1/attention/backends/mamba_selectors.py
+++ b/vllm/v1/attention/backends/mamba_selectors.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.attention.backends.abstract import AttentionBackend
+from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
@@ -8,9 +9,10 @@ from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]:
if mamba_type == "mamba1":
return Mamba1AttentionBackend
-
if mamba_type == "mamba2":
return Mamba2AttentionBackend
+ if mamba_type == "linear_attention":
+ return LinearAttentionBackend
raise NotImplementedError(f"Mamba Attention type {mamba_type} is not "
"supported yet.")
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index b5aecff9937f3..2b0f52cf80bfa 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -70,6 +70,22 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
self.cg_buf_tile_scheduler_metadata = None
self.cg_buf_num_splits = None
+ device_properties = torch.cuda.get_device_properties(self.device)
+ num_sms = device_properties.multi_processor_count
+
+ if self.compilation_config.full_cuda_graph:
+ self.cg_buf_tile_scheduler_metadata = torch.zeros(
+ # Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
+ # TileSchedulerMetaDataSize = 8
+ (num_sms, 8),
+ device=self.device,
+ dtype=torch.int32,
+ )
+ self.cg_buf_num_splits = torch.empty(
+ (vllm_config.scheduler_config.max_num_seqs + 1),
+ device=self.device,
+ dtype=torch.int32)
+
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
tile_scheduler_metadata, num_splits = \
@@ -80,28 +96,28 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
)
if self.compilation_config.full_cuda_graph:
- # First time around (CUDAGraph capture), allocate the static buffer
- if self.cg_buf_tile_scheduler_metadata is None:
- self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
- self.cg_buf_num_splits = num_splits
- else:
- assert self.cg_buf_num_splits is not None
+ assert self.cg_buf_tile_scheduler_metadata is not None
+ assert self.cg_buf_num_splits is not None
- # Metadata per-SM, fixed size (#SMs, TileMetadataSize)
- assert (self.cg_buf_tile_scheduler_metadata.size() ==
- tile_scheduler_metadata.size())
- self.cg_buf_tile_scheduler_metadata.\
- copy_(tile_scheduler_metadata)
- tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata
+ sm_parts = tile_scheduler_metadata.size(0)
+ # Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize)
+ assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0)
+ tile_scheduler_metadata_view = \
+ self.cg_buf_tile_scheduler_metadata[:sm_parts]
+ tile_scheduler_metadata_view.copy_(tile_scheduler_metadata)
+ tile_scheduler_metadata = tile_scheduler_metadata_view
- # Num splits is per-batch, varying size (batch_size,)
- n = num_splits.size(0)
- # make sure static buffer is large enough
- assert n <= self.cg_buf_num_splits.size(0)
- num_splits_view = self.cg_buf_num_splits[:n]
- num_splits_view.copy_(num_splits)
- self.cg_buf_num_splits[n:].fill_(0) # fill the rest with 0s
- num_splits = num_splits_view
+ # Num splits is per-batch, varying size (batch_size,)
+ n = num_splits.size(0)
+ # make sure static buffer is large enough
+ assert n <= self.cg_buf_num_splits.size(0)
+ num_splits_view = self.cg_buf_num_splits[:n]
+ num_splits_view.copy_(num_splits)
+ # Num splits needs to monotonically increasing
+ # (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise
+ # it needs to monotonically increasing by 1)
+ self.cg_buf_num_splits[n:].fill_(num_splits[-1])
+ num_splits = num_splits_view
return FlashMLADecodeMetadata(
block_table=block_table_tensor,
diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py
index 67ea3b007ecee..faf5c132f8640 100644
--- a/vllm/v1/core/encoder_cache_manager.py
+++ b/vllm/v1/core/encoder_cache_manager.py
@@ -189,7 +189,7 @@ def compute_encoder_budget(
in the input sequence.
"""
- if not model_config.is_multimodal_model:
+ if not mm_registry.supports_multimodal_inputs(model_config):
return 0, 0
# TODO: handle encoder-decoder models once we support them.
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index d39aea1f2d116..dcb9f4dd36f52 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -437,14 +437,24 @@ class Scheduler(SchedulerInterface):
# The request cannot be scheduled.
break
+ # Handles an edge case when P/D Disaggregation
+ # is used with Spec Decoding where an
+ # extra block gets allocated which
+ # creates a mismatch between the number
+ # of local and remote blocks.
+ effective_lookahead_tokens = (0 if request.num_computed_tokens
+ == 0 else
+ self.num_lookahead_tokens)
+
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens + num_external_computed_tokens,
num_new_local_computed_tokens,
new_computed_blocks,
- num_lookahead_tokens=self.num_lookahead_tokens,
+ num_lookahead_tokens=effective_lookahead_tokens,
delay_cache_blocks=load_kv_async,
)
+
if new_blocks is None:
# The request cannot be scheduled.
break
@@ -1140,6 +1150,10 @@ class Scheduler(SchedulerInterface):
# if finished_recving: add to state so we can
scheduler the request during the next step.
"""
+
+ if self.connector is not None:
+ self.connector.update_connector_output(kv_connector_output)
+
# KV Connector:: update recv and send status from last step.
for req_id in (kv_connector_output.finished_recving or ()):
logger.debug("Finished recving KV transfer for request %s", req_id)
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index 45f450291ab63..a2706327914c5 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -27,7 +27,7 @@ from vllm.transformers_utils.config import (
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
-from vllm.utils import Device, cdiv, deprecate_kwargs
+from vllm.utils import Device, cancel_task_threadsafe, cdiv, deprecate_kwargs
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
@@ -219,8 +219,7 @@ class AsyncLLM(EngineClient):
if engine_core := getattr(self, "engine_core", None):
engine_core.shutdown()
- if handler := getattr(self, "output_handler", None):
- handler.cancel()
+ cancel_task_threadsafe(getattr(self, "output_handler", None))
async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return await self.engine_core.get_supported_tasks_async()
@@ -566,7 +565,7 @@ class AsyncLLM(EngineClient):
await self.engine_core.profile_async(False)
async def reset_mm_cache(self) -> None:
- self.processor.mm_registry.reset_processor_cache()
+ self.processor.mm_registry.reset_processor_cache(self.model_config)
self.processor.mm_input_cache_client.reset()
await self.engine_core.reset_mm_cache_async()
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 78b8fe4ea676f..f92a3e43da1f2 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -21,6 +21,7 @@ from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.logger import init_logger
from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest
+from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
@@ -125,7 +126,7 @@ class EngineCore:
)
self.mm_input_cache_server = MultiModalInputCacheServer(
- vllm_config.model_config)
+ vllm_config.model_config, MULTIMODAL_REGISTRY)
# Setup batch queue for pipeline parallelism.
# Batch queue for scheduled batches. This enables us to asynchronously
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index 4d30bb6b74466..05b4d72608963 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -23,7 +23,8 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.tasks import SupportedTask
-from vllm.utils import get_open_port, get_open_zmq_inproc_path, make_zmq_socket
+from vllm.utils import (cancel_task_threadsafe, get_open_port,
+ get_open_zmq_inproc_path, make_zmq_socket)
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType,
ReconfigureDistributedRequest, ReconfigureRankType,
@@ -342,10 +343,8 @@ class BackgroundResources:
if self.coordinator is not None:
self.coordinator.close()
- if self.output_queue_task is not None:
- self.output_queue_task.cancel()
- if self.stats_update_task is not None:
- self.stats_update_task.cancel()
+ cancel_task_threadsafe(self.output_queue_task)
+ cancel_task_threadsafe(self.stats_update_task)
# ZMQ context termination can hang if the sockets
# aren't explicitly closed first.
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index efbdffbc0900d..5a00a930951cc 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -271,7 +271,7 @@ class LLMEngine:
self.engine_core.profile(False)
def reset_mm_cache(self):
- self.processor.mm_registry.reset_processor_cache()
+ self.processor.mm_registry.reset_processor_cache(self.model_config)
self.processor.mm_input_cache_client.reset()
self.engine_core.reset_mm_cache()
diff --git a/vllm/v1/engine/mm_input_cache.py b/vllm/v1/engine/mm_input_cache.py
index 279c9f0007bce..0532cda03d9a7 100644
--- a/vllm/v1/engine/mm_input_cache.py
+++ b/vllm/v1/engine/mm_input_cache.py
@@ -3,7 +3,7 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING, Optional
-from vllm.multimodal import MultiModalKwargs
+from vllm.multimodal import MultiModalKwargs, MultiModalRegistry
from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
from vllm.utils import is_list_of
@@ -46,10 +46,11 @@ if TYPE_CHECKING:
class MultiModalInputCacheClient:
"""Used by P0 to check whether multi-modal kwargs are cached in P1."""
- def __init__(self, model_config: "ModelConfig") -> None:
+ def __init__(self, model_config: "ModelConfig",
+ mm_registry: MultiModalRegistry) -> None:
super().__init__()
- self.enabled = model_config.enable_mm_input_cache
+ self.enabled = mm_registry.enable_mm_input_cache(model_config)
self.mm_cache = MultiModalCache.get_lru_cache(
model_config.get_mm_input_cache_gb(),
MultiModalCacheItemMetadata,
@@ -85,10 +86,11 @@ class MultiModalInputCacheClient:
class MultiModalInputCacheServer:
"""Used by P1 to avoid requiring past multi-modal kwargs from P0."""
- def __init__(self, model_config: "ModelConfig") -> None:
+ def __init__(self, model_config: "ModelConfig",
+ mm_registry: MultiModalRegistry) -> None:
super().__init__()
- self.enabled = model_config.enable_mm_input_cache
+ self.enabled = mm_registry.enable_mm_input_cache(model_config)
self.mm_cache = MultiModalCache.get_lru_cache(
model_config.get_mm_input_cache_gb(),
MultiModalKwargs,
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 6e37ebeb87781..b9419142caf6c 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -51,7 +51,7 @@ class Processor:
mm_registry)
self.mm_input_cache_client = MultiModalInputCacheClient(
- self.model_config)
+ self.model_config, mm_registry)
@property
def mm_registry(self):
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 3c36971fe5b49..f75d76dd978fd 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -17,10 +17,14 @@ from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
+from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
+from vllm.v1.attention.backends.rocm_aiter_fa import (
+ AiterFlashAttentionMetadata)
from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
TreeAttentionMetadataBuilder)
+from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata
@@ -230,11 +234,19 @@ class EagleProposer:
# one layer. Adapt this code to support multiple layers once
# there's a multi-layer MTP module.
- # Currently, only FlashAttention and TreeAttention support multi-token
- # eagle spec decode. This is because the code below
- # makes assumptions about attn_metadata attributes available.
- assert isinstance(attn_metadata,
- (FlashAttentionMetadata, TreeAttentionMetadata))
+ # On ROCm, both AiterFlashAttention and TritonAttention
+ # support multi-token eagle spec decode.
+ if current_platform.is_rocm():
+ assert isinstance(
+ attn_metadata,
+ (TritonAttentionMetadata, AiterFlashAttentionMetadata,
+ FlashAttentionMetadata))
+ else:
+ # Currently, only FlashAttention and TreeAttention support
+ # multi-token eagle spec decode. This is because the code below
+ # makes assumptions about attn_metadata attributes available.
+ assert isinstance(attn_metadata,
+ (FlashAttentionMetadata, TreeAttentionMetadata))
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 08b253dcdb35c..2e1cc37b1b761 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -35,6 +35,7 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
+ supports_eagle3,
supports_transcription)
from vllm.model_executor.models.interfaces_base import (
VllmModelForPooling, is_pooling_model, is_text_generation_model)
@@ -129,7 +130,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
cache_config.cache_dtype]
- self.is_multimodal_model = model_config.is_multimodal_model
self.is_pooling_model = model_config.pooler_config is not None
self.is_encoder_only_model = False
self.is_multimodal_raw_input_supported = (
@@ -149,6 +149,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope
+ self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
+ model_config)
# Sampler
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
@@ -330,10 +332,46 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.mm_registry,
max_model_len=self.max_model_len,
max_num_reqs=self.max_num_reqs,
- ) if self.is_multimodal_model else None)
+ ) if self.supports_mm_inputs \
+ else None)
self.reorder_batch_threshold: Optional[int] = None
+ def _init_model_kwargs(self, num_tokens: int):
+ model_kwargs = dict[str, Any]()
+ num_reqs = self.input_batch.num_reqs
+
+ pooling_params = self.input_batch.pooling_metadata.pooling_params
+
+ num_pooling_reqs = len(pooling_params)
+
+ if num_pooling_reqs == 0:
+ return model_kwargs
+
+ assert num_pooling_reqs == num_reqs
+
+ token_type_id_requests = dict[int, Any]()
+ for i, param in enumerate(pooling_params):
+ if param.extra_kwargs is not None and \
+ (token_types := param.extra_kwargs.get(
+ "compressed_token_type_ids")) is not None:
+ token_type_id_requests[i] = token_types
+
+ if len(token_type_id_requests) == 0:
+ return model_kwargs
+
+ seq_lens = self.seq_lens[:num_reqs]
+ token_type_ids = []
+
+ for i in range(num_reqs):
+ pos = token_type_id_requests.get(i, seq_lens[i])
+ ids = (torch.arange(seq_lens[i]) >= pos).int()
+ token_type_ids.append(ids)
+
+ model_kwargs["token_type_ids"] = torch.concat(token_type_ids).to(
+ device=self.device)
+ return model_kwargs
+
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
"""
Update the order of requests in the batch based on the attention
@@ -789,7 +827,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Prepare encoder attention metadata separately
# (encoder layers are not in KV cache groups)
if self.is_encoder_only_model:
- common_attn_metadata, encoder_attn_metadata = \
+
+ per_layer_metadata = \
self._build_encoder_only_attn_metadata(
scheduler_output)
@@ -798,6 +837,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.vllm_config, Attention)
for layer_name, attn_module in attention_layers.items():
if attn_module.attn_type == AttentionType.ENCODER_ONLY:
+ common_attn_metadata, encoder_attn_metadata =\
+ per_layer_metadata[layer_name]
attn_metadata[layer_name] = encoder_attn_metadata
# Prepare the attention metadata for each KV cache group and make layers
@@ -1235,7 +1276,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if not is_pooling_model(model):
return []
- return list(model.pooler.get_supported_tasks())
+ supported_tasks = list(model.pooler.get_supported_tasks())
+
+ if (self.scheduler_config.chunked_prefill_enabled
+ and "encode" in supported_tasks):
+ supported_tasks.remove("encode")
+
+ logger.info_once("Chunked prefill is not supported with "
+ "encode task which using ALL pooling. "
+ "Please turn off chunked prefill by "
+ "`--no-enable-chunked-prefill` before using it.")
+
+ return supported_tasks
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
tasks = list[SupportedTask]()
@@ -1479,14 +1531,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# _prepare_inputs may reorder the batch, so we must gather multi
# modal outputs after that to ensure the correct order
- if self.is_multimodal_model:
+ if self.supports_mm_inputs:
# Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output)
else:
mm_embeds = []
- if self.is_multimodal_model and get_pp_group().is_first_rank:
+ if self.supports_mm_inputs and get_pp_group().is_first_rank:
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
@@ -1502,12 +1554,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_ids = None
inputs_embeds = self.inputs_embeds[:num_input_tokens]
model_mm_kwargs = self._extract_mm_kwargs(scheduler_output)
+ model_kwargs = self._init_model_kwargs(num_scheduled_tokens)
else:
# For text-only models, we use token ids as input.
# While it is possible to use embeddings as input just like the
# multimodal models, it is not desirable for performance since
# then the embedding layer is not included in the CUDA graph.
input_ids = self.input_ids[:num_input_tokens]
+ model_kwargs = self._init_model_kwargs(num_input_tokens)
inputs_embeds = None
model_mm_kwargs = {}
if self.uses_mrope:
@@ -1546,6 +1600,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
model_mm_kwargs,
device=self.device,
),
+ **model_kwargs,
)
if self.use_aux_hidden_state_outputs:
@@ -1817,7 +1872,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else:
target_hidden_states = hidden_states[token_indices]
mm_embeds = None
- if self.is_multimodal_model:
+ if self.supports_mm_inputs:
mm_embeds = self._gather_mm_embeddings(scheduler_output,
shift_computed_tokens=1)
@@ -1927,8 +1982,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logger.info("Loading drafter model...")
self.drafter.load_model(self.model)
if self.use_aux_hidden_state_outputs:
- self.model.set_aux_hidden_state_layers(
- self.model.get_eagle3_aux_hidden_state_layers())
+ if supports_eagle3(self.model):
+ self.model.set_aux_hidden_state_layers(
+ self.model.get_eagle3_aux_hidden_state_layers())
+ else:
+ raise RuntimeError(
+ "Model does not support EAGLE3 interface but "
+ "aux_hidden_state_outputs was requested")
time_after_load = time.perf_counter()
self.model_memory_usage = m.consumed_memory
logger.info("Model loading took %.4f GiB and %.6f seconds",
@@ -2209,7 +2269,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens):
- if self.is_multimodal_model:
+ model_kwargs = self._init_model_kwargs(num_tokens)
+ if self.supports_mm_inputs:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens]
model_mm_kwargs = self._dummy_mm_kwargs(num_reqs)
@@ -2250,6 +2311,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
model_mm_kwargs,
device=self.device,
),
+ **model_kwargs,
)
if self.use_aux_hidden_state_outputs:
@@ -2417,7 +2479,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def profile_run(self) -> None:
# Profile with multimodal encoder & encoder cache.
- if self.is_multimodal_model:
+ if self.supports_mm_inputs:
mm_budget = self.mm_budget
assert mm_budget is not None
@@ -2630,30 +2692,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Check if model is encoder-only
block_size = self.vllm_config.cache_config.block_size
use_mla = self.vllm_config.model_config.use_mla
- attn_specs = list[AttentionSpec]()
- for attn_module in attn_layers.values():
+ attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list)
+ for layer_name, attn_module in attn_layers.items():
if attn_module.attn_type == AttentionType.ENCODER_ONLY:
- assert attn_module.sliding_window is None, "Sliding "
- "window attention is not supported for encoder-only models"
+ if attn_module.sliding_window is None:
+ attn_spec: AttentionSpec = FullAttentionSpec(
+ block_size=block_size,
+ num_kv_heads=attn_module.num_kv_heads,
+ head_size=attn_module.head_size,
+ dtype=self.kv_cache_dtype,
+ use_mla=use_mla)
+ else:
+ attn_spec = SlidingWindowSpec(
+ block_size=block_size,
+ num_kv_heads=attn_module.num_kv_heads,
+ head_size=attn_module.head_size,
+ dtype=self.kv_cache_dtype,
+ sliding_window=attn_module.sliding_window,
+ use_mla=use_mla)
+ attn_specs[attn_spec].append(layer_name)
- attn_specs.append(
- FullAttentionSpec(block_size=block_size,
- num_kv_heads=attn_module.num_kv_heads,
- head_size=attn_module.head_size,
- dtype=self.kv_cache_dtype,
- use_mla=use_mla))
else:
raise ValueError("Expected only encoder-only layers")
if len(attn_specs) > 0:
- assert len(attn_specs) == len(attn_layers), \
+ total_layers = 0
+ for attn_spec, layer_names in attn_specs.items():
+
+ attn_backends = get_attn_backends_for_layers(layer_names)
+ total_layers += len(layer_names)
+
+ self.attn_groups.append(
+ create_attn_groups(attn_backends, attn_spec))
+ assert total_layers == len(attn_layers), \
"All or none of the layers are expected to be encoder-only"
-
- attn_backends = get_attn_backends_for_layers(attn_layers.keys())
-
- self.attn_groups.append(
- create_attn_groups(attn_backends, attn_specs[0]))
self.is_encoder_only_model = True
def calculate_reorder_batch_threshold(self) -> None:
@@ -3018,7 +3091,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _build_encoder_only_attn_metadata(
self, scheduler_output: "SchedulerOutput") -> \
- tuple[CommonAttentionMetadata, Any]:
+ dict[str, tuple[CommonAttentionMetadata, Any]]:
"""Prepare encoder attention metadata for encoder-only models.
Args:
@@ -3035,10 +3108,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
max_num_scheduled_tokens = max(tokens)
- # Use the first attention metadata builder
- # to create encoder attention metadata
- builder = self.attn_groups[0][0].metadata_builder
-
dummy_block_table = torch.zeros((num_reqs, 1),
dtype=torch.int32,
device=self.device)
@@ -3046,22 +3115,38 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dtype=torch.int32,
device=self.device)
- common_metadata = CommonAttentionMetadata(
- query_start_loc=self.query_start_loc[:num_reqs + 1],
- query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
- seq_lens=self.seq_lens[:num_reqs],
- seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
- num_computed_tokens_cpu=self.input_batch.
- num_computed_tokens_cpu_tensor[:num_reqs],
- num_reqs=num_reqs,
- num_actual_tokens=total_num_scheduled_tokens,
- max_query_len=max_num_scheduled_tokens,
- block_table_tensor=dummy_block_table,
- slot_mapping=dummy_slot_mapping,
- causal=False,
- )
+ group_metadata = dict[str, tuple[CommonAttentionMetadata, Any]]()
- return common_metadata, builder.build(
- common_prefix_len=0, # No cascade for encoder
- common_attn_metadata=common_metadata,
- )
+ for attn_group_list in self.attn_groups:
+
+ assert len(attn_group_list) == 1
+ attn_group = attn_group_list[0]
+
+ # Use the first attention metadata builder
+ # to create encoder attention metadata
+ builder = attn_group.metadata_builder
+
+ common_metadata = CommonAttentionMetadata(
+ query_start_loc=self.query_start_loc[:num_reqs + 1],
+ query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
+ seq_lens=self.seq_lens[:num_reqs],
+ seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
+ num_computed_tokens_cpu=self.input_batch.
+ num_computed_tokens_cpu_tensor[:num_reqs],
+ num_reqs=num_reqs,
+ num_actual_tokens=total_num_scheduled_tokens,
+ max_query_len=max_num_scheduled_tokens,
+ block_table_tensor=dummy_block_table,
+ slot_mapping=dummy_slot_mapping,
+ causal=False,
+ )
+
+ metadata = builder.build(
+ common_prefix_len=0, # No cascade for encoder
+ common_attn_metadata=common_metadata,
+ )
+
+ for layer_name in attn_group.layer_names:
+ group_metadata[layer_name] = (common_metadata, metadata)
+
+ return group_metadata
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 7fca245c1bef8..0ea23921a0806 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -21,6 +21,7 @@ from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
+from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
@@ -338,6 +339,10 @@ class Worker(WorkerBase):
self.model_runner._dummy_sampler_run(
hidden_states=last_hidden_states)
+ # Warmup kernels used during model execution
+ kernel_warmup(self.get_model(),
+ max_tokens=self.scheduler_config.max_num_batched_tokens)
+
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 81252f9b606ae..ae0219458ecfb 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -157,7 +157,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cache_config.cache_dtype]
self._hidden_states_dtype = self.dtype
- self.is_multimodal_model = model_config.is_multimodal_model
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.max_model_len = model_config.max_model_len
@@ -193,6 +192,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope
+ self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
+ model_config)
# TODO: Support M-RoPE (e.g, Qwen2-VL)
assert not self.uses_mrope, "TPU does not support M-RoPE yet."
@@ -293,7 +294,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.mm_registry,
max_model_len=self.max_model_len,
max_num_reqs=self.max_num_reqs,
- ) if self.is_multimodal_model else None)
+ ) if self.supports_mm_inputs else None)
if not self.use_spmd:
self.sample_from_logits_func = torch.compile(
@@ -744,7 +745,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_kv_update_slices = slot_mapping_metadata.shape[0]
padded_num_slices = _get_padded_num_kv_cache_update_slices(
padded_total_num_scheduled_tokens, self.max_num_reqs,
- self.block_size, self._num_slices_per_kv_cache_update_block)
+ self.block_size)
slot_mapping_metadata = np.pad(
slot_mapping_metadata,
[[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]],
@@ -947,7 +948,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _get_model_inputs(self, input_ids: torch.Tensor,
mm_embeds: list[torch.Tensor]):
- if self.is_multimodal_model:
+ if self.supports_mm_inputs:
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
@@ -979,7 +980,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return self.kv_connector_no_forward(scheduler_output,
self.vllm_config)
- if self.is_multimodal_model:
+ if self.supports_mm_inputs:
# Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output)
@@ -1137,6 +1138,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
i, target_slice] = valid_sampled_token_ids[i]
req_state.output_token_ids.extend(valid_sampled_token_ids[i])
+ kv_connector_output = None if (
+ finished_sending is None
+ and finished_recving is None) else KVConnectorOutput(
+ finished_sending=finished_sending,
+ finished_recving=finished_recving,
+ )
+
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
@@ -1145,10 +1153,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
- kv_connector_output=KVConnectorOutput(
- finished_sending=finished_sending,
- finished_recving=finished_recving,
- ))
+ kv_connector_output=kv_connector_output,
+ )
# Check there are no new graphs compiled - all the graphs should be
# captured and compiled during warm up.
@@ -1230,7 +1236,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
@torch.no_grad()
def _dummy_run(self, num_tokens: int, num_reqs: int,
num_blocks: int) -> None:
- if self.is_multimodal_model:
+ if self.supports_mm_inputs:
input_ids = None
inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
dtype=self.dtype,
@@ -1243,8 +1249,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
position_ids = torch.zeros(num_tokens,
dtype=torch.int32).to(self.device)
padded_num_slices = _get_padded_num_kv_cache_update_slices(
- num_tokens, self.max_num_reqs, self.block_size,
- self._num_slices_per_kv_cache_update_block)
+ num_tokens, self.max_num_reqs, self.block_size)
num_kv_update_slices = torch.tensor([padded_num_slices],
dtype=torch.int32).to(self.device)
slot_mapping = torch.zeros((3, padded_num_slices),
@@ -1271,7 +1276,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
_num_slices_per_kv_cache_update_block,
)
- if self.is_multimodal_model:
+ if self.supports_mm_inputs:
torch._dynamo.mark_dynamic(inputs_embeds, 0)
else:
torch._dynamo.mark_dynamic(input_ids, 0)
@@ -1305,7 +1310,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
xm.mark_step() # Captures metadata updates
def _precompile_mm_encoder(self) -> None:
- if not self.is_multimodal_model:
+ if not self.supports_mm_inputs:
return
# Pre-compile MM encoder for all supported data modalities.
@@ -1527,7 +1532,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_tokens: int,
) -> None:
# Profile with multimodal encoder & encoder cache.
- if self.is_multimodal_model:
+ if self.supports_mm_inputs:
mm_budget = self.mm_budget
assert mm_budget is not None
@@ -1684,7 +1689,11 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
get_kv_transfer_group().set_host_xfer_buffer_ops(copy_kv_blocks)
def reset_dynamo_cache(self):
- if self.is_multimodal_model:
+
+ # NOTE: We check `is_multimodal_model` instead of `supports_mm_inputs`
+ # since the compiled model object of the language backbone of a
+ # multimodal model needs to be extracted via `get_language_model`.
+ if self.model_config.is_multimodal_model:
compiled_model = self.model.get_language_model().model
else:
compiled_model = self.model.model
@@ -1958,17 +1967,17 @@ def copy_kv_blocks(
_copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
-def _get_padded_num_kv_cache_update_slices(
- num_tokens: int, max_num_reqs: int, page_size: int,
- num_slices_per_kv_cache_update_block: int) -> int:
+def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
+ page_size: int) -> int:
"""Calculates the padded number of KV cache update slices to avoid
recompilation."""
+ # NOTE(chengjiyao): let's say R_i is the token num for i-th request,
+ # so it occupies most 2 + R_i // page_size pages. The total maximum
+ # possible number of pages needed is sum(2 + R_i // page_size), which
+ # is <= 2 * max_num_reqs + sum(R_i) // page_size
+ # = 2 * max_num_reqs + num_tokens // page_size
padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
padded_num_slices = min(padded_num_slices, num_tokens)
- padded_num_slices = (
- padded_num_slices + num_slices_per_kv_cache_update_block - 1
- ) // num_slices_per_kv_cache_update_block * \
- num_slices_per_kv_cache_update_block
return padded_num_slices
diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py
index 2a7e0625b2f87..134d839252653 100644
--- a/vllm/v1/worker/xpu_worker.py
+++ b/vllm/v1/worker/xpu_worker.py
@@ -152,7 +152,7 @@ class XPUWorker(Worker):
raise RuntimeError(
f"Not support device type: {self.device_config.device}")
- ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE", "drmfd")
+ ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE", "pidfd")
ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi")
ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE",
str(self.parallel_config.world_size))