Merge branch 'main' into rename_file_info_to_pkg/file

This commit is contained in:
Ning Xie 2025-12-03 19:28:15 +08:00 committed by GitHub
commit b4b79c5eba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
87 changed files with 1951 additions and 760 deletions

View File

@ -0,0 +1,73 @@
#!/usr/bin/env bash
set -euxo pipefail
# args: [THRESHOLD] [NUM_QUESTIONS] [START_PORT]
THRESHOLD=${1:-0.25}
NUM_Q=${2:-1319}
PORT=${3:-8030}
OUT_DIR=${OUT_DIR:-/tmp/vllm-scheduled}
mkdir -p "${OUT_DIR}"
wait_for_server() {
local port=$1
timeout 600 bash -c '
until curl -sf "http://127.0.0.1:'"$port"'/health" > /dev/null; do
sleep 1
done'
}
MODEL="deepseek-ai/DeepSeek-V2-lite"
# Set BACKENDS based on platform
if command -v rocm-smi &> /dev/null || [[ -d /opt/rocm ]] || [[ -n "${ROCM_PATH:-}" ]]; then
# ROCm platform
BACKENDS=("allgather_reducescatter")
# Disable MOE padding for ROCm since it is causing eplb to fail
export VLLM_ROCM_MOE_PADDING=0
else
# Non-ROCm platform (CUDA/other)
BACKENDS=("deepep_high_throughput" "deepep_low_latency")
fi
cleanup() {
if [[ -n "${SERVER_PID:-}" ]] && kill -0 "${SERVER_PID}" 2>/dev/null; then
kill "${SERVER_PID}" 2>/dev/null || true
for _ in {1..20}; do
kill -0 "${SERVER_PID}" 2>/dev/null || break
sleep 0.5
done
kill -9 "${SERVER_PID}" 2>/dev/null || true
fi
}
trap cleanup EXIT
for BACK in "${BACKENDS[@]}"; do
VLLM_DEEP_GEMM_WARMUP=skip \
VLLM_ALL2ALL_BACKEND=$BACK \
vllm serve "$MODEL" \
--enforce-eager \
--tensor-parallel-size 2 \
--data-parallel-size 2 \
--enable-expert-parallel \
--enable-eplb \
--eplb-config '{"window_size":200,"step_interval":600,"use_async":true}' \
--trust-remote-code \
--max-model-len 2048 \
--port $PORT &
SERVER_PID=$!
wait_for_server $PORT
TAG=$(echo "$MODEL" | tr '/: \\n' '_____')
OUT="${OUT_DIR}/${TAG}_${BACK}_async_eplb.json"
python3 tests/evals/gsm8k/gsm8k_eval.py --host http://127.0.0.1 --port $PORT --num-questions ${NUM_Q} --save-results ${OUT}
python3 - <<PY
import json; acc=json.load(open('${OUT}'))['accuracy']
print(f"${MODEL} ${BACK}: accuracy {acc:.3f}")
assert acc >= ${THRESHOLD}, f"${MODEL} ${BACK} accuracy {acc}"
PY
cleanup
SERVER_PID=
sleep 1
PORT=$((PORT+1))
done

View File

@ -50,6 +50,7 @@ for BACK in "${BACKENDS[@]}"; do
--data-parallel-size 2 \
--enable-expert-parallel \
--enable-eplb \
--eplb-config '{"window_size":200,"step_interval":600}' \
--trust-remote-code \
--max-model-len 2048 \
--port $PORT &

View File

@ -0,0 +1,74 @@
#!/usr/bin/env bash
set -euxo pipefail
# args: [THRESHOLD] [NUM_QUESTIONS] [START_PORT]
THRESHOLD=${1:-0.25}
NUM_Q=${2:-1319}
PORT=${3:-8040}
OUT_DIR=${OUT_DIR:-/tmp/vllm-scheduled}
mkdir -p "${OUT_DIR}"
wait_for_server() {
local port=$1
timeout 600 bash -c '
until curl -sf "http://127.0.0.1:'"$port"'/health" > /dev/null; do
sleep 1
done'
}
MODEL="Qwen/Qwen3-Next-80B-A3B-Instruct"
# Set BACKENDS based on platform
if command -v rocm-smi &> /dev/null || [[ -d /opt/rocm ]] || [[ -n "${ROCM_PATH:-}" ]]; then
# ROCm platform
BACKENDS=("allgather_reducescatter")
# Disable MOE padding for ROCm since it is causing eplb to fail
export VLLM_ROCM_MOE_PADDING=0
else
# Non-ROCm platform (CUDA/other)
BACKENDS=("deepep_high_throughput" "deepep_low_latency")
fi
cleanup() {
if [[ -n "${SERVER_PID:-}" ]] && kill -0 "${SERVER_PID}" 2>/dev/null; then
kill "${SERVER_PID}" 2>/dev/null || true
for _ in {1..20}; do
kill -0 "${SERVER_PID}" 2>/dev/null || break
sleep 0.5
done
kill -9 "${SERVER_PID}" 2>/dev/null || true
fi
}
trap cleanup EXIT
for BACK in "${BACKENDS[@]}"; do
VLLM_DEEP_GEMM_WARMUP=skip \
VLLM_ALL2ALL_BACKEND=$BACK \
vllm serve "$MODEL" \
--enforce-eager \
--tensor-parallel-size 4 \
--enable-expert-parallel \
--enable-eplb \
--eplb-config '{"window_size":200,"step_interval":600,"use_async":true}' \
--speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":1}' \
--trust-remote-code \
--max-model-len 2048 \
--gpu-memory-utilization 0.9 \
--port $PORT &
SERVER_PID=$!
wait_for_server $PORT
TAG=$(echo "$MODEL" | tr '/: \\n' '_____')
OUT="${OUT_DIR}/${TAG}_${BACK}.json"
python3 tests/evals/gsm8k/gsm8k_eval.py --host http://127.0.0.1 --port $PORT --num-questions ${NUM_Q} --save-results ${OUT}
python3 - <<PY
import json; acc=json.load(open('${OUT}'))['accuracy']
print(f"${MODEL} ${BACK}: accuracy {acc:.3f}")
assert acc >= ${THRESHOLD}, f"${MODEL} ${BACK} accuracy {acc}"
PY
cleanup
SERVER_PID=
sleep 1
PORT=$((PORT+1))
done

View File

@ -1373,4 +1373,22 @@ steps:
num_gpus: 2
working_dir: "/vllm-workspace"
commands:
- bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1
- bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1
- label: DeepSeek V2-Lite Async EPLB Accuracy
timeout_in_minutes: 60
gpu: h100
optional: true
num_gpus: 4
working_dir: "/vllm-workspace"
commands:
- bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_async_eplb.sh 0.25 1319 8030
- label: Qwen3-Next-80B-A3B-Instruct MTP Async EPLB Accuracy
timeout_in_minutes: 60
gpu: h100
optional: true
num_gpus: 4
working_dir: "/vllm-workspace"
commands:
- bash .buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh 0.8 1319 8040

View File

@ -65,7 +65,6 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/tests /tests
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/examples /examples
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/docker/Dockerfile.rocm /docker/
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/.buildkite /.buildkite
# Centralized v1 package - copied to both test and final stages
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/vllm/v1 /vllm_v1
# -----------------------
@ -98,7 +97,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system hf_transfer
ENV HF_HUB_ENABLE_HF_TRANSFER=1
# Copy in the v1 package
# Copy in the v1 package (for python-only install test group)
COPY --from=export_vllm /vllm_v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1
# Source code is used in the `python_only_compile.sh` test
@ -130,9 +129,6 @@ RUN --mount=type=bind,from=export_vllm,src=/,target=/install \
&& pip uninstall -y vllm \
&& uv pip install --system *.whl
# Copy in the v1 package
COPY --from=export_vllm /vllm_v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1
ARG COMMON_WORKDIR
# Copy over the benchmark scripts as well

View File

@ -108,6 +108,116 @@ networks.
Consult your operating system or application platform documentation for specific
firewall configuration instructions.
## API Key Authentication Limitations
### Overview
The `--api-key` flag (or `VLLM_API_KEY` environment variable) provides authentication for vLLM's HTTP server, but **only for OpenAI-compatible API endpoints under the `/v1` path prefix**. Many other sensitive endpoints are exposed on the same HTTP server without any authentication enforcement.
**Important:** Do not rely exclusively on `--api-key` for securing access to vLLM. Additional security measures are required for production deployments.
### Protected Endpoints (Require API Key)
When `--api-key` is configured, the following `/v1` endpoints require Bearer token authentication:
- `/v1/models` - List available models
- `/v1/chat/completions` - Chat completions
- `/v1/completions` - Text completions
- `/v1/embeddings` - Generate embeddings
- `/v1/audio/transcriptions` - Audio transcription
- `/v1/audio/translations` - Audio translation
- `/v1/messages` - Anthropic-compatible messages API
- `/v1/responses` - Response management
- `/v1/score` - Scoring API
- `/v1/rerank` - Reranking API
### Unprotected Endpoints (No API Key Required)
The following endpoints **do not require authentication** even when `--api-key` is configured:
**Inference endpoints:**
- `/invocations` - SageMaker-compatible endpoint (routes to the same inference functions as `/v1` endpoints)
- `/inference/v1/generate` - Generate completions
- `/pooling` - Pooling API
- `/classify` - Classification API
- `/score` - Scoring API (non-`/v1` variant)
- `/rerank` - Reranking API (non-`/v1` variant)
**Operational control endpoints (always enabled):**
- `/pause` - Pause generation (causes denial of service)
- `/resume` - Resume generation
- `/scale_elastic_ep` - Trigger scaling operations
**Utility endpoints:**
- `/tokenize` - Tokenize text
- `/detokenize` - Detokenize tokens
- `/health` - Health check
- `/ping` - SageMaker health check
- `/version` - Version information
- `/load` - Server load metrics
**Tokenizer information endpoint (only when `--enable-tokenizer-info-endpoint` is set):**
This endpoint is **only available when the `--enable-tokenizer-info-endpoint` flag is set**. It may expose sensitive information such as chat templates and tokenizer configuration:
- `/tokenizer_info` - Get comprehensive tokenizer information including chat templates and configuration
**Development endpoints (only when `VLLM_SERVER_DEV_MODE=1`):**
These endpoints are **only available when the environment variable `VLLM_SERVER_DEV_MODE` is set to `1`**. They are intended for development and debugging purposes and should never be enabled in production:
- `/server_info` - Get detailed server configuration
- `/reset_prefix_cache` - Reset prefix cache (can disrupt service)
- `/reset_mm_cache` - Reset multimodal cache (can disrupt service)
- `/sleep` - Put engine to sleep (causes denial of service)
- `/wake_up` - Wake engine from sleep
- `/is_sleeping` - Check if engine is sleeping
- `/collective_rpc` - Execute arbitrary RPC methods on the engine (extremely dangerous)
**Profiler endpoints (only when `VLLM_TORCH_PROFILER_DIR` or `VLLM_TORCH_CUDA_PROFILE` are set):**
These endpoints are only available when profiling is enabled and should only be used for local development:
- `/start_profile` - Start PyTorch profiler
- `/stop_profile` - Stop PyTorch profiler
**Note:** The `/invocations` endpoint is particularly concerning as it provides unauthenticated access to the same inference capabilities as the protected `/v1` endpoints.
### Security Implications
An attacker who can reach the vLLM HTTP server can:
1. **Bypass authentication** by using non-`/v1` endpoints like `/invocations`, `/inference/v1/generate`, `/pooling`, `/classify`, `/score`, or `/rerank` to run arbitrary inference without credentials
2. **Cause denial of service** by calling `/pause` or `/scale_elastic_ep` without a token
3. **Access operational controls** to manipulate server state (e.g., pausing generation)
4. **If `--enable-tokenizer-info-endpoint` is set:** Access sensitive tokenizer configuration including chat templates, which may reveal prompt engineering strategies or other implementation details
5. **If `VLLM_SERVER_DEV_MODE=1` is set:** Execute arbitrary RPC commands via `/collective_rpc`, reset caches, put the engine to sleep, and access detailed server configuration
### Recommended Security Practices
#### 1. Minimize Exposed Endpoints
**CRITICAL:** Never set `VLLM_SERVER_DEV_MODE=1` in production environments. Development endpoints expose extremely dangerous functionality including:
- Arbitrary RPC execution via `/collective_rpc`
- Cache manipulation that can disrupt service
- Detailed server configuration disclosure
Similarly, never enable profiler endpoints (`VLLM_TORCH_PROFILER_DIR` or `VLLM_TORCH_CUDA_PROFILE`) in production.
**Be cautious with `--enable-tokenizer-info-endpoint`:** Only enable the `/tokenizer_info` endpoint if you need to expose tokenizer configuration information. This endpoint reveals chat templates and tokenizer settings that may contain sensitive implementation details or prompt engineering strategies.
#### 2. Deploy Behind a Reverse Proxy
The most effective approach is to deploy vLLM behind a reverse proxy (such as nginx, Envoy, or a Kubernetes Gateway) that:
- Explicitly allowlists only the endpoints you want to expose to end users
- Blocks all other endpoints, including the unauthenticated inference and operational control endpoints
- Implements additional authentication, rate limiting, and logging at the proxy layer
## Reporting Security Vulnerabilities
If you believe you have found a security vulnerability in vLLM, please report it following the project's security policy. For more information on how to report security issues and the project's security policy, please see the [vLLM Security Policy](https://github.com/vllm-project/vllm/blob/main/SECURITY.md).

View File

@ -309,6 +309,28 @@ def load_h2ovl(question: str, image_urls: list[str]) -> ModelRequestData:
)
# HunyuanOCR
def load_hunyuan_vl(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "tencent/HunyuanOCR"
engine_args = EngineArgs(
model=model_name,
max_model_len=8192,
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholder = (
"<hy_place▁holder▁no▁100><hy_place▁holder▁no▁102><hy_place▁holder▁no▁101>" # noqa: E501
) * len(image_urls)
prompt = f"<hy_begin▁of▁sentence>{placeholder}{question}<hy_User>"
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=[fetch_image(url) for url in image_urls],
)
def load_hyperclovax_seed_vision(
question: str, image_urls: list[str]
) -> ModelRequestData:
@ -1322,6 +1344,7 @@ model_example_map = {
"deepseek_ocr": load_deepseek_ocr,
"gemma3": load_gemma3,
"h2ovl_chat": load_h2ovl,
"hunyuan_vl": load_hunyuan_vl,
"hyperclovax_seed_vision": load_hyperclovax_seed_vision,
"idefics3": load_idefics3,
"interns1": load_interns1,

View File

@ -70,8 +70,8 @@ torchgeo==0.7.0
mteb==2.1.2
# Data processing
xgrammar @ git+https://github.com/mlc-ai/xgrammar.git@eafd4db51b78acc64b3f0764ef27dfd206c28628
# Test async scheduling
xgrammar==0.1.27
# Test async scheduling
# Utilities
num2words==0.5.14

View File

@ -326,7 +326,7 @@ def async_tp_pass_on_test_model(
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(
pass_config=PassConfig(
enable_async_tp=True,
fuse_gemm_comms=True,
),
)
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
@ -413,7 +413,7 @@ def test_async_tp_pass_correctness(
"mode": CompilationMode.VLLM_COMPILE,
"compile_sizes": [2, 4, 8],
"splitting_ops": [],
"pass_config": {"enable_async_tp": async_tp_enabled},
"pass_config": {"fuse_gemm_comms": async_tp_enabled},
}
async_tp_args = [

View File

@ -295,7 +295,7 @@ def all_reduce_fusion_pass_on_test_model(
)
)
vllm_config.compilation_config.pass_config = PassConfig(
enable_fi_allreduce_fusion=True, enable_noop=True
fuse_allreduce_rms=True, eliminate_noops=True
)
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
vllm_config.parallel_config.rank = local_rank # Setup rank for debug path

View File

@ -192,7 +192,7 @@ def test_attn_quant(
splitting_ops=splitting_ops,
# Common
mode=CompilationMode.VLLM_COMPILE,
pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
# Inductor caches custom passes by default as well via uuid
inductor_compile_config={"force_disable_caches": True},
)
@ -282,9 +282,9 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
# Common
mode=CompilationMode.VLLM_COMPILE,
pass_config=PassConfig(
enable_attn_fusion=True,
enable_noop=True,
enable_fi_allreduce_fusion=True,
fuse_attn_quant=True,
eliminate_noops=True,
fuse_allreduce_rms=True,
),
# Inductor caches custom passes by default as well via uuid
inductor_compile_config={"force_disable_caches": True},
@ -384,10 +384,10 @@ def test_tp2_attn_quant_async_tp(
# Common
level=CompilationMode.VLLM_COMPILE,
pass_config=PassConfig(
enable_attn_fusion=True,
enable_noop=True,
enable_sequence_parallelism=True,
enable_async_tp=True,
fuse_attn_quant=True,
eliminate_noops=True,
enable_sp=True,
fuse_gemm_comms=True,
),
# Inductor caches custom passes by default as well via uuid
inductor_compile_config={"force_disable_caches": True},

View File

@ -153,7 +153,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
]
def ops_in_model(self):
if self.vllm_config.compilation_config.pass_config.enable_fusion:
if self.vllm_config.compilation_config.pass_config.fuse_norm_quant:
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
elif RMSNorm.enabled():
return [
@ -183,7 +183,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("enable_fusion", [True, False])
@pytest.mark.parametrize("fuse_norm_quant", [True, False])
@pytest.mark.parametrize("dynamic", [False, True])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
def test_sequence_parallelism_pass(
@ -193,7 +193,7 @@ def test_sequence_parallelism_pass(
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
enable_fusion: bool,
fuse_norm_quant: bool,
dynamic: bool,
):
num_processes = 2
@ -211,7 +211,7 @@ def test_sequence_parallelism_pass(
seq_len,
hidden_size,
dtype,
enable_fusion,
fuse_norm_quant,
dynamic,
),
nprocs=nprocs,
@ -229,7 +229,7 @@ def sequence_parallelism_pass_on_test_model(
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
enable_fusion: bool,
fuse_norm_quant: bool,
dynamic: bool,
):
current_platform.seed_everything(0)
@ -260,9 +260,9 @@ def sequence_parallelism_pass_on_test_model(
cudagraph_mode=CUDAGraphMode.NONE, # avoid piecewise warnings
custom_ops=custom_ops_list,
pass_config=PassConfig(
enable_sequence_parallelism=True,
enable_fusion=enable_fusion,
enable_noop=True,
enable_sp=True,
fuse_norm_quant=fuse_norm_quant,
eliminate_noops=True,
),
) # NoOp needed for fusion
device_config = DeviceConfig(device=torch.device("cuda"))
@ -297,7 +297,7 @@ def sequence_parallelism_pass_on_test_model(
sequence_parallelism_pass,
]
if enable_fusion:
if fuse_norm_quant:
fusion_pass = RMSNormQuantFusionPass(vllm_config)
passes_for_backend.append(fusion_pass)

View File

@ -122,7 +122,9 @@ def test_full_graph(
CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=["+rms_norm"],
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
pass_config=PassConfig(
fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
),
),
*model_info,
)

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import logging
from contextlib import nullcontext
from unittest.mock import patch
@ -10,8 +11,9 @@ from pydantic import ValidationError
from vllm.compilation.counter import compilation_counter
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.compilation import CompilationMode
from vllm.config.compilation import CompilationMode, PassConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.logger import _print_warning_once
from vllm.platforms import current_platform
from vllm.utils.torch_utils import _is_torch_equal_or_newer
@ -191,7 +193,7 @@ def test_splitting_ops_dynamic():
config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
pass_config={"enable_attn_fusion": True, "enable_noop": True},
pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
)
@ -206,7 +208,7 @@ def test_splitting_ops_dynamic():
config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
pass_config={"enable_attn_fusion": True, "enable_noop": True},
pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
# work around for accessing all attntion ops
@ -219,7 +221,7 @@ def test_splitting_ops_dynamic():
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=True,
pass_config={"enable_attn_fusion": True, "enable_noop": True},
pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
)
@ -227,7 +229,7 @@ def test_splitting_ops_dynamic():
# With inductor graph partition, attn_fusion and splitting_ops
# work together. Default splitting_ops include attention ops.
assert config.compilation_config.splitting_ops_contain_attention()
# enable_attn_fusion is directly supported under
# fuse_attn_quant is directly supported under
# use_inductor_graph_partition=True, and cudagraph_mode
# is unchanged.
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
@ -301,7 +303,7 @@ def test_should_split():
"cudagraph_capture_sizes",
"max_cudagraph_capture_size",
"tp_size",
"enable_sequence_parallelism",
"enable_sp",
"max_num_batched_tokens",
"cudagraph_mode",
"expected_max_size",
@ -339,7 +341,7 @@ def test_cudagraph_sizes_post_init(
cudagraph_capture_sizes,
max_cudagraph_capture_size,
tp_size,
enable_sequence_parallelism,
enable_sp,
max_num_batched_tokens,
cudagraph_mode,
expected_max_size,
@ -355,11 +357,12 @@ def test_cudagraph_sizes_post_init(
compilation_config = CompilationConfig(
cudagraph_capture_sizes=cudagraph_capture_sizes,
max_cudagraph_capture_size=max_cudagraph_capture_size,
pass_config={
"enable_sequence_parallelism": enable_sequence_parallelism,
"enable_fusion": True,
"enable_noop": True,
},
pass_config=PassConfig(
enable_sp=enable_sp,
fuse_norm_quant=True,
fuse_act_quant=True,
eliminate_noops=True,
),
cudagraph_mode=cudagraph_mode,
)
engine_args = EngineArgs(
@ -375,3 +378,53 @@ def test_cudagraph_sizes_post_init(
vllm_config.compilation_config.max_cudagraph_capture_size
== expected_max_size
)
def test_pass_config_deprecation(caplog_vllm):
caplog_vllm.set_level(logging.WARNING)
# Clear cache to ensure warnings are re-issued
_print_warning_once.cache_clear()
# Test enable_fusion -> fuse_norm_quant, fuse_act_quant
caplog_vllm.clear()
config = PassConfig(enable_fusion=True)
assert "enable_fusion is deprecated" in caplog_vllm.text
assert config.fuse_norm_quant is True
assert config.fuse_act_quant is True
assert config.enable_fusion is None
# Test enable_attn_fusion -> fuse_attn_quant
caplog_vllm.clear()
config = PassConfig(enable_attn_fusion=True)
assert "enable_attn_fusion is deprecated" in caplog_vllm.text
assert config.fuse_attn_quant is True
assert config.enable_attn_fusion is None
# Test enable_noop -> eliminate_noops
caplog_vllm.clear()
config = PassConfig(enable_noop=True)
assert "enable_noop is deprecated" in caplog_vllm.text
assert config.eliminate_noops is True
assert config.enable_noop is None
# Test enable_sequence_parallelism -> enable_sp
caplog_vllm.clear()
config = PassConfig(enable_sequence_parallelism=True)
assert "enable_sequence_parallelism is deprecated" in caplog_vllm.text
assert config.enable_sp is True
assert config.enable_sequence_parallelism is None
# Test enable_async_tp -> fuse_gemm_comms
caplog_vllm.clear()
config = PassConfig(enable_async_tp=True)
assert "enable_async_tp is deprecated" in caplog_vllm.text
assert config.fuse_gemm_comms is True
assert config.enable_async_tp is None
# Test enable_fi_allreduce_fusion -> fuse_allreduce_rms
caplog_vllm.clear()
config = PassConfig(enable_fi_allreduce_fusion=True)
assert "enable_fi_allreduce_fusion is deprecated" in caplog_vllm.text
assert config.fuse_allreduce_rms is True
assert config.enable_fi_allreduce_fusion is None

View File

@ -223,7 +223,11 @@ def test_fix_functionalization(
model_config=ModelConfig(dtype=dtype),
compilation_config=CompilationConfig(
custom_ops=["all"],
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True),
pass_config=PassConfig(
fuse_norm_quant=do_fusion,
fuse_act_quant=do_fusion,
eliminate_noops=True,
),
),
)

View File

@ -159,7 +159,9 @@ def test_fusion_rmsnorm_quant(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=custom_ops,
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
pass_config=PassConfig(
fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
),
),
)
with vllm.config.set_current_vllm_config(vllm_config):

View File

@ -373,7 +373,7 @@ def test_attention_quant_pattern(
# Run model with attn fusion enabled
vllm_config.compilation_config.pass_config = PassConfig(
enable_attn_fusion=True, enable_noop=True
fuse_attn_quant=True, eliminate_noops=True
)
with (
set_current_vllm_config(vllm_config),

View File

@ -51,7 +51,7 @@ def test_noop_elimination(dtype, num_tokens, hidden_size, buffer_size):
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
pass_config=PassConfig(enable_noop=True),
pass_config=PassConfig(eliminate_noops=True),
)
)
with vllm.config.set_current_vllm_config(vllm_config):
@ -99,7 +99,7 @@ def test_non_noop_slice_preserved():
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
pass_config=PassConfig(enable_noop=True),
pass_config=PassConfig(eliminate_noops=True),
)
)
with vllm.config.set_current_vllm_config(vllm_config):

View File

@ -64,8 +64,11 @@ def test_pass_manager_uuid(callable):
# UUID should be different due to config change
config2 = copy.deepcopy(config)
config2.compilation_config.pass_config.enable_fusion = (
not config2.compilation_config.pass_config.enable_fusion
config2.compilation_config.pass_config.fuse_norm_quant = (
not config2.compilation_config.pass_config.fuse_norm_quant
)
config2.compilation_config.pass_config.fuse_act_quant = (
not config2.compilation_config.pass_config.fuse_act_quant
)
pass_manager3 = PostGradPassManager()
pass_manager3.configure(config2)

View File

@ -140,7 +140,7 @@ def test_qk_norm_rope_fusion(
custom_ops=custom_ops,
pass_config=PassConfig(
enable_qk_norm_rope_fusion=True,
enable_noop=True,
eliminate_noops=True,
),
),
)

View File

@ -168,7 +168,7 @@ def test_fusion_silu_and_mul_quant(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=custom_ops,
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
pass_config=PassConfig(fuse_act_quant=True, eliminate_noops=True),
),
)

View File

@ -32,7 +32,8 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
class ParallelSetup(NamedTuple):
tp_size: int
pp_size: int
enable_fusion: bool
fuse_norm_quant: bool
fuse_act_quant: bool
eager_mode: bool
chunked_prefill: bool
@ -66,7 +67,8 @@ class SPTestSettings:
ParallelSetup(
tp_size=tp_base,
pp_size=pp_multiplier * pp_base,
enable_fusion=False,
fuse_norm_quant=False,
fuse_act_quant=False,
eager_mode=eager_mode_val,
chunked_prefill=chunked_prefill_val,
)
@ -97,7 +99,8 @@ class SPTestSettings:
ParallelSetup(
tp_size=tp_base,
pp_size=pp_multiplier * pp_base,
enable_fusion=False,
fuse_norm_quant=False,
fuse_act_quant=False,
eager_mode=eager_mode_val,
chunked_prefill=chunked_prefill_val,
)
@ -126,7 +129,8 @@ class SPTestSettings:
ParallelSetup(
tp_size=tp_base,
pp_size=pp_base,
enable_fusion=fusion_val,
fuse_norm_quant=fusion_val,
fuse_act_quant=fusion_val,
eager_mode=True,
chunked_prefill=False,
)
@ -162,7 +166,7 @@ def _compare_sp(
test_options: SPTestOptions,
num_gpus_available: int,
use_inductor_graph_partition: bool,
enable_async_tp: bool,
fuse_gemm_comms: bool,
*,
method: Literal["generate", "encode"],
is_multimodal: bool,
@ -170,7 +174,8 @@ def _compare_sp(
(
tp_size,
pp_size,
enable_fusion,
fuse_norm_quant,
fuse_act_quant,
eager_mode,
chunked_prefill,
) = parallel_setup
@ -248,10 +253,11 @@ def _compare_sp(
"mode": CompilationMode.VLLM_COMPILE,
"compile_sizes": [4, 8],
"pass_config": {
"enable_sequence_parallelism": True,
"enable_async_tp": enable_async_tp,
"enable_fusion": enable_fusion,
"enable_noop": True,
"enable_sp": True,
"fuse_gemm_comms": fuse_gemm_comms,
"fuse_norm_quant": fuse_norm_quant,
"fuse_act_quant": fuse_act_quant,
"eliminate_noops": True,
},
"use_inductor_graph_partition": use_inductor_graph_partition,
}
@ -309,7 +315,7 @@ SP_TEST_MODELS = [
],
)
@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
@pytest.mark.parametrize("enable_async_tp", [False]) # TODO: enable async TP
@pytest.mark.parametrize("fuse_gemm_comms", [False]) # TODO: enable async TP
@create_new_process_for_each_test()
def test_tp_sp_generation(
model_id: str,
@ -319,7 +325,7 @@ def test_tp_sp_generation(
test_options: SPTestOptions,
num_gpus_available,
use_inductor_graph_partition: bool,
enable_async_tp: bool,
fuse_gemm_comms: bool,
):
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
@ -328,7 +334,7 @@ def test_tp_sp_generation(
if (
"fp8" in model_id.lower()
and current_platform.get_device_capability() < (9, 0)
and (not enable_async_tp)
and (not fuse_gemm_comms)
):
pytest.skip("FP8 reduction support begins with sm90 capable devices.")
@ -340,7 +346,7 @@ def test_tp_sp_generation(
test_options,
num_gpus_available,
use_inductor_graph_partition,
enable_async_tp=enable_async_tp,
fuse_gemm_comms=fuse_gemm_comms,
method="generate",
is_multimodal=False,
)

View File

@ -232,7 +232,7 @@ async def test_server_load(server: RemoteOpenAIServer):
@pytest.mark.asyncio
async def test_health_check_engine_dead_error():
# Import the health function directly to test it in isolation
from vllm.entrypoints.openai.api_server import health
from vllm.entrypoints.serve.instrumentator.health import health
# Create a mock request that simulates what FastAPI would provide
mock_request = Mock(spec=Request)

View File

@ -42,6 +42,24 @@ async def test_basic(client: OpenAI, model_name: str):
assert response.status == "completed"
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_enable_response_messages(client: OpenAI, model_name: str):
response = await client.responses.create(
model=model_name,
input="Hello?",
extra_body={"enable_response_messages": True},
)
assert response.status == "completed"
assert response.input_messages[0]["type"] == "raw_message_tokens"
assert type(response.input_messages[0]["message"]) is str
assert len(response.input_messages[0]["message"]) > 10
assert type(response.input_messages[0]["tokens"][0]) is int
assert type(response.output_messages[0]["message"]) is str
assert len(response.output_messages[0]["message"]) > 10
assert type(response.output_messages[0]["tokens"][0]) is int
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_reasoning_item(client: OpenAI, model_name: str):

View File

@ -0,0 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Pytest configuration for vLLM tests."""
import torch
from vllm.platforms import current_platform
def pytest_configure(config):
"""Disable Flash/MemEfficient SDP on ROCm to avoid HF
Transformers accuracy issues.
"""
if not current_platform.is_rocm():
return
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)

View File

@ -137,7 +137,7 @@ VLM_TEST_SETTINGS = {
max_num_seqs=2,
auto_cls=AutoModelForImageTextToText,
vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
image_size_factors=[(0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
"qwen2_5_omni": VLMTestInfo(
@ -152,7 +152,7 @@ VLM_TEST_SETTINGS = {
auto_cls=AutoModelForTextToWaveform,
vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
patch_hf_runner=model_utils.qwen2_5_omni_patch_hf_runner,
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
image_size_factors=[(0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
"qwen3_vl": VLMTestInfo(
@ -173,7 +173,7 @@ VLM_TEST_SETTINGS = {
auto_cls=AutoModelForImageTextToText,
vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
patch_hf_runner=model_utils.qwen3_vl_patch_hf_runner,
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
image_size_factors=[(0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[
pytest.mark.core_model,
],
@ -350,7 +350,7 @@ VLM_TEST_SETTINGS = {
patch_hf_runner=model_utils.deepseekvl2_patch_hf_runner,
hf_output_post_proc=model_utils.deepseekvl2_trunc_hf_output,
stop_str=["<end▁of▁sentence>", "<begin▁of▁sentence>"],
image_size_factors=[(), (1.0,), (1.0, 1.0, 1.0), (0.1, 0.5, 1.0)],
image_size_factors=[(1.0,), (1.0, 1.0, 1.0), (0.1, 0.5, 1.0)],
),
"fuyu": VLMTestInfo(
models=["adept/fuyu-8b"],
@ -707,7 +707,7 @@ VLM_TEST_SETTINGS = {
max_model_len=8192,
max_num_seqs=2,
auto_cls=AutoModelForCausalLM,
image_size_factors=[(), (0.25,)],
image_size_factors=[(0.25,)],
marks=[
pytest.mark.skipif(
Version(TRANSFORMERS_VERSION) == Version("4.57.3"),
@ -760,7 +760,7 @@ VLM_TEST_SETTINGS = {
max_num_seqs=2,
auto_cls=AutoModelForImageTextToText,
vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
image_size_factors=[(0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[pytest.mark.cpu_model],
),
"skywork_r1v": VLMTestInfo(
@ -812,7 +812,7 @@ VLM_TEST_SETTINGS = {
max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForImageTextToText,
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
image_size_factors=[(0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[pytest.mark.skip("Model initialization hangs")],
),
### Tensor parallel / multi-gpu broadcast tests

View File

@ -62,6 +62,65 @@ def get_filtered_test_settings(
return matching_tests
def get_model_type_cases(
model_type: str,
test_info: VLMTestInfo,
test_type: VLMTestType,
):
# Ensure that something is wrapped as an iterable it's not already
ensure_wrapped = lambda e: e if isinstance(e, (list, tuple)) else (e,)
# This is essentially the same as nesting a bunch of mark.parametrize
# decorators, but we do it programmatically to allow overrides for on
# a per-model basis, while still being able to execute each of these
# as individual test cases in pytest.
iter_kwargs = OrderedDict(
[
("model", ensure_wrapped(test_info.models)),
("max_tokens", ensure_wrapped(test_info.max_tokens)),
("num_logprobs", ensure_wrapped(test_info.num_logprobs)),
("dtype", ensure_wrapped(test_info.dtype)),
(
"distributed_executor_backend",
ensure_wrapped(test_info.distributed_executor_backend),
),
]
)
# num_frames is video only
if test_type == VLMTestType.VIDEO:
iter_kwargs["num_video_frames"] = ensure_wrapped(test_info.num_video_frames)
iter_kwargs["needs_video_metadata"] = ensure_wrapped(
test_info.needs_video_metadata
)
# No sizes passed for custom inputs, since inputs are directly provided
if test_type not in (
VLMTestType.CUSTOM_INPUTS,
VLMTestType.AUDIO,
):
wrapped_sizes = get_wrapped_test_sizes(test_info, test_type)
if wrapped_sizes is None:
raise ValueError(f"Sizes must be set for test type {test_type}")
iter_kwargs["size_wrapper"] = wrapped_sizes
# Otherwise expand the custom test options instead
elif test_type == VLMTestType.CUSTOM_INPUTS:
if test_info.custom_test_opts is None:
raise ValueError("Test has type CUSTOM_INPUTS, but none given")
iter_kwargs["custom_test_opts"] = test_info.custom_test_opts
# Wrap all model cases in a pytest parameter & pass marks through
return [
pytest.param(
model_type,
ExpandableVLMTestArgs(**{k: v for k, v in zip(iter_kwargs.keys(), case)}),
marks=test_info.marks if test_info.marks is not None else [],
)
for case in list(itertools.product(*iter_kwargs.values()))
]
def get_parametrized_options(
test_settings: dict[str, VLMTestInfo],
test_type: VLMTestType,
@ -76,64 +135,11 @@ def get_parametrized_options(
test_settings, test_type, create_new_process_for_each_test
)
# Ensure that something is wrapped as an iterable it's not already
ensure_wrapped = lambda e: e if isinstance(e, (list, tuple)) else (e,)
def get_model_type_cases(model_type: str, test_info: VLMTestInfo):
# This is essentially the same as nesting a bunch of mark.parametrize
# decorators, but we do it programmatically to allow overrides for on
# a per-model basis, while still being able to execute each of these
# as individual test cases in pytest.
iter_kwargs = OrderedDict(
[
("model", ensure_wrapped(test_info.models)),
("max_tokens", ensure_wrapped(test_info.max_tokens)),
("num_logprobs", ensure_wrapped(test_info.num_logprobs)),
("dtype", ensure_wrapped(test_info.dtype)),
(
"distributed_executor_backend",
ensure_wrapped(test_info.distributed_executor_backend),
),
]
)
# num_frames is video only
if test_type == VLMTestType.VIDEO:
iter_kwargs["num_video_frames"] = ensure_wrapped(test_info.num_video_frames)
iter_kwargs["needs_video_metadata"] = ensure_wrapped(
test_info.needs_video_metadata
)
# No sizes passed for custom inputs, since inputs are directly provided
if test_type not in (VLMTestType.CUSTOM_INPUTS, VLMTestType.AUDIO):
wrapped_sizes = get_wrapped_test_sizes(test_info, test_type)
if wrapped_sizes is None:
raise ValueError(f"Sizes must be set for test type {test_type}")
iter_kwargs["size_wrapper"] = wrapped_sizes
# Otherwise expand the custom test options instead
elif test_type == VLMTestType.CUSTOM_INPUTS:
if test_info.custom_test_opts is None:
raise ValueError("Test has type CUSTOM_INPUTS, but none given")
iter_kwargs["custom_test_opts"] = test_info.custom_test_opts
# Wrap all model cases in a pytest parameter & pass marks through
return [
pytest.param(
model_type,
ExpandableVLMTestArgs(
**{k: v for k, v in zip(iter_kwargs.keys(), case)}
),
marks=test_info.marks if test_info.marks is not None else [],
)
for case in list(itertools.product(*iter_kwargs.values()))
]
# Get a list per model type, where each entry contains a tuple of all of
# that model type's cases, then flatten them into the top level so that
# we can consume them in one mark.parametrize call.
cases_by_model_type = [
get_model_type_cases(model_type, test_info)
get_model_type_cases(model_type, test_info, test_type)
for model_type, test_info in matching_tests.items()
]
return list(itertools.chain(*cases_by_model_type))

View File

@ -50,8 +50,8 @@ MULTI_IMAGE_BASE_PROMPT = f"Image-1: {TEST_IMG_PLACEHOLDER}Image-2: {TEST_IMG_PL
VIDEO_BASE_PROMPT = f"{TEST_VIDEO_PLACEHOLDER}Why is this video funny?"
IMAGE_SIZE_FACTORS = [(), (1.0,), (1.0, 1.0, 1.0), (0.25, 0.5, 1.0)]
EMBEDDING_SIZE_FACTORS = [(), (1.0,), (1.0, 1.0, 1.0)]
IMAGE_SIZE_FACTORS = [(1.0,), (1.0, 1.0, 1.0), (0.25, 0.5, 1.0)]
EMBEDDING_SIZE_FACTORS = [(1.0,), (1.0, 1.0, 1.0)]
RunnerOutput = tuple[list[int], str, SampleLogprobs | None]

View File

@ -47,6 +47,12 @@ QWEN2_CONFIG = GGUFTestConfig(
gguf_filename="qwen2.5-1.5b-instruct-q6_k.gguf",
)
QWEN3_CONFIG = GGUFTestConfig(
original_model="Qwen/Qwen3-0.6B",
gguf_repo="unsloth/Qwen3-0.6B-GGUF",
gguf_filename="Qwen3-0.6B-BF16.gguf",
)
PHI3_CONFIG = GGUFTestConfig(
original_model="microsoft/Phi-3.5-mini-instruct",
gguf_repo="bartowski/Phi-3.5-mini-instruct-GGUF",
@ -87,6 +93,7 @@ GEMMA3_CONFIG = GGUFTestConfig(
MODELS = [
# LLAMA_CONFIG, # broken: https://github.com/vllm-project/vllm/issues/19458
QWEN2_CONFIG,
QWEN3_CONFIG,
PHI3_CONFIG,
GPT2_CONFIG,
STABLELM_CONFIG,

View File

@ -1023,17 +1023,17 @@ def test_vllm_config_explicit_overrides():
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
# Explicit pass config flags to override defaults
pass_config = PassConfig(enable_noop=True, enable_attn_fusion=True)
pass_config = PassConfig(eliminate_noops=True, fuse_attn_quant=True)
compilation_config = CompilationConfig(pass_config=pass_config)
config = VllmConfig(
optimization_level=OptimizationLevel.O0,
compilation_config=compilation_config,
)
assert config.compilation_config.pass_config.enable_noop is True
assert config.compilation_config.pass_config.enable_attn_fusion is True
assert config.compilation_config.pass_config.eliminate_noops is True
assert config.compilation_config.pass_config.fuse_attn_quant is True
# Explicit cudagraph mode override on quantized model at O2
pass_config = PassConfig(enable_async_tp=True)
pass_config = PassConfig(fuse_gemm_comms=True)
compilation_config = CompilationConfig(
cudagraph_mode=CUDAGraphMode.NONE, pass_config=pass_config
)
@ -1043,7 +1043,7 @@ def test_vllm_config_explicit_overrides():
compilation_config=compilation_config,
)
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
assert config.compilation_config.pass_config.enable_async_tp is True
assert config.compilation_config.pass_config.fuse_gemm_comms is True
# Mode should still use default for O2
assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE
@ -1093,7 +1093,7 @@ def test_vllm_config_explicit_overrides():
compilation_config=compilation_config,
)
# Explicit override should be respected
assert config.compilation_config.pass_config.enable_noop is False
assert config.compilation_config.pass_config.eliminate_noops is False
# Other fields should still use defaults
assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from collections import Counter
from collections.abc import Callable
from contextlib import ExitStack
from typing import Any
@ -22,6 +23,99 @@ from vllm.utils.torch_utils import weak_ref_tensors
logger = init_logger(__name__)
@dataclasses.dataclass(frozen=True)
class CUDAGraphStat:
num_unpadded_tokens: int
num_padded_tokens: int
num_paddings: int
runtime_mode: str
class CUDAGraphLogging:
"""Aggregate and log cudagraph metrics"""
COLUMN_HEADERS = [
"Unpadded Tokens",
"Padded Tokens",
"Num Paddings",
"Runtime Mode",
"Count",
]
def __init__(self, cg_mode: CUDAGraphMode, cg_capture_sizes: list[int] | None):
self.reset()
self.cg_mode = str(cg_mode)
self.cg_capture_sizes = str(cg_capture_sizes or [])
self.settings_header = (
"**CUDAGraph Config Settings:**\n\n"
f"- Mode: {self.cg_mode}\n"
f"- Capture sizes: {self.cg_capture_sizes}\n\n"
"**CUDAGraph Stats:**\n\n"
)
def reset(self):
self.stats = []
def observe(self, cudagraph_stat: CUDAGraphStat):
self.stats.append(cudagraph_stat)
def generate_metric_table(self) -> str:
stats_counts = Counter(self.stats)
# Convert stats to rows of strings, in descending order of observed frequencies
rows = []
for stat, count in sorted(
stats_counts.items(), key=lambda item: item[1], reverse=True
):
rows.append(
[
str(stat.num_unpadded_tokens),
str(stat.num_padded_tokens),
str(stat.num_paddings),
stat.runtime_mode,
str(count),
]
)
# Calculate column widths (max of header and data)
col_widths = []
for i, header_text in enumerate(self.COLUMN_HEADERS):
max_width = len(header_text)
for row in rows:
max_width = max(max_width, len(row[i]))
col_widths.append(max_width)
table_header_list = [
h.ljust(w) for h, w in zip(self.COLUMN_HEADERS, col_widths)
]
table_header = "| " + " | ".join(table_header_list) + " |\n"
table_separator = "|" + "|".join("-" * (w + 2) for w in col_widths) + "|\n"
# Create data rows with proper alignment
data_rows = []
for row in rows:
formatted_row = [
str(val).ljust(width) for val, width in zip(row, col_widths)
]
data_rows.append("| " + " | ".join(formatted_row) + " |")
return (
self.settings_header
+ table_header
+ table_separator
+ "\n".join(data_rows)
+ "\n"
)
def log(self, log_fn=logger.info):
if not self.stats:
return
log_fn(self.generate_metric_table())
self.reset()
@dataclasses.dataclass
class CUDAGraphEntry:
batch_descriptor: BatchDescriptor

View File

@ -103,6 +103,18 @@ class FixFunctionalizationPass(VllmInductorPass):
]:
mutated_args = {1: "result"}
self.defunctionalize(graph, node, mutated_args)
elif (
at_target
== torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
):
mutated_args = {
1: "allreduce_in",
2: "residual",
3: "norm_out",
4: "quant_out",
5: "scale_out",
}
self.defunctionalize(graph, node, mutated_args)
# For some reason we need to specify the args for both
# silu_and_mul and silu_and_mul_quant. The kwargs
# pathway gets the wrong answer.

View File

@ -75,8 +75,8 @@ def find_op_nodes(
return
assert isinstance(op, OpOverload)
if not op._schema.is_mutable:
yield from graph.find_nodes(op="call_function", target=op)
yield from graph.find_nodes(op="call_function", target=op)
for n in graph.find_nodes(op="call_function", target=auto_functionalized):
if n.args[0] == op:

View File

@ -92,22 +92,23 @@ class PostGradPassManager(CustomGraphPass):
# Set the current vllm config to allow tracing CustomOp instances
with set_current_vllm_config(config, check_compile=False):
if self.pass_config.enable_noop:
if self.pass_config.eliminate_noops:
self.passes += [NoOpEliminationPass(config)]
if self.pass_config.enable_sequence_parallelism:
if self.pass_config.enable_sp:
self.passes += [SequenceParallelismPass(config)]
if self.pass_config.enable_async_tp:
if self.pass_config.fuse_gemm_comms:
self.passes += [AsyncTPPass(config)]
if self.pass_config.enable_fi_allreduce_fusion:
if self.pass_config.fuse_allreduce_rms:
self.passes += [AllReduceFusionPass(config)]
if self.pass_config.enable_fusion:
if self.pass_config.fuse_norm_quant:
self.passes += [RMSNormQuantFusionPass(config)]
if self.pass_config.fuse_act_quant:
self.passes += [ActivationQuantFusionPass(config)]
if self.pass_config.enable_attn_fusion:
if self.pass_config.fuse_attn_quant:
self.passes += [AttnFusionPass(config)]
if self.pass_config.enable_qk_norm_rope_fusion:

View File

@ -13,7 +13,7 @@ from pydantic.dataclasses import dataclass
import vllm.envs as envs
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.config.utils import config
from vllm.config.utils import config, handle_deprecated
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.import_utils import resolve_obj_by_qualname
@ -105,18 +105,43 @@ class PassConfig:
improper state.
"""
# New flags
fuse_norm_quant: bool = Field(default=None)
"""Fuse the custom RMSNorm + quant ops."""
fuse_act_quant: bool = Field(default=None)
"""Fuse the custom SiluMul + quant ops."""
fuse_attn_quant: bool = Field(default=None)
"""Fuse the custom attention + quant ops."""
eliminate_noops: bool = Field(default=None)
"""Eliminate no-op ops."""
enable_sp: bool = Field(default=None)
"""Enable sequence parallelism."""
fuse_gemm_comms: bool = Field(default=None)
"""Enable async TP."""
fuse_allreduce_rms: bool = Field(default=None)
"""Enable flashinfer allreduce fusion."""
# Deprecated flags
enable_fusion: bool = Field(default=None)
"""Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass."""
"""Deprecated in: v0.12.0. Use fuse_norm_quant and fuse_act_quant
instead. Will be removed in v0.13.0 or v1.0.0, whichever is sooner.
"""
enable_attn_fusion: bool = Field(default=None)
"""Whether to enable the custom attention+quant fusion pass."""
"""Deprecated in: v0.12.0. Use fuse_attn_quant instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
enable_noop: bool = Field(default=None)
"""Whether to enable the custom no-op elimination pass."""
"""Deprecated in: v0.12.0. Use eliminate_noops instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
enable_sequence_parallelism: bool = Field(default=None)
"""Whether to enable sequence parallelism."""
"""Deprecated in: v0.12.0. Use enable_sp instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
enable_async_tp: bool = Field(default=None)
"""Whether to enable async TP."""
"""Deprecated in: v0.12.0. Use fuse_gemm_comms instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
enable_fi_allreduce_fusion: bool = Field(default=None)
"""Whether to enable flashinfer allreduce fusion."""
"""Deprecated in: v0.12.0. Use fuse_allreduce_rms instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
fi_allreduce_fusion_max_size_mb: float | None = None
"""The threshold of the communicated tensor sizes under which
vllm should use flashinfer fused allreduce. Specified as a
@ -136,7 +161,7 @@ class PassConfig:
},
}, where key is the device capability"""
enable_qk_norm_rope_fusion: bool = False
"""Whether to enable the fused Q/K RMSNorm + RoPE pass."""
"""Enable fused Q/K RMSNorm + RoPE pass."""
# TODO(luka) better pass enabling system.
@ -174,6 +199,13 @@ class PassConfig:
return InductorPass.hash_dict(asdict(self))
@field_validator(
"fuse_norm_quant",
"fuse_act_quant",
"fuse_attn_quant",
"eliminate_noops",
"enable_sp",
"fuse_gemm_comms",
"fuse_allreduce_rms",
"enable_fusion",
"enable_attn_fusion",
"enable_noop",
@ -190,18 +222,71 @@ class PassConfig:
return handler(value)
def __post_init__(self) -> None:
if not self.enable_noop:
if self.enable_fusion:
# Handle deprecation and defaults
# Map old flags to new flags and issue warnings
handle_deprecated(
self,
"enable_fusion",
["fuse_norm_quant", "fuse_act_quant"],
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_attn_fusion",
"fuse_attn_quant",
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_sequence_parallelism",
"enable_sp",
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_async_tp",
"fuse_gemm_comms",
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_fi_allreduce_fusion",
"fuse_allreduce_rms",
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_noop",
"eliminate_noops",
"v0.13.0 or v1.0.0, whichever is sooner",
)
# Force old flags to None to ensure they are not used
self.enable_fusion = None
self.enable_attn_fusion = None
self.enable_noop = None
self.enable_sequence_parallelism = None
self.enable_async_tp = None
self.enable_fi_allreduce_fusion = None
if not self.eliminate_noops:
if self.fuse_norm_quant or self.fuse_act_quant:
logger.warning_once(
"Fusion enabled but reshape elimination disabled. "
"RMSNorm/SiluMul + quant (fp8) fusion might not work"
)
if self.enable_attn_fusion:
if self.fuse_attn_quant:
logger.warning_once(
"Fusion enabled but reshape elimination disabled. "
"Attention + quant (fp8) fusion might not work"
)
if self.enable_fi_allreduce_fusion:
if self.fuse_allreduce_rms:
logger.warning_once(
"Fusion enabled but reshape elimination disabled. "
"Allreduce + rms norm + quant (fp8) fusion might not work"
@ -873,7 +958,7 @@ class CompilationConfig:
self.set_splitting_ops_for_inductor_graph_partition()
return
if self.pass_config.enable_attn_fusion:
if self.pass_config.fuse_attn_quant:
# here use_inductor_graph_partition is False
self.set_splitting_ops_for_attn_fusion()
return
@ -915,12 +1000,12 @@ class CompilationConfig:
self.splitting_ops = list(self._attention_ops)
def set_splitting_ops_for_attn_fusion(self):
assert self.pass_config.enable_attn_fusion
assert self.pass_config.fuse_attn_quant
if self.splitting_ops is None:
self.splitting_ops = []
if self.cudagraph_mode.has_piecewise_cudagraphs():
logger.warning_once(
"enable_attn_fusion is incompatible with piecewise "
"fuse_attn_quant is incompatible with piecewise "
"cudagraph when use_inductor_graph_partition is off. "
"In this case, splitting_ops will be set to empty "
"list, and cudagraph_mode will be set to FULL. "
@ -931,8 +1016,7 @@ class CompilationConfig:
self.cudagraph_mode = CUDAGraphMode.FULL
assert not self.splitting_ops_contain_attention(), (
"attention ops should not be in splitting_ops "
"when enable_attn_fusion is True"
"attention ops should not be in splitting_ops when fuse_attn_quant is True"
)
def splitting_ops_contain_attention(self) -> bool:
@ -1008,7 +1092,7 @@ class CompilationConfig:
self, uniform_decode_query_len: int, tensor_parallel_size: int
):
multiple_of = uniform_decode_query_len
if tensor_parallel_size > 1 and self.pass_config.enable_sequence_parallelism:
if tensor_parallel_size > 1 and self.pass_config.enable_sp:
multiple_of = max(uniform_decode_query_len, tensor_parallel_size)
if (
multiple_of % uniform_decode_query_len != 0

View File

@ -55,6 +55,10 @@ class ObservabilityConfig:
kv_cache_metrics_sample: float = Field(default=0.01, gt=0, le=1)
"""Sampling rate for KV cache metrics (0.0, 1.0]. Default 0.01 = 1% of blocks."""
cudagraph_metrics: bool = False
"""Enable CUDA graph metrics (number of padded/unpadded tokens, runtime cudagraph
dispatch modes, and their observed frequencies at every logging interval)."""
@cached_property
def collect_model_forward_time(self) -> bool:
"""Whether to collect model forward time for the request."""

View File

@ -19,6 +19,10 @@ import torch
from pydantic.fields import FieldInfo
from typing_extensions import runtime_checkable
from vllm.logger import init_logger
logger = init_logger(__name__)
if TYPE_CHECKING:
from _typeshed import DataclassInstance
else:
@ -293,3 +297,28 @@ def get_hash_factors(config: ConfigT, ignored_factors: set[str]) -> dict[str, ob
def hash_factors(items: dict[str, object]) -> str:
"""Return a SHA-256 hex digest of the canonical items structure."""
return hashlib.sha256(json.dumps(items, sort_keys=True).encode()).hexdigest()
def handle_deprecated(
config: ConfigT,
old_name: str,
new_name_or_names: str | list[str],
removal_version: str,
) -> None:
old_val = getattr(config, old_name)
if old_val is None:
return
if isinstance(new_name_or_names, str):
new_names = [new_name_or_names]
else:
new_names = new_name_or_names
msg = (
f"{old_name} is deprecated and will be removed in {removal_version}. "
f"Use {', '.join(new_names)} instead."
)
logger.warning(msg)
for new_name in new_names:
setattr(config, new_name, old_val)

View File

@ -83,22 +83,33 @@ IS_DENSE = False
# See https://github.com/vllm-project/vllm/issues/25689.
def enable_fusion(cfg: "VllmConfig") -> bool:
"""Returns True if RMS norm or quant FP8 is enabled."""
def enable_norm_fusion(cfg: "VllmConfig") -> bool:
"""Enable if either RMS norm or quant FP8 custom op is active;
otherwise Inductor handles fusion."""
return cfg.compilation_config.is_custom_op_enabled(
"rms_norm"
) or cfg.compilation_config.is_custom_op_enabled("quant_fp8")
def enable_act_fusion(cfg: "VllmConfig") -> bool:
"""Enable if either SiLU+Mul or quant FP8 custom op is active;
otherwise Inductor handles fusion."""
return cfg.compilation_config.is_custom_op_enabled(
"silu_and_mul"
) or cfg.compilation_config.is_custom_op_enabled("quant_fp8")
OPTIMIZATION_LEVEL_00 = {
"compilation_config": {
"pass_config": {
"enable_noop": False,
"enable_fusion": False,
"enable_fi_allreduce_fusion": False,
"enable_attn_fusion": False,
"enable_sequence_parallelism": False,
"enable_async_tp": False,
"eliminate_noops": False,
"fuse_norm_quant": False,
"fuse_act_quant": False,
"fuse_allreduce_rms": False,
"fuse_attn_quant": False,
"enable_sp": False,
"fuse_gemm_comms": False,
},
"cudagraph_mode": CUDAGraphMode.NONE,
"use_inductor_graph_partition": False,
@ -107,12 +118,13 @@ OPTIMIZATION_LEVEL_00 = {
OPTIMIZATION_LEVEL_01 = {
"compilation_config": {
"pass_config": {
"enable_noop": True,
"enable_fusion": enable_fusion,
"enable_fi_allreduce_fusion": False,
"enable_attn_fusion": False,
"enable_sequence_parallelism": False,
"enable_async_tp": False,
"eliminate_noops": True,
"fuse_norm_quant": enable_norm_fusion,
"fuse_act_quant": enable_act_fusion,
"fuse_allreduce_rms": False,
"fuse_attn_quant": False,
"enable_sp": False,
"fuse_gemm_comms": False,
},
"cudagraph_mode": CUDAGraphMode.PIECEWISE,
"use_inductor_graph_partition": False,
@ -121,12 +133,13 @@ OPTIMIZATION_LEVEL_01 = {
OPTIMIZATION_LEVEL_02 = {
"compilation_config": {
"pass_config": {
"enable_noop": True,
"enable_fusion": enable_fusion,
"enable_fi_allreduce_fusion": False,
"enable_attn_fusion": IS_QUANTIZED,
"enable_sequence_parallelism": IS_DENSE,
"enable_async_tp": IS_DENSE,
"eliminate_noops": True,
"fuse_norm_quant": enable_norm_fusion,
"fuse_act_quant": enable_act_fusion,
"fuse_allreduce_rms": False,
"fuse_attn_quant": IS_QUANTIZED,
"enable_sp": IS_DENSE,
"fuse_gemm_comms": IS_DENSE,
},
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
"use_inductor_graph_partition": False,
@ -135,12 +148,13 @@ OPTIMIZATION_LEVEL_02 = {
OPTIMIZATION_LEVEL_03 = {
"compilation_config": {
"pass_config": {
"enable_noop": True,
"enable_fusion": enable_fusion,
"enable_fi_allreduce_fusion": False,
"enable_attn_fusion": IS_QUANTIZED,
"enable_sequence_parallelism": IS_DENSE,
"enable_async_tp": IS_DENSE,
"eliminate_noops": True,
"fuse_norm_quant": enable_norm_fusion,
"fuse_act_quant": enable_act_fusion,
"fuse_allreduce_rms": False,
"fuse_attn_quant": IS_QUANTIZED,
"enable_sp": IS_DENSE,
"fuse_gemm_comms": IS_DENSE,
},
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
"use_inductor_graph_partition": False,
@ -645,9 +659,9 @@ class VllmConfig:
# async tp is built on top of sequence parallelism
# and requires it to be enabled.
if self.compilation_config.pass_config.enable_async_tp:
self.compilation_config.pass_config.enable_sequence_parallelism = True
if self.compilation_config.pass_config.enable_sequence_parallelism:
if self.compilation_config.pass_config.fuse_gemm_comms:
self.compilation_config.pass_config.enable_sp = True
if self.compilation_config.pass_config.enable_sp:
if "-rms_norm" in self.compilation_config.custom_ops:
logger.warning(
"RMS norm force disabled, sequence parallelism might break"
@ -797,7 +811,7 @@ class VllmConfig:
# Do this after all the updates to compilation_config.mode
self.compilation_config.set_splitting_ops_for_v1()
if self.compilation_config.pass_config.enable_sequence_parallelism:
if self.compilation_config.pass_config.enable_sp:
# With pipeline parallelism or dynamo partitioning,
# native rms norm tracing errors due to incorrect residual shape.
# Use custom rms norm to unblock. In the future,
@ -1062,7 +1076,7 @@ class VllmConfig:
if (
self.parallel_config.tensor_parallel_size > 1
and self.compilation_config.pass_config.enable_sequence_parallelism
and self.compilation_config.pass_config.enable_sp
):
cudagraph_capture_sizes = self.update_sizes_for_sequence_parallelism(
cudagraph_capture_sizes

View File

@ -322,9 +322,6 @@ async def transfer_layer(
num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
assert num_physical_experts == ep_size * num_local_physical_experts
# A buffer to hold the expert weights in one layer during the exchange.
# NOTE: Currently we assume the same weights across different layers
# have the same shape.
is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer(
num_local_experts=num_local_physical_experts,

View File

@ -518,6 +518,7 @@ class EngineArgs:
kv_cache_metrics_sample: float = get_field(
ObservabilityConfig, "kv_cache_metrics_sample"
)
cudagraph_metrics: bool = ObservabilityConfig.cudagraph_metrics
scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
@ -1021,6 +1022,10 @@ class EngineArgs:
"--kv-cache-metrics-sample",
**observability_kwargs["kv_cache_metrics_sample"],
)
observability_group.add_argument(
"--cudagraph-metrics",
**observability_kwargs["cudagraph_metrics"],
)
# Scheduler arguments
scheduler_kwargs = get_kwargs(SchedulerConfig)
@ -1698,6 +1703,7 @@ class EngineArgs:
collect_detailed_traces=self.collect_detailed_traces,
kv_cache_metrics=self.kv_cache_metrics,
kv_cache_metrics_sample=self.kv_cache_metrics_sample,
cudagraph_metrics=self.cudagraph_metrics,
)
# Compilation config overrides

View File

@ -118,6 +118,7 @@ async def init_app(
)
)
app.state.engine_client = engine
app.state.args = args
return app

View File

@ -109,6 +109,10 @@ def _add_query_options(parser: FlexibleArgumentParser) -> FlexibleArgumentParser
help=(
"API key for OpenAI services. If provided, this api key "
"will overwrite the api key obtained through environment variables."
" It is important to note that this option only applies to the "
"OpenAI-compatible API endpoints and NOT other endpoints that may "
"be present in the server. See the security guide in the vLLM docs "
"for more details."
),
)
return parser

View File

@ -23,6 +23,7 @@ from vllm.entrypoints.openai.parser.responses_parser import (
)
from vllm.entrypoints.openai.protocol import (
ResponseInputOutputItem,
ResponseRawMessageAndToken,
ResponsesRequest,
)
from vllm.entrypoints.responses_utils import construct_tool_dicts
@ -148,6 +149,8 @@ def _create_json_parse_error_messages(
class SimpleContext(ConversationContext):
"""This is a context that cannot handle MCP tool calls"""
def __init__(self):
self.last_output = None
self.num_prompt_tokens = 0
@ -158,6 +161,9 @@ class SimpleContext(ConversationContext):
# not implemented yet for SimpleContext
self.all_turn_metrics = []
self.input_messages: list[ResponseRawMessageAndToken] = []
self.output_messages: list[ResponseRawMessageAndToken] = []
def append_output(self, output) -> None:
self.last_output = output
if not isinstance(output, RequestOutput):
@ -166,6 +172,22 @@ class SimpleContext(ConversationContext):
self.num_cached_tokens = output.num_cached_tokens or 0
self.num_output_tokens += len(output.outputs[0].token_ids or [])
if len(self.input_messages) == 0:
output_prompt = output.prompt or ""
output_prompt_token_ids = output.prompt_token_ids or []
self.input_messages.append(
ResponseRawMessageAndToken(
message=output_prompt,
tokens=output_prompt_token_ids,
)
)
self.output_messages.append(
ResponseRawMessageAndToken(
message=output.outputs[0].text,
tokens=output.outputs[0].token_ids,
)
)
def append_tool_output(self, output) -> None:
raise NotImplementedError("Should not be called.")

View File

@ -20,21 +20,15 @@ from http import HTTPStatus
from typing import Annotated, Any, Literal
import model_hosting_container_standards.sagemaker as sagemaker_standards
import prometheus_client
import pydantic
import regex as re
import uvloop
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Query, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
from prometheus_fastapi_instrumentator import Instrumentator
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import URL, Headers, MutableHeaders, State
from starlette.routing import Mount
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from typing_extensions import assert_never
import vllm.envs as envs
from vllm.config import VllmConfig
@ -56,17 +50,11 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionResponse,
CompletionRequest,
CompletionResponse,
DetokenizeRequest,
DetokenizeResponse,
ErrorInfo,
ErrorResponse,
GenerateRequest,
GenerateResponse,
ResponsesRequest,
ResponsesResponse,
StreamingResponsesResponse,
TokenizeRequest,
TokenizeResponse,
TranscriptionRequest,
TranscriptionResponseVariant,
TranslationRequest,
@ -80,8 +68,6 @@ from vllm.entrypoints.openai.serving_models import (
OpenAIServingModels,
)
from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
from vllm.entrypoints.openai.serving_tokens import ServingTokens
from vllm.entrypoints.openai.serving_transcription import (
OpenAIServingTranscription,
OpenAIServingTranslation,
@ -92,6 +78,11 @@ from vllm.entrypoints.pooling.classify.serving import ServingClassification
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
from vllm.entrypoints.pooling.score.serving import ServingScores
from vllm.entrypoints.serve.disagg.serving import ServingTokens
from vllm.entrypoints.serve.elastic_ep.middleware import (
ScalingMiddleware,
)
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
from vllm.entrypoints.tool_server import DemoToolServer, MCPToolServer, ToolServer
from vllm.entrypoints.utils import (
cli_env_setup,
@ -109,8 +100,6 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.gc_utils import freeze_gc_heap
from vllm.utils.network_utils import is_valid_ipv6_address
from vllm.utils.system_utils import decorate_logs, set_ulimit
from vllm.v1.engine.exceptions import EngineDeadError
from vllm.v1.metrics.prometheus import get_prometheus_registry
from vllm.version import __version__ as VLLM_VERSION
prometheus_multiproc_dir: tempfile.TemporaryDirectory
@ -245,39 +234,6 @@ async def build_async_engine_client_from_engine_args(
router = APIRouter()
class PrometheusResponse(Response):
media_type = prometheus_client.CONTENT_TYPE_LATEST
def mount_metrics(app: FastAPI):
"""Mount prometheus metrics to a FastAPI app."""
registry = get_prometheus_registry()
# `response_class=PrometheusResponse` is needed to return an HTTP response
# with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
# instead of the default "application/json" which is incorrect.
# See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/163#issue-1296092364
Instrumentator(
excluded_handlers=[
"/metrics",
"/health",
"/load",
"/ping",
"/version",
"/server_info",
],
registry=registry,
).add().instrument(app).expose(app, response_class=PrometheusResponse)
# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
# Workaround for 307 Redirect for /metrics
metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
app.routes.append(metrics_route)
def base(request: Request) -> OpenAIServing:
# Reuse the existing instance
return tokenization(request)
@ -323,16 +279,6 @@ def generate_tokens(request: Request) -> ServingTokens | None:
return request.app.state.serving_tokens
@router.get("/health", response_class=Response)
async def health(raw_request: Request) -> Response:
"""Health check."""
try:
await engine_client(raw_request).check_health()
return Response(status_code=200)
except EngineDeadError:
return Response(status_code=503)
@router.get("/load")
async def get_server_load_metrics(request: Request):
# This endpoint returns the current server load metrics.
@ -352,167 +298,6 @@ async def get_server_load_metrics(request: Request):
return JSONResponse(content={"server_load": request.app.state.server_load_metrics})
@router.post("/pause")
async def pause_generation(
raw_request: Request,
wait_for_inflight_requests: bool = Query(False),
clear_cache: bool = Query(True),
) -> JSONResponse:
"""Pause generation requests to allow weight updates.
Args:
wait_for_inflight_requests: When ``True`` waits for in-flight
requests to finish before pausing. When ``False`` (default),
aborts any in-flight requests immediately.
clear_cache: Whether to clear KV/prefix caches after draining.
"""
engine = engine_client(raw_request)
try:
await engine.pause_generation(
wait_for_inflight_requests=wait_for_inflight_requests,
clear_cache=clear_cache,
)
return JSONResponse(
content={"status": "paused"},
status_code=HTTPStatus.OK.value,
)
except ValueError as err:
return JSONResponse(
content={"error": str(err)},
status_code=HTTPStatus.BAD_REQUEST.value,
)
except Exception as err: # pragma: no cover - defensive
logger.exception("Failed to pause generation")
return JSONResponse(
content={"error": f"Failed to pause generation: {err}"},
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
)
@router.post("/resume")
async def resume_generation(raw_request: Request) -> JSONResponse:
"""Resume generation after a pause."""
engine = engine_client(raw_request)
try:
await engine.resume_generation()
return JSONResponse(
content={"status": "resumed"},
status_code=HTTPStatus.OK.value,
)
except Exception as err: # pragma: no cover - defensive
logger.exception("Failed to resume generation")
return JSONResponse(
content={"error": f"Failed to resume generation: {err}"},
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
)
@router.get("/is_paused")
async def is_paused(raw_request: Request) -> JSONResponse:
"""Return the current pause status."""
engine = engine_client(raw_request)
try:
paused = await engine.is_paused()
except Exception as err: # pragma: no cover - defensive
logger.exception("Failed to fetch pause status")
return JSONResponse(
content={"error": f"Failed to fetch pause status: {err}"},
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
)
return JSONResponse(content={"is_paused": paused})
@router.post(
"/tokenize",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse},
},
)
@with_cancellation
async def tokenize(request: TokenizeRequest, raw_request: Request):
handler = tokenization(raw_request)
try:
generator = await handler.create_tokenize(request, raw_request)
except NotImplementedError as e:
raise HTTPException(
status_code=HTTPStatus.NOT_IMPLEMENTED.value, detail=str(e)
) from e
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, TokenizeResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.post(
"/detokenize",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
@with_cancellation
async def detokenize(request: DetokenizeRequest, raw_request: Request):
handler = tokenization(raw_request)
try:
generator = await handler.create_detokenize(request, raw_request)
except OverflowError as e:
raise RequestValidationError(errors=[str(e)]) from e
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, DetokenizeResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
def maybe_register_tokenizer_info_endpoint(args):
"""Conditionally register the tokenizer info endpoint if enabled."""
if getattr(args, "enable_tokenizer_info_endpoint", False):
@router.get("/tokenizer_info")
async def get_tokenizer_info(raw_request: Request):
"""Get comprehensive tokenizer information."""
result = await tokenization(raw_request).get_tokenizer_info()
return JSONResponse(
content=result.model_dump(),
status_code=result.error.code
if isinstance(result, ErrorResponse)
else 200,
)
@router.get("/v1/models")
async def show_available_models(raw_request: Request):
handler = models(raw_request)
@ -898,33 +683,6 @@ if envs.VLLM_SERVER_DEV_MODE:
await engine_client(raw_request).reset_mm_cache()
return Response(status_code=200)
@router.post("/sleep")
async def sleep(raw_request: Request):
# get POST params
level = raw_request.query_params.get("level", "1")
await engine_client(raw_request).sleep(int(level))
# FIXME: in v0 with frontend multiprocessing, the sleep command
# is sent but does not finish yet when we return a response.
return Response(status_code=200)
@router.post("/wake_up")
async def wake_up(raw_request: Request):
tags = raw_request.query_params.getlist("tags")
if tags == []:
# set to None to wake up all tags if no tags are provided
tags = None
logger.info("wake up the engine with tags: %s", tags)
await engine_client(raw_request).wake_up(tags)
# FIXME: in v0 with frontend multiprocessing, the wake-up command
# is sent but does not finish yet when we return a response.
return Response(status_code=200)
@router.get("/is_sleeping")
async def is_sleeping(raw_request: Request):
logger.info("check whether the engine is sleeping")
is_sleeping = await engine_client(raw_request).is_sleeping()
return JSONResponse(content={"is_sleeping": is_sleeping})
@router.post("/collective_rpc")
async def collective_rpc(raw_request: Request):
try:
@ -952,138 +710,13 @@ if envs.VLLM_SERVER_DEV_MODE:
return Response(status_code=200)
response: list[Any] = []
for result in results:
if result is None or isinstance(result, (dict, list)):
if result is None or isinstance(result, dict | list):
response.append(result)
else:
response.append(str(result))
return JSONResponse(content={"results": response})
@router.post(
"/scale_elastic_ep",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.OK.value: {"model": dict},
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.REQUEST_TIMEOUT.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
async def scale_elastic_ep(raw_request: Request):
try:
body = await raw_request.json()
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail="Invalid JSON format") from e # noqa: B904
new_data_parallel_size = body.get("new_data_parallel_size")
drain_timeout = body.get("drain_timeout", 120) # Default 2 minutes
if new_data_parallel_size is None:
raise HTTPException(
status_code=400, detail="new_data_parallel_size is required"
)
if not isinstance(new_data_parallel_size, int) or new_data_parallel_size <= 0:
raise HTTPException(
status_code=400, detail="new_data_parallel_size must be a positive integer"
)
if not isinstance(drain_timeout, int) or drain_timeout <= 0:
raise HTTPException(
status_code=400, detail="drain_timeout must be a positive integer"
)
# Set scaling flag to prevent new requests
global _scaling_elastic_ep
_scaling_elastic_ep = True
client = engine_client(raw_request)
try:
await client.scale_elastic_ep(new_data_parallel_size, drain_timeout)
return JSONResponse(
{
"message": f"Scaled to {new_data_parallel_size} data parallel engines",
}
)
except TimeoutError as e:
raise HTTPException(
status_code=408,
detail="Scale failed due to request drain timeout "
f"after {drain_timeout} seconds",
) from e
except Exception as e:
logger.error("Scale failed: %s", e)
raise HTTPException(status_code=500, detail="Scale failed") from e
finally:
_scaling_elastic_ep = False
@router.post("/is_scaling_elastic_ep")
async def is_scaling_elastic_ep(raw_request: Request):
return JSONResponse({"is_scaling_elastic_ep": _scaling_elastic_ep})
@router.post(
"/inference/v1/generate",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
@with_cancellation
@load_aware_call
async def generate(request: GenerateRequest, raw_request: Request):
handler = generate_tokens(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support generate tokens API"
)
try:
generator = await handler.serve_tokens(request, raw_request)
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, GenerateResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
if envs.VLLM_TORCH_PROFILER_DIR:
logger.warning_once(
"Torch Profiler is enabled in the API server. This should ONLY be "
"used for local development!"
)
elif envs.VLLM_TORCH_CUDA_PROFILE:
logger.warning_once(
"CUDA Profiler is enabled in the API server. This should ONLY be "
"used for local development!"
)
if envs.VLLM_TORCH_PROFILER_DIR or envs.VLLM_TORCH_CUDA_PROFILE:
@router.post("/start_profile")
async def start_profile(raw_request: Request):
logger.info("Starting profiler...")
await engine_client(raw_request).start_profile()
logger.info("Profiler started.")
return Response(status_code=200)
@router.post("/stop_profile")
async def stop_profile(raw_request: Request):
logger.info("Stopping profiler...")
await engine_client(raw_request).stop_profile()
logger.info("Profiler stopped.")
return Response(status_code=200)
def load_log_config(log_config_file: str | None) -> dict | None:
if not log_config_file:
return None
@ -1176,41 +809,6 @@ class XRequestIdMiddleware:
return self.app(scope, receive, send_with_request_id)
# Global variable to track scaling state
_scaling_elastic_ep = False
class ScalingMiddleware:
"""
Middleware that checks if the model is currently scaling and
returns a 503 Service Unavailable response if it is.
This middleware applies to all HTTP requests and prevents
processing when the model is in a scaling state.
"""
def __init__(self, app: ASGIApp) -> None:
self.app = app
def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
if scope["type"] != "http":
return self.app(scope, receive, send)
# Check global scaling state
global _scaling_elastic_ep
if _scaling_elastic_ep:
# Return 503 Service Unavailable response
response = JSONResponse(
content={
"error": "The model is currently scaling. Please try again later."
},
status_code=503,
)
return response(scope, receive, send)
return self.app(scope, receive, send)
def _extract_content_from_chunk(chunk_data: dict) -> str:
"""Extract content from a streaming response chunk."""
try:
@ -1353,15 +951,10 @@ def build_app(args: Namespace) -> FastAPI:
)
else:
app = FastAPI(lifespan=lifespan)
app.state.args = args
from vllm.entrypoints.serve import register_vllm_serve_api_routers
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
logger.warning(
"LoRA dynamic loading & unloading is enabled in the API server. "
"This should ONLY be used for local development!"
)
from vllm.entrypoints.dynamic_lora import register_dynamic_lora_routes
register_dynamic_lora_routes(router)
register_vllm_serve_api_routers(app)
from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes
@ -1370,8 +963,6 @@ def build_app(args: Namespace) -> FastAPI:
app.root_path = args.root_path
mount_metrics(app)
from vllm.entrypoints.pooling import register_pooling_api_routers
register_pooling_api_routers(app)
@ -1462,31 +1053,6 @@ def build_app(args: Namespace) -> FastAPI:
)
app = sagemaker_standards.bootstrap(app)
# Optional endpoints
if args.tokens_only:
@app.post("/abort_requests")
async def abort_requests(raw_request: Request):
"""
Abort one or more requests. To be used in a
Disaggregated Everything setup.
"""
try:
body = await raw_request.json()
except json.JSONDecodeError as e:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
detail=f"JSON decode error: {e}",
) from e
request_ids = body.get("request_ids")
if request_ids is None:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
detail="Missing 'request_ids' in request body",
)
# Abort requests in background
asyncio.create_task(engine_client(raw_request).abort(request_ids))
return Response(status_code=200)
return app
@ -1515,7 +1081,7 @@ async def init_app_state(
state.engine_client = engine_client
state.log_stats = not args.disable_log_stats
state.vllm_config = vllm_config
state.args = args
supported_tasks = await engine_client.get_supported_tasks()
logger.info("Supported tasks: %s", supported_tasks)
@ -1839,7 +1405,6 @@ async def run_server_worker(
args,
client_config=client_config,
) as engine_client:
maybe_register_tokenizer_info_endpoint(args)
app = build_app(args)
await init_app_state(engine_client, app.state, args)

View File

@ -1598,6 +1598,20 @@ def serialize_messages(msgs):
return [serialize_message(msg) for msg in msgs] if msgs else None
class ResponseRawMessageAndToken(OpenAIBaseModel):
"""Class to show the raw message.
If message / tokens diverge, tokens is the source of truth"""
message: str
tokens: list[int]
type: Literal["raw_message_tokens"] = "raw_message_tokens"
ResponseInputOutputMessage: TypeAlias = (
list[ChatCompletionMessageParam] | list[ResponseRawMessageAndToken]
)
class ResponsesResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"resp_{random_uuid()}")
created_at: int = Field(default_factory=lambda: int(time.time()))
@ -1631,8 +1645,8 @@ class ResponsesResponse(OpenAIBaseModel):
# These are populated when enable_response_messages is set to True
# NOTE: custom serialization is needed
# see serialize_input_messages and serialize_output_messages
input_messages: list[ChatCompletionMessageParam] | None = None
output_messages: list[ChatCompletionMessageParam] | None = None
input_messages: ResponseInputOutputMessage | None = None
output_messages: ResponseInputOutputMessage | None = None
# --8<-- [end:responses-extra-params]
# NOTE: openAI harmony doesn't serialize TextContent properly,
@ -1658,8 +1672,8 @@ class ResponsesResponse(OpenAIBaseModel):
output: list[ResponseOutputItem],
status: ResponseStatus,
usage: ResponseUsage | None = None,
input_messages: list[ChatCompletionMessageParam] | None = None,
output_messages: list[ChatCompletionMessageParam] | None = None,
input_messages: ResponseInputOutputMessage | None = None,
output_messages: ResponseInputOutputMessage | None = None,
) -> "ResponsesResponse":
incomplete_details: IncompleteDetails | None = None
if status == "incomplete":

View File

@ -74,8 +74,6 @@ from vllm.entrypoints.openai.protocol import (
ErrorResponse,
FunctionCall,
FunctionDefinition,
GenerateRequest,
GenerateResponse,
ResponsesRequest,
TokenizeChatRequest,
TokenizeCompletionRequest,
@ -87,6 +85,7 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs.data import PromptType
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt

View File

@ -86,6 +86,7 @@ from vllm.entrypoints.openai.protocol import (
ResponseCompletedEvent,
ResponseCreatedEvent,
ResponseInProgressEvent,
ResponseInputOutputMessage,
ResponseReasoningPartAddedEvent,
ResponseReasoningPartDoneEvent,
ResponsesRequest,
@ -629,8 +630,8 @@ class OpenAIServingResponses(OpenAIServing):
# "completed" is implemented as the "catch-all" for now.
status: ResponseStatus = "completed"
input_messages = None
output_messages = None
input_messages: ResponseInputOutputMessage | None = None
output_messages: ResponseInputOutputMessage | None = None
if self.use_harmony:
assert isinstance(context, HarmonyContext)
output = self._make_response_output_items_with_harmony(context)
@ -670,12 +671,10 @@ class OpenAIServingResponses(OpenAIServing):
output = self._make_response_output_items(request, final_output, tokenizer)
# TODO: context for non-gptoss models doesn't use messages
# so we can't get them out yet
if request.enable_response_messages:
raise NotImplementedError(
"enable_response_messages is currently only supported for gpt-oss"
)
input_messages = context.input_messages
output_messages = context.output_messages
# Calculate usage.
assert final_res.prompt_token_ids is not None
num_tool_output_tokens = 0

View File

@ -16,7 +16,6 @@ from vllm.entrypoints.openai.api_server import (
completion,
create_chat_completion,
create_completion,
health,
validate_json_request,
)
from vllm.entrypoints.openai.protocol import (
@ -38,6 +37,7 @@ from vllm.entrypoints.pooling.score.api_router import (
score,
)
from vllm.entrypoints.pooling.score.protocol import RerankRequest, ScoreRequest
from vllm.entrypoints.serve.instrumentator.health import health
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# (requires typing_extensions >= 4.13)

View File

@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from fastapi import FastAPI
def register_vllm_serve_api_routers(app: FastAPI):
from vllm.entrypoints.serve.lora.api_router import (
attach_router as attach_lora_router,
)
attach_lora_router(app)
from vllm.entrypoints.serve.elastic_ep.api_router import (
attach_router as attach_elastic_ep_router,
)
attach_elastic_ep_router(app)
from vllm.entrypoints.serve.profile.api_router import (
attach_router as attach_profile_router,
)
attach_profile_router(app)
from vllm.entrypoints.serve.sleep.api_router import (
attach_router as attach_sleep_router,
)
attach_sleep_router(app)
from vllm.entrypoints.serve.tokenize.api_router import (
attach_router as attach_tokenize_router,
)
attach_tokenize_router(app)
from vllm.entrypoints.serve.disagg.api_router import (
attach_router as attach_disagg_router,
)
attach_disagg_router(app)
from vllm.entrypoints.serve.rlhf.api_router import (
attach_router as attach_rlhf_router,
)
attach_rlhf_router(app)
from vllm.entrypoints.serve.instrumentator.metrics import (
attach_router as attach_metrics_router,
)
attach_metrics_router(app)
from vllm.entrypoints.serve.instrumentator.health import (
attach_router as attach_health_router,
)
attach_health_router(app)

View File

@ -0,0 +1,110 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import json
from http import HTTPStatus
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.api_server import validate_json_request
from vllm.entrypoints.openai.protocol import (
ErrorResponse,
)
from vllm.entrypoints.serve.disagg.protocol import (
GenerateRequest,
GenerateResponse,
)
from vllm.entrypoints.serve.disagg.serving import (
ServingTokens,
)
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
from vllm.entrypoints.utils import (
load_aware_call,
with_cancellation,
)
from vllm.logger import init_logger
logger = init_logger(__name__)
def tokenization(request: Request) -> OpenAIServingTokenization:
return request.app.state.openai_serving_tokenization
def generate_tokens(request: Request) -> ServingTokens | None:
return request.app.state.serving_tokens
def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client
router = APIRouter()
@router.post(
"/inference/v1/generate",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
@with_cancellation
@load_aware_call
async def generate(request: GenerateRequest, raw_request: Request):
handler = generate_tokens(raw_request)
if handler is None:
return tokenization(raw_request).create_error_response(
message="The model does not support generate tokens API"
)
try:
generator = await handler.serve_tokens(request, raw_request)
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, GenerateResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
def attach_router(app: FastAPI):
if getattr(app.state.args, "tokens_only", False):
@router.post("/abort_requests")
async def abort_requests(raw_request: Request):
"""
Abort one or more requests. To be used in a
Disaggregated Everything setup.
"""
try:
body = await raw_request.json()
except json.JSONDecodeError as e:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
detail=f"JSON decode error: {e}",
) from e
request_ids = body.get("request_ids")
if request_ids is None:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
detail="Missing 'request_ids' in request body",
)
# Abort requests in background
asyncio.create_task(engine_client(raw_request).abort(request_ids))
return Response(status_code=200)
app.include_router(router)

View File

@ -0,0 +1,90 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from pydantic import BaseModel, Field
from vllm.entrypoints.openai.protocol import (
ChatCompletionLogProbs,
Logprob,
SamplingParams,
StreamOptions,
)
from vllm.utils import random_uuid
####### Tokens IN <> Tokens OUT #######
class GenerateRequest(BaseModel):
request_id: str = Field(
default_factory=lambda: f"{random_uuid()}",
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
),
)
token_ids: list[int]
"""The token ids to generate text from."""
# features: MultiModalFeatureSpec
# TODO (NickLucche): implement once Renderer work is completed
features: str | None = None
"""The processed MM inputs for the model."""
sampling_params: SamplingParams
"""The sampling parameters for the model."""
model: str | None = None
stream: bool | None = False
stream_options: StreamOptions | None = None
cache_salt: str | None = Field(
default=None,
description=(
"If specified, the prefix cache will be salted with the provided "
"string to prevent an attacker to guess prompts in multi-user "
"environments. The salt should be random, protected from "
"access by 3rd parties, and long enough to be "
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
"to 256 bit)."
),
)
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
),
)
kv_transfer_params: dict[str, Any] | None = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.",
)
class GenerateResponseChoice(BaseModel):
index: int
logprobs: ChatCompletionLogProbs | None = None
# per OpenAI spec this is the default
finish_reason: str | None = "stop"
token_ids: list[int] | None = None
class GenerateResponse(BaseModel):
request_id: str = Field(
default_factory=lambda: f"{random_uuid()}",
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
),
)
choices: list[GenerateResponseChoice]
prompt_logprobs: list[dict[int, Logprob] | None] | None = None
kv_transfer_params: dict[str, Any] | None = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.",
)

View File

@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import time
from collections.abc import AsyncGenerator
@ -14,15 +16,17 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionLogProbs,
ChatCompletionLogProbsContent,
ErrorResponse,
GenerateRequest,
GenerateResponse,
GenerateResponseChoice,
PromptTokenUsageInfo,
RequestResponseMetadata,
UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.serve.disagg.protocol import (
GenerateRequest,
GenerateResponse,
GenerateResponseChoice,
)
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob

View File

@ -0,0 +1,96 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from http import HTTPStatus
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.api_server import validate_json_request
from vllm.entrypoints.openai.protocol import (
ErrorResponse,
)
from vllm.entrypoints.serve.elastic_ep.middleware import (
get_scaling_elastic_ep,
set_scaling_elastic_ep,
)
from vllm.logger import init_logger
logger = init_logger(__name__)
def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client
router = APIRouter()
@router.post(
"/scale_elastic_ep",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.OK.value: {"model": dict},
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.REQUEST_TIMEOUT.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
async def scale_elastic_ep(raw_request: Request):
try:
body = await raw_request.json()
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail="Invalid JSON format") from e # noqa: B904
new_data_parallel_size = body.get("new_data_parallel_size")
drain_timeout = body.get("drain_timeout", 120) # Default 2 minutes
if new_data_parallel_size is None:
raise HTTPException(
status_code=400, detail="new_data_parallel_size is required"
)
if not isinstance(new_data_parallel_size, int) or new_data_parallel_size <= 0:
raise HTTPException(
status_code=400,
detail="new_data_parallel_size must be a positive integer",
)
if not isinstance(drain_timeout, int) or drain_timeout <= 0:
raise HTTPException(
status_code=400, detail="drain_timeout must be a positive integer"
)
# Set scaling flag to prevent new requests
set_scaling_elastic_ep(True)
client = engine_client(raw_request)
try:
await client.scale_elastic_ep(new_data_parallel_size, drain_timeout)
return JSONResponse(
{
"message": f"Scaled to {new_data_parallel_size} data parallel engines",
}
)
except TimeoutError as e:
raise HTTPException(
status_code=408,
detail="Scale failed due to request drain timeout "
f"after {drain_timeout} seconds",
) from e
except Exception as e:
logger.error("Scale failed: %s", e)
raise HTTPException(status_code=500, detail="Scale failed") from e
finally:
set_scaling_elastic_ep(False)
@router.post("/is_scaling_elastic_ep")
async def is_scaling_elastic_ep(raw_request: Request):
return JSONResponse({"is_scaling_elastic_ep": get_scaling_elastic_ep()})
def attach_router(app: FastAPI):
app.include_router(router)

View File

@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Awaitable
from fastapi.responses import JSONResponse
from starlette.types import ASGIApp, Receive, Scope, Send
# Global variable to track scaling state
_scaling_elastic_ep = False
def get_scaling_elastic_ep():
return _scaling_elastic_ep
def set_scaling_elastic_ep(value):
global _scaling_elastic_ep
_scaling_elastic_ep = value
class ScalingMiddleware:
"""
Middleware that checks if the model is currently scaling and
returns a 503 Service Unavailable response if it is.
This middleware applies to all HTTP requests and prevents
processing when the model is in a scaling state.
"""
def __init__(self, app: ASGIApp) -> None:
self.app = app
def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
if scope["type"] != "http":
return self.app(scope, receive, send)
# Check global scaling state
if get_scaling_elastic_ep():
# Return 503 Service Unavailable response
response = JSONResponse(
content={
"error": "The model is currently scaling. Please try again later."
},
status_code=503,
)
return response(scope, receive, send)
return self.app(scope, receive, send)

View File

@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from fastapi import APIRouter, Request
from fastapi.responses import Response
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger
from vllm.v1.engine.exceptions import EngineDeadError
logger = init_logger(__name__)
router = APIRouter()
def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client
@router.get("/health", response_class=Response)
async def health(raw_request: Request) -> Response:
"""Health check."""
try:
await engine_client(raw_request).check_health()
return Response(status_code=200)
except EngineDeadError:
return Response(status_code=503)
def attach_router(app):
app.include_router(router)

View File

@ -0,0 +1,46 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import re
import prometheus_client
from fastapi import FastAPI, Response
from prometheus_client import make_asgi_app
from prometheus_fastapi_instrumentator import Instrumentator
from starlette.routing import Mount
from vllm.v1.metrics.prometheus import get_prometheus_registry
class PrometheusResponse(Response):
media_type = prometheus_client.CONTENT_TYPE_LATEST
def attach_router(app: FastAPI):
"""Mount prometheus metrics to a FastAPI app."""
registry = get_prometheus_registry()
# `response_class=PrometheusResponse` is needed to return an HTTP response
# with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
# instead of the default "application/json" which is incorrect.
# See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/163#issue-1296092364
Instrumentator(
excluded_handlers=[
"/metrics",
"/health",
"/load",
"/ping",
"/version",
"/server_info",
],
registry=registry,
).add().instrument(app).expose(app, response_class=PrometheusResponse)
# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
# Workaround for 307 Redirect for /metrics
metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
app.routes.append(metrics_route)

View File

View File

@ -1,9 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import model_hosting_container_standards.sagemaker as sagemaker_standards
from fastapi import APIRouter, Depends, Request
from fastapi import APIRouter, Depends, FastAPI, Request
from fastapi.responses import JSONResponse, Response
from vllm import envs
from vllm.entrypoints.openai.api_server import models, validate_json_request
from vllm.entrypoints.openai.protocol import (
ErrorResponse,
@ -14,9 +17,18 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.logger import init_logger
logger = init_logger(__name__)
router = APIRouter()
def register_dynamic_lora_routes(router: APIRouter):
def attach_router(app: FastAPI):
if not envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
"""If LoRA dynamic loading & unloading is not enabled, do nothing."""
return
logger.warning(
"LoRA dynamic loading & unloading is enabled in the API server. "
"This should ONLY be used for local development!"
)
@sagemaker_standards.register_load_adapter_handler(
request_shape={
"lora_name": "body.name",
@ -54,4 +66,5 @@ def register_dynamic_lora_routes(router: APIRouter):
return Response(status_code=200, content=response)
return router
# register the router
app.include_router(router)

View File

@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from fastapi import APIRouter, FastAPI, Request
from fastapi.responses import Response
import vllm.envs as envs
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger
logger = init_logger(__name__)
router = APIRouter()
def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client
@router.post("/start_profile")
async def start_profile(raw_request: Request):
logger.info("Starting profiler...")
await engine_client(raw_request).start_profile()
logger.info("Profiler started.")
return Response(status_code=200)
@router.post("/stop_profile")
async def stop_profile(raw_request: Request):
logger.info("Stopping profiler...")
await engine_client(raw_request).stop_profile()
logger.info("Profiler stopped.")
return Response(status_code=200)
def attach_router(app: FastAPI):
if envs.VLLM_TORCH_PROFILER_DIR:
logger.warning_once(
"Torch Profiler is enabled in the API server. This should ONLY be "
"used for local development!"
)
elif envs.VLLM_TORCH_CUDA_PROFILE:
logger.warning_once(
"CUDA Profiler is enabled in the API server. This should ONLY be "
"used for local development!"
)
if envs.VLLM_TORCH_PROFILER_DIR or envs.VLLM_TORCH_CUDA_PROFILE:
app.include_router(router)

View File

View File

@ -0,0 +1,102 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from fastapi import APIRouter, FastAPI, Query, Request
from fastapi.responses import JSONResponse
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger
logger = init_logger(__name__)
def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client
router = APIRouter()
@router.post("/pause")
async def pause_generation(
raw_request: Request,
wait_for_inflight_requests: bool = Query(False),
clear_cache: bool = Query(True),
) -> JSONResponse:
"""Pause generation requests to allow weight updates.
Args:
wait_for_inflight_requests: When ``True`` waits for in-flight
requests to finish before pausing. When ``False`` (default),
aborts any in-flight requests immediately.
clear_cache: Whether to clear KV/prefix caches after draining.
"""
engine = engine_client(raw_request)
try:
await engine.pause_generation(
wait_for_inflight_requests=wait_for_inflight_requests,
clear_cache=clear_cache,
)
return JSONResponse(
content={"status": "paused"},
status_code=HTTPStatus.OK.value,
)
except ValueError as err:
return JSONResponse(
content={"error": str(err)},
status_code=HTTPStatus.BAD_REQUEST.value,
)
except Exception as err: # pragma: no cover - defensive
logger.exception("Failed to pause generation")
return JSONResponse(
content={"error": f"Failed to pause generation: {err}"},
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
)
@router.post("/resume")
async def resume_generation(raw_request: Request) -> JSONResponse:
"""Resume generation after a pause."""
engine = engine_client(raw_request)
try:
await engine.resume_generation()
return JSONResponse(
content={"status": "resumed"},
status_code=HTTPStatus.OK.value,
)
except Exception as err: # pragma: no cover - defensive
logger.exception("Failed to resume generation")
return JSONResponse(
content={"error": f"Failed to resume generation: {err}"},
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
)
@router.get("/is_paused")
async def is_paused(raw_request: Request) -> JSONResponse:
"""Return the current pause status."""
engine = engine_client(raw_request)
try:
paused = await engine.is_paused()
except Exception as err: # pragma: no cover - defensive
logger.exception("Failed to fetch pause status")
return JSONResponse(
content={"error": f"Failed to fetch pause status: {err}"},
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
)
return JSONResponse(content={"is_paused": paused})
def attach_router(app: FastAPI):
app.include_router(router)

View File

View File

@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from fastapi import APIRouter, FastAPI, Request
from fastapi.responses import JSONResponse, Response
import vllm.envs as envs
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger
logger = init_logger(__name__)
def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client
router = APIRouter()
@router.post("/sleep")
async def sleep(raw_request: Request):
# get POST params
level = raw_request.query_params.get("level", "1")
await engine_client(raw_request).sleep(int(level))
# FIXME: in v0 with frontend multiprocessing, the sleep command
# is sent but does not finish yet when we return a response.
return Response(status_code=200)
@router.post("/wake_up")
async def wake_up(raw_request: Request):
tags = raw_request.query_params.getlist("tags")
if tags == []:
# set to None to wake up all tags if no tags are provided
tags = None
logger.info("wake up the engine with tags: %s", tags)
await engine_client(raw_request).wake_up(tags)
# FIXME: in v0 with frontend multiprocessing, the wake-up command
# is sent but does not finish yet when we return a response.
return Response(status_code=200)
@router.get("/is_sleeping")
async def is_sleeping(raw_request: Request):
logger.info("check whether the engine is sleeping")
is_sleeping = await engine_client(raw_request).is_sleeping()
return JSONResponse(content={"is_sleeping": is_sleeping})
def attach_router(app: FastAPI):
if not envs.VLLM_SERVER_DEV_MODE:
return
logger.warning(
"SECURITY WARNING: Development endpoints are enabled! "
"This should NOT be used in production!"
)
app.include_router(router)

View File

@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from typing_extensions import assert_never
from vllm.entrypoints.openai.api_server import validate_json_request
from vllm.entrypoints.openai.protocol import (
DetokenizeRequest,
DetokenizeResponse,
ErrorResponse,
TokenizeRequest,
TokenizeResponse,
)
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
from vllm.entrypoints.utils import (
with_cancellation,
)
from vllm.logger import init_logger
logger = init_logger(__name__)
def tokenization(request: Request) -> OpenAIServingTokenization:
return request.app.state.openai_serving_tokenization
router = APIRouter()
@router.post(
"/tokenize",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse},
},
)
@with_cancellation
async def tokenize(request: TokenizeRequest, raw_request: Request):
handler = tokenization(raw_request)
try:
generator = await handler.create_tokenize(request, raw_request)
except NotImplementedError as e:
raise HTTPException(
status_code=HTTPStatus.NOT_IMPLEMENTED.value, detail=str(e)
) from e
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, TokenizeResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.post(
"/detokenize",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
@with_cancellation
async def detokenize(request: DetokenizeRequest, raw_request: Request):
handler = tokenization(raw_request)
try:
generator = await handler.create_detokenize(request, raw_request)
except OverflowError as e:
raise RequestValidationError(errors=[str(e)]) from e
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, DetokenizeResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
def attach_router(app: FastAPI):
if getattr(app.state.args, "enable_tokenizer_info_endpoint", False):
"""Conditionally register the tokenizer info endpoint if enabled."""
@router.get("/tokenizer_info")
async def get_tokenizer_info(raw_request: Request):
"""Get comprehensive tokenizer information."""
result = await tokenization(raw_request).get_tokenizer_info()
return JSONResponse(
content=result.model_dump(),
status_code=result.error.code
if isinstance(result, ErrorResponse)
else 200,
)
app.include_router(router)

View File

@ -767,8 +767,10 @@ class CompressedTensorsConfig(QuantizationConfig):
targets=self.target_scheme_map.keys(),
fused_mapping=self.packed_modules_mapping,
)
return self.target_scheme_map[matched_target]
scheme_dict = self.target_scheme_map[matched_target]
if scheme_dict.get("format") is None:
scheme_dict["format"] = self.quant_format
return scheme_dict
return None

View File

@ -7,7 +7,11 @@ from enum import Enum
import torch
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import ActivationOrdering, QuantizationStrategy
from compressed_tensors.quantization import (
ActivationOrdering,
QuantizationArgs,
QuantizationStrategy,
)
from torch.nn.parameter import Parameter
import vllm.envs as envs
@ -142,10 +146,26 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
# are supported + check if the layer is being ignored.
weight_quant = scheme_dict.get("weights")
input_quant = scheme_dict.get("input_activations")
format = scheme_dict.get("format")
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
# group_size=None means channelwise
group_size = weight_quant.group_size or -1
valid_format_and_bits = (
weight_quant.num_bits in WNA16_SUPPORTED_BITS
and format == CompressionFormat.pack_quantized.value
)
if not valid_format_and_bits:
raise ValueError(
"For Fused MoE layers, only format: ",
f"{CompressionFormat.pack_quantized.value} ",
f" and bits: {WNA16_SUPPORTED_BITS} is supported ",
f"but got format: {CompressionFormat.pack_quantized.value} "
f" and bits: {weight_quant.num_bits}",
)
# Prefer to use the MarlinMoE kernel when it is supported.
if (
not check_moe_marlin_supports_layer(layer, group_size)
@ -161,12 +181,12 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
)
logger.info_once("Using CompressedTensorsWNA16MoEMethod")
return CompressedTensorsWNA16MoEMethod(
quant_config, layer.moe_config, layer_name
weight_quant, input_quant, layer.moe_config
)
else:
logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
return CompressedTensorsWNA16MarlinMoEMethod(
quant_config, layer.moe_config, layer_name
weight_quant, input_quant, layer.moe_config
)
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A4Nvfp4MoEMethod(layer.moe_config, layer_name)
@ -176,15 +196,15 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
or quant_config._is_fp8_w8a8(weight_quant, input_quant)
):
return CompressedTensorsW8A8Fp8MoEMethod(
quant_config, layer.moe_config, layer_name
weight_quant, input_quant, layer.moe_config
)
elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8MoEMethod(
quant_config, layer.moe_config, layer_name
weight_quant, input_quant, layer.moe_config
)
elif quant_config._is_dynamic_token_w4a8_int(weight_quant, input_quant):
return CompressedTensorsW4A8Int8MoEMethod(
quant_config, layer.moe_config, layer_name
weight_quant, input_quant, layer.moe_config
)
else:
raise RuntimeError(
@ -650,17 +670,19 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
weight_quant: QuantizationArgs,
input_quant: QuantizationArgs,
moe: FusedMoEConfig,
layer_name: str | None = None,
):
super().__init__(moe)
self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
"input_activations"
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsConfig,
)
super().__init__(moe)
self.weight_quant = weight_quant
self.input_quant = input_quant
per_tensor = (
self.weight_quant.strategy == QuantizationStrategy.TENSOR
and self.input_quant.strategy == QuantizationStrategy.TENSOR
@ -698,11 +720,13 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
# cutlass path
self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
self.is_fp8_w8a8_sm100 = CompressedTensorsConfig._is_fp8_w8a8_sm100(
self.weight_quant, self.input_quant
)
self.use_cutlass = not self.block_quant and (
quant_config._is_fp8_w8a8_sm90(self.weight_quant, self.input_quant)
CompressedTensorsConfig._is_fp8_w8a8_sm90(
self.weight_quant, self.input_quant
)
or self.is_fp8_w8a8_sm100
)
self.disable_expert_map = False
@ -1261,16 +1285,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
weight_quant: QuantizationArgs,
input_quant: QuantizationArgs,
moe: FusedMoEConfig,
layer_name: str | None = None,
):
super().__init__(moe)
self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
"input_activations"
)
self.weight_quant = weight_quant
self.input_quant = input_quant
per_channel = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
@ -1414,36 +1436,27 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
weight_quant: QuantizationArgs,
input_quant: QuantizationArgs | None,
moe: FusedMoEConfig,
layer_name: str | None = None,
):
super().__init__(moe)
self.quant_config = quant_config
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
config = self.quant_config.target_scheme_map["Linear"].get("weights")
self.num_bits = config.num_bits
self.packed_factor = 32 // config.num_bits
self.strategy = config.strategy
self.group_size = config.group_size
self.actorder = config.actorder
self.layer_name = layer_name
self.marlin_input_dtype = get_marlin_input_dtype(layer_name)
assert config.symmetric, "Only symmetric quantization is supported for MoE"
self.weight_quant = weight_quant
self.input_quant = input_quant
assert weight_quant.symmetric, (
"Only symmetric quantization is supported for MoE"
)
# Extract properties from weight_quant
self.num_bits = weight_quant.num_bits
self.packed_factor = 32 // weight_quant.num_bits
self.strategy = weight_quant.strategy
self.group_size = weight_quant.group_size
self.actorder = weight_quant.actorder
if not (
self.quant_config.quant_format == CompressionFormat.pack_quantized.value
and self.num_bits in WNA16_SUPPORTED_BITS
):
raise ValueError(
"For Fused MoE layers, only ",
f"{CompressionFormat.pack_quantized.value} ",
"is supported for the following bits: ",
f"{WNA16_SUPPORTED_BITS}",
)
self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits]
self.use_marlin = True
self.marlin_input_dtype = get_marlin_input_dtype(layer_name)
def create_weights(
self,
@ -1812,35 +1825,26 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
weight_quant: QuantizationArgs,
input_quant: QuantizationArgs | None,
moe: FusedMoEConfig,
layer_name: str | None = None,
):
super().__init__(moe)
self.quant_config = quant_config
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
config = self.quant_config.target_scheme_map["Linear"].get("weights")
self.num_bits = config.num_bits
self.packed_factor = 32 // config.num_bits
self.strategy = config.strategy
self.weight_quant = weight_quant
self.input_quant = input_quant
# Extract properties from weight_quant
self.num_bits = weight_quant.num_bits
self.packed_factor = 32 // weight_quant.num_bits
self.strategy = weight_quant.strategy
# channelwise is not supported by this kernel
assert config.strategy == "group"
self.group_size = config.group_size
assert weight_quant.strategy == "group"
self.group_size = weight_quant.group_size
# grouped actorder isn't supported by this kernel
assert config.actorder != "group"
assert config.symmetric, "Only symmetric quantization is supported for MoE"
if not (
self.quant_config.quant_format == CompressionFormat.pack_quantized.value
and self.num_bits in WNA16_SUPPORTED_BITS
):
raise ValueError(
"For Fused MoE layers, only ",
f"{CompressionFormat.pack_quantized.value} ",
"is supported for the following bits: ",
f"{WNA16_SUPPORTED_BITS}",
)
assert weight_quant.actorder != "group"
assert weight_quant.symmetric, (
"Only symmetric quantization is supported for MoE"
)
def create_weights(
self,
@ -2065,28 +2069,33 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
weight_quant: QuantizationArgs,
input_quant: QuantizationArgs,
moe: FusedMoEConfig,
layer_name: str | None = None,
):
super().__init__(moe)
self.has_bias = self.moe.has_bias
self.quant_config = quant_config
self.weight_quant = weight_quant
self.input_quant = input_quant
# Validate scheme: weights=W4 (channel or group),
# activations=dynamic TOKEN (A8)
wq = self.quant_config.target_scheme_map["Linear"].get("weights")
aq = self.quant_config.target_scheme_map["Linear"].get("input_activations")
# Must be dynamic per-token activations
if aq.strategy != QuantizationStrategy.TOKEN or not aq.dynamic:
if (
input_quant.strategy != QuantizationStrategy.TOKEN
or not input_quant.dynamic
):
raise ValueError(
"W4A8-int MoE needs dynamic per-token activation quantization."
)
# Weight can be channel-wise (group_size=None) or group-wise
self.group_size = wq.group_size if (wq.group_size is not None) else -1
if wq.num_bits != 4:
self.group_size = (
weight_quant.group_size if (weight_quant.group_size is not None) else -1
)
if weight_quant.num_bits != 4:
raise ValueError("This method only supports 4-bit weights (num_bits=4).")
# CPU only

View File

@ -921,7 +921,17 @@ def gguf_quant_weights_iterator(
name = gguf_to_hf_name_map[tensor.name]
if weight_type.name not in ("F32", "BF16", "F16"):
name = name.replace("weight", "qweight")
param = torch.tensor(weight)
if weight_type.name == "BF16" and tensor.data.dtype == np.uint8:
# BF16 is currently the only "quantization" type that isn't
# actually quantized but is read as a raw byte tensor.
# Reinterpret as `torch.bfloat16` tensor.
weight = weight.view(np.uint16)
if reader.byte_order == "S":
# GGUF endianness != system endianness
weight = weight.byteswap()
param = torch.tensor(weight).view(torch.bfloat16)
else:
param = torch.tensor(weight)
yield name, param

View File

@ -785,6 +785,7 @@ class HunYuanVLForConditionalGeneration(
SupportsQuant,
SupportsXDRoPE,
):
merge_by_field_config = True
multimodal_cpu_fields = {"image_grid_thw"}
# To ensure correct weight loading and mapping.

View File

@ -21,6 +21,8 @@ class MediaWithBytes(Generic[_T]):
The wrapper delegates attribute access to the underlying media object,
making it behave transparently like the wrapped type (e.g., PIL.Image).
NOTE: Currently, this wrapper is used only for the image modality.
"""
media: _T

View File

@ -32,6 +32,7 @@ if TYPE_CHECKING:
from PIL.Image import Image
from transformers.feature_extraction_utils import BatchFeature
from .base import MediaWithBytes
from .processing import MultiModalHashes
else:
@ -59,7 +60,7 @@ Represents a single audio
item, which can be passed to a HuggingFace `AudioProcessor`.
"""
ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor"]
ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor", "MediaWithBytes[HfImageItem]"]
"""
A `transformers.image_utils.ImageInput` representing a single image
item, which can be passed to a HuggingFace `ImageProcessor`.

View File

@ -134,11 +134,17 @@ class EmbeddingItems(
or a list of embedding tensors (one per item).
"""
def _unwrap(
self, item: torch.Tensor | MediaWithBytes[torch.Tensor]
) -> torch.Tensor:
"""Extract media from wrapper if present."""
return item.media if isinstance(item, MediaWithBytes) else item
def get_count(self) -> int:
return len(self.data)
def get(self, index: int) -> torch.Tensor:
return self.data[index]
return self._unwrap(self.data[index])
def get_processor_data(self) -> Mapping[str, object]:
return {}
@ -478,7 +484,7 @@ class MultiModalDataParser:
return ImageEmbeddingItems(data)
if (
isinstance(data, PILImage.Image)
isinstance(data, (PILImage.Image, MediaWithBytes))
or isinstance(data, (np.ndarray, torch.Tensor))
and data.ndim == 3
):

View File

@ -67,8 +67,9 @@ class MediaConnector:
to set num_frames for video, set
`--media-io-kwargs '{"video":{"num_frames":40}}'`
connection: HTTP connection client to download media contents.
allowed_local_media_path: A local directory to load media files
from.
allowed_local_media_path: A local directory to load media files from.
allowed_media_domains: If set, only media URLs that belong to this
domain can be used for multi-modal inputs.
"""
super().__init__()
@ -123,16 +124,16 @@ class MediaConnector:
"Cannot load local files without `--allowed-local-media-path`."
)
filepath = Path(url2pathname(url_spec.path))
filepath = Path(url2pathname(url_spec.netloc + url_spec.path))
if allowed_local_media_path not in filepath.resolve().parents:
raise ValueError(
f"The file path {filepath} must be a subpath "
f"of `--allowed-local-media-path` {allowed_local_media_path}."
f"of `--allowed-local-media-path {allowed_local_media_path}`."
)
return media_io.load_file(filepath)
def _assert_url_in_allowed_media_domains(self, url_spec) -> None:
def _assert_url_in_allowed_media_domains(self, url_spec: ParseResult) -> None:
if (
self.allowed_media_domains
and url_spec.hostname not in self.allowed_media_domains
@ -489,9 +490,16 @@ def fetch_audio(
Args:
audio_url: URL of the audio file to fetch.
audio_io_kwargs: Additional kwargs passed to handle audio IO.
Warning:
This method has direct access to local files and is only intended
to be called by user code. Never call this from the online server!
"""
media_io_kwargs = None if not audio_io_kwargs else {"audio": audio_io_kwargs}
media_connector = MediaConnector(media_io_kwargs=media_io_kwargs)
media_connector = MediaConnector(
media_io_kwargs=media_io_kwargs,
allowed_local_media_path="/",
)
return media_connector.fetch_audio(audio_url)
@ -503,9 +511,16 @@ def fetch_image(
Args:
image_url: URL of the image file to fetch.
image_io_kwargs: Additional kwargs passed to handle image IO.
Warning:
This method has direct access to local files and is only intended
to be called by user code. Never call this from the online server!
"""
media_io_kwargs = None if not image_io_kwargs else {"image": image_io_kwargs}
media_connector = MediaConnector(media_io_kwargs=media_io_kwargs)
media_connector = MediaConnector(
media_io_kwargs=media_io_kwargs,
allowed_local_media_path="/",
)
return media_connector.fetch_image(image_url)
@ -517,7 +532,14 @@ def fetch_video(
Args:
video_url: URL of the video file to fetch.
video_io_kwargs: Additional kwargs passed to handle video IO.
Warning:
This method has direct access to local files and is only intended
to be called by user code. Never call this from the online server!
"""
media_io_kwargs = None if not video_io_kwargs else {"video": video_io_kwargs}
media_connector = MediaConnector(media_io_kwargs=media_io_kwargs)
media_connector = MediaConnector(
media_io_kwargs=media_io_kwargs,
allowed_local_media_path="/",
)
return media_connector.fetch_video(video_url)

View File

@ -267,7 +267,7 @@ class OpenCVDynamicVideoBackend(OpenCVVideoBackend):
return frames, metadata
class VideoMediaIO(MediaIO[npt.NDArray]):
class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]):
def __init__(
self,
image_io: ImageMediaIO,

View File

@ -123,7 +123,7 @@ class HunYuanVLProcessor(ProcessorMixin):
attention_mask = input_ids.ne(self.pad_id)
text_inputs["attention_mask"] = attention_mask
text_inputs["imgs_pos"] = [self.get_imgs_pos(input_ids)]
text_inputs["imgs_pos"] = [self.get_imgs_pos(e) for e in input_ids]
# image_inputs["imgs"] = [[image_inputs["pixel_values"]]]
return_tensors = kwargs.pop("return_tensors", None)

View File

@ -7,6 +7,7 @@ from collections.abc import Iterable
from typing import Any
from vllm import envs
from vllm.compilation.cuda_graph import CUDAGraphStat
from vllm.config import VllmConfig
from vllm.distributed.ec_transfer.ec_connector.base import (
ECConnectorMetadata,
@ -1037,6 +1038,7 @@ class Scheduler(SchedulerInterface):
pooler_outputs = model_runner_output.pooler_output
num_nans_in_logits = model_runner_output.num_nans_in_logits
kv_connector_output = model_runner_output.kv_connector_output
cudagraph_stats = model_runner_output.cudagraph_stats
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
spec_decoding_stats: SpecDecodingStats | None = None
@ -1219,7 +1221,9 @@ class Scheduler(SchedulerInterface):
finished_req_ids.clear()
if (
stats := self.make_stats(spec_decoding_stats, kv_connector_stats)
stats := self.make_stats(
spec_decoding_stats, kv_connector_stats, cudagraph_stats
)
) is not None:
# Return stats to only one of the front-ends.
if (eco := next(iter(engine_core_outputs.values()), None)) is None:
@ -1420,6 +1424,7 @@ class Scheduler(SchedulerInterface):
self,
spec_decoding_stats: SpecDecodingStats | None = None,
kv_connector_stats: KVConnectorStats | None = None,
cudagraph_stats: CUDAGraphStat | None = None,
) -> SchedulerStats | None:
if not self.log_stats:
return None
@ -1444,6 +1449,7 @@ class Scheduler(SchedulerInterface):
kv_cache_eviction_events=eviction_events,
spec_decoding_stats=spec_stats,
kv_connector_stats=connector_stats_payload,
cudagraph_stats=cudagraph_stats,
)
def make_spec_decoding_stats(

View File

@ -10,6 +10,7 @@ from typing import TypeAlias
from prometheus_client import Counter, Gauge, Histogram
import vllm.envs as envs
from vllm.compilation.cuda_graph import CUDAGraphLogging
from vllm.config import SupportsMetricsInfo, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorLogging,
@ -106,6 +107,12 @@ class LoggingStatLogger(StatLoggerBase):
self.spec_decoding_logging = SpecDecodingLogging()
kv_transfer_config = self.vllm_config.kv_transfer_config
self.kv_connector_logging = KVConnectorLogging(kv_transfer_config)
self.cudagraph_logging = None
if self.vllm_config.observability_config.cudagraph_metrics:
self.cudagraph_logging = CUDAGraphLogging(
self.vllm_config.compilation_config.cudagraph_mode,
self.vllm_config.compilation_config.cudagraph_capture_sizes,
)
self.last_prompt_throughput: float = 0.0
self.last_generation_throughput: float = 0.0
self.engine_is_idle = False
@ -161,6 +168,11 @@ class LoggingStatLogger(StatLoggerBase):
self.spec_decoding_logging.observe(scheduler_stats.spec_decoding_stats)
if kv_connector_stats := scheduler_stats.kv_connector_stats:
self.kv_connector_logging.observe(kv_connector_stats)
if (
self.cudagraph_logging is not None
and scheduler_stats.cudagraph_stats is not None
):
self.cudagraph_logging.observe(scheduler_stats.cudagraph_stats)
if not self.aggregated:
self.last_scheduler_stats = scheduler_stats
if mm_cache_stats:
@ -240,6 +252,8 @@ class LoggingStatLogger(StatLoggerBase):
self.spec_decoding_logging.log(log_fn=log_fn)
self.kv_connector_logging.log(log_fn=log_fn)
if self.cudagraph_logging is not None:
self.cudagraph_logging.log(log_fn=log_fn)
def log_engine_initialized(self):
if self.vllm_config.cache_config.num_gpu_blocks:

View File

@ -7,6 +7,7 @@ from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
import vllm.envs as envs
from vllm.compilation.cuda_graph import CUDAGraphStat
from vllm.v1.spec_decode.metrics import SpecDecodingStats
if TYPE_CHECKING:
@ -183,6 +184,8 @@ class SchedulerStats:
waiting_lora_adapters: dict[str, int] = field(default_factory=dict)
running_lora_adapters: dict[str, int] = field(default_factory=dict)
cudagraph_stats: CUDAGraphStat | None = None
@dataclass
class RequestStateStats:

View File

@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, NamedTuple
import numpy as np
import torch
from vllm.compilation.cuda_graph import CUDAGraphStat
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
@ -169,6 +170,9 @@ class ModelRunnerOutput:
# req_id -> num_nans_in_logits
num_nans_in_logits: dict[str, int] | None = None
# information related to cudagraph execution
cudagraph_stats: CUDAGraphStat | None = None
# ModelRunnerOutput wrapper for async scheduling.
class AsyncModelRunnerOutput(ABC):

View File

@ -27,7 +27,7 @@ from vllm.attention.backends.abstract import (
)
from vllm.attention.layer import Attention, MLAAttention
from vllm.compilation.counter import compilation_counter
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.cuda_graph import CUDAGraphStat, CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (
CompilationMode,
@ -257,6 +257,7 @@ class ExecuteModelState(NamedTuple):
sample_hidden_states: torch.Tensor
aux_hidden_states: list[torch.Tensor] | None
ec_connector_output: ECConnectorOutput | None
cudagraph_stats: CUDAGraphStat | None
class GPUModelRunner(
@ -2417,10 +2418,7 @@ class GPUModelRunner(
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if (
self.compilation_config.pass_config.enable_sequence_parallelism
and tp_size > 1
):
if self.compilation_config.pass_config.enable_sp and tp_size > 1:
return round_up(num_scheduled_tokens, tp_size)
return num_scheduled_tokens
@ -2758,7 +2756,11 @@ class GPUModelRunner(
force_uniform_decode: bool | None = None,
force_has_lora: bool | None = None,
) -> tuple[
CUDAGraphMode, BatchDescriptor, UBatchSlices | None, torch.Tensor | None
CUDAGraphMode,
BatchDescriptor,
UBatchSlices | None,
torch.Tensor | None,
CUDAGraphStat | None,
]:
num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
uniform_decode = (
@ -2823,7 +2825,22 @@ class GPUModelRunner(
# num_tokens_across_dp will no-longer be valid
assert batch_descriptor.num_tokens == num_tokens_padded
return cudagraph_mode, batch_descriptor, ubatch_slices, num_tokens_across_dp
cudagraph_stats = None
if self.vllm_config.observability_config.cudagraph_metrics:
cudagraph_stats = CUDAGraphStat(
num_unpadded_tokens=num_tokens,
num_padded_tokens=batch_descriptor.num_tokens,
num_paddings=batch_descriptor.num_tokens - num_tokens,
runtime_mode=str(cudagraph_mode),
)
return (
cudagraph_mode,
batch_descriptor,
ubatch_slices,
num_tokens_across_dp,
cudagraph_stats,
)
@torch.inference_mode()
def execute_model(
@ -2921,6 +2938,7 @@ class GPUModelRunner(
batch_desc,
ubatch_slices,
num_tokens_across_dp,
cudagraph_stats,
) = self._determine_batch_execution_and_padding(
num_tokens=num_tokens_unpadded,
num_reqs=num_reqs,
@ -3070,6 +3088,7 @@ class GPUModelRunner(
sample_hidden_states,
aux_hidden_states,
ec_connector_output,
cudagraph_stats,
)
self.kv_connector_output = kv_connector_output
return None
@ -3105,6 +3124,7 @@ class GPUModelRunner(
sample_hidden_states,
aux_hidden_states,
ec_connector_output,
cudagraph_stats,
) = self.execute_model_state
# Clear ephemeral state.
self.execute_model_state = None
@ -3220,6 +3240,7 @@ class GPUModelRunner(
if self.supports_mm_inputs
else None,
num_nans_in_logits=num_nans_in_logits,
cudagraph_stats=cudagraph_stats,
)
if not self.use_async_scheduling:
@ -3940,7 +3961,7 @@ class GPUModelRunner(
num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
_cudagraph_mode, batch_desc, ubatch_slices, num_tokens_across_dp = (
_cudagraph_mode, batch_desc, ubatch_slices, num_tokens_across_dp, _ = (
self._determine_batch_execution_and_padding(
num_tokens=num_tokens_unpadded,
num_reqs=num_reqs,

View File

@ -552,7 +552,7 @@ class Worker(WorkerBase):
if (
parallel_config.pipeline_parallel_size > 1
and compilation_config.pass_config.enable_sequence_parallelism
and compilation_config.pass_config.enable_sp
and forward_pass
):
# currently only supported by V1 GPUModelRunner
@ -564,7 +564,7 @@ class Worker(WorkerBase):
# TODO(lucas): This is pretty gross; ideally we should only ever call
# `_determine_batch_execution_and_padding` once (will get called again
# in `execute_model`) but this requires a larger refactor of PP.
_, batch_desc, _, _ = (
_, batch_desc, _, _, _ = (
self.model_runner._determine_batch_execution_and_padding(
num_tokens=num_scheduled_tokens,
num_reqs=len(num_scheduled_tokens_np),

View File

@ -342,7 +342,7 @@ def is_residual_scattered_for_sp(
partition), SP is always applied
- Otherwise, SP is only applied for specific shapes in compile_sizes
"""
if not vllm_config.compilation_config.pass_config.enable_sequence_parallelism:
if not vllm_config.compilation_config.pass_config.enable_sp:
return False
tp = vllm_config.parallel_config.tensor_parallel_size