Merge branch 'main' into imarkov/eplb_optimizations

This commit is contained in:
ilmarkov 2025-12-11 11:38:02 +00:00
commit b57a04516e
167 changed files with 4820 additions and 2221 deletions

View File

@ -36,6 +36,11 @@ function cpu_tests() {
set -e
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m"
# Run model tests
docker exec cpu-test bash -c "
set -e
pytest -x -v -s tests/models/multimodal/generation/test_whisper.py -m cpu_model"
# Run kernel tests
docker exec cpu-test bash -c "
set -e

View File

@ -1,73 +0,0 @@
#!/usr/bin/env bash
set -euxo pipefail
# args: [THRESHOLD] [NUM_QUESTIONS] [START_PORT]
THRESHOLD=${1:-0.25}
NUM_Q=${2:-1319}
PORT=${3:-8030}
OUT_DIR=${OUT_DIR:-/tmp/vllm-scheduled}
mkdir -p "${OUT_DIR}"
wait_for_server() {
local port=$1
timeout 600 bash -c '
until curl -sf "http://127.0.0.1:'"$port"'/health" > /dev/null; do
sleep 1
done'
}
MODEL="deepseek-ai/DeepSeek-V2-lite"
# Set BACKENDS based on platform
if command -v rocm-smi &> /dev/null || [[ -d /opt/rocm ]] || [[ -n "${ROCM_PATH:-}" ]]; then
# ROCm platform
BACKENDS=("allgather_reducescatter")
# Disable MOE padding for ROCm since it is causing eplb to fail
export VLLM_ROCM_MOE_PADDING=0
else
# Non-ROCm platform (CUDA/other)
BACKENDS=("deepep_high_throughput" "deepep_low_latency")
fi
cleanup() {
if [[ -n "${SERVER_PID:-}" ]] && kill -0 "${SERVER_PID}" 2>/dev/null; then
kill "${SERVER_PID}" 2>/dev/null || true
for _ in {1..20}; do
kill -0 "${SERVER_PID}" 2>/dev/null || break
sleep 0.5
done
kill -9 "${SERVER_PID}" 2>/dev/null || true
fi
}
trap cleanup EXIT
for BACK in "${BACKENDS[@]}"; do
VLLM_DEEP_GEMM_WARMUP=skip \
VLLM_ALL2ALL_BACKEND=$BACK \
vllm serve "$MODEL" \
--enforce-eager \
--tensor-parallel-size 2 \
--data-parallel-size 2 \
--enable-expert-parallel \
--enable-eplb \
--eplb-config '{"window_size":200,"step_interval":600,"use_async":true}' \
--trust-remote-code \
--max-model-len 2048 \
--port $PORT &
SERVER_PID=$!
wait_for_server $PORT
TAG=$(echo "$MODEL" | tr '/: \\n' '_____')
OUT="${OUT_DIR}/${TAG}_${BACK}_async_eplb.json"
python3 tests/evals/gsm8k/gsm8k_eval.py --host http://127.0.0.1 --port $PORT --num-questions ${NUM_Q} --save-results ${OUT}
python3 - <<PY
import json; acc=json.load(open('${OUT}'))['accuracy']
print(f"${MODEL} ${BACK}: accuracy {acc:.3f}")
assert acc >= ${THRESHOLD}, f"${MODEL} ${BACK} accuracy {acc}"
PY
cleanup
SERVER_PID=
sleep 1
PORT=$((PORT+1))
done

View File

@ -50,7 +50,6 @@ for BACK in "${BACKENDS[@]}"; do
--data-parallel-size 2 \
--enable-expert-parallel \
--enable-eplb \
--eplb-config '{"window_size":200,"step_interval":600}' \
--trust-remote-code \
--max-model-len 2048 \
--port $PORT &

View File

@ -1,74 +0,0 @@
#!/usr/bin/env bash
set -euxo pipefail
# args: [THRESHOLD] [NUM_QUESTIONS] [START_PORT]
THRESHOLD=${1:-0.25}
NUM_Q=${2:-1319}
PORT=${3:-8040}
OUT_DIR=${OUT_DIR:-/tmp/vllm-scheduled}
mkdir -p "${OUT_DIR}"
wait_for_server() {
local port=$1
timeout 600 bash -c '
until curl -sf "http://127.0.0.1:'"$port"'/health" > /dev/null; do
sleep 1
done'
}
MODEL="Qwen/Qwen3-Next-80B-A3B-Instruct"
# Set BACKENDS based on platform
if command -v rocm-smi &> /dev/null || [[ -d /opt/rocm ]] || [[ -n "${ROCM_PATH:-}" ]]; then
# ROCm platform
BACKENDS=("allgather_reducescatter")
# Disable MOE padding for ROCm since it is causing eplb to fail
export VLLM_ROCM_MOE_PADDING=0
else
# Non-ROCm platform (CUDA/other)
BACKENDS=("deepep_high_throughput" "deepep_low_latency")
fi
cleanup() {
if [[ -n "${SERVER_PID:-}" ]] && kill -0 "${SERVER_PID}" 2>/dev/null; then
kill "${SERVER_PID}" 2>/dev/null || true
for _ in {1..20}; do
kill -0 "${SERVER_PID}" 2>/dev/null || break
sleep 0.5
done
kill -9 "${SERVER_PID}" 2>/dev/null || true
fi
}
trap cleanup EXIT
for BACK in "${BACKENDS[@]}"; do
VLLM_DEEP_GEMM_WARMUP=skip \
VLLM_ALL2ALL_BACKEND=$BACK \
vllm serve "$MODEL" \
--enforce-eager \
--tensor-parallel-size 4 \
--enable-expert-parallel \
--enable-eplb \
--eplb-config '{"window_size":200,"step_interval":600,"use_async":true}' \
--speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":1}' \
--trust-remote-code \
--max-model-len 2048 \
--gpu-memory-utilization 0.9 \
--port $PORT &
SERVER_PID=$!
wait_for_server $PORT
TAG=$(echo "$MODEL" | tr '/: \\n' '_____')
OUT="${OUT_DIR}/${TAG}_${BACK}.json"
python3 tests/evals/gsm8k/gsm8k_eval.py --host http://127.0.0.1 --port $PORT --num-questions ${NUM_Q} --save-results ${OUT}
python3 - <<PY
import json; acc=json.load(open('${OUT}'))['accuracy']
print(f"${MODEL} ${BACK}: accuracy {acc:.3f}")
assert acc >= ${THRESHOLD}, f"${MODEL} ${BACK} accuracy {acc}"
PY
cleanup
SERVER_PID=
sleep 1
PORT=$((PORT+1))
done

View File

@ -1379,22 +1379,4 @@ steps:
num_gpus: 2
working_dir: "/vllm-workspace"
commands:
- bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1
- label: DeepSeek V2-Lite Async EPLB Accuracy
timeout_in_minutes: 60
gpu: h100
optional: true
num_gpus: 4
working_dir: "/vllm-workspace"
commands:
- bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_async_eplb.sh 0.25 1319 8030
- label: Qwen3-Next-80B-A3B-Instruct MTP Async EPLB Accuracy
timeout_in_minutes: 60
gpu: h100
optional: true
num_gpus: 4
working_dir: "/vllm-workspace"
commands:
- bash .buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh 0.8 1319 8040
- bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1

View File

@ -32,12 +32,11 @@ def benchmark_propose(args):
model_config = ModelConfig(
model="facebook/opt-125m",
task="generate",
max_model_len=args.num_token + args.num_spec_token,
tokenizer="facebook/opt-125m",
tokenizer_mode="auto",
dtype="auto",
seed=None,
seed=0,
trust_remote_code=False,
)
proposer = NgramProposer(

View File

@ -574,7 +574,7 @@ async def benchmark(
)
print(
"{:<40} {:<10.2f}".format(
"Total Token throughput (tok/s):", metrics.total_token_throughput
"Total token throughput (tok/s):", metrics.total_token_throughput
)
)

View File

@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark script comparing torch.cat vs direct copy for k_nope/k_pe concatenation
in MLA (Multi-head Latent Attention) prefill.
This validates that the optimization from commit 8d4142bd is beneficial across
various batch sizes, not just the originally tested batch size of 32768.
"""
import time
from collections.abc import Callable
import torch
# DeepSeek-V3 MLA dimensions
NUM_HEADS = 128
QK_NOPE_HEAD_DIM = 128
PE_DIM = 64
def cat_method(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor:
"""Original torch.cat approach with expand."""
return torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
def direct_copy_method(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor:
"""Optimized direct copy approach (avoids expand + cat overhead)."""
k = torch.empty(
(*k_nope.shape[:-1], k_nope.shape[-1] + k_pe.shape[-1]),
dtype=k_nope.dtype,
device=k_nope.device,
)
k[..., : k_nope.shape[-1]] = k_nope
k[..., k_nope.shape[-1] :] = k_pe
return k
def benchmark_method(
method: Callable,
k_nope: torch.Tensor,
k_pe: torch.Tensor,
num_warmup: int = 10,
num_iters: int = 100,
) -> float:
"""Benchmark a concatenation method and return mean latency in ms."""
# Warmup
for _ in range(num_warmup):
_ = method(k_nope, k_pe)
torch.cuda.synchronize()
# Benchmark
start = time.perf_counter()
for _ in range(num_iters):
_ = method(k_nope, k_pe)
torch.cuda.synchronize()
end = time.perf_counter()
return (end - start) / num_iters * 1000 # Convert to ms
@torch.inference_mode()
def run_benchmark(dtype: torch.dtype, dtype_name: str):
"""Run benchmark for a specific dtype."""
torch.set_default_device("cuda")
# Batch sizes to test (powers of 2 from 32 to 65536)
batch_sizes = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536]
print("=" * 80)
print("Benchmark: torch.cat vs direct copy for MLA k_nope/k_pe concatenation")
print("=" * 80)
print(
f"Tensor shapes: k_nope=[B, {NUM_HEADS}, {QK_NOPE_HEAD_DIM}], "
f"k_pe=[B, 1, {PE_DIM}]"
)
print(f"dtype: {dtype_name}")
print()
print(
f"{'Batch Size':>12} | {'cat (ms)':>10} | {'direct (ms)':>12} | "
f"{'Speedup':>8} | {'Reduction':>10}"
)
print("-" * 70)
results = []
for batch_size in batch_sizes:
# Create input tensors (generate in float32 then convert for FP8 compatibility)
k_nope = torch.randn(
batch_size, NUM_HEADS, QK_NOPE_HEAD_DIM, dtype=torch.float32, device="cuda"
).to(dtype)
k_pe = torch.randn(
batch_size, 1, PE_DIM, dtype=torch.float32, device="cuda"
).to(dtype)
# Benchmark both methods
cat_time = benchmark_method(cat_method, k_nope, k_pe)
direct_time = benchmark_method(direct_copy_method, k_nope, k_pe)
speedup = cat_time / direct_time
reduction = (1 - direct_time / cat_time) * 100
results.append((batch_size, cat_time, direct_time, speedup, reduction))
print(
f"{batch_size:>12} | {cat_time:>10.3f} | {direct_time:>12.3f} | "
f"{speedup:>7.2f}x | {reduction:>9.1f}%"
)
print("=" * 80)
# Summary statistics
speedups = [r[3] for r in results]
print("\nSpeedup summary:")
print(f" Min: {min(speedups):.2f}x")
print(f" Max: {max(speedups):.2f}x")
print(f" Mean: {sum(speedups) / len(speedups):.2f}x")
# Find crossover point
crossover_batch = None
for batch_size, _, _, speedup, _ in results:
if speedup >= 1.0:
crossover_batch = batch_size
break
print("\nConclusion:")
if crossover_batch:
print(f" - Direct copy becomes beneficial at batch size >= {crossover_batch}")
# Filter for large batches (>= 512 which is typical for prefill)
large_batch_speedups = [r[3] for r in results if r[0] >= 512]
if large_batch_speedups:
avg_large = sum(large_batch_speedups) / len(large_batch_speedups)
print(f" - For batch sizes >= 512: avg speedup = {avg_large:.2f}x")
print(" - MLA prefill typically uses large batches, so optimization is effective")
return results
@torch.inference_mode()
def main():
# Test bfloat16
print("\n")
run_benchmark(torch.bfloat16, "bfloat16")
# Test float8_e4m3fn
print("\n")
run_benchmark(torch.float8_e4m3fn, "float8_e4m3fn")
if __name__ == "__main__":
main()

View File

@ -251,17 +251,6 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
endif()
# Build ACL with CMake
set(ARM_COMPUTE_BUILD_SHARED_LIB "OFF")
set(CMAKE_BUILD_TYPE "Release")
set(ARM_COMPUTE_ARCH "armv8.2-a")
set(ARM_COMPUTE_ENABLE_ASSERTS "OFF")
set(ARM_COMPUTE_ENABLE_CPPTHREADS "OFF")
set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
set(ARM_COMPUTE_ENABLE_OPENMP "ON")
set(ARM_COMPUTE_ENABLE_WERROR "OFF")
set(ARM_COMPUTE_BUILD_EXAMPLES "OFF")
set(ARM_COMPUTE_BUILD_TESTING "OFF")
set(_cmake_config_cmd
${CMAKE_COMMAND} -G Ninja -B build
-DARM_COMPUTE_BUILD_SHARED_LIB=OFF

View File

@ -117,7 +117,6 @@ torch::Tensor get_scheduler_metadata(
input.casual = casual;
input.isa = isa;
input.enable_kv_split = enable_kv_split;
TORCH_CHECK(casual, "Only supports casual mask for now.");
VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() {
CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] {

View File

@ -186,7 +186,7 @@ struct AttentionMetadata {
// - Intermediate outputs: q_tile_size * head_dim * output_buffer_elem_size + 2
// * q_tile_size * 4, partial output, max + sum (float)
// Reduction scratchpad contains:
// - flags: bool array to indicate wether the split is finished
// - flags: bool array to indicate whether the split is finished
// - outputs: split_num * q_tile_size * head_dim * output_buffer_elem_size
// - max, sum: 2 * split_num * q_tile_size * 4
class AttentionScratchPad {

View File

@ -617,7 +617,7 @@ struct MacheteCollectiveMma {
// Same as upstream, should be kept the same when possible, not formatted for
// easier comparison
// with `SwapAB ? N : M -> M` since we dont support SwapAB
// with `SwapAB ? N : M -> M` since we don't support SwapAB
// clang-format off
template<class ProblemShape>
static bool

View File

@ -1241,33 +1241,16 @@ __global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx,
}
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
// Find the min val of div2 that doesn't increase N/(div1*div2)
int mindiv(int N, int div1, int div2) {
int nPrRnd = div1 * div2;
int rnds0 = N / nPrRnd;
nPrRnd -= div1 * 3;
int rnds3 = N / nPrRnd;
nPrRnd -= div1;
int rnds4 = N / nPrRnd;
nPrRnd -= div1;
int rnds5 = N / nPrRnd;
nPrRnd -= div1;
int rnds6 = N / nPrRnd;
nPrRnd -= div1;
int rnds7 = N / nPrRnd;
nPrRnd -= div1;
int rnds8 = N / nPrRnd;
nPrRnd -= div1;
int rnds9 = N / nPrRnd;
nPrRnd -= div1;
int rtn = div2;
if (rnds0 == rnds3) rtn = div2 - 3;
if (rnds0 == rnds4) rtn = div2 - 4;
if (rnds0 == rnds5) rtn = div2 - 5;
if (rnds0 == rnds6) rtn = div2 - 6;
if (rnds0 == rnds7) rtn = div2 - 7;
if (rnds0 == rnds8) rtn = div2 - 8;
if (rnds0 == rnds9) rtn = div2 - 9;
return rtn;
int rnds[13];
for (int i = 0; i < 13; i++) {
rnds[i] = (N + nPrRnd - 1) / nPrRnd;
nPrRnd -= div1;
}
for (int i = 12; i >= 0; i--)
if (rnds[0] == rnds[i]) return (div2 - i);
}
torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
@ -1300,26 +1283,37 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int max_lds_len = get_lds_size() / 2;
#define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
_N) \
{ \
dim3 block(64, _WvPrGrp); \
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
} else if (K_in * N_in <= max_lds_len * 1.2) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
} else { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \
wvSplitK_hf_big_<fptype, 64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
} \
#define WVSPLITK(_YTILE, _UNRL, _N) \
{ \
dim3 block(64, 16); \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \
wvSplitK_hf_sml_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
else if (K_in * N_in <= max_lds_len * 1.2) \
wvSplitK_hf_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
else \
wvSplitK_hf_big_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
}
#define WVSPLIT_TILE(_sYT, __N) \
{ \
bool fit_lds = (K_in * N_in <= max_lds_len); \
if (_sYT <= 1) \
WVSPLITK(1, 4, __N) \
else if ((__N == 1) || (!fit_lds) || (_sYT <= 4 * 2)) \
WVSPLITK(2, 2, __N) \
else if (_sYT <= 4 * 3) \
WVSPLITK(3, 2, __N) \
else if (__N == 4) \
WVSPLITK(4, 1, __N) \
else \
WVSPLITK(4, 2, __N) \
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] {
@ -1331,18 +1325,23 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
? reinterpret_cast<const fptype*>(in_bias->data_ptr())
: nullptr;
fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
// first shoot for biggest tile-size that keeps all simd busy,
// then cut the active waves to balance their distribution...
int sYT = (M_in + CuCount * 4 - 1) / (CuCount * 4);
switch (N_in) {
case 1:
WVSPLITK(16, 2, 2, 2, 2, 2, 2, 1)
WVSPLIT_TILE(sYT, 1)
break;
case 2:
WVSPLITK(16, 2, 2, 2, 2, 2, 2, 2)
WVSPLIT_TILE(sYT, 2)
break;
case 3:
WVSPLITK(16, 4, 7, 7, 1, 1, 1, 3)
WVSPLIT_TILE(sYT, 3)
break;
case 4:
WVSPLITK(16, 4, 7, 7, 1, 1, 1, 4)
WVSPLIT_TILE(sYT, 4)
break;
default:
throw std::runtime_error(

View File

@ -84,7 +84,7 @@ Total input tokens: 1369
Total generated tokens: 2212
Request throughput (req/s): 1.73
Output token throughput (tok/s): 382.89
Total Token throughput (tok/s): 619.85
Total token throughput (tok/s): 619.85
---------------Time to First Token----------------
Mean TTFT (ms): 71.54
Median TTFT (ms): 73.88

View File

@ -21,30 +21,20 @@ The mental model is that server-level metrics help explain the values of request
### v1 Metrics
In v1, the following metrics are exposed via a Prometheus-compatible `/metrics` endpoint using the `vllm:` prefix:
In v1, an extensive set of metrics are exposed via a Prometheus-compatible `/metrics` endpoint using the `vllm:` prefix, for example:
- `vllm:num_requests_running` (Gauge) - Number of requests currently running.
- `vllm:num_requests_waiting` (Gauge) - Number of requests currently waiting.
- `vllm:kv_cache_usage_perc` (Gauge) - Fraction of used KV cache blocks (01).
- `vllm:prefix_cache_queries` (Counter) - Number of prefix cache queries.
- `vllm:prefix_cache_hits` (Counter) - Number of prefix cache hits.
- `vllm:mm_cache_queries` (Counter) - (For multimodal models) Number of multimodal cache queries.
- `vllm:mm_cache_hits` (Counter) - (For multimodal models) Number of multimodal cache hits.
- `vllm:num_preemptions_total` (Counter) - Number of preemptions.
- `vllm:prompt_tokens_total` (Counter) - Total number of prompt tokens processed.
- `vllm:generation_tokens_total` (Counter) - Total number of generated tokens.
- `vllm:iteration_tokens_total` (Histogram) - Histogram of tokens processed in each engine step.
- `vllm:cache_config_info` (Gauge) - Information about the cache configuration.
- `vllm:request_success_total` (Counter) - Number of finished requests (by finish reason).
- `vllm:request_prompt_tokens` (Histogram) - Histogram of input prompt token counts.
- `vllm:request_generation_tokens` (Histogram) - Histogram of generation token counts.
- `vllm:request_params_n` (Histogram) - Histogram of request parameter n.
- `vllm:request_params_max_tokens` - (Histogram) - Histogram of max_tokens parameter in requests.
- `vllm:time_to_first_token_seconds` (Histogram) - Time to first token (TTFT).
- `vllm:inter_token_latency_seconds` (Histogram) - Inter-token latency.
- `vllm:e2e_request_latency_seconds` (Histogram) - End-to-end request latency.
- `vllm:request_queue_time_seconds` (Histogram) - Time spent in the queue.
- `vllm:request_inference_time_seconds` (Histogram) - Request inference time.
- `vllm:request_prefill_time_seconds` (Histogram) - Request prefill time.
- `vllm:request_decode_time_seconds` (Histogram) - Request decode time.

View File

@ -152,5 +152,5 @@ The interface for the model/module may change during vLLM's development. If you
## Deprecation announcement
!!! warning "Deprecations"
- `use_v1` parameter in `Platform.get_attn_backend_cls` is deprecated. It will be removed in v0.13.0 or v1.0.0.
- `_Backend` in `vllm.attention` is deprecated. It will be removed in v0.13.0 or v1.0.0. Please use `vllm.attention.backends.registry.register_backend` to add new attention backend to `AttentionBackendEnum` instead.
- `use_v1` parameter in `Platform.get_attn_backend_cls` is deprecated. It has been removed in v0.13.0.
- `_Backend` in `vllm.attention` is deprecated. It has been removed in v0.13.0. Please use `vllm.attention.backends.registry.register_backend` to add new attention backend to `AttentionBackendEnum` instead.

View File

@ -22,7 +22,7 @@ python tools/install_nixl_from_source_ubuntu.py
NixlConnector uses NIXL library for underlying communication, which supports multiple transport backends. UCX (Unified Communication X) is the primary default transport library used by NIXL. Configure transport environment variables:
```bash
# Example UCX configuration, adjust according to your enviroment
# Example UCX configuration, adjust according to your environment
export UCX_TLS=all # or specify specific transports like "rc,ud,sm,^cuda_ipc" ..etc
export UCX_NET_DEVICES=all # or specify network devices like "mlx5_0:1,mlx5_1:1"
```

View File

@ -61,7 +61,7 @@ Now let´s see an example for each of the cases, starting with the `choice`, as
print(completion.choices[0].message.content)
```
The next example shows how to use the `regex`. The idea is to generate an email address, given a simple regex template:
The next example shows how to use the `regex`. The supported regex syntax depends on the structured output backend. For example, `xgrammar`, `guidance`, and `outlines` use Rust-style regex, while `lm-format-enforcer` uses Python's `re` module. The idea is to generate an email address, given a simple regex template:
??? code

View File

@ -26,3 +26,4 @@ The backends below live **outside** the main `vllm` repository and follow the
| Rebellions ATOM / REBEL NPU | `vllm-rbln` | <https://github.com/rebellions-sw/vllm-rbln> |
| IBM Spyre AIU | `vllm-spyre` | <https://github.com/vllm-project/vllm-spyre> |
| Cambricon MLU | `vllm-mlu` | <https://github.com/Cambricon/vllm-mlu> |
| Baidu Kunlun XPU | N/A, install from source | <https://github.com/baidu/vLLM-Kunlun> |

View File

@ -0,0 +1,149 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import logging
from pathlib import Path
from typing import Literal
logger = logging.getLogger("mkdocs")
ROOT_DIR = Path(__file__).parent.parent.parent.parent
DOCS_DIR = ROOT_DIR / "docs"
GENERATED_METRICS_DIR = DOCS_DIR / "generated" / "metrics"
# Files to scan for metric definitions - each will generate a separate table
METRIC_SOURCE_FILES = [
{"path": "vllm/v1/metrics/loggers.py", "output": "general.md"},
{
"path": "vllm/v1/spec_decode/metrics.py",
"output": "spec_decode.md",
},
{
"path": "vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py",
"output": "nixl_connector.md",
},
]
class MetricExtractor(ast.NodeVisitor):
"""AST visitor to extract metric definitions."""
def __init__(self):
self.metrics: list[dict[str, str]] = []
def visit_Call(self, node: ast.Call) -> None:
"""Visit function calls to find metric class instantiations."""
metric_type = self._get_metric_type(node)
if metric_type:
name = self._extract_kwarg(node, "name")
documentation = self._extract_kwarg(node, "documentation")
if name:
self.metrics.append(
{
"name": name,
"type": metric_type,
"documentation": documentation or "",
}
)
self.generic_visit(node)
def _get_metric_type(self, node: ast.Call) -> str | None:
"""Determine if this call creates a metric and return its type."""
metric_type_map = {
"_gauge_cls": "gauge",
"_counter_cls": "counter",
"_histogram_cls": "histogram",
}
if isinstance(node.func, ast.Attribute):
return metric_type_map.get(node.func.attr)
return None
def _extract_kwarg(self, node: ast.Call, key: str) -> str | None:
"""Extract a keyword argument value from a function call."""
for keyword in node.keywords:
if keyword.arg == key:
return self._get_string_value(keyword.value)
return None
def _get_string_value(self, node: ast.AST) -> str | None:
"""Extract string value from an AST node."""
if isinstance(node, ast.Constant):
return str(node.value) if node.value is not None else None
return None
def extract_metrics_from_file(filepath: Path) -> list[dict[str, str]]:
"""Parse a Python file and extract all metric definitions."""
try:
with open(filepath, encoding="utf-8") as f:
source = f.read()
tree = ast.parse(source, filename=str(filepath))
extractor = MetricExtractor()
extractor.visit(tree)
return extractor.metrics
except Exception as e:
raise RuntimeError(f"Failed to parse {filepath}: {e}") from e
def generate_markdown_table(metrics: list[dict[str, str]]) -> str:
"""Generate a markdown table from extracted metrics."""
if not metrics:
return "No metrics found.\n"
# Sort by type, then by name
metrics_sorted = sorted(metrics, key=lambda m: (m["type"], m["name"]))
lines = []
lines.append("| Metric Name | Type | Description |")
lines.append("|-------------|------|-------------|")
for metric in metrics_sorted:
name = metric["name"]
metric_type = metric["type"].capitalize()
doc = metric["documentation"].replace("\n", " ").strip()
lines.append(f"| `{name}` | {metric_type} | {doc} |")
return "\n".join(lines) + "\n"
def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
"""Generate metrics documentation tables from source files."""
logger.info("Generating metrics documentation")
# Create generated directory if it doesn't exist
GENERATED_METRICS_DIR.mkdir(parents=True, exist_ok=True)
total_metrics = 0
for source_config in METRIC_SOURCE_FILES:
source_path = source_config["path"]
output_file = source_config["output"]
filepath = ROOT_DIR / source_path
if not filepath.exists():
raise FileNotFoundError(f"Metrics source file not found: {filepath}")
logger.debug("Extracting metrics from: %s", source_path)
metrics = extract_metrics_from_file(filepath)
logger.debug("Found %d metrics in %s", len(metrics), source_path)
# Generate and write the markdown table for this source
table_content = generate_markdown_table(metrics)
output_path = GENERATED_METRICS_DIR / output_file
with open(output_path, "w", encoding="utf-8") as f:
f.write(table_content)
total_metrics += len(metrics)
logger.info(
"Generated metrics table: %s (%d metrics)",
output_path.relative_to(ROOT_DIR),
len(metrics),
)
logger.info(
"Total metrics generated: %d across %d files",
total_metrics,
len(METRIC_SOURCE_FILES),
)

View File

@ -316,10 +316,13 @@ We have split the `encode` task into two more specific token-wise tasks: `token_
### Remove softmax from PoolingParams
We are going to remove `softmax` and `activation` from `PoolingParams`. Instead, use `use_activation`, since we allow `classify` and `token_classify` to use any activation function.
We are going to remove `softmax` and `activation` from `PoolingParams` in v0.15. Instead, use `use_activation`, since we allow `classify` and `token_classify` to use any activation function.
### as_reward_model
!!! warning
We are going to remove `--convert reward` in v0.15, use `--convert embed` instead.
Pooling models now default support all pooling, you can use it without any settings.
- Extracting hidden states prefers using `token_embed` task.

View File

@ -24,7 +24,7 @@ There are two distinct modes supported for online deployments - self-contained w
vLLM supports "self-contained" data parallel deployments that expose a single API endpoint.
It can be configured by simply including e.g. `--data-parallel-size=4` in the vllm serve command line arguments. This will require 4 GPUs. It can be combined with tensor parallel, for example `--data-parallel-size=4 --tensor-parallel-size=2`, which would require 8 GPUs.
It can be configured by simply including e.g. `--data-parallel-size=4` in the vllm serve command line arguments. This will require 4 GPUs. It can be combined with tensor parallel, for example `--data-parallel-size=4 --tensor-parallel-size=2`, which would require 8 GPUs. When sizing DP deployments, remember that `--max-num-seqs` applies per DP rank.
Running a single data parallel deployment across multiple nodes requires a different `vllm serve` to be run on each node, specifying which DP ranks should run on that node. In this case, there will still be a single HTTP entrypoint - the API server(s) will run only on one node, but it doesn't necessarily need to be co-located with the DP ranks.
@ -80,6 +80,18 @@ When deploying large DP sizes using this method, the API server process can beco
![DP Internal LB Diagram](../assets/deployment/dp_internal_lb.png)
</figure>
## Hybrid Load Balancing
Hybrid load balancing sits between the internal and external approaches. Each node runs its own API server(s) that only queue requests to the data-parallel engines colocated on that node. An upstream load balancer (for example, an ingress controller or traffic router) spreads user requests across those per-node endpoints.
Enable this mode with `--data-parallel-hybrid-lb` while still launching every node with the global data-parallel size. The key differences from internal load balancing are:
- You must provide `--data-parallel-size-local` and `--data-parallel-start-rank` so each node knows which ranks it owns.
- Not compatible with `--headless` since every node exposes an API endpoint.
- Scale `--api-server-count` per node based on the number of local ranks
In this configuration, each node keeps scheduling decisions local, which reduces cross-node traffic and avoids single node bottlenecks at larger DP sizes.
## External Load Balancing
For larger scale deployments especially, it can make sense to handle the orchestration and load balancing of data parallel ranks externally.

View File

@ -40,10 +40,12 @@ EP_SIZE = TP_SIZE × DP_SIZE
Where:
- `TP_SIZE`: Tensor parallel size (always 1 for now)
- `TP_SIZE`: Tensor parallel size
- `DP_SIZE`: Data parallel size
- `EP_SIZE`: Expert parallel size (computed automatically)
When EP is enabled, MoE layers use expert parallelism instead of tensor parallelism, while attention layers continue to use tensor parallelism if `TP_SIZE > 1`.
### Example Command
The following command serves a `DeepSeek-V3-0324` model with 1-way tensor parallel, 8-way (attention) data parallel, and 8-way expert parallel. The attention weights are replicated across all GPUs, while the expert weights are split across GPUs. It will work on a H200 (or H20) node with 8 GPUs. For H100, you can try to serve a smaller model or refer to the multi-node deployment section.
@ -81,7 +83,7 @@ vllm serve deepseek-ai/DeepSeek-V3-0324 \
--data-parallel-size-local 8 \ # Local DP size on this node (8 GPUs per node)
--data-parallel-address 192.168.1.100 \ # Replace with actual IP of Node 1
--data-parallel-rpc-port 13345 \ # RPC communication port, can be any port as long as reachable by all nodes
--api-server-count=8 # Number of API servers for load handling (scaling this out to total ranks are recommended)
--api-server-count=8 # Number of API servers for load handling (scaling this out to # local ranks is recommended)
# Node 2 (Secondary - headless mode, no API server)
vllm serve deepseek-ai/DeepSeek-V3-0324 \
@ -119,9 +121,6 @@ While MoE models are typically trained so that each expert receives a similar nu
Enable EPLB with the `--enable-eplb` flag.
!!! note "Model Support"
Currently only DeepSeek V3 architecture is supported.
When enabled, vLLM collects load statistics with every forward pass and periodically rebalances expert distribution.
### EPLB Parameters
@ -134,6 +133,8 @@ Configure EPLB with the `--eplb-config` argument, which accepts a JSON string. T
| `step_interval`| Frequency of rebalancing (every N engine steps) | 3000 |
| `log_balancedness` | Log balancedness metrics (avg tokens per expert ÷ max tokens per expert) | `false` |
| `num_redundant_experts` | Additional global experts per EP rank beyond equal distribution | `0` |
| `use_async` | Use non-blocking EPLB for reduced latency overhead | `false` |
| `policy` | The policy type for expert parallel load balancing | `"default"` |
For example:
@ -183,6 +184,26 @@ vllm serve deepseek-ai/DeepSeek-V3-0324 \
For multi-node deployment, add these EPLB flags to each node's command. We recommend setting `--eplb-config '{"num_redundant_experts":32}'` to 32 in large scale use cases so the most popular experts are always available.
## Advanced Configuration
### Performance Optimization
- **DeepEP kernels**: The `high_throughput` and `low_latency` kernels are optimized for disaggregated serving and may show poor performance for mixed workloads
- **Dual Batch Overlap**: Use `--enable-dbo` to overlap all-to-all communication with compute. See [Dual Batch Overlap](../design/dbo.md) for more details.
- **Async scheduling (experimental)**: Try `--async-scheduling` to overlap scheduling with model execution.
### Troubleshooting
- **`non-zero status: 7 cannot register cq buf`**: When using Infiniband/RoCE, make sure host VM and pods show `ulimit -l` "unlimited".
- **`init failed for transport: IBGDA`**: The InfiniBand GDA kernel modules are missing. Run `tools/ep_kernels/configure_system_drivers.sh` on each GPU node and reboot. Also fixes error `NVSHMEM API called before NVSHMEM initialization has completed`.
- **NVSHMEM peer disconnect**: Usually a networking misconfiguration. If deploying via Kubernetes, verify that every pod runs with `hostNetwork: true`, `securityContext.privileged: true` to access Infiniband.
### Benchmarking
- Use simulator flags `VLLM_MOE_ROUTING_SIMULATION_STRATEGY=uniform_random` and `VLLM_RANDOMIZE_DP_DUMMY_INPUTS=1` so token routing is balanced across EP ranks.
- Increasing `VLLM_MOE_DP_CHUNK_SIZE` may increase throughput by increasing the maximum batch size for inter-rank token transfers. This may cause DeepEP to throw `assert self.nvshmem_qp_depth >= (num_max_dispatch_tokens_per_rank + 1) * 2`, which can be fixed by increasing environment variable `NVSHMEM_QP_DEPTH`.
## Disaggregated Serving (Prefill/Decode Split)
For production deployments requiring strict SLA guarantees for time-to-first-token and inter-token latency, disaggregated serving allows independent scaling of prefill and decode operations.
@ -273,3 +294,9 @@ except Exception as e:
print(f"❌ Error during disaggregated serving: {e}")
print("Check that both prefill and decode instances are running and accessible")
```
### Benchmarking
- To simulate the decode deployment of disaggregated serving, pass `--kv-transfer-config '{"kv_connector":"DecodeBenchConnector","kv_role":"kv_both"}'` to the `vllm serve` invocation. The connector populates KV cache with random values so decode can be profiled in isolation.
- **CUDAGraph capture**: Use `--compilation_config '{"cudagraph_mode": "FULL_DECODE_ONLY"}'` to enable CUDA graph capture for decode only and save KV cache.

View File

@ -33,11 +33,19 @@ Then query the endpoint to get the latest metrics from the server:
The following metrics are exposed:
??? code
## General Metrics
```python
--8<-- "vllm/engine/metrics.py:metrics-definitions"
```
--8<-- "docs/generated/metrics/general.md"
## Speculative Decoding Metrics
--8<-- "docs/generated/metrics/spec_decode.md"
## NIXL KV Connector Metrics
--8<-- "docs/generated/metrics/nixl_connector.md"
## Deprecation Policy
Note: when metrics are deprecated in version `X.Y`, they are hidden in version `X.Y+1`
but can be re-enabled using the `--show-hidden-metrics-for-version=X.Y` escape hatch,

View File

@ -422,7 +422,7 @@ def parse_args():
parser.add_argument(
"--seed",
type=int,
default=None,
default=0,
help="Set the seed when initializing `vllm.LLM`.",
)
parser.add_argument(

View File

@ -77,7 +77,7 @@ def parse_args():
parser.add_argument(
"--seed",
type=int,
default=None,
default=0,
help="Set the seed when initializing `vllm.LLM`.",
)
return parser.parse_args()

View File

@ -158,7 +158,7 @@ def parse_args():
parser.add_argument(
"--seed",
type=int,
default=None,
default=0,
help="Set the seed when initializing `vllm.LLM`.",
)

View File

@ -158,7 +158,7 @@ def parse_args():
parser.add_argument(
"--seed",
type=int,
default=None,
default=0,
help="Set the seed when initializing `vllm.LLM`.",
)

View File

@ -2031,7 +2031,7 @@ def parse_args():
parser.add_argument(
"--seed",
type=int,
default=None,
default=0,
help="Set the seed when initializing `vllm.LLM`.",
)

View File

@ -1382,7 +1382,7 @@ def run_generate(
model,
question: str,
image_urls: list[str],
seed: int | None,
seed: int,
tensor_parallel_size: int | None,
):
req_data = model_example_map[model](question, image_urls)
@ -1416,7 +1416,7 @@ def run_chat(
model: str,
question: str,
image_urls: list[str],
seed: int | None,
seed: int,
tensor_parallel_size: int | None,
):
req_data = model_example_map[model](question, image_urls)
@ -1494,7 +1494,7 @@ def parse_args():
parser.add_argument(
"--seed",
type=int,
default=None,
default=0,
help="Set the seed when initializing `vllm.LLM`.",
)
parser.add_argument(

View File

@ -16,7 +16,7 @@ import requests
# - start vllm in serving mode with the below args
# --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
# --model-impl terratorch
# --task embed --trust-remote-code
# --trust-remote-code
# --skip-tokenizer-init --enforce-eager
# --io-processor-plugin terratorch_segmentation
# --enable-mm-embeds

View File

@ -305,7 +305,7 @@ def get_query(modality: QueryModality):
raise ValueError(msg)
def run_encode(model: str, modality: QueryModality, seed: int | None):
def run_encode(model: str, modality: QueryModality, seed: int):
query = get_query(modality)
req_data = model_example_map[model](query)
@ -335,7 +335,7 @@ def run_encode(model: str, modality: QueryModality, seed: int | None):
print("-" * 50)
def run_score(model: str, modality: QueryModality, seed: int | None):
def run_score(model: str, modality: QueryModality, seed: int):
query = get_query(modality)
req_data = model_example_map[model](query)
@ -390,7 +390,7 @@ def parse_args():
parser.add_argument(
"--seed",
type=int,
default=None,
default=0,
help="Set the seed when initializing `vllm.LLM`.",
)
return parser.parse_args()

View File

@ -51,6 +51,7 @@ hooks:
- docs/mkdocs/hooks/remove_announcement.py
- docs/mkdocs/hooks/generate_examples.py
- docs/mkdocs/hooks/generate_argparse.py
- docs/mkdocs/hooks/generate_metrics.py
- docs/mkdocs/hooks/url_schemes.py
plugins:

View File

@ -1,2 +1,2 @@
lmcache
lmcache >= 0.3.10.post1
nixl >= 0.7.1 # Required for disaggregated prefill

View File

@ -75,7 +75,7 @@ torchgeo==0.7.0
mteb==2.1.2
# Data processing
xgrammar==0.1.27
xgrammar @ git+https://github.com/divakar-amd/xgrammar@3272f7c520564858056a60480d5afdf69ae79c84
# Test async scheduling
# Utilities

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import logging
from contextlib import nullcontext
from unittest.mock import patch
@ -13,7 +12,6 @@ from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.config.compilation import CompilationMode, PassConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.logger import _print_warning_once
from vllm.platforms import current_platform
from vllm.utils.torch_utils import _is_torch_equal_or_newer
@ -290,7 +288,7 @@ def test_moe_splitting_ops_deepep_ht_attn_fusion_no_inductor():
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
pass_config={"enable_attn_fusion": True, "enable_noop": True},
pass_config={"fuse_attn_quant": True, "eliminate_noops": True},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
),
@ -442,62 +440,3 @@ def test_cudagraph_sizes_post_init(
vllm_config.compilation_config.max_cudagraph_capture_size
== expected_max_size
)
def test_pass_config_deprecation(caplog_vllm):
caplog_vllm.set_level(logging.WARNING)
# Clear cache to ensure warnings are re-issued
_print_warning_once.cache_clear()
# Test enable_fusion -> fuse_norm_quant, fuse_act_quant
caplog_vllm.clear()
config = PassConfig(enable_fusion=True)
assert "enable_fusion is deprecated" in caplog_vllm.text
assert config.fuse_norm_quant is True
assert config.fuse_act_quant is True
assert config.enable_fusion is True
# Test enable_attn_fusion -> fuse_attn_quant
caplog_vllm.clear()
config = PassConfig(enable_attn_fusion=True)
assert "enable_attn_fusion is deprecated" in caplog_vllm.text
assert config.fuse_attn_quant is True
assert config.enable_attn_fusion is True
# Test enable_noop -> eliminate_noops
caplog_vllm.clear()
config = PassConfig(enable_noop=True)
assert "enable_noop is deprecated" in caplog_vllm.text
assert config.eliminate_noops is True
assert config.enable_noop is True
# Test enable_sequence_parallelism -> enable_sp
caplog_vllm.clear()
config = PassConfig(enable_sequence_parallelism=True)
assert "enable_sequence_parallelism is deprecated" in caplog_vllm.text
assert config.enable_sp is True
assert config.enable_sequence_parallelism is True
# Test enable_async_tp -> fuse_gemm_comms
caplog_vllm.clear()
config = PassConfig(enable_async_tp=True)
assert "enable_async_tp is deprecated" in caplog_vllm.text
assert config.fuse_gemm_comms is True
assert config.enable_async_tp is True
# Test enable_fi_allreduce_fusion -> fuse_allreduce_rms
caplog_vllm.clear()
config = PassConfig(enable_fi_allreduce_fusion=True)
assert "enable_fi_allreduce_fusion is deprecated" in caplog_vllm.text
assert config.fuse_allreduce_rms is True
assert config.enable_fi_allreduce_fusion is True
# Test hash consistency
config_old = PassConfig(enable_fusion=True)
config_new = PassConfig(fuse_norm_quant=True, fuse_act_quant=True)
assert config_old.compute_hash() == config_new.compute_hash()
config_old = PassConfig(enable_async_tp=True)
config_new = PassConfig(fuse_gemm_comms=True)
assert config_old.compute_hash() == config_new.compute_hash()

View File

@ -1,10 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import pytest
import torch
import vllm.plugins
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.matcher_utils import QUANT_OPS
@ -152,13 +155,79 @@ GROUP_SHAPES = [
]
class TestRmsnormGroupFp8QuantModel(torch.nn.Module):
def __init__(self, hidden_size: int, eps: float, **kwargs):
super().__init__()
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(128, 128),
act_quant_group_shape=GroupShape(1, 128),
cutlass_block_fp8_supported=False,
use_aiter_and_is_supported=True,
)
self.w = [
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
for _ in range(3)
]
scale_hidden_size = (hidden_size + 128 - 1) // 128
self.wscale = [
torch.rand((scale_hidden_size, scale_hidden_size), dtype=torch.float32)
for _ in range(3)
]
self.norm_weight = [torch.ones(hidden_size) for _ in range(4)]
self.eps = eps
def forward(self, x):
# avoid having graph input be an arg to a pattern directly
x = resid = torch.relu(x)
y = rocm_aiter_ops.rms_norm(x, self.norm_weight[0], self.eps)
x2 = self.w8a8_block_fp8_linear.apply(y, self.w[0], self.wscale[0])
# make sure resid is used for replacement to work
y2, resid = rocm_aiter_ops.rms_norm2d_with_add(
x2, resid, self.norm_weight[1], self.eps
)
x3 = self.w8a8_block_fp8_linear.apply(y2, self.w[1], self.wscale[1])
y3, resid = rocm_aiter_ops.rms_norm2d_with_add(
x3, resid, self.norm_weight[2], self.eps
)
x4 = self.w8a8_block_fp8_linear.apply(y3, self.w[2], self.wscale[2])
y4, resid = rocm_aiter_ops.rms_norm2d_with_add(
x4, resid, self.norm_weight[3], self.eps
)
return y4
def ops_in_model_before(self):
return [
torch.ops.vllm.rocm_aiter_rms_norm,
torch.ops.vllm.rocm_aiter_group_fp8_quant,
]
def ops_in_model_before_partial(self):
return []
def ops_in_model_after(self):
return [
torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant,
torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant,
]
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("group_shape", GROUP_SHAPES)
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
@pytest.mark.parametrize(
"model_class, enable_rms_norm_custom_op, enable_quant_fp8_custom_op",
list(itertools.product([TestModel], [True, False], [True, False]))
+ [(TestRmsnormGroupFp8QuantModel, False, False)],
)
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize(
@ -173,10 +242,14 @@ def test_fusion_rmsnorm_quant(
num_tokens,
eps,
group_shape,
model_class,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
cuda_force_torch,
):
if model_class is TestRmsnormGroupFp8QuantModel and not IS_AITER_FOUND:
pytest.skip("AITER is not supported on this GPU.")
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(1)
@ -209,12 +282,24 @@ def test_fusion_rmsnorm_quant(
with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config)
if model_class is TestRmsnormGroupFp8QuantModel:
from vllm.compilation.rocm_aiter_fusion import (
RocmAiterRMSNormFp8GroupQuantFusionPass,
)
fusion_pass = RocmAiterRMSNormFp8GroupQuantFusionPass(vllm_config)
else:
fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
backend2 = TestBackend(noop_pass, cleanup_pass)
model = TestModel(hidden_size, eps, group_shape, cuda_force_torch)
model = model_class(
hidden_size=hidden_size,
eps=eps,
group_shape=group_shape,
cuda_force_torch=cuda_force_torch,
)
# First dimension dynamic
x = torch.rand(num_tokens, hidden_size)
torch._dynamo.mark_dynamic(x, 0)
@ -243,7 +328,10 @@ def test_fusion_rmsnorm_quant(
# there's a risk that the fused add doesn't get included in the
# replacement and only the rms part gets fused with quant.
# Hence, we check only 2 add nodes are left (final fused rmsnorm add).
if not enable_rms_norm_custom_op:
if (
not enable_rms_norm_custom_op
and model_class is not TestRmsnormGroupFp8QuantModel
):
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
assert n_add_nodes(backend.graph_pre_pass) == 7

View File

@ -7,6 +7,7 @@ import torch
import vllm.envs as envs
from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor
from vllm._aiter_ops import IS_AITER_FOUND
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.compilation.activation_quant_fusion import (
FUSED_OPS,
@ -24,6 +25,7 @@ from vllm.config import (
set_current_vllm_config,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
kFp8StaticTensorSym,
@ -126,6 +128,39 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
return [FUSED_OPS[kNvfp4Quant]]
class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
def __init__(self, hidden_size: int, **kwargs):
super().__init__()
self.silu_and_mul = SiluAndMul()
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(128, 128),
act_quant_group_shape=GroupShape(1, 128),
cutlass_block_fp8_supported=False,
use_aiter_and_is_supported=True,
)
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
scale_hidden_size = (hidden_size + 128 - 1) // 128
self.wscale = torch.rand(
(scale_hidden_size, scale_hidden_size), dtype=torch.float32
)
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
def forward(self, x):
y = self.silu_and_mul(x)
x2 = self.w8a8_block_fp8_linear.apply(y, self.w, self.wscale)
return x2
def ops_in_model_before(self):
return [
SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
]
def ops_in_model_after(self):
return [torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant]
@pytest.mark.parametrize("num_tokens", [32, 64])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@ -133,7 +168,10 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
@pytest.mark.parametrize(
"model_class, enable_quant_fp8_custom_op, cuda_force_torch",
list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False]))
+ [(TestSiluMulNvfp4QuantModel, False, False)],
+ [
(TestSiluMulNvfp4QuantModel, False, False),
(TestSiluMulGroupFp8QuantModel, False, False),
],
)
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@ -144,13 +182,19 @@ def test_fusion_silu_and_mul_quant(
num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
model_class: type[TestSiluMulFp8QuantModel | TestSiluMulNvfp4QuantModel],
model_class: type[
TestSiluMulFp8QuantModel
| TestSiluMulNvfp4QuantModel
| TestSiluMulGroupFp8QuantModel
],
enable_silu_mul_custom_op: bool,
enable_quant_fp8_custom_op: bool,
cuda_force_torch: bool,
):
if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
pytest.skip("NVFP4 is not supported on this GPU.")
if model_class is TestSiluMulGroupFp8QuantModel and not IS_AITER_FOUND:
pytest.skip("AITER is not supported on this GPU.")
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
@ -173,9 +217,15 @@ def test_fusion_silu_and_mul_quant(
)
with set_current_vllm_config(config):
fusion_pass = ActivationQuantFusionPass(config)
fusion_passes = [ActivationQuantFusionPass(config)]
if IS_AITER_FOUND:
from vllm.compilation.rocm_aiter_fusion import (
RocmAiterSiluMulFp8GroupQuantFusionPass,
)
passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
fusion_passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)]
backend = TestBackend(*passes)
model = model_class(
hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x
@ -194,12 +244,14 @@ def test_fusion_silu_and_mul_quant(
atol, rtol = 1e-3, 1e-3
elif model_class == TestSiluMulNvfp4QuantModel:
atol, rtol = 1e-1, 1e-1
elif model_class == TestSiluMulGroupFp8QuantModel:
atol, rtol = 5e-2, 5e-2
torch.testing.assert_close(
result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
)
assert fusion_pass.matched_count == 1
assert sum([p.matched_count for p in fusion_passes]) == 1
# In pre-nodes, quant op should be present and fused kernels should not
backend.check_before_ops(model.ops_in_model_before())

View File

@ -741,7 +741,7 @@ class VllmRunner:
tokenizer_name: str | None = None,
tokenizer_mode: str = "auto",
trust_remote_code: bool = True,
seed: int | None = 0,
seed: int = 0,
max_model_len: int | None = 1024,
dtype: str = "auto",
disable_log_stats: bool = True,

View File

@ -350,21 +350,35 @@ def test_human_readable_model_len():
assert args.max_model_len == 1_000_000
args = parser.parse_args(["--max-model-len", "10k"])
assert args.max_model_len == 10_000
args = parser.parse_args(["--max-model-len", "2g"])
assert args.max_model_len == 2_000_000_000
args = parser.parse_args(["--max-model-len", "2t"])
assert args.max_model_len == 2_000_000_000_000
# Capital
args = parser.parse_args(["--max-model-len", "3K"])
assert args.max_model_len == 1024 * 3
assert args.max_model_len == 2**10 * 3
args = parser.parse_args(["--max-model-len", "10M"])
assert args.max_model_len == 2**20 * 10
args = parser.parse_args(["--max-model-len", "4G"])
assert args.max_model_len == 2**30 * 4
args = parser.parse_args(["--max-model-len", "4T"])
assert args.max_model_len == 2**40 * 4
# Decimal values
args = parser.parse_args(["--max-model-len", "10.2k"])
assert args.max_model_len == 10200
# ..truncated to the nearest int
args = parser.parse_args(["--max-model-len", "10.212345k"])
args = parser.parse_args(["--max-model-len", "10.2123451234567k"])
assert args.max_model_len == 10212
args = parser.parse_args(["--max-model-len", "10.2123451234567m"])
assert args.max_model_len == 10212345
args = parser.parse_args(["--max-model-len", "10.2123451234567g"])
assert args.max_model_len == 10212345123
args = parser.parse_args(["--max-model-len", "10.2123451234567t"])
assert args.max_model_len == 10212345123456
# Invalid (do not allow decimals with binary multipliers)
for invalid in ["1a", "pwd", "10.24", "1.23M"]:
for invalid in ["1a", "pwd", "10.24", "1.23M", "1.22T"]:
with pytest.raises(ArgumentError):
args = parser.parse_args(["--max-model-len", invalid])
parser.parse_args(["--max-model-len", invalid])

View File

@ -0,0 +1,228 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.v1.engine.async_llm import AsyncLLM
MODEL_NAME = "openai-community/gpt2"
MODEL_NAME_SHORT = "gpt2"
BASE_MODEL_PATHS = [
BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT),
]
@dataclass
class MockHFConfig:
model_type: str = "any"
@dataclass
class MockModelConfig:
task = "generate"
runner_type = "generate"
tokenizer = MODEL_NAME
trust_remote_code = False
tokenizer_mode = "auto"
max_model_len = 100
tokenizer_revision = None
multimodal_config = MultiModalConfig()
hf_config = MockHFConfig()
logits_processor_pattern = None
logits_processors: list[str] | None = None
diff_sampling_param: dict | None = None
allowed_local_media_path: str = ""
allowed_media_domains: list[str] | None = None
encoder_config = None
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
models = OpenAIServingModels(
engine_client=engine,
base_model_paths=BASE_MODEL_PATHS,
)
serving_chat = OpenAIServingChat(
engine,
models,
response_role="assistant",
request_logger=None,
chat_template=None,
chat_template_content_format="auto",
)
async def _fake_process_inputs(
request_id,
engine_prompt,
sampling_params,
*,
lora_request,
trace_headers,
priority,
):
return dict(engine_prompt), {}
async def _fake_preprocess_chat(*args, **kwargs):
# return conversation, request_prompts, engine_prompts
return (
[{"role": "user", "content": "Test"}],
[[1, 2, 3]],
[{"prompt_token_ids": [1, 2, 3]}],
)
serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
serving_chat._preprocess_chat = AsyncMock(side_effect=_fake_preprocess_chat)
return serving_chat
@pytest.mark.asyncio
async def test_chat_error_non_stream():
"""test finish_reason='error' returns 500 InternalServerError (non-streaming)"""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
serving_chat = _build_serving_chat(mock_engine)
completion_output = CompletionOutput(
index=0,
text="",
token_ids=[],
cumulative_logprob=None,
logprobs=None,
finish_reason="error",
)
request_output = RequestOutput(
request_id="test-id",
prompt="Test prompt",
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[completion_output],
finished=True,
metrics=None,
lora_request=None,
encoder_prompt=None,
encoder_prompt_token_ids=None,
)
async def mock_generate(*args, **kwargs):
yield request_output
mock_engine.generate = MagicMock(side_effect=mock_generate)
request = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{"role": "user", "content": "Test prompt"}],
max_tokens=10,
stream=False,
)
response = await serving_chat.create_chat_completion(request)
assert isinstance(response, ErrorResponse)
assert response.error.type == "InternalServerError"
assert response.error.message == "Internal server error"
assert response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR
@pytest.mark.asyncio
async def test_chat_error_stream():
"""test finish_reason='error' returns 500 InternalServerError (streaming)"""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
serving_chat = _build_serving_chat(mock_engine)
completion_output_1 = CompletionOutput(
index=0,
text="Hello",
token_ids=[100],
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
)
request_output_1 = RequestOutput(
request_id="test-id",
prompt="Test prompt",
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[completion_output_1],
finished=False,
metrics=None,
lora_request=None,
encoder_prompt=None,
encoder_prompt_token_ids=None,
)
completion_output_2 = CompletionOutput(
index=0,
text="Hello",
token_ids=[100],
cumulative_logprob=None,
logprobs=None,
finish_reason="error",
)
request_output_2 = RequestOutput(
request_id="test-id",
prompt="Test prompt",
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[completion_output_2],
finished=True,
metrics=None,
lora_request=None,
encoder_prompt=None,
encoder_prompt_token_ids=None,
)
async def mock_generate(*args, **kwargs):
yield request_output_1
yield request_output_2
mock_engine.generate = MagicMock(side_effect=mock_generate)
request = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{"role": "user", "content": "Test prompt"}],
max_tokens=10,
stream=True,
)
response = await serving_chat.create_chat_completion(request)
chunks = []
async for chunk in response:
chunks.append(chunk)
assert len(chunks) >= 2
assert any("Internal server error" in chunk for chunk in chunks), (
f"Expected error message in chunks: {chunks}"
)
assert chunks[-1] == "data: [DONE]\n\n"

View File

@ -0,0 +1,216 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.protocol import CompletionRequest, ErrorResponse
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.v1.engine.async_llm import AsyncLLM
MODEL_NAME = "openai-community/gpt2"
MODEL_NAME_SHORT = "gpt2"
BASE_MODEL_PATHS = [
BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT),
]
@dataclass
class MockHFConfig:
model_type: str = "any"
@dataclass
class MockModelConfig:
task = "generate"
runner_type = "generate"
tokenizer = MODEL_NAME
trust_remote_code = False
tokenizer_mode = "auto"
max_model_len = 100
tokenizer_revision = None
multimodal_config = MultiModalConfig()
hf_config = MockHFConfig()
logits_processor_pattern = None
logits_processors: list[str] | None = None
diff_sampling_param: dict | None = None
allowed_local_media_path: str = ""
allowed_media_domains: list[str] | None = None
encoder_config = None
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
models = OpenAIServingModels(
engine_client=engine,
base_model_paths=BASE_MODEL_PATHS,
)
serving_completion = OpenAIServingCompletion(
engine,
models,
request_logger=None,
)
async def _fake_process_inputs(
request_id,
engine_prompt,
sampling_params,
*,
lora_request,
trace_headers,
priority,
):
return dict(engine_prompt), {}
serving_completion._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
return serving_completion
@pytest.mark.asyncio
async def test_completion_error_non_stream():
"""test finish_reason='error' returns 500 InternalServerError (non-streaming)"""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
serving_completion = _build_serving_completion(mock_engine)
completion_output = CompletionOutput(
index=0,
text="",
token_ids=[],
cumulative_logprob=None,
logprobs=None,
finish_reason="error",
)
request_output = RequestOutput(
request_id="test-id",
prompt="Test prompt",
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[completion_output],
finished=True,
metrics=None,
lora_request=None,
encoder_prompt=None,
encoder_prompt_token_ids=None,
)
async def mock_generate(*args, **kwargs):
yield request_output
mock_engine.generate = MagicMock(side_effect=mock_generate)
request = CompletionRequest(
model=MODEL_NAME,
prompt="Test prompt",
max_tokens=10,
stream=False,
)
response = await serving_completion.create_completion(request)
assert isinstance(response, ErrorResponse)
assert response.error.type == "InternalServerError"
assert response.error.message == "Internal server error"
assert response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR
@pytest.mark.asyncio
async def test_completion_error_stream():
"""test finish_reason='error' returns 500 InternalServerError (streaming)"""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
serving_completion = _build_serving_completion(mock_engine)
completion_output_1 = CompletionOutput(
index=0,
text="Hello",
token_ids=[100],
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
)
request_output_1 = RequestOutput(
request_id="test-id",
prompt="Test prompt",
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[completion_output_1],
finished=False,
metrics=None,
lora_request=None,
encoder_prompt=None,
encoder_prompt_token_ids=None,
)
completion_output_2 = CompletionOutput(
index=0,
text="Hello",
token_ids=[100],
cumulative_logprob=None,
logprobs=None,
finish_reason="error",
)
request_output_2 = RequestOutput(
request_id="test-id",
prompt="Test prompt",
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[completion_output_2],
finished=True,
metrics=None,
lora_request=None,
encoder_prompt=None,
encoder_prompt_token_ids=None,
)
async def mock_generate(*args, **kwargs):
yield request_output_1
yield request_output_2
mock_engine.generate = MagicMock(side_effect=mock_generate)
request = CompletionRequest(
model=MODEL_NAME,
prompt="Test prompt",
max_tokens=10,
stream=True,
)
response = await serving_completion.create_completion(request)
chunks = []
async for chunk in response:
chunks.append(chunk)
assert len(chunks) >= 2
assert any("Internal server error" in chunk for chunk in chunks), (
f"Expected error message in chunks: {chunks}"
)
assert chunks[-1] == "data: [DONE]\n\n"

View File

@ -0,0 +1,89 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from unittest.mock import MagicMock
import pytest
from vllm.entrypoints.openai.protocol import ErrorResponse
from vllm.entrypoints.openai.serving_engine import GenerationError, OpenAIServing
@pytest.mark.asyncio
async def test_raise_if_error_raises_generation_error():
"""test _raise_if_error raises GenerationError"""
# create a minimal OpenAIServing instance
mock_engine = MagicMock()
mock_engine.model_config = MagicMock()
mock_engine.model_config.max_model_len = 100
mock_models = MagicMock()
serving = OpenAIServing(
engine_client=mock_engine,
models=mock_models,
request_logger=None,
)
# test that error finish_reason raises GenerationError
with pytest.raises(GenerationError) as exc_info:
serving._raise_if_error("error", "test-request-id")
assert str(exc_info.value) == "Internal server error"
assert exc_info.value.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
# test that other finish_reasons don't raise
serving._raise_if_error("stop", "test-request-id") # should not raise
serving._raise_if_error("length", "test-request-id") # should not raise
serving._raise_if_error(None, "test-request-id") # should not raise
@pytest.mark.asyncio
async def test_convert_generation_error_to_response():
"""test _convert_generation_error_to_response creates proper ErrorResponse"""
mock_engine = MagicMock()
mock_engine.model_config = MagicMock()
mock_engine.model_config.max_model_len = 100
mock_models = MagicMock()
serving = OpenAIServing(
engine_client=mock_engine,
models=mock_models,
request_logger=None,
)
# create a GenerationError
gen_error = GenerationError("Internal server error")
# convert to ErrorResponse
error_response = serving._convert_generation_error_to_response(gen_error)
assert isinstance(error_response, ErrorResponse)
assert error_response.error.type == "InternalServerError"
assert error_response.error.message == "Internal server error"
assert error_response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR
@pytest.mark.asyncio
async def test_convert_generation_error_to_streaming_response():
"""test _convert_generation_error_to_streaming_response output"""
mock_engine = MagicMock()
mock_engine.model_config = MagicMock()
mock_engine.model_config.max_model_len = 100
mock_models = MagicMock()
serving = OpenAIServing(
engine_client=mock_engine,
models=mock_models,
request_logger=None,
)
# create a GenerationError
gen_error = GenerationError("Internal server error")
# convert to streaming error response
error_json = serving._convert_generation_error_to_streaming_response(gen_error)
assert isinstance(error_json, str)
assert "Internal server error" in error_json
assert "InternalServerError" in error_json

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem,
)
@ -14,7 +15,8 @@ from openai.types.responses.response_reasoning_item import (
)
from vllm.entrypoints.responses_utils import (
construct_chat_message_with_tool_call,
_construct_single_message_from_response_item,
construct_chat_messages_with_tool_call,
convert_tool_responses_to_completions_format,
)
@ -42,7 +44,43 @@ class TestResponsesUtils:
assert result == {"type": "function", "function": input_tool}
def test_construct_chat_message_with_tool_call(self):
def test_construct_chat_messages_with_tool_call(self):
"""Test construction of chat messages with tool calls."""
reasoning_item = ResponseReasoningItem(
id="lol",
summary=[],
type="reasoning",
content=[
Content(
text="Leroy Jenkins",
type="reasoning_text",
)
],
encrypted_content=None,
status=None,
)
mcp_tool_item = ResponseFunctionToolCall(
id="mcp_123",
call_id="call_123",
type="function_call",
status="completed",
name="python",
arguments='{"code": "123+456"}',
)
input_items = [reasoning_item, mcp_tool_item]
messages = construct_chat_messages_with_tool_call(input_items)
assert len(messages) == 1
message = messages[0]
assert message["role"] == "assistant"
assert message["reasoning"] == "Leroy Jenkins"
assert message["tool_calls"][0]["id"] == "call_123"
assert message["tool_calls"][0]["function"]["name"] == "python"
assert (
message["tool_calls"][0]["function"]["arguments"] == '{"code": "123+456"}'
)
def test_construct_single_message_from_response_item(self):
item = ResponseReasoningItem(
id="lol",
summary=[],
@ -56,7 +94,7 @@ class TestResponsesUtils:
encrypted_content=None,
status=None,
)
formatted_item = construct_chat_message_with_tool_call(item)
formatted_item = _construct_single_message_from_response_item(item)
assert formatted_item["role"] == "assistant"
assert formatted_item["reasoning"] == "Leroy Jenkins"
@ -74,7 +112,7 @@ class TestResponsesUtils:
status=None,
)
formatted_item = construct_chat_message_with_tool_call(item)
formatted_item = _construct_single_message_from_response_item(item)
assert formatted_item["role"] == "assistant"
assert (
formatted_item["reasoning"]
@ -88,7 +126,7 @@ class TestResponsesUtils:
output="1234",
status="completed",
)
formatted_item = construct_chat_message_with_tool_call(tool_call_output)
formatted_item = _construct_single_message_from_response_item(tool_call_output)
assert formatted_item["role"] == "tool"
assert formatted_item["content"] == "1234"
assert formatted_item["tool_call_id"] == "temp"
@ -102,7 +140,7 @@ class TestResponsesUtils:
status=None,
)
with pytest.raises(ValueError):
construct_chat_message_with_tool_call(item)
_construct_single_message_from_response_item(item)
output_item = ResponseOutputMessage(
id="msg_bf585bbbe3d500e0",
@ -119,6 +157,6 @@ class TestResponsesUtils:
type="message",
)
formatted_item = construct_chat_message_with_tool_call(output_item)
formatted_item = _construct_single_message_from_response_item(output_item)
assert formatted_item["role"] == "assistant"
assert formatted_item["content"] == "dongyi"

View File

@ -7,7 +7,8 @@ import math
import pytest
import torch
from vllm.platforms import current_platform
from vllm.platforms import CpuArchEnum, current_platform
from vllm.v1.attention.backends.cpu_attn import _get_attn_isa
if not current_platform.is_cpu():
pytest.skip("skipping CPU-only tests", allow_module_level=True)
@ -36,6 +37,21 @@ SEQ_LENS = [ # (q_len, kv_len)
]
def get_attn_isa(
block_size: int | None = None,
dtype: torch.dtype | None = None,
):
if block_size and dtype:
return _get_attn_isa(dtype, block_size)
else:
if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
return "neon"
elif torch._C._cpu._is_amx_tile_supported():
return "amx"
else:
return "vec"
# rand number generation takes too much time, cache rand tensors
@functools.lru_cache(maxsize=128, typed=False)
def tensor_cache(
@ -452,6 +468,49 @@ def test_varlen_with_paged_kv_normal_vec16(
)
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", [96, 128])
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
@pytest.mark.parametrize("dtype", QTYPES)
@pytest.mark.parametrize("soft_cap", [None])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("use_alibi", [False])
@pytest.mark.parametrize("use_sink", [False])
@pytest.mark.parametrize("isa", ["neon"])
@pytest.mark.skipif(
current_platform.get_cpu_architecture() != CpuArchEnum.ARM,
reason="Not an Arm CPU.",
)
def test_varlen_with_paged_kv_normal_neon(
seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int],
head_size: int,
sliding_window: int | None,
dtype: torch.dtype,
block_size: int,
soft_cap: float | None,
num_blocks: int,
use_alibi: bool,
use_sink: bool,
isa: str,
) -> None:
varlen_with_paged_kv(
seq_lens=seq_lens,
num_heads=num_heads,
head_size=head_size,
sliding_window=sliding_window,
dtype=dtype,
block_size=block_size,
soft_cap=soft_cap,
num_blocks=num_blocks,
use_alibi=use_alibi,
use_sink=use_sink,
isa=isa,
)
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [96])
@ -462,9 +521,7 @@ def test_varlen_with_paged_kv_normal_vec16(
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("use_alibi", [False])
@pytest.mark.parametrize("use_sink", [False])
@pytest.mark.parametrize(
"isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
)
@pytest.mark.parametrize("isa", [get_attn_isa()])
def test_varlen_with_paged_kv_softcap(
seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int],
@ -503,9 +560,7 @@ def test_varlen_with_paged_kv_softcap(
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("use_alibi", [True])
@pytest.mark.parametrize("use_sink", [False])
@pytest.mark.parametrize(
"isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
)
@pytest.mark.parametrize("isa", [get_attn_isa()])
def test_varlen_with_paged_kv_alibi(
seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int],
@ -544,9 +599,7 @@ def test_varlen_with_paged_kv_alibi(
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("use_alibi", [False])
@pytest.mark.parametrize("use_sink", [True])
@pytest.mark.parametrize(
"isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
)
@pytest.mark.parametrize("isa", [get_attn_isa()])
def test_varlen_with_paged_kv_sink(
seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int],

View File

@ -70,12 +70,12 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
f"{torch.cuda.device_count()}"
)
# `cuda_graph_sizes=[16]` to reduce load time.
# `cudagraph_capture_sizes=[16]` to reduce load time.
with vllm_runner(
model_case.model_id,
tensor_parallel_size=model_case.tp,
load_format="dummy",
cuda_graph_sizes=[16],
cudagraph_capture_sizes=[16],
) as llm:
# Disabled as check_model is broken: https://github.com/vllm-project/vllm/pull/18465#issuecomment-3329880562
# def check_model(model):

View File

@ -54,6 +54,10 @@ def setup_cuda():
torch.set_default_device("cuda")
@pytest.mark.skipif(
current_platform.is_fp8_fnuz(),
reason="This platform supports e4m3fnuz, not e4m3fn.",
)
@pytest.mark.parametrize(
"num_tokens,d,dtype,group_size,seed",
itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS),
@ -78,14 +82,14 @@ def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed):
def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
torch.manual_seed(seed)
factor_for_scale = 1e-2
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_info = torch.finfo(current_platform.fp8_dtype())
fp8_max, fp8_min = fp8_info.max, fp8_info.min
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(current_platform.fp8_dtype())
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(current_platform.fp8_dtype())
block_n, block_k = block_size[0], block_size[1]
n_tiles = (N + block_n - 1) // block_n
@ -103,6 +107,9 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
assert rel_diff < 0.001
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="CUTLASS only supported on CUDA platform."
)
@torch.inference_mode()
def test_w8a8_block_fp8_cutlass_matmul():
# Test simple case where weight.shape % 128 != 0,
@ -151,6 +158,10 @@ def test_w8a8_block_fp8_cutlass_matmul():
assert rel_diff < 0.001
@pytest.mark.skipif(
current_platform.is_fp8_fnuz(),
reason="This platform supports e4m3fnuz, not e4m3fn.",
)
@pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),

View File

@ -15,6 +15,9 @@ from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
if not current_platform.is_cuda():
pytest.skip("These tests use CUTLASS which requires CUDA", allow_module_level=True)
MNK_FACTORS = [
(1, 256, 128),
(1, 16384, 1024),

View File

@ -21,6 +21,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
if not current_platform.is_cuda():
pytest.skip("These tests use CUTLASS which requires CUDA", allow_module_level=True)
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
# unit tests to a common utility function. Currently the use of
# `is_quant_method_supported` conflates kernels with quantization methods

View File

@ -17,7 +17,6 @@ def test_idefics_multimodal(
with vllm_runner(
model_name="HuggingFaceM4/Idefics3-8B-Llama3",
runner="pooling",
task="classify",
convert="classify",
load_format="dummy",
max_model_len=512,
@ -86,7 +85,6 @@ def test_gemma_multimodal(
with vllm_runner(
model_name="google/gemma-3-4b-it",
runner="pooling",
task="classify",
convert="classify",
load_format="auto",
hf_overrides=update_config,

View File

@ -92,16 +92,19 @@ def run_test(
*,
tensor_parallel_size: int,
distributed_executor_backend: str | None = None,
dtype: str = "half",
) -> None:
prompt_list = PROMPTS * 10
expected_list = EXPECTED[model] * 10
with vllm_runner(
model,
dtype="half",
dtype=dtype,
max_model_len=448,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
# TODO (NickLucche) figure out output differences with non-eager and re-enable
enforce_eager=True,
) as vllm_model:
llm = vllm_model.llm
@ -120,12 +123,28 @@ def run_test(
@pytest.mark.core_model
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
@pytest.mark.parametrize("dtype", ["half"])
@create_new_process_for_each_test()
def test_models(vllm_runner, model) -> None:
def test_models(vllm_runner, model, dtype) -> None:
run_test(
vllm_runner,
model,
tensor_parallel_size=1,
dtype=dtype,
)
@pytest.mark.cpu_model
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
@pytest.mark.parametrize("dtype", ["half"])
def test_models_cpu(vllm_runner, model, dtype) -> None:
# @create_new_process_for_each_test() does not work for some runners
# TODO: to fix cpu privilege issues in run-cpu-test-arm.sh
run_test(
vllm_runner,
model,
tensor_parallel_size=1,
dtype=dtype,
)

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import base64
import mimetypes
import os
@ -186,6 +187,7 @@ async def test_fetch_image_error_conversion():
connector.fetch_image(broken_img)
@pytest.mark.flaky(reruns=3, reruns_delay=5)
@pytest.mark.asyncio
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
@ -198,8 +200,12 @@ async def test_fetch_video_http(video_url: str, num_frames: int):
}
)
video_sync, metadata_sync = connector.fetch_video(video_url)
video_async, metadata_async = await connector.fetch_video_async(video_url)
try:
video_sync, metadata_sync = connector.fetch_video(video_url)
video_async, metadata_async = await connector.fetch_video_async(video_url)
except (TimeoutError, asyncio.TimeoutError) as e:
pytest.skip(f"Timeout fetching video (CI network flakiness): {e}")
assert np.array_equal(video_sync, video_async)
assert metadata_sync == metadata_async

View File

@ -147,7 +147,7 @@ def test_video_backend_handles_broken_frames(monkeypatch: pytest.MonkeyPatch):
"""
Regression test for handling videos with broken frames.
This test uses a pre-corrupted video file (assets/corrupted.mp4) that
contains broken/unreadable frames to verify the video loader handles
contains broken frames to verify the video loader handles
them gracefully without crashing and returns accurate metadata.
"""
with monkeypatch.context() as m:
@ -177,3 +177,125 @@ def test_video_backend_handles_broken_frames(monkeypatch: pytest.MonkeyPatch):
f"Expected fewer than {metadata['total_num_frames']} frames, "
f"but loaded {frames.shape[0]} frames"
)
@VIDEO_LOADER_REGISTRY.register("test_video_backend_override_1")
class TestVideoBackendOverride1(VideoLoader):
"""Test loader that returns FAKE_OUTPUT_1 to verify backend selection."""
@classmethod
def load_bytes(
cls, data: bytes, num_frames: int = -1, **kwargs
) -> tuple[npt.NDArray, dict]:
return FAKE_OUTPUT_1, {"video_backend": "test_video_backend_override_1"}
@VIDEO_LOADER_REGISTRY.register("test_video_backend_override_2")
class TestVideoBackendOverride2(VideoLoader):
"""Test loader that returns FAKE_OUTPUT_2 to verify backend selection."""
@classmethod
def load_bytes(
cls, data: bytes, num_frames: int = -1, **kwargs
) -> tuple[npt.NDArray, dict]:
return FAKE_OUTPUT_2, {"video_backend": "test_video_backend_override_2"}
def test_video_media_io_backend_kwarg_override(monkeypatch: pytest.MonkeyPatch):
"""
Test that video_backend kwarg can override the VLLM_VIDEO_LOADER_BACKEND
environment variable.
This allows users to dynamically select a different video backend
via --media-io-kwargs without changing the global env var, which is
useful when plugins set a default backend but a specific request
needs a different one.
"""
with monkeypatch.context() as m:
# Set the env var to one backend
m.setenv("VLLM_VIDEO_LOADER_BACKEND", "test_video_backend_override_1")
imageio = ImageMediaIO()
# Without video_backend kwarg, should use env var backend
videoio_default = VideoMediaIO(imageio, num_frames=10)
frames_default, metadata_default = videoio_default.load_bytes(b"test")
np.testing.assert_array_equal(frames_default, FAKE_OUTPUT_1)
assert metadata_default["video_backend"] == "test_video_backend_override_1"
# With video_backend kwarg, should override env var
videoio_override = VideoMediaIO(
imageio, num_frames=10, video_backend="test_video_backend_override_2"
)
frames_override, metadata_override = videoio_override.load_bytes(b"test")
np.testing.assert_array_equal(frames_override, FAKE_OUTPUT_2)
assert metadata_override["video_backend"] == "test_video_backend_override_2"
def test_video_media_io_backend_kwarg_not_passed_to_loader(
monkeypatch: pytest.MonkeyPatch,
):
"""
Test that video_backend kwarg is consumed by VideoMediaIO and NOT passed
through to the underlying video loader's load_bytes method.
This ensures the kwarg is properly popped from kwargs before forwarding.
"""
@VIDEO_LOADER_REGISTRY.register("test_reject_video_backend_kwarg")
class RejectVideoBackendKwargLoader(VideoLoader):
"""Test loader that fails if video_backend is passed through."""
@classmethod
def load_bytes(
cls, data: bytes, num_frames: int = -1, **kwargs
) -> tuple[npt.NDArray, dict]:
# This should never receive video_backend in kwargs
if "video_backend" in kwargs:
raise AssertionError(
"video_backend should be consumed by VideoMediaIO, "
"not passed to loader"
)
return FAKE_OUTPUT_1, {"received_kwargs": list(kwargs.keys())}
with monkeypatch.context() as m:
m.setenv("VLLM_VIDEO_LOADER_BACKEND", "test_reject_video_backend_kwarg")
imageio = ImageMediaIO()
# Even when video_backend is provided, it should NOT be passed to loader
videoio = VideoMediaIO(
imageio,
num_frames=10,
video_backend="test_reject_video_backend_kwarg",
other_kwarg="should_pass_through",
)
# This should NOT raise AssertionError
frames, metadata = videoio.load_bytes(b"test")
np.testing.assert_array_equal(frames, FAKE_OUTPUT_1)
# Verify other kwargs are still passed through
assert "other_kwarg" in metadata["received_kwargs"]
def test_video_media_io_backend_env_var_fallback(monkeypatch: pytest.MonkeyPatch):
"""
Test that when video_backend kwarg is None or not provided,
VideoMediaIO falls back to VLLM_VIDEO_LOADER_BACKEND env var.
"""
with monkeypatch.context() as m:
m.setenv("VLLM_VIDEO_LOADER_BACKEND", "test_video_backend_override_2")
imageio = ImageMediaIO()
# Explicit None should fall back to env var
videoio_none = VideoMediaIO(imageio, num_frames=10, video_backend=None)
frames_none, metadata_none = videoio_none.load_bytes(b"test")
np.testing.assert_array_equal(frames_none, FAKE_OUTPUT_2)
assert metadata_none["video_backend"] == "test_video_backend_override_2"
# Not providing video_backend should also fall back to env var
videoio_missing = VideoMediaIO(imageio, num_frames=10)
frames_missing, metadata_missing = videoio_missing.load_bytes(b"test")
np.testing.assert_array_equal(frames_missing, FAKE_OUTPUT_2)
assert metadata_missing["video_backend"] == "test_video_backend_override_2"

View File

@ -10,10 +10,14 @@ import torch
from tests.quantization.utils import is_quant_method_supported
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.quantization.fp8 import (
Fp8Config,
Fp8KVCacheMethod,
Fp8LinearMethod,
Fp8MoEMethod,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.platforms import current_platform
MODELS = [
@ -261,3 +265,87 @@ def test_scaled_fp8_quant(dtype) -> None:
torch.narrow(y_nc_pad, 0, 0, x_nc.shape[0]), inv_scale_nc, dtype
),
)
@pytest.mark.parametrize("method_cls", [Fp8LinearMethod, Fp8MoEMethod])
# FP8 weight reloading does not support online quantization
@pytest.mark.parametrize("is_checkpoint_fp8_serialized", [True]) # skip False
@pytest.mark.parametrize("weight_block_size", [None, [1, 1]])
# any postprocessing that is applied to the weights such as padding and repacking
# (excluding device sharding) must also be applied to the reloaded weights
#
# this is the case for marlin as well as per-tensor Fp8MoEMethod
@pytest.mark.parametrize("use_marlin", [False]) # skip True
def test_fp8_reloading(
method_cls, is_checkpoint_fp8_serialized, weight_block_size, use_marlin, dist_init
):
if is_checkpoint_fp8_serialized is False:
pytest.skip("FP8 weight reloading does not support online quantization")
if method_cls is Fp8MoEMethod and weight_block_size is None:
pytest.skip(
"FP8 Tensor weight reloading does not support fusing w13_weight_scale. "
"If this is your use case, consider using a restore function like #26327"
)
with torch.device("cuda:0"):
config = Fp8Config(
is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
weight_block_size=weight_block_size,
)
if method_cls is Fp8LinearMethod:
layer = torch.nn.Linear(1, 1)
method = method_cls(config)
method.create_weights(
layer=layer,
input_size_per_partition=1,
output_partition_sizes=[1],
input_size=1,
output_size=1,
params_dtype=torch.bfloat16,
weight_loader=default_weight_loader,
)
else:
layer = FusedMoE(
num_experts=1,
top_k=1,
hidden_size=1,
intermediate_size=1,
)
method = method_cls(config, layer)
method.create_weights(
layer=layer,
num_experts=1,
hidden_size=1,
intermediate_size_per_partition=1,
params_dtype=torch.bfloat16,
weight_loader=default_weight_loader,
)
method.use_marlin = use_marlin
# capture weights format during loading
original_metadata = [
(name, param.shape, getattr(param, "weight_loader", default_weight_loader))
for name, param in layer.named_parameters()
]
# test loading
for name, shape, _ in original_metadata:
param = getattr(layer, name)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, torch.zeros(shape)) # cannot use empty
method.process_weights_after_loading(layer)
# test reloading works after loading
# assuming that no reshaping occurred
for name, shape, original_weight_loader in original_metadata:
param = getattr(layer, name)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
assert weight_loader is original_weight_loader
weight_loader(param, torch.zeros(shape)) # cannot use empty
method.process_weights_after_loading(layer)

View File

@ -212,11 +212,11 @@ def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
task = "wikitext"
rtol = 0.1
# Smaller cuda_graph_sizes to speed up the test.
# Smaller cudagraph_capture_sizes to speed up the test.
results = lm_eval.simple_evaluate(
model="vllm",
model_args=config.get_model_args(
tp_size=tp_size, kwargs={"cuda_graph_sizes": [16]}
tp_size=tp_size, kwargs={"cudagraph_capture_sizes": [16]}
),
tasks=task,
batch_size=64,

View File

@ -0,0 +1,195 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from transformers import AutoTokenizer
from tests.reasoning.utils import run_reasoning_extraction
from vllm.reasoning import ReasoningParser, ReasoningParserManager
parser_name = "minimax_m2_append_think"
end_token = "</think>"
# MiniMax M2 model path
REASONING_MODEL_NAME = "MiniMaxAI/MiniMax-M2"
@pytest.fixture(scope="module")
def minimax_m2_tokenizer():
return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
# =============================================================================
# MiniMaxM2AppendThinkReasoningParser behavior:
# - Prepends <think> to the beginning of the output
# - Does NOT separate reasoning and content
# - Returns everything as content (with <think> prepended)
# - reasoning is always None
#
# This parser is used when you want to keep the raw output with <think> added
# =============================================================================
# Case: simple output with end token
SIMPLE_OUTPUT = {
"output": "This is reasoning</think>This is response",
"reasoning": None,
"content": "<think>This is reasoning</think>This is response",
"is_reasoning_end": True,
}
# Case: output without end token (reasoning in progress)
NO_END_TOKEN = {
"output": "This is reasoning in progress",
"reasoning": None,
"content": "<think>This is reasoning in progress",
"is_reasoning_end": False,
}
# Case: only end token
ONLY_END_TOKEN = {
"output": "</think>This is response",
"reasoning": None,
"content": "<think></think>This is response",
"is_reasoning_end": True,
}
# Case: multiple lines
MULTIPLE_LINES = {
"output": "Line 1\nLine 2</think>Response 1\nResponse 2",
"reasoning": None,
"content": "<think>Line 1\nLine 2</think>Response 1\nResponse 2",
"is_reasoning_end": True,
}
# Case: empty output (non-streaming prepends <think>)
EMPTY = {
"output": "",
"reasoning": None,
"content": "<think>",
"is_reasoning_end": False,
}
# Case: empty output streaming (no tokens = no output)
EMPTY_STREAMING = {
"output": "",
"reasoning": None,
"content": None,
"is_reasoning_end": False,
}
# Case: special characters
SPECIAL_CHARS = {
"output": "Let me think... 1+1=2</think>Yes!",
"reasoning": None,
"content": "<think>Let me think... 1+1=2</think>Yes!",
"is_reasoning_end": True,
}
# Case: code in output
CODE_OUTPUT = {
"output": "```python\nprint('hi')\n```</think>Here's the code.",
"reasoning": None,
"content": "<think>```python\nprint('hi')\n```</think>Here's the code.",
"is_reasoning_end": True,
}
TEST_CASES = [
pytest.param(
False,
SIMPLE_OUTPUT,
id="simple_output",
),
pytest.param(
True,
SIMPLE_OUTPUT,
id="simple_output_streaming",
),
pytest.param(
False,
NO_END_TOKEN,
id="no_end_token",
),
pytest.param(
True,
NO_END_TOKEN,
id="no_end_token_streaming",
),
pytest.param(
False,
ONLY_END_TOKEN,
id="only_end_token",
),
pytest.param(
True,
ONLY_END_TOKEN,
id="only_end_token_streaming",
),
pytest.param(
False,
MULTIPLE_LINES,
id="multiple_lines",
),
pytest.param(
True,
MULTIPLE_LINES,
id="multiple_lines_streaming",
),
pytest.param(
False,
EMPTY,
id="empty",
),
pytest.param(
True,
EMPTY_STREAMING,
id="empty_streaming",
),
pytest.param(
False,
SPECIAL_CHARS,
id="special_chars",
),
pytest.param(
True,
SPECIAL_CHARS,
id="special_chars_streaming",
),
pytest.param(
False,
CODE_OUTPUT,
id="code_output",
),
pytest.param(
True,
CODE_OUTPUT,
id="code_output_streaming",
),
]
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_reasoning(
streaming: bool,
param_dict: dict,
minimax_m2_tokenizer,
):
output = minimax_m2_tokenizer.tokenize(param_dict["output"])
# decode everything to tokens
output_tokens: list[str] = [
minimax_m2_tokenizer.convert_tokens_to_string([token]) for token in output
]
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)(
minimax_m2_tokenizer
)
reasoning, content = run_reasoning_extraction(
parser, output_tokens, streaming=streaming
)
assert reasoning == param_dict["reasoning"]
assert content == param_dict["content"]
# Test is_reasoning_end
output_ids = minimax_m2_tokenizer.convert_tokens_to_ids(output)
is_reasoning_end = parser.is_reasoning_end(output_ids)
assert is_reasoning_end == param_dict["is_reasoning_end"]

View File

@ -0,0 +1,230 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from transformers import AutoTokenizer
from tests.reasoning.utils import run_reasoning_extraction
from vllm.reasoning import ReasoningParser, ReasoningParserManager
parser_name = "minimax_m2"
end_token = "</think>"
# MiniMax M2 model path
REASONING_MODEL_NAME = "MiniMaxAI/MiniMax-M2"
@pytest.fixture(scope="module")
def minimax_m2_tokenizer():
return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
# =============================================================================
# MiniMax M2 specific behavior:
# - Model does NOT generate <think> start token
# - Model only generates </think> end token
# - All content before </think> is reasoning
# - All content after </think> is the actual response (content)
# =============================================================================
# Case: reasoning + end token + content (typical case)
SIMPLE_REASONING = {
"output": "This is a reasoning section</think>This is the rest",
"reasoning": "This is a reasoning section",
"content": "This is the rest",
"is_reasoning_end": True,
}
# Case: reasoning + end token only (no content after)
COMPLETE_REASONING = {
"output": "This is a reasoning section</think>",
"reasoning": "This is a reasoning section",
"content": None,
"is_reasoning_end": True,
}
# Case: no end token yet (streaming in progress, all is reasoning)
NO_END_TOKEN = {
"output": "This is reasoning in progress",
"reasoning": "This is reasoning in progress",
"content": None,
"is_reasoning_end": False,
}
# Case: multiple lines of reasoning
MULTIPLE_LINES = {
"output": "First line\nSecond line</think>Response first line\nResponse second",
"reasoning": "First line\nSecond line",
"content": "Response first line\nResponse second",
"is_reasoning_end": True,
}
# Case: only end token (empty reasoning, immediate response)
SHORTEST_REASONING_NO_STREAMING = {
"output": "</think>This is the response",
"reasoning": "",
"content": "This is the response",
"is_reasoning_end": True,
}
# Case: only end token streaming (reasoning is None because it's just the token)
SHORTEST_REASONING_STREAMING = {
"output": "</think>This is the response",
"reasoning": None,
"content": "This is the response",
"is_reasoning_end": True,
}
# Case: empty output
EMPTY = {
"output": "",
"reasoning": "",
"content": None,
"is_reasoning_end": False,
}
# Case: empty streaming
EMPTY_STREAMING = {
"output": "",
"reasoning": None,
"content": None,
"is_reasoning_end": False,
}
# Case: long reasoning with special characters
SPECIAL_CHARS = {
"output": "Let me think... 1+1=2, right?</think>Yes, 1+1=2.",
"reasoning": "Let me think... 1+1=2, right?",
"content": "Yes, 1+1=2.",
"is_reasoning_end": True,
}
# Case: reasoning with code blocks
CODE_IN_REASONING = {
"output": "```python\nprint('hello')\n```</think>Here is the code.",
"reasoning": "```python\nprint('hello')\n```",
"content": "Here is the code.",
"is_reasoning_end": True,
}
TEST_CASES = [
# Core cases: no start token (MiniMax M2 actual behavior)
pytest.param(
False,
SIMPLE_REASONING,
id="simple_reasoning",
),
pytest.param(
True,
SIMPLE_REASONING,
id="simple_reasoning_streaming",
),
pytest.param(
False,
COMPLETE_REASONING,
id="complete_reasoning",
),
pytest.param(
True,
COMPLETE_REASONING,
id="complete_reasoning_streaming",
),
pytest.param(
False,
NO_END_TOKEN,
id="no_end_token",
),
pytest.param(
True,
NO_END_TOKEN,
id="no_end_token_streaming",
),
pytest.param(
False,
MULTIPLE_LINES,
id="multiple_lines",
),
pytest.param(
True,
MULTIPLE_LINES,
id="multiple_lines_streaming",
),
pytest.param(
False,
SHORTEST_REASONING_NO_STREAMING,
id="shortest_reasoning",
),
pytest.param(
True,
SHORTEST_REASONING_STREAMING,
id="shortest_reasoning_streaming",
),
pytest.param(
False,
EMPTY,
id="empty",
),
pytest.param(
True,
EMPTY_STREAMING,
id="empty_streaming",
),
pytest.param(
False,
SPECIAL_CHARS,
id="special_chars",
),
pytest.param(
True,
SPECIAL_CHARS,
id="special_chars_streaming",
),
pytest.param(
False,
CODE_IN_REASONING,
id="code_in_reasoning",
),
pytest.param(
True,
CODE_IN_REASONING,
id="code_in_reasoning_streaming",
),
]
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_reasoning(
streaming: bool,
param_dict: dict,
minimax_m2_tokenizer,
):
output = minimax_m2_tokenizer.tokenize(param_dict["output"])
# decode everything to tokens
output_tokens: list[str] = [
minimax_m2_tokenizer.convert_tokens_to_string([token]) for token in output
]
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)(
minimax_m2_tokenizer
)
reasoning, content = run_reasoning_extraction(
parser, output_tokens, streaming=streaming
)
assert reasoning == param_dict["reasoning"]
assert content == param_dict["content"]
# Test is_reasoning_end
output_ids = minimax_m2_tokenizer.convert_tokens_to_ids(output)
is_reasoning_end = parser.is_reasoning_end(output_ids)
assert is_reasoning_end == param_dict["is_reasoning_end"]
# Test extract_content
if param_dict["content"] is not None:
content = parser.extract_content_ids(output_ids)
assert content == minimax_m2_tokenizer.convert_tokens_to_ids(
minimax_m2_tokenizer.tokenize(param_dict["content"])
)
else:
content = parser.extract_content_ids(output)
assert content == []

View File

@ -89,64 +89,6 @@ def test_update_config():
new_config3 = update_config(config3, {"a": "new_value"})
# Can remove once --task option is fully deprecated
@pytest.mark.parametrize(
("model_id", "expected_runner_type", "expected_convert_type", "expected_task"),
[
("distilbert/distilgpt2", "generate", "none", "generate"),
("intfloat/multilingual-e5-small", "pooling", "none", "embed"),
("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify", "classify"),
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "none", "classify"),
("Qwen/Qwen2.5-Math-RM-72B", "pooling", "none", "embed"),
("openai/whisper-small", "generate", "none", "transcription"),
],
)
def test_auto_task(
model_id, expected_runner_type, expected_convert_type, expected_task
):
config = ModelConfig(model_id, task="auto")
assert config.runner_type == expected_runner_type
assert config.convert_type == expected_convert_type
# Can remove once --task option is fully deprecated
@pytest.mark.parametrize(
("model_id", "expected_runner_type", "expected_convert_type", "expected_task"),
[
("distilbert/distilgpt2", "pooling", "embed", "embed"),
("intfloat/multilingual-e5-small", "pooling", "embed", "embed"),
("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify", "classify"),
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify", "classify"),
("Qwen/Qwen2.5-Math-RM-72B", "pooling", "embed", "embed"),
("openai/whisper-small", "pooling", "embed", "embed"),
],
)
def test_score_task(
model_id, expected_runner_type, expected_convert_type, expected_task
):
config = ModelConfig(model_id, task="score")
assert config.runner_type == expected_runner_type
assert config.convert_type == expected_convert_type
# Can remove once --task option is fully deprecated
@pytest.mark.parametrize(
("model_id", "expected_runner_type", "expected_convert_type", "expected_task"),
[
("openai/whisper-small", "generate", "none", "transcription"),
],
)
def test_transcription_task(
model_id, expected_runner_type, expected_convert_type, expected_task
):
config = ModelConfig(model_id, task="transcription")
assert config.runner_type == expected_runner_type
assert config.convert_type == expected_convert_type
@pytest.mark.parametrize(
("model_id", "expected_runner_type", "expected_convert_type"),
[
@ -1085,7 +1027,7 @@ def test_vllm_config_explicit_overrides():
)
# Override one field but not others
pass_config = PassConfig(enable_noop=False)
pass_config = PassConfig(eliminate_noops=False)
compilation_config = CompilationConfig(pass_config=pass_config)
config = VllmConfig(
model_config=regular_model,

View File

@ -8,6 +8,7 @@ import pytest
import vllm.envs as envs
from vllm.envs import (
disable_envs_cache,
enable_envs_cache,
env_list_with_choices,
env_set_with_choices,
@ -57,6 +58,43 @@ def test_getattr_with_cache(monkeypatch: pytest.MonkeyPatch):
envs.__getattr__ = envs.__getattr__.__wrapped__
def test_getattr_with_reset(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("VLLM_HOST_IP", "1.1.1.1")
# __getattr__ is not decorated with functools.cache
assert not hasattr(envs.__getattr__, "cache_info")
# Enable envs cache and ignore ongoing environment changes
enable_envs_cache()
assert envs.VLLM_HOST_IP == "1.1.1.1"
# With cache enabled, the environment variable value is cached and unchanged
monkeypatch.setenv("VLLM_HOST_IP", "2.2.2.2")
assert envs.VLLM_HOST_IP == "1.1.1.1"
disable_envs_cache()
assert envs.VLLM_HOST_IP == "2.2.2.2"
# After cache disabled, the environment variable value would be synced
# with os.environ
monkeypatch.setenv("VLLM_HOST_IP", "3.3.3.3")
assert envs.VLLM_HOST_IP == "3.3.3.3"
def test_is_envs_cache_enabled() -> None:
assert not envs._is_envs_cache_enabled()
enable_envs_cache()
assert envs._is_envs_cache_enabled()
# Only wrap one-layer of cache, so we only need to
# call disable once to reset.
enable_envs_cache()
enable_envs_cache()
enable_envs_cache()
disable_envs_cache()
assert not envs._is_envs_cache_enabled()
disable_envs_cache()
assert not envs._is_envs_cache_enabled()
class TestEnvWithChoices:
"""Test cases for env_with_choices function."""

View File

@ -119,7 +119,7 @@ class RemoteOpenAIServer:
vllm_serve_args: list[str],
*,
env_dict: dict[str, str] | None = None,
seed: int | None = 0,
seed: int = 0,
auto_port: bool = True,
max_wait_seconds: float | None = None,
override_hf_configs: dict[str, Any] | None = None,
@ -283,7 +283,7 @@ class RemoteOpenAIServerCustom(RemoteOpenAIServer):
child_process_fxn: Callable[[dict[str, str] | None, str, list[str]], None],
*,
env_dict: dict[str, str] | None = None,
seed: int | None = 0,
seed: int = 0,
auto_port: bool = True,
max_wait_seconds: float | None = None,
) -> None:

View File

@ -106,8 +106,8 @@ def create_common_attn_metadata(
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens,
seq_lens_cpu=seq_lens_cpu,
num_computed_tokens_cpu=num_computed_tokens_cpu,
_seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=batch_spec.batch_size,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,

View File

@ -8,6 +8,7 @@ import torch._dynamo.config as dynamo_config
from vllm import SamplingParams
from vllm.logprobs import Logprob
from vllm.platforms import current_platform
from vllm.sampling_params import StructuredOutputsParams
from vllm.v1.metrics.reader import Metric
@ -70,6 +71,18 @@ def test_without_spec_decoding(
(True, "uni", True, None, True),
]
if current_platform.is_rocm():
# On ROCm, Only test with structured_outputs (deterministic)
# and skip chunk_prefill (more variable).
test_configs = [
cfg
for cfg in test_configs
if not cfg[4] # skip chunk_prefill=True
]
test_sampling_params = [
p for p in test_sampling_params if p.get("structured_outputs") is not None
]
run_tests(monkeypatch, MODEL, test_configs, test_sampling_params)
@ -108,7 +121,14 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
(True, "uni", True, spec_config_short, True),
]
run_tests(monkeypatch, MTP_MODEL, test_configs, test_sampling_params)
# On ROCm, use TRITON_ATTN + float32 for better numerical consistency
run_tests(
monkeypatch,
MTP_MODEL,
test_configs,
test_sampling_params,
is_testing_with_spec_decoding=True,
)
@dynamo_config.patch(cache_size_limit=16)
@ -117,15 +137,23 @@ def run_tests(
model: str,
test_configs: list[tuple],
test_sampling_params: list[dict[str, Any]],
is_testing_with_spec_decoding: bool = False,
):
"""Test consistency of combos of async scheduling, preemption,
uni/multiproc executor with spec decoding."""
with monkeypatch.context() as m:
# avoid precision errors
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
# lock matmul precision to full FP32
m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "highest")
if current_platform.is_rocm():
if is_testing_with_spec_decoding:
# Use TRITON_ATTN for spec decoding test for consistency
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
else:
m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_FA")
else:
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
# lock matmul precision to full FP32 (IEEE)
m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "ieee")
# m.setenv("VLLM_BATCH_INVARIANT", "1")
outputs: list[tuple[str, list, list]] = []
for n, (
@ -145,6 +173,7 @@ def run_tests(
async_scheduling,
spec_config,
test_prefill_chunking=test_prefill_chunking,
is_testing_with_spec_decoding=is_testing_with_spec_decoding,
)
outputs.append(test_results)
@ -174,17 +203,34 @@ def run_tests(
name_0=f"baseline=[{baseline_config}], params={params}",
name_1=f"config=[{test_config}], params={params}",
)
assert _all_logprobs_match(base_logprobs, test_logprobs)
# On ROCm with TRITON_ATTN (spec decoding test), skip strict
# logprobs comparison when logprobs are requested
skip_logprobs_check = (
current_platform.is_rocm()
and params.get("logprobs")
and is_testing_with_spec_decoding
)
if not skip_logprobs_check:
assert _all_logprobs_match(base_logprobs, test_logprobs)
if (
base_acceptance_rate is not None
and test_acceptance_rate is not None
):
if "spec_mml=None" in test_config:
# Preemption causes more variance in acceptance rates
if (
current_platform.is_rocm()
and "preemption=True" in test_config
):
tolerance = 0.10
else:
tolerance = 0.05
assert (
test_acceptance_rate > base_acceptance_rate
or test_acceptance_rate
== pytest.approx(base_acceptance_rate, rel=5e-2)
== pytest.approx(base_acceptance_rate, rel=tolerance)
)
else:
# Currently the reported acceptance rate is expected to be
@ -215,6 +261,7 @@ def run_test(
async_scheduling: bool,
spec_config: dict[str, Any] | None,
test_prefill_chunking: bool,
is_testing_with_spec_decoding: bool = False,
):
spec_decoding = spec_config is not None
cache_arg: dict[str, Any] = (
@ -233,6 +280,15 @@ def run_test(
print("-" * 80)
print(f"---- TESTING {test_str}: {test_config}")
print("-" * 80)
# On ROCm: use float16 for first test (ROCM_AITER_FA), but float32 for
# spec decoding test (TRITON_ATTN) for better precision.
# On others: always use float32.
if current_platform.is_rocm() and not is_testing_with_spec_decoding:
dtype = "float16"
else:
dtype = "float32"
with VllmRunner(
model,
max_model_len=512,
@ -242,7 +298,7 @@ def run_test(
# enforce_eager=True,
async_scheduling=async_scheduling,
distributed_executor_backend=executor,
dtype="float32", # avoid precision errors
dtype=dtype,
speculative_config=spec_config,
disable_log_stats=False,
**cache_arg,
@ -302,11 +358,21 @@ def _all_logprobs_match(req_a, req_b) -> bool:
def _logprobs_match(lps_a: dict[int, Logprob], lps_b: dict[int, Logprob]) -> bool:
return len(lps_a) == len(lps_b) and all(
a.decoded_token == b.decoded_token
and a.rank == b.rank
and a.logprob == pytest.approx(b.logprob, rel=1e-3, abs=1e-6)
for a, b in ((lps_a[x], lps_b[x]) for x in lps_a)
if current_platform.is_rocm():
# ROCm has higher numerical variance
# due to use of float16.
rel_tol, abs_tol = 5e-2, 1e-5
else:
rel_tol, abs_tol = 1e-3, 1e-6
return (
len(lps_a) == len(lps_b)
and lps_a.keys() == lps_b.keys()
and all(
a.decoded_token == b.decoded_token
and a.rank == b.rank
and a.logprob == pytest.approx(b.logprob, rel=rel_tol, abs=abs_tol)
for a, b in ((lps_a[x], lps_b[x]) for x in lps_a)
)
)

View File

@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test that verifies no implicit GPU-CPU synchronization occurs during
speculative decoding generation under expected conditions.
"""
import multiprocessing
import sys
import traceback
import pytest
import torch
@pytest.fixture
def sync_tracker():
"""
Fixture that patches CommonAttentionMetadata.seq_lens_cpu to detect
lazy init syncs. Prints stack traces immediately when syncs occur.
"""
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
# Shared counter for cross-process communication (inherited by fork)
sync_count = multiprocessing.Value("i", 0)
# Save original property
original_prop = CommonAttentionMetadata.seq_lens_cpu
original_fget = original_prop.fget
# Create tracking wrapper
def tracking_seq_lens_cpu(self):
if self._seq_lens_cpu is None:
# Increment counter
with sync_count.get_lock():
sync_count.value += 1
count = sync_count.value
# Print stack trace immediately (shows in subprocess output)
print(f"\n{'=' * 60}", file=sys.stderr)
print(f"SYNC #{count}: seq_lens_cpu lazy init triggered!", file=sys.stderr)
print(f"{'=' * 60}", file=sys.stderr)
traceback.print_stack(file=sys.stderr)
print(f"{'=' * 60}\n", file=sys.stderr)
sys.stderr.flush()
return original_fget(self)
# Apply patch
CommonAttentionMetadata.seq_lens_cpu = property(tracking_seq_lens_cpu)
class SyncTracker:
@property
def count(self) -> int:
return sync_count.value
def assert_no_sync(self, msg: str = ""):
count = sync_count.value
assert count == 0, (
f"Unexpected GPU-CPU sync: seq_lens_cpu lazy init triggered "
f"{count} times. See stack traces above. {msg}"
)
yield SyncTracker()
# Restore original property
CommonAttentionMetadata.seq_lens_cpu = original_prop
torch._dynamo.reset()
# Test configurations: (model, spec_model, method, num_spec_tokens, backend_env)
SPEC_DECODE_CONFIGS = [
pytest.param(
"meta-llama/Llama-3.2-1B-Instruct",
"nm-testing/Llama3_2_1B_speculator.eagle3",
"eagle3",
2,
id="eagle3-llama",
),
pytest.param(
"eagle618/deepseek-v3-random",
"eagle618/eagle-deepseek-v3-random",
"eagle",
2,
id="eagle-mla-deepseek",
),
]
@pytest.mark.parametrize(
"model,spec_model,method,num_spec_tokens",
SPEC_DECODE_CONFIGS,
)
def test_no_sync_with_spec_decode(
sync_tracker,
model: str,
spec_model: str,
method: str,
num_spec_tokens: int,
):
"""
Test that no implicit GPU-CPU sync occurs during speculative decoding
generation.
"""
# Import vLLM AFTER sync_tracker fixture has applied the patch
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
llm = LLM(
model=model,
max_model_len=256,
speculative_config={
"method": method,
"num_speculative_tokens": num_spec_tokens,
"model": spec_model,
},
enforce_eager=True,
async_scheduling=True,
)
outputs = llm.generate(
["Hello, my name is"],
SamplingParams(temperature=0, max_tokens=10),
)
assert len(outputs) == 1
assert len(outputs[0].outputs[0].text) > 0
del llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
sync_tracker.assert_no_sync()

View File

@ -191,8 +191,8 @@ def test_suffix_decoding_acceptance(
# Expect the acceptance rate to improve.
assert first_accept_rate < last_accept_rate
# Heuristic: expect at least 82.5% acceptance rate at the end.
assert last_accept_rate > 0.825
# Heuristic: expect at least 80.0% acceptance rate at the end.
assert last_accept_rate > 0.80
del spec_llm
torch.cuda.empty_cache()

View File

@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.v1.core.kv_cache_utils import check_enough_kv_cache_memory
from vllm.v1.kv_cache_interface import FullAttentionSpec
def test_kv_cache_oom_no_memory():
from unittest.mock import MagicMock
config = MagicMock()
config.model_config.max_model_len = 2048
spec = {
"layer_0": FullAttentionSpec(
block_size=16,
num_kv_heads=8,
head_size=128,
dtype="float16",
)
}
with pytest.raises(ValueError):
check_enough_kv_cache_memory(config, spec, 0)
def test_kv_cache_oom_insufficient_memory(monkeypatch):
from unittest.mock import MagicMock
config = MagicMock()
config.model_config.max_model_len = 2048
config.cache_config.block_size = 16
config.parallel_config.tensor_parallel_size = 1
config.parallel_config.pipeline_parallel_size = 1
config.parallel_config.decode_context_parallel_size = 1
monkeypatch.setattr(
"vllm.v1.core.kv_cache_utils.max_memory_usage_bytes",
lambda c, s: 100 * 1024**3, # 100 GiB
)
spec = {
"layer_0": FullAttentionSpec(
block_size=16,
num_kv_heads=8,
head_size=128,
dtype="float16",
)
}
with pytest.raises(ValueError):
check_enough_kv_cache_memory(config, spec, 1024**3) # 1 GiB

View File

@ -0,0 +1,163 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
test that invalid blocks are evicted from prefix cache to prevent pollution.
verifies that when sync-loading fails, invalid blocks are removed from the
prefix cache hash table so future requests cannot match and reuse corrupted data.
"""
from collections.abc import Callable
from unittest.mock import Mock
import pytest
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import Request, RequestStatus
from .utils import (
create_model_runner_output,
create_request,
create_scheduler,
create_vllm_config,
)
pytestmark = pytest.mark.cpu_test
def _make_get_num_new_matched_tokens(
req_num_new_matched_tokens: dict[str, int],
async_load: bool,
) -> Callable[[Request, int], tuple[int, bool]]:
def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
value = req_num_new_matched_tokens.get(request.request_id, 0)
return value, async_load
return get_num_new_matched_tokens
@pytest.fixture
def fail_scheduler():
"""scheduler with kv_load_failure_policy='fail'"""
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_load_failure_policy = "fail"
return create_scheduler(vllm_config)
def test_invalid_blocks_evicted_prevents_cache_pollution(
fail_scheduler: Scheduler,
):
"""
verify invalid blocks are evicted to prevent future cache hits.
scenario:
1. request 1 loads externally-computed blocks (sync mode)
2. some blocks fail to load and are marked invalid
3. with fail policy, invalid blocks should be evicted from prefix cache
4. request is marked as FINISHED_ERROR
"""
num_prompt_blocks = 100
num_external_computed_blocks = 99
invalid_block_idx = 50
num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
num_external_computed_tokens = (
num_external_computed_blocks * fail_scheduler.block_size
)
# request 1: will have invalid blocks
request1 = create_request(num_tokens=num_prompt_tokens, request_id=1)
fail_scheduler.add_request(request=request1)
req_num_new_matched_tokens = {
request1.request_id: num_external_computed_tokens,
}
# mock connector indicating sync load
fail_scheduler.connector = Mock()
fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
)
fail_scheduler.connector.request_finished.return_value = (False, None)
fail_scheduler.connector.take_events.return_value = ()
scheduler_output = fail_scheduler.schedule()
# request should be running with sync KV load
assert len(fail_scheduler.running) == 1
assert request1.status == RequestStatus.RUNNING
# get allocated block IDs
req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
invalid_block_id = req_block_ids[invalid_block_idx]
invalid_block_ids = {invalid_block_id}
# get the block object to verify eviction later
block = fail_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id]
# cache the blocks to simulate they've been computed and cached
# (in real scenario blocks would be cached after compute)
fail_scheduler.kv_cache_manager.cache_blocks(request1, num_external_computed_tokens)
# verify block has a hash (is cached) before reporting invalid blocks
assert block.block_hash is not None, (
f"block {invalid_block_id} should be cached (have a hash) before "
f"eviction test, but hash is None"
)
# report invalid blocks
model_runner_output = create_model_runner_output(
[request1],
invalid_block_ids=invalid_block_ids,
use_eos=False,
)
fail_scheduler.update_from_output(scheduler_output, model_runner_output)
# verify request finished with error (fail policy)
assert request1.status == RequestStatus.FINISHED_ERROR
# critical assertion: invalid block and all subsequent blocks should be evicted
# all blocks from invalid_block_idx onwards become invalid since they were
# computed based on the failed block
for idx in range(invalid_block_idx, len(req_block_ids)):
block_id = req_block_ids[idx]
block_obj = fail_scheduler.kv_cache_manager.block_pool.blocks[block_id]
assert block_obj.block_hash is None, (
f"block {block_id} at index {idx} should have been evicted "
f"(hash reset to None), but hash is {block_obj.block_hash}. "
f"All blocks from index {invalid_block_idx} onwards should be evicted "
f"since they depend on the invalid block at index {invalid_block_idx}."
)
# verify cache contains exactly the valid blocks (before first affected block)
# and none of the invalid blocks (from first affected block onwards)
# valid blocks: all blocks before invalid_block_idx should be cached
for idx in range(invalid_block_idx):
block_id = req_block_ids[idx]
block_obj = fail_scheduler.kv_cache_manager.block_pool.blocks[block_id]
assert block_obj.block_hash is not None, (
f"valid block {block_id} at index {idx} should still be cached "
f"(have a hash), but hash is None. Only blocks from index "
f"{invalid_block_idx} onwards should be evicted."
)
# invalid blocks: verify they're not in the cached_block_hash_to_block map
cached_blocks = (
fail_scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block
)
cached_block_ids = {
b.block_id
for blocks_val in cached_blocks._cache.values()
for b in (
[blocks_val] if not isinstance(blocks_val, dict) else blocks_val.values()
)
}
for idx in range(invalid_block_idx, len(req_block_ids)):
block_id = req_block_ids[idx]
assert block_id not in cached_block_ids, (
f"invalid block {block_id} at index {idx} should not be in cache hash table"
)

View File

@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from unittest.mock import Mock
import pytest
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import FinishReason, Request, RequestStatus
from .utils import (
create_model_runner_output,
create_request,
create_scheduler,
create_vllm_config,
)
pytestmark = pytest.mark.cpu_test
def _make_get_num_new_matched_tokens(
req_num_new_matched_tokens: dict[str, int],
async_load: bool,
) -> Callable[[Request, int], tuple[int, bool]]:
def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
value = req_num_new_matched_tokens.get(request.request_id, 0)
return value, async_load
return get_num_new_matched_tokens
@pytest.fixture
def fail_scheduler():
"""scheduler with kv_load_failure_policy='fail'"""
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_load_failure_policy = "fail"
return create_scheduler(vllm_config)
def test_error_propagation_sync_load(fail_scheduler: Scheduler):
"""test invalid_block_ids with fail policy -> FINISHED_ERROR (sync load)"""
num_prompt_blocks = 100
num_external_computed_blocks = 99
invalid_block_idx = 50
num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
num_external_computed_tokens = (
num_external_computed_blocks * fail_scheduler.block_size
)
request = create_request(num_tokens=num_prompt_tokens)
fail_scheduler.add_request(request=request)
req_num_new_matched_tokens = {
request.request_id: num_external_computed_tokens,
}
fail_scheduler.connector = Mock()
fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
)
fail_scheduler.connector.request_finished.return_value = (False, None)
fail_scheduler.connector.take_events.return_value = ()
scheduler_output = fail_scheduler.schedule()
assert len(fail_scheduler.running) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1
assert fail_scheduler.connector.get_num_new_matched_tokens.call_count == 1
req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
invalid_block_ids = {req_block_ids[invalid_block_idx]}
model_runner_output = create_model_runner_output(
[request],
invalid_block_ids=invalid_block_ids,
use_eos=True,
)
outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output)
assert request.status == RequestStatus.FINISHED_ERROR
assert request.get_finished_reason() == FinishReason.ERROR
assert len(outputs) == 1
engine_outputs = next(iter(outputs.values()))
assert len(engine_outputs.outputs) == 1
output = engine_outputs.outputs[0]
assert output.request_id == request.request_id
assert output.finish_reason == FinishReason.ERROR
assert len(fail_scheduler.running) == 0
def test_error_propagation_async_load(fail_scheduler: Scheduler):
"""test invalid_block_ids with fail policy -> FINISHED_ERROR (async load)"""
num_prompt_blocks = 100
num_external_computed_blocks = 99
invalid_block_idx = 50
num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
num_external_computed_tokens = (
num_external_computed_blocks * fail_scheduler.block_size
)
request = create_request(num_tokens=num_prompt_tokens)
fail_scheduler.add_request(request=request)
req_num_new_matched_tokens = {
request.request_id: num_external_computed_tokens,
}
fail_scheduler.connector = Mock()
fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, True)
)
fail_scheduler.connector.request_finished.return_value = (False, None)
fail_scheduler.connector.take_events.return_value = ()
scheduler_output = fail_scheduler.schedule()
assert len(fail_scheduler.waiting) == 1
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert request.num_computed_tokens == 0
(req_block_ids,) = fail_scheduler.kv_cache_manager.get_block_ids(request.request_id)
invalid_block_ids = {req_block_ids[invalid_block_idx]}
model_runner_output = create_model_runner_output(
reqs=[],
finished_recving=set(),
invalid_block_ids=invalid_block_ids,
use_eos=True,
)
outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output)
assert request.status == RequestStatus.FINISHED_ERROR
assert request.get_finished_reason() == FinishReason.ERROR
assert len(outputs) == 1
engine_outputs = next(iter(outputs.values()))
assert len(engine_outputs.outputs) == 1
output = engine_outputs.outputs[0]
assert output.request_id == request.request_id
assert output.finish_reason == FinishReason.ERROR
assert len(fail_scheduler.waiting) == 0

View File

@ -0,0 +1,454 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for correctness in invalid block handling.
These tests verify correct behavior in three scenarios:
1. Sync recompute case: Blocks should not be freed for running requests
that need to recompute invalid blocks
2. Sync fail case: Invalid blocks must be evicted from cache when request fails
3. Async recompute case: Invalid blocks should not be cached after transfer
"""
from collections.abc import Callable
from unittest.mock import Mock
import pytest
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import FinishReason, Request, RequestStatus
from .utils import (
create_model_runner_output,
create_request,
create_scheduler,
create_vllm_config,
)
pytestmark = pytest.mark.cpu_test
def _make_get_num_new_matched_tokens(
req_num_new_matched_tokens: dict[str, int],
async_load: bool,
) -> Callable[[Request, int], tuple[int, bool]]:
def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
value = req_num_new_matched_tokens.get(request.request_id, 0)
return value, async_load
return get_num_new_matched_tokens
@pytest.fixture
def fail_scheduler():
"""scheduler with kv_load_failure_policy='fail'"""
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_load_failure_policy = "fail"
return create_scheduler(vllm_config)
@pytest.fixture
def recompute_scheduler():
"""scheduler with kv_load_failure_policy='recompute'"""
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_load_failure_policy = "recompute"
return create_scheduler(vllm_config)
def test_sync_recompute_blocks_not_freed_for_running_requests(
recompute_scheduler: Scheduler,
):
"""
Test sync recompute case - blocks must not be freed for running requests.
When a running request has invalid blocks and retry_policy is 'recompute':
1. Request should remain in RUNNING state
2. num_computed_tokens should be truncated to invalid block boundary
3. Blocks should NOT be freed (request still needs them for recomputation)
4. Request should remain in scheduler.requests and scheduler.running
"""
num_prompt_blocks = 100
num_external_computed_blocks = 99
invalid_block_idx = 50
num_prompt_tokens = num_prompt_blocks * recompute_scheduler.block_size
num_external_computed_tokens = (
num_external_computed_blocks * recompute_scheduler.block_size
)
request = create_request(num_tokens=num_prompt_tokens)
recompute_scheduler.add_request(request=request)
req_num_new_matched_tokens = {
request.request_id: num_external_computed_tokens,
}
# mock connector indicating sync load
recompute_scheduler.connector = Mock()
recompute_scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
)
recompute_scheduler.connector.request_finished.return_value = (False, None)
recompute_scheduler.connector.take_events.return_value = ()
scheduler_output = recompute_scheduler.schedule()
# request should be running with sync KV load
assert len(recompute_scheduler.running) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1
assert request.status == RequestStatus.RUNNING
# get the allocated block IDs before invalid blocks are reported
req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
invalid_block_ids = {req_block_ids[invalid_block_idx]}
# store original num_computed_tokens for comparison
original_num_computed_tokens = request.num_computed_tokens
model_runner_output = create_model_runner_output(
[request],
invalid_block_ids=invalid_block_ids,
use_eos=False, # not finished - should continue running
)
outputs = recompute_scheduler.update_from_output(
scheduler_output, model_runner_output
)
# critical assertions for recompute case:
# 1. request should still be RUNNING (not finished, not aborted)
assert request.status == RequestStatus.RUNNING, (
f"Request should remain RUNNING for recompute, got {request.status}"
)
# 2. num_computed_tokens should be truncated to first invalid block
expected_truncated_tokens = invalid_block_idx * recompute_scheduler.block_size
assert request.num_computed_tokens == expected_truncated_tokens, (
f"num_computed_tokens should be truncated to {expected_truncated_tokens}, "
f"got {request.num_computed_tokens}"
)
assert request.num_computed_tokens < original_num_computed_tokens, (
"num_computed_tokens should be reduced after invalid block detection"
)
# 3. no output should be generated (request is still running)
# the request should be skipped in the output loop
assert len(outputs) == 0 or request.request_id not in [
out.request_id for outs in outputs.values() for out in outs.outputs
], "No output should be generated for recompute requests"
# 4. request should still be in running queue
assert request in recompute_scheduler.running, (
"Request should remain in running queue for recomputation"
)
# 5. request should still be in scheduler.requests (not deleted)
assert request.request_id in recompute_scheduler.requests, (
"Request should not be deleted from scheduler.requests"
)
# 6. blocks should NOT be freed - verify blocks are still allocated
try:
allocated_blocks = recompute_scheduler.kv_cache_manager.get_block_ids(
request.request_id
)
assert allocated_blocks is not None
assert len(allocated_blocks[0]) > 0, (
"Blocks should still be allocated for recomputation"
)
except KeyError:
pytest.fail(
"Blocks were freed incorrectly! Running requests need their blocks "
"to recompute invalid portions."
)
# 7. verify request can be rescheduled in next step
scheduler_output_2 = recompute_scheduler.schedule()
# request should appear in the new schedule to recompute invalid blocks
scheduled_req_ids = [
req.request_id for req in scheduler_output_2.scheduled_new_reqs
]
if scheduler_output_2.num_scheduled_tokens:
scheduled_req_ids.extend(scheduler_output_2.num_scheduled_tokens.keys())
assert (
request.request_id in scheduled_req_ids or len(recompute_scheduler.running) > 0
), "Request should be reschedulable for recomputation"
def test_sync_fail_invalid_blocks_evicted(fail_scheduler: Scheduler):
"""
Test sync fail case - invalid blocks must be evicted from cache.
When a request fails with policy='fail' and has invalid blocks from sync loading:
1. Request should be finished with FINISHED_ERROR
2. Invalid blocks should be evicted from the KV cache
3. Valid blocks (if shared) should remain in cache
4. Future requests should not reuse the invalid blocks
This test verifies that invalid blocks are properly evicted to prevent
cache corruption and reuse of invalid data.
"""
num_prompt_blocks = 100
num_external_computed_blocks = 99
invalid_block_idx = 50
num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
num_external_computed_tokens = (
num_external_computed_blocks * fail_scheduler.block_size
)
request = create_request(num_tokens=num_prompt_tokens)
fail_scheduler.add_request(request=request)
req_num_new_matched_tokens = {
request.request_id: num_external_computed_tokens,
}
# mock connector indicating sync load
fail_scheduler.connector = Mock()
fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
)
fail_scheduler.connector.request_finished.return_value = (False, None)
fail_scheduler.connector.take_events.return_value = ()
scheduler_output = fail_scheduler.schedule()
# request should be running with sync KV load
assert len(fail_scheduler.running) == 1
assert request.status == RequestStatus.RUNNING
# get allocated block IDs
req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
invalid_block_id = req_block_ids[invalid_block_idx]
invalid_block_ids = {invalid_block_id}
# verify the block is in the block pool before we report it as invalid
block = fail_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id]
assert block is not None
# report invalid blocks - request should fail
model_runner_output = create_model_runner_output(
[request],
invalid_block_ids=invalid_block_ids,
use_eos=True,
)
outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output)
# verify request is finished with error
assert request.status == RequestStatus.FINISHED_ERROR
assert request.get_finished_reason() == FinishReason.ERROR
# verify output is generated
assert len(outputs) == 1
engine_outputs = next(iter(outputs.values()))
assert len(engine_outputs.outputs) == 1
output = engine_outputs.outputs[0]
assert output.request_id == request.request_id
assert output.finish_reason == FinishReason.ERROR
# verify the request was removed from scheduler
assert request.request_id not in fail_scheduler.requests
assert len(fail_scheduler.running) == 0
# critical: verify invalid block was actually freed from cache
# this is the key assertion - the invalid block should no longer be
# tracked by the KV cache manager for this request
# if it's still there, a future request could reuse the invalid data
try:
block_ids = fail_scheduler.kv_cache_manager.get_block_ids(request.request_id)
# if we get here, check if blocks were actually freed
if block_ids is not None and len(block_ids[0]) > 0:
pytest.fail(
f"Invalid blocks still tracked for finished request! "
f"Request {request.request_id} should have been freed but "
f"still has {len(block_ids[0])} blocks allocated."
)
# blocks list exists but is empty - this is fine, they were freed
except KeyError:
# expected - request completely removed from tracking
pass
# critical: verify invalid block was evicted from prefix cache
# the block should no longer have a hash (hash is reset on eviction)
assert block.block_hash is None, (
f"Invalid block {invalid_block_id} should have been evicted from cache "
f"(hash should be None), but hash is still {block.block_hash}"
)
def test_async_recompute_blocks_not_cached_when_invalid(
recompute_scheduler: Scheduler,
):
"""
Test async recompute case - invalid blocks not cached after transfer.
When async KV loading has invalid blocks and retry_policy is 'recompute':
1. Blocks are allocated but not cached yet
2. When async transfer completes, only valid blocks should be cached
3. Invalid blocks should never enter the prefix cache
This test verifies correctness, the failed_recving_kv_req_ids protection
ensures only valid blocks are cached when the transfer completes, and we
only evict blocks from cache that are already hashed in the block table.
"""
from unittest.mock import patch
num_prompt_blocks = 100
num_external_computed_blocks = 99
invalid_block_idx = 50
num_prompt_tokens = num_prompt_blocks * recompute_scheduler.block_size
num_external_computed_tokens = (
num_external_computed_blocks * recompute_scheduler.block_size
)
request = create_request(num_tokens=num_prompt_tokens)
recompute_scheduler.add_request(request=request)
req_num_new_matched_tokens = {
request.request_id: num_external_computed_tokens,
}
# mock connector indicating async load
recompute_scheduler.connector = Mock()
recompute_scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, True)
)
recompute_scheduler.connector.request_finished.return_value = (False, None)
recompute_scheduler.connector.take_events.return_value = ()
scheduler_output = recompute_scheduler.schedule()
# request should be waiting for remote KVs
assert len(recompute_scheduler.waiting) == 1
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert request.num_computed_tokens == 0
# get the allocated block IDs
(req_block_ids,) = recompute_scheduler.kv_cache_manager.get_block_ids(
request.request_id
)
invalid_block_id = req_block_ids[invalid_block_idx]
invalid_block_ids = {invalid_block_id}
# get the block object to verify it's not cached yet and stays uncached
block = recompute_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id]
# verify block has no hash before invalid blocks are reported
assert block.block_hash is None, (
"Async loading blocks should not be cached yet (no hash)"
)
# report invalid blocks (transfer not finished yet)
model_runner_output = create_model_runner_output(
reqs=[],
finished_recving=None, # transfer NOT finished
invalid_block_ids=invalid_block_ids,
use_eos=False,
)
# critical: spy on evict_blocks to verify it's NOT called for async blocks
original_evict_blocks = recompute_scheduler.kv_cache_manager.evict_blocks
evict_blocks_calls = []
def evict_blocks_spy(block_ids):
evict_blocks_calls.append(set(block_ids))
return original_evict_blocks(block_ids)
with patch.object(
recompute_scheduler.kv_cache_manager, "evict_blocks", evict_blocks_spy
):
recompute_scheduler.update_from_output(scheduler_output, model_runner_output)
# verify evict_blocks was NOT called (async blocks excluded from eviction)
assert len(evict_blocks_calls) == 0, (
f"evict_blocks should not be called for async-only invalid blocks, "
f"but was called {len(evict_blocks_calls)} time(s) with {evict_blocks_calls}"
)
# request should still be waiting (not finished with error due to recompute policy)
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert request.request_id in recompute_scheduler.failed_recving_kv_req_ids
# verify num_computed_tokens was truncated to before invalid block
expected_valid_tokens = invalid_block_idx * recompute_scheduler.block_size
assert request.num_computed_tokens == expected_valid_tokens
# verify invalid block still has no hash (was not evicted)
assert block.block_hash is None, (
f"Async loading blocks shouldn't be cached or evicted. "
f"Block {invalid_block_id} hash should be None but is {block.block_hash}"
)
# now simulate async transfer completing
model_runner_output_2 = create_model_runner_output(
reqs=[],
finished_recving={request.request_id},
invalid_block_ids=None,
use_eos=False,
)
recompute_scheduler.update_from_output(scheduler_output, model_runner_output_2)
# verify request is now marked as finished receiving and ready to be processed
assert request.request_id in recompute_scheduler.finished_recving_kv_req_ids
assert request.request_id in recompute_scheduler.failed_recving_kv_req_ids
# critical: verify invalid block still has no hash before recompute
# the async transfer invalid data was never cached
assert block.block_hash is None, (
f"Invalid block {invalid_block_id} should not be cached before recompute "
f"(hash should be None), but hash is {block.block_hash}"
)
# critical end-to-end test: spy on cache_blocks to verify it's called with
# the truncated num_computed_tokens value
original_cache_blocks = recompute_scheduler.kv_cache_manager.cache_blocks
cache_blocks_calls = []
def cache_blocks_spy(req, num_tokens):
cache_blocks_calls.append((req.request_id, num_tokens))
return original_cache_blocks(req, num_tokens)
with patch.object(
recompute_scheduler.kv_cache_manager, "cache_blocks", cache_blocks_spy
):
# call schedule() again - this triggers _update_waiting_for_remote_kv()
# which should call cache_blocks with the truncated value
recompute_scheduler.schedule()
# verify cache_blocks was called with the truncated value
assert len(cache_blocks_calls) == 1, (
f"cache_blocks should be called exactly once, "
f"got {len(cache_blocks_calls)} calls"
)
cached_req_id, cached_num_tokens = cache_blocks_calls[0]
assert cached_req_id == request.request_id
assert cached_num_tokens == expected_valid_tokens, (
f"cache_blocks should be called with truncated value {expected_valid_tokens}, "
f"but was called with {cached_num_tokens}"
)
# request should now be RUNNING (scheduled immediately after transfer completes)
# the flow is: WAITING_FOR_REMOTE_KVS -> WAITING -> RUNNING in same schedule() call
assert request.status == RequestStatus.RUNNING
# num_computed_tokens should be >= expected_valid_tokens because the scheduler
# will schedule additional new tokens (up to max_num_batched_tokens) for the request
assert request.num_computed_tokens >= expected_valid_tokens, (
f"num_computed_tokens should be at least {expected_valid_tokens}, "
f"got {request.num_computed_tokens}"
)
# request should no longer be in the failed/finished receiving sets
assert request.request_id not in recompute_scheduler.failed_recving_kv_req_ids
assert request.request_id not in recompute_scheduler.finished_recving_kv_req_ids
# request should be in the running queue
assert request in recompute_scheduler.running

View File

@ -88,8 +88,8 @@ def forward_attention(
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc.cpu(),
seq_lens=seq_lens,
seq_lens_cpu=seq_lens.cpu(),
num_computed_tokens_cpu=context_lens.cpu(),
_seq_lens_cpu=seq_lens.cpu(),
_num_computed_tokens_cpu=context_lens.cpu(),
num_reqs=batch_size,
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,

View File

@ -7,7 +7,7 @@ Here we break down the requirements in 2 steps:
1. Build and install the Python libraries (both [pplx-kernels](https://github.com/ppl-ai/pplx-kernels) and [DeepEP](https://github.com/deepseek-ai/DeepEP)), including necessary dependencies like NVSHMEM. This step does not require any privileged access. Any user can do this.
2. Configure NVIDIA driver to enable IBGDA. This step requires root access, and must be done on the host machine.
2 is necessary for multi-node deployment.
Step 2 is necessary for multi-node deployment.
All scripts accept a positional argument as workspace path for staging the build, defaulting to `$(pwd)/ep_kernels_workspace`.
@ -23,6 +23,6 @@ TORCH_CUDA_ARCH_LIST="10.0" bash install_python_libraries.sh
Additional step for multi-node deployment:
```bash
sudo bash configure_system_drivers.sh
sudo bash configure_system_drivers.sh # update-initramfs can take several minutes
sudo reboot # Reboot is required to load the new driver
```

View File

@ -24,6 +24,15 @@ def is_aiter_found() -> bool:
# we keep this global outside to not cause torch compile breaks.
IS_AITER_FOUND = is_aiter_found()
# Can't use dtypes.fp8 directly inside an op
# because it returns wrong result on gfx942.
# This is a workaround to get the correct FP8 dtype.
# This might because that the get_gfx() is wrapped as a custom op.
if IS_AITER_FOUND:
from aiter import dtypes
AITER_FP8_DTYPE = dtypes.fp8
def if_aiter_supported(func: Callable) -> Callable:
"""Decorator that only executes the function if
@ -45,36 +54,6 @@ def if_aiter_supported(func: Callable) -> Callable:
return wrapper
def _rocm_aiter_group_fp8_quant_impl(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size"
from aiter import QuantType, dtypes, get_hip_quant
aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
return aiter_per1x128_quant(x.contiguous(), quant_dtype=dtypes.fp8)
def _rocm_aiter_group_fp8_quant_fake(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter import dtypes
M, N = x.shape
x_fp8 = torch.empty((M, N), dtype=dtypes.fp8, device=x.device)
out_bs = torch.empty(
(
M,
(N + group_size - 1) // group_size,
),
dtype=torch.float32,
device=x.device,
)
return x_fp8, out_bs
def _rocm_aiter_fused_moe_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
@ -522,6 +501,142 @@ def _rocm_aiter_per_token_quant_fake(
)
def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant
(x_quant, x_quant_scales), _, _, res = fused_rms_fp8_group_quant(
x,
weight,
variance_epsilon,
None,
None,
None,
group_size=group_size,
dtype_quant=AITER_FP8_DTYPE,
res1=residual,
)
return (x_quant, x_quant_scales, res)
def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
M, N = x.shape
scale_shape = (M, (N + group_size - 1) // group_size)
return (
torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device),
torch.empty(scale_shape, dtype=torch.float32, device=x.device),
torch.empty_like(residual, device=residual.device),
)
def _rocm_aiter_rmsnorm_fp8_group_quant_impl(
x: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant
(x_quant, x_quant_scales), _, _, res = fused_rms_fp8_group_quant(
x,
weight,
variance_epsilon,
None,
None,
None,
group_size=group_size,
dtype_quant=AITER_FP8_DTYPE,
res1=None,
)
return (x_quant, x_quant_scales)
def _rocm_aiter_rmsnorm_fp8_group_quant_fake(
x: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
M, N = x.shape
scale_shape = (M, (N + group_size - 1) // group_size)
return (
torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device),
torch.empty(scale_shape, dtype=torch.float32, device=x.device),
)
def _rocm_aiter_group_fp8_quant_impl(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size"
from aiter import QuantType, get_hip_quant
aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
return aiter_per1x128_quant(x.contiguous(), quant_dtype=AITER_FP8_DTYPE)
def _rocm_aiter_group_fp8_quant_fake(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
M, N = x.shape
x_fp8 = torch.empty((M, N), dtype=AITER_FP8_DTYPE, device=x.device)
out_bs = torch.empty(
(
M,
(N + group_size - 1) // group_size,
),
dtype=torch.float32,
device=x.device,
)
return x_fp8, out_bs
def _rocm_aiter_act_mul_and_fp8_group_quant_impl(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter.ops.triton.activation import act_mul_and_fp8_group_quant
return act_mul_and_fp8_group_quant(
x,
activation="silu",
group_size=group_size,
dtype_quant=AITER_FP8_DTYPE,
)
def _rocm_aiter_act_mul_and_fp8_group_quant_fake(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
M, N = x.shape
assert N % 2 == 0
N_half = N // 2
x_fp8 = torch.empty((M, N_half), dtype=AITER_FP8_DTYPE, device=x.device)
out_bs = torch.empty(
(
M,
(N_half + group_size - 1) // group_size,
),
dtype=torch.float32,
device=x.device,
)
return x_fp8, out_bs
# Global flag to ensure ops are registered only once
_OPS_REGISTERED = False
@ -557,7 +672,7 @@ class rocm_aiter_ops:
@if_aiter_supported
def is_linear_fp8_enaled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls.is_linear_enabled() and current_platform.is_fp8_fnuz()
return cls.is_linear_enabled()
@classmethod
@if_aiter_supported
@ -632,14 +747,6 @@ class rocm_aiter_ops:
)
# register all the custom ops here
direct_register_custom_op(
op_name="rocm_aiter_group_fp8_quant",
op_func=_rocm_aiter_group_fp8_quant_impl,
mutates_args=[],
fake_impl=_rocm_aiter_group_fp8_quant_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_asm_moe_tkw1",
op_func=_rocm_aiter_asm_moe_tkw1_impl,
@ -699,27 +806,46 @@ class rocm_aiter_ops:
direct_register_custom_op(
op_name="rocm_aiter_gemm_a8w8_blockscale",
op_func=_rocm_aiter_gemm_a8w8_blockscale_impl,
mutates_args=[],
fake_impl=_rocm_aiter_gemm_a8w8_blockscale_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_rms_norm",
op_func=_rocm_aiter_rms_norm_impl,
mutates_args=[],
fake_impl=_rocm_aiter_rms_norm_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl,
mutates_args=[],
fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm_fp8_group_quant",
op_func=_rocm_aiter_rmsnorm_fp8_group_quant_impl,
fake_impl=_rocm_aiter_rmsnorm_fp8_group_quant_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm_with_add_fp8_group_quant",
op_func=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl,
fake_impl=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_act_mul_and_fp8_group_quant",
op_func=_rocm_aiter_act_mul_and_fp8_group_quant_impl,
fake_impl=_rocm_aiter_act_mul_and_fp8_group_quant_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_group_fp8_quant",
op_func=_rocm_aiter_group_fp8_quant_impl,
fake_impl=_rocm_aiter_group_fp8_quant_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_per_tensor_quant",
op_func=_rocm_aiter_per_tensor_quant_impl,

View File

@ -294,6 +294,12 @@ class AttentionImpl(ABC, Generic[T]):
# Some features like decode context parallelism require the softmax lse.
can_return_lse_for_decode: bool = False
# Whether the attention impl supports Prefill Context Parallelism.
supports_pcp: bool = False
# Whether the attention impl(or ops) supports MTP
# when cp_kv_cache_interleave_size > 1
supports_mtp_with_cp_non_trivial_interleave_size: bool = False
# some attention backends might not always want to return lse
# even if they can return lse (for efficiency reasons)
need_to_return_lse_for_decode: bool = False

View File

@ -252,35 +252,3 @@ def register_backend(
return lambda x: x
return decorator
# Backwards compatibility alias for plugins
class _BackendMeta(type):
"""Metaclass to provide deprecation warnings when accessing _Backend."""
def __getattribute__(cls, name: str):
if name not in ("__class__", "__mro__", "__name__"):
logger.warning(
"_Backend has been renamed to AttentionBackendEnum. "
"Please update your code to use AttentionBackendEnum instead. "
"_Backend will be removed in a future release."
)
return getattr(AttentionBackendEnum, name)
def __getitem__(cls, name: str):
logger.warning(
"_Backend has been renamed to AttentionBackendEnum. "
"Please update your code to use AttentionBackendEnum instead. "
"_Backend will be removed in a future release."
)
return AttentionBackendEnum[name]
class _Backend(metaclass=_BackendMeta):
"""Deprecated: Use AttentionBackendEnum instead.
This class is provided for backwards compatibility with plugins
and will be removed in a future release.
"""
pass

View File

@ -103,7 +103,7 @@ def create_cross_attention_backend(
# needed here to know how many tokens to attend to from the cached
# cross-attention KV cache.
new_metadata.seq_lens = common_attn_metadata.encoder_seq_lens
new_metadata.seq_lens_cpu = torch.from_numpy(
new_metadata._seq_lens_cpu = torch.from_numpy(
common_attn_metadata.encoder_seq_lens_cpu
)

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
from functools import cache
from typing import cast, get_args
@ -73,39 +72,18 @@ def _cached_get_attn_backend(
) -> type[AttentionBackend]:
from vllm.platforms import current_platform
sig = inspect.signature(current_platform.get_attn_backend_cls)
if "use_v1" in sig.parameters:
logger.warning_once(
"use_v1 parameter for get_attn_backend_cls is deprecated and will "
"be removed in v0.13.0 or v1.0.0, whichever is soonest. Please "
"remove it from your plugin code."
)
attention_cls = current_platform.get_attn_backend_cls(
backend,
head_size,
dtype,
kv_cache_dtype,
block_size,
True, # use_v1
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
attn_type,
)
else:
attention_cls = current_platform.get_attn_backend_cls(
backend,
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
attn_type,
)
attention_cls = current_platform.get_attn_backend_cls(
backend,
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
attn_type,
)
if not attention_cls:
raise ValueError(
f"Invalid attention backend for {current_platform.device_name}"

View File

@ -788,7 +788,7 @@ async def benchmark(
)
print(
"{:<40} {:<10.2f}".format(
"Total Token throughput (tok/s):", metrics.total_token_throughput
"Total token throughput (tok/s):", metrics.total_token_throughput
)
)

View File

@ -5,6 +5,7 @@ import functools
from torch import fx as fx
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
@ -13,6 +14,12 @@ from vllm.utils.system_utils import set_env_var
from .post_cleanup import PostCleanupPass
from .vllm_inductor_pass import VllmInductorPass
if rocm_aiter_ops.is_enabled():
from vllm.compilation.rocm_aiter_fusion import (
RocmAiterRMSNormFp8GroupQuantFusionPass,
RocmAiterSiluMulFp8GroupQuantFusionPass,
)
if current_platform.is_cuda_alike():
from .activation_quant_fusion import ActivationQuantFusionPass
from .fusion import RMSNormQuantFusionPass
@ -109,8 +116,12 @@ class PostGradPassManager(CustomGraphPass):
if self.pass_config.fuse_norm_quant:
self.passes += [RMSNormQuantFusionPass(config)]
if rocm_aiter_ops.is_enabled():
self.passes += [RocmAiterRMSNormFp8GroupQuantFusionPass(config)]
if self.pass_config.fuse_act_quant:
self.passes += [ActivationQuantFusionPass(config)]
if rocm_aiter_ops.is_enabled():
self.passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
if self.pass_config.fuse_attn_quant:
self.passes += [AttnFusionPass(config)]

View File

@ -0,0 +1,242 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
from vllm.compilation.activation_quant_fusion import ActivationQuantPattern
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from .fusion import empty_bf16
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherSiluAndMul
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype()
AITER_RMS_GROUP_QUANT_OP = torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default
AITER_RMS_ADD_GROUP_QUANT_OP = (
torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default
)
AITER_RMS_OP = torch.ops.vllm.rocm_aiter_rms_norm.default
AITER_RMS_ADD_OP = torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default
AITER_GROUP_FP8_QUANT_OP = torch.ops.vllm.rocm_aiter_group_fp8_quant.default
TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default
FUSED_SILU_MUL_QUANT_OP = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default
class AiterRMSFp8GroupQuantPattern:
"""
This pattern fuses aiter rms_norm & group fp8 quant custom
ops into an aiter rms_norm_group_fp8_quant op.
"""
def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload):
self.epsilon = epsilon
self.quant_dtype = quant_dtype
self.quant_op = quant_op
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
):
at1 = AITER_RMS_OP(x=input, weight=weight, variance_epsilon=self.epsilon)
at2 = self.quant_op(at1, 128)
return at2[0], at2[1]
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
):
at = AITER_RMS_GROUP_QUANT_OP(
x=input,
weight=weight,
variance_epsilon=self.epsilon,
group_size=128,
)
return at[0], at[1]
inputs = [
empty_bf16(5, 4), # input
empty_bf16(1, 5), # weight
]
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
class AiterFusedAddRMSFp8GroupQuantPattern:
"""
This pattern fuses aiter rms_norm_with_add & group fp8 quant custom ops
into a aiter rms_norm_with_add_group_fp8_quant op.
"""
def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload):
self.epsilon = epsilon
self.quant_dtype = quant_dtype
self.quant_op = quant_op
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
):
at1 = AITER_RMS_ADD_OP(
x=input,
residual=residual,
weight=weight,
variance_epsilon=self.epsilon,
)
at2 = self.quant_op(at1[0], 128)
# result, scale, residual
return at2[0], at2[1], at1[1]
def replacement(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
):
at = AITER_RMS_ADD_GROUP_QUANT_OP(
x=input,
residual=residual,
weight=weight,
variance_epsilon=self.epsilon,
group_size=128,
)
# result, scale, residual
return at[0], at[1], at[2]
inputs = [
empty_bf16(5, 4), # input
empty_bf16(5, 4), # residual
empty_bf16(1, 5), # weight
]
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
class RocmAiterRMSNormFp8GroupQuantFusionPass(VllmPatternMatcherPass):
"""
This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
It also supports fused_add_rms_norm.
"""
@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="rocm_aiter_rms_norm_fp8_group_quant_fusion_pass"
)
# Make sure fused add patterns are before simple rms norm,
# as the latter is a subset of the former in torch ops
for epsilon in [1e-5, 1e-6]:
# Fuse rms_norm + dynamic group fp8 quant
for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]:
AiterRMSFp8GroupQuantPattern(epsilon, FP8_DTYPE, quant_op).register(
self.patterns
)
AiterFusedAddRMSFp8GroupQuantPattern(
epsilon, FP8_DTYPE, quant_op
).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self) -> Any:
fusion_patterns = [
AiterRMSFp8GroupQuantPattern,
AiterFusedAddRMSFp8GroupQuantPattern,
]
return self.hash_source(self, *fusion_patterns)
class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
"""
This pattern fuses aiter silu_and_mul & group fp8 quant custom
ops into an aiter silu_and_mul_group_fp8_quant op.
"""
def __init__(self, quant_op: OpOverload):
self.silu_and_mul_matcher = MatcherSiluAndMul()
self.quant_op = quant_op
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
):
at1 = self.silu_and_mul_matcher(input)
at2 = self.quant_op(at1, 128)
return at2[0], at2[1]
def replacement(
input: torch.Tensor,
):
at = FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128)
return at[0], at[1]
inputs = [
self.silu_and_mul_matcher.inputs()[0],
]
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
"""
This pass fuses a pre-defined set of custom ops into fused ops.
It uses the torch pattern matcher to find the patterns and replace them.
Because patterns can only be registered once, the pass is a singleton.
This will be addressed in a future version of PyTorch:
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""
@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
)
for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]:
AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph):
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self):
fusion_patterns = [
ActivationQuantPattern,
AiterSiluMulFp8GroupQuantPattern,
]
return VllmInductorPass.hash_source(self, *fusion_patterns)

View File

@ -17,7 +17,6 @@ from vllm.config.utils import (
Range,
config,
get_hash_factors,
handle_deprecated,
hash_factors,
)
from vllm.logger import init_logger
@ -127,27 +126,6 @@ class PassConfig:
fuse_allreduce_rms: bool = Field(default=None)
"""Enable flashinfer allreduce fusion."""
# Deprecated flags
enable_fusion: bool = Field(default=None)
"""Deprecated in: v0.12.0. Use fuse_norm_quant and fuse_act_quant
instead. Will be removed in v0.13.0 or v1.0.0, whichever is sooner.
"""
enable_attn_fusion: bool = Field(default=None)
"""Deprecated in: v0.12.0. Use fuse_attn_quant instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
enable_noop: bool = Field(default=None)
"""Deprecated in: v0.12.0. Use eliminate_noops instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
enable_sequence_parallelism: bool = Field(default=None)
"""Deprecated in: v0.12.0. Use enable_sp instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
enable_async_tp: bool = Field(default=None)
"""Deprecated in: v0.12.0. Use fuse_gemm_comms instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
enable_fi_allreduce_fusion: bool = Field(default=None)
"""Deprecated in: v0.12.0. Use fuse_allreduce_rms instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
fi_allreduce_fusion_max_size_mb: float | None = None
"""The threshold of the communicated tensor sizes under which
vllm should use flashinfer fused allreduce. Specified as a
@ -206,15 +184,7 @@ class PassConfig:
Any future fields that don't affect compilation should be excluded.
"""
ignored_fields = [
"enable_fusion",
"enable_attn_fusion",
"enable_noop",
"enable_sequence_parallelism",
"enable_async_tp",
"enable_fi_allreduce_fusion",
]
return hash_factors(get_hash_factors(self, ignored_factors=ignored_fields))
return hash_factors(get_hash_factors(self, set()))
@field_validator(
"fuse_norm_quant",
@ -224,12 +194,6 @@ class PassConfig:
"enable_sp",
"fuse_gemm_comms",
"fuse_allreduce_rms",
"enable_fusion",
"enable_attn_fusion",
"enable_noop",
"enable_sequence_parallelism",
"enable_async_tp",
"enable_fi_allreduce_fusion",
mode="wrap",
)
@classmethod
@ -242,49 +206,6 @@ class PassConfig:
def __post_init__(self) -> None:
# Handle deprecation and defaults
# Map old flags to new flags and issue warnings
handle_deprecated(
self,
"enable_fusion",
["fuse_norm_quant", "fuse_act_quant"],
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_attn_fusion",
"fuse_attn_quant",
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_sequence_parallelism",
"enable_sp",
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_async_tp",
"fuse_gemm_comms",
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_fi_allreduce_fusion",
"fuse_allreduce_rms",
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_noop",
"eliminate_noops",
"v0.13.0 or v1.0.0, whichever is sooner",
)
if not self.eliminate_noops:
if self.fuse_norm_quant or self.fuse_act_quant:
logger.warning_once(

View File

@ -64,6 +64,11 @@ class KVTransferConfig:
enable_permute_local_kv: bool = False
"""Experiment feature flag to enable HND to NHD KV Transfer"""
kv_load_failure_policy: Literal["recompute", "fail"] = "recompute"
"""Policy for handling KV cache load failures.
'recompute': reschedule the request to recompute failed blocks (default)
'fail': immediately fail the request with an error finish reason"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,

View File

@ -73,17 +73,6 @@ logger = init_logger(__name__)
RunnerOption = Literal["auto", RunnerType]
ConvertType = Literal["none", "embed", "classify", "reward"]
ConvertOption = Literal["auto", ConvertType]
TaskOption = Literal[
"auto",
"generate",
"embedding",
"embed",
"classify",
"score",
"reward",
"transcription",
"draft",
]
TokenizerMode = Literal["auto", "hf", "slow", "mistral", "deepseek_v32"]
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
LogprobsMode = Literal[
@ -93,12 +82,6 @@ HfOverrides = dict[str, Any] | Callable[[PretrainedConfig], PretrainedConfig]
ModelImpl = Literal["auto", "vllm", "transformers", "terratorch"]
LayerBlockType = Literal["attention", "linear_attention", "mamba"]
_RUNNER_TASKS: dict[RunnerType, list[TaskOption]] = {
"generate": ["generate", "transcription"],
"pooling": ["embedding", "embed", "classify", "score", "reward"],
"draft": ["draft"],
}
_RUNNER_CONVERTS: dict[RunnerType, list[ConvertType]] = {
"generate": [],
"pooling": ["embed", "classify", "reward"],
@ -126,12 +109,6 @@ class ModelConfig:
"""Convert the model using adapters defined in
[vllm.model_executor.models.adapters][]. The most common use case is to
adapt a text generation model to be used for pooling tasks."""
task: TaskOption | None = None
"""[DEPRECATED] The task to use the model for. If the model supports more
than one model runner, this is used to select which model runner to run.
Note that the model may support other tasks using the same model runner.
"""
tokenizer: SkipValidation[str] = None # type: ignore
"""Name or path of the Hugging Face tokenizer to use. If unspecified, model
name or path will be used."""
@ -335,7 +312,6 @@ class ModelConfig:
ignored_factors = {
"runner",
"convert",
"task",
"tokenizer",
"tokenizer_mode",
"seed",
@ -510,97 +486,6 @@ class ModelConfig:
is_generative_model = registry.is_text_generation_model(architectures, self)
is_pooling_model = registry.is_pooling_model(architectures, self)
def _task_to_convert(task: TaskOption) -> ConvertType:
if task == "embedding" or task == "embed":
return "embed"
if task == "classify":
return "classify"
if task == "reward":
logger.warning(
"Pooling models now default support all pooling; "
"you can use it without any settings."
)
return "embed"
if task == "score":
new_task = self._get_default_pooling_task(architectures)
return "classify" if new_task == "classify" else "embed"
return "none"
if self.task is not None:
runner: RunnerOption = "auto"
convert: ConvertOption = "auto"
msg_prefix = (
"The 'task' option has been deprecated and will be "
"removed in v0.13.0 or v1.0, whichever comes first."
)
msg_hint = "Please remove this option."
is_generative_task = self.task in _RUNNER_TASKS["generate"]
is_pooling_task = self.task in _RUNNER_TASKS["pooling"]
if is_generative_model and is_pooling_model:
if is_generative_task:
runner = "generate"
convert = "auto"
msg_hint = (
"Please replace this option with `--runner "
"generate` to continue using this model "
"as a generative model."
)
elif is_pooling_task:
runner = "pooling"
convert = "auto"
msg_hint = (
"Please replace this option with `--runner "
"pooling` to continue using this model "
"as a pooling model."
)
else: # task == "auto"
pass
elif is_generative_model or is_pooling_model:
if is_generative_task:
runner = "generate"
convert = "auto"
msg_hint = "Please remove this option"
elif is_pooling_task:
runner = "pooling"
convert = _task_to_convert(self.task)
msg_hint = (
"Please replace this option with `--convert "
f"{convert}` to continue using this model "
"as a pooling model."
)
else: # task == "auto"
pass
else:
# Neither generative nor pooling model - try to convert if possible
if is_pooling_task:
runner = "pooling"
convert = _task_to_convert(self.task)
msg_hint = (
"Please replace this option with `--runner pooling "
f"--convert {convert}` to continue using this model "
"as a pooling model."
)
else:
debug_info = {
"architectures": architectures,
"is_generative_model": is_generative_model,
"is_pooling_model": is_pooling_model,
}
raise AssertionError(
"The model should be a generative or "
"pooling model when task is set to "
f"{self.task!r}. Found: {debug_info}"
)
self.runner = runner
self.convert = convert
msg = f"{msg_prefix} {msg_hint}"
warnings.warn(msg, DeprecationWarning, stacklevel=2)
self.runner_type = self._get_runner_type(architectures, self.runner)
self.convert_type = self._get_convert_type(
architectures, self.runner_type, self.convert
@ -903,6 +788,13 @@ class ModelConfig:
runner_type: RunnerType,
convert: ConvertOption,
) -> ConvertType:
if convert == "reward":
logger.warning(
"`--convert reward` is deprecated and will be removed in v0.15. "
"Please use `--convert embed` instead."
)
return "embed"
if convert != "auto":
return convert
@ -918,22 +810,6 @@ class ModelConfig:
return convert_type
def _get_default_pooling_task(
self,
architectures: list[str],
) -> Literal["embed", "classify", "reward"]:
if self.registry.is_cross_encoder_model(architectures, self):
return "classify"
for arch in architectures:
match = try_match_architecture_defaults(arch, runner_type="pooling")
if match:
_, (_, convert_type) = match
assert convert_type != "none"
return convert_type
return "embed"
def _parse_quant_hf_config(self, hf_config: PretrainedConfig):
quant_cfg = getattr(hf_config, "quantization_config", None)
if quant_cfg is None:

View File

@ -321,11 +321,6 @@ class ParallelConfig:
"num_redundant_experts."
)
if self.prefill_context_parallel_size > 1:
raise ValueError(
"Prefill context parallelism is not fully supported. "
"Please set prefill_context_parallel_size to 1."
)
return self
@property

View File

@ -111,13 +111,15 @@ class PoolerConfig:
def get_use_activation(o: object):
if softmax := getattr(o, "softmax", None) is not None:
logger.warning_once(
"softmax will be deprecated, please use use_activation instead."
"softmax will be deprecated and will be removed in v0.15. "
"Please use use_activation instead."
)
return softmax
if activation := getattr(o, "activation", None) is not None:
logger.warning_once(
"activation will be deprecated, please use use_activation instead."
"activation will be deprecated and will be removed in v0.15. "
"Please use use_activation instead."
)
return activation

View File

@ -666,8 +666,9 @@ class VllmConfig:
default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level]
self._apply_optimization_level_defaults(default_config)
if (
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
self.compilation_config.cudagraph_mode.requires_piecewise_compilation()
and self.compilation_config.mode != CompilationMode.VLLM_COMPILE
):
logger.info(
@ -692,22 +693,29 @@ class VllmConfig:
if current_platform.support_static_graph_mode():
# if cudagraph_mode has full cudagraphs, we need to check support
if (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
and self.model_config is not None
):
if self.model_config.pooler_config is not None:
if model_config := self.model_config:
if (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
and model_config.pooler_config is not None
):
logger.warning_once(
"Pooling models do not support full cudagraphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
elif self.model_config.is_encoder_decoder:
logger.warning_once(
"Encoder-decoder models do not support full cudagraphs. "
"Overriding cudagraph_mode to PIECEWISE."
elif (
model_config.is_encoder_decoder
and self.compilation_config.cudagraph_mode
not in (CUDAGraphMode.NONE, CUDAGraphMode.FULL_DECODE_ONLY)
):
logger.info_once(
"Encoder-decoder models do not support %s. "
"Overriding cudagraph_mode to FULL_DECODE_ONLY.",
self.compilation_config.cudagraph_mode.name,
)
self.compilation_config.cudagraph_mode = (
CUDAGraphMode.FULL_DECODE_ONLY
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
# disable cudagraph when enforce eager execution
if self.model_config is not None and self.model_config.enforce_eager:
@ -812,11 +820,6 @@ class VllmConfig:
f"({self.parallel_config.cp_kv_cache_interleave_size})."
)
assert (
self.parallel_config.cp_kv_cache_interleave_size == 1
or self.speculative_config is None
), "MTP with cp_kv_cache_interleave_size > 1 is not supported now."
# Do this after all the updates to compilation_config.mode
self.compilation_config.set_splitting_ops_for_v1(
all2all_backend=self.parallel_config.all2all_backend,
@ -1006,7 +1009,7 @@ class VllmConfig:
max_graph_size = min(max_num_seqs * 2, 512)
# 1, 2, 4, then multiples of 8 up to 256 and then multiples of 16
# up to max_graph_size
cuda_graph_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
cudagraph_capture_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
range(256, max_graph_size + 1, 16))
In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
@ -1047,8 +1050,14 @@ class VllmConfig:
self.compilation_config.max_cudagraph_capture_size
)
if max_cudagraph_capture_size is None:
decode_query_len = 1
if (
self.speculative_config
and self.speculative_config.num_speculative_tokens
):
decode_query_len += self.speculative_config.num_speculative_tokens
max_cudagraph_capture_size = min(
self.scheduler_config.max_num_seqs * 2, 512
self.scheduler_config.max_num_seqs * decode_query_len * 2, 512
)
max_num_tokens = self.scheduler_config.max_num_batched_tokens
max_cudagraph_capture_size = min(max_num_tokens, max_cudagraph_capture_size)

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import pickle
import threading
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
@ -43,6 +44,33 @@ VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
from_bytes_big = functools.partial(int.from_bytes, byteorder="big")
# Memory fence for cross-process shared memory visibility.
# Required for correct producer-consumer synchronization when using
# shared memory without locks.
_memory_fence_lock = threading.Lock()
def memory_fence():
"""
Full memory barrier for shared memory synchronization.
Ensures all prior memory writes are visible to other processes before
any subsequent reads. This is critical for lock-free producer-consumer
patterns using shared memory.
Implementation acquires and immediately releases a lock. Python's
threading.Lock provides sequentially consistent memory barrier semantics
across all major platforms (POSIX, Windows). This is a lightweight
operation (~20ns) that guarantees:
- All stores before the barrier are visible to other threads/processes
- All loads after the barrier see the latest values
"""
# Lock acquire/release provides full memory barrier semantics.
# Using context manager ensures lock release even on exceptions.
with _memory_fence_lock:
pass
def to_bytes_big(value: int, size: int) -> bytes:
return value.to_bytes(size, byteorder="big")
@ -414,6 +442,10 @@ class MessageQueue:
n_warning = 1
while True:
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
# Memory fence ensures we see the latest read flags from readers.
# Without this, we may read stale flags from our CPU cache and
# spin indefinitely even though readers have completed.
memory_fence()
read_count = sum(metadata_buffer[1:])
written_flag = metadata_buffer[0]
if written_flag and read_count != self.buffer.n_reader:
@ -458,6 +490,10 @@ class MessageQueue:
metadata_buffer[i] = 0
# mark the block as written
metadata_buffer[0] = 1
# Memory fence ensures the write is visible to readers on other cores
# before we proceed. Without this, readers may spin indefinitely
# waiting for a write that's stuck in our CPU's store buffer.
memory_fence()
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
break
@ -473,6 +509,10 @@ class MessageQueue:
n_warning = 1
while True:
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
# Memory fence ensures we see the latest writes from the writer.
# Without this, we may read stale flags from our CPU cache
# and spin indefinitely even though writer has updated them.
memory_fence()
read_flag = metadata_buffer[self.local_reader_rank + 1]
written_flag = metadata_buffer[0]
if not written_flag or read_flag:
@ -513,6 +553,10 @@ class MessageQueue:
# caller has read from the buffer
# set the read flag
metadata_buffer[self.local_reader_rank + 1] = 1
# Memory fence ensures the read flag is visible to the writer.
# Without this, writer may not see our read completion and
# could wait indefinitely for all readers to finish.
memory_fence()
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
self._read_spin_timer.record_activity()

View File

@ -491,6 +491,9 @@ async def transfer_layer(
num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
assert num_physical_experts == ep_size * num_local_physical_experts
# A buffer to hold the expert weights in one layer during the exchange.
# NOTE: Currently we assume the same weights across different layers
# have the same shape.
old_global_expert_indices_np = old_global_expert_indices.cpu().numpy()
new_global_expert_indices_np = new_global_expert_indices.cpu().numpy()

View File

@ -27,7 +27,7 @@ from lmcache.v1.lookup_client.lmcache_async_lookup_client import (
LMCacheAsyncLookupServer,
)
from lmcache.v1.offload_server.zmq_server import ZMQOffloadServer
from lmcache.v1.plugin.plugin_launcher import PluginLauncher
from lmcache.v1.plugin.runtime_plugin_launcher import RuntimePluginLauncher
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
@ -683,7 +683,7 @@ class LMCacheConnectorV1Impl:
self.api_server = InternalAPIServer(self)
self.api_server.start()
# Launch plugins
self.plugin_launcher = PluginLauncher(
self.plugin_launcher = RuntimePluginLauncher(
self.config,
role,
self.worker_count,

View File

@ -1586,6 +1586,8 @@ def destroy_distributed_environment():
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
# Reset environment variable cache
envs.disable_envs_cache()
# Ensure all objects are not frozen before cleanup
gc.unfreeze()

View File

@ -71,7 +71,6 @@ from vllm.config.model import (
LogprobsMode,
ModelDType,
RunnerOption,
TaskOption,
TokenizerMode,
)
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode
@ -360,7 +359,6 @@ class EngineArgs:
hf_config_path: str | None = ModelConfig.hf_config_path
runner: RunnerOption = ModelConfig.runner
convert: ConvertOption = ModelConfig.convert
task: TaskOption | None = ModelConfig.task
skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
tokenizer_mode: TokenizerMode | str = ModelConfig.tokenizer_mode
@ -373,9 +371,8 @@ class EngineArgs:
config_format: str = ModelConfig.config_format
dtype: ModelDType = ModelConfig.dtype
kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
seed: int | None = 0
seed: int = ModelConfig.seed
max_model_len: int | None = ModelConfig.max_model_len
cuda_graph_sizes: list[int] | None = CompilationConfig.cudagraph_capture_sizes
cudagraph_capture_sizes: list[int] | None = (
CompilationConfig.cudagraph_capture_sizes
)
@ -463,7 +460,6 @@ class EngineArgs:
MultiModalConfig, "media_io_kwargs"
)
mm_processor_kwargs: dict[str, Any] | None = MultiModalConfig.mm_processor_kwargs
disable_mm_preprocessor_cache: bool = False # DEPRECATED
mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
mm_processor_cache_type: MMCacheType | None = (
MultiModalConfig.mm_processor_cache_type
@ -559,9 +555,6 @@ class EngineArgs:
use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
pt_load_map_location: str = LoadConfig.pt_load_map_location
# DEPRECATED
enable_multimodal_encoder_data_parallel: bool = False
logits_processors: list[str | type[LogitsProcessor]] | None = (
ModelConfig.logits_processors
)
@ -629,7 +622,6 @@ class EngineArgs:
model_group.add_argument("--model", **model_kwargs["model"])
model_group.add_argument("--runner", **model_kwargs["runner"])
model_group.add_argument("--convert", **model_kwargs["convert"])
model_group.add_argument("--task", **model_kwargs["task"], deprecated=True)
model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"])
model_group.add_argument("--tokenizer-mode", **model_kwargs["tokenizer_mode"])
model_group.add_argument(
@ -883,11 +875,6 @@ class EngineArgs:
parallel_group.add_argument(
"--worker-extension-cls", **parallel_kwargs["worker_extension_cls"]
)
parallel_group.add_argument(
"--enable-multimodal-encoder-data-parallel",
action="store_true",
deprecated=True,
)
# KV cache arguments
cache_kwargs = get_kwargs(CacheConfig)
@ -961,9 +948,6 @@ class EngineArgs:
multimodal_group.add_argument(
"--mm-processor-cache-gb", **multimodal_kwargs["mm_processor_cache_gb"]
)
multimodal_group.add_argument(
"--disable-mm-preprocessor-cache", action="store_true", deprecated=True
)
multimodal_group.add_argument(
"--mm-processor-cache-type", **multimodal_kwargs["mm_processor_cache_type"]
)
@ -1121,15 +1105,6 @@ class EngineArgs:
compilation_group.add_argument(
"--cudagraph-capture-sizes", **compilation_kwargs["cudagraph_capture_sizes"]
)
compilation_kwargs["cudagraph_capture_sizes"]["help"] = (
"--cuda-graph-sizes is deprecated and will be removed in v0.13.0 or v1.0.0,"
" whichever is soonest. Please use --cudagraph-capture-sizes instead."
)
compilation_group.add_argument(
"--cuda-graph-sizes",
**compilation_kwargs["cudagraph_capture_sizes"],
deprecated=True,
)
compilation_group.add_argument(
"--max-cudagraph-capture-size",
**compilation_kwargs["max_cudagraph_capture_size"],
@ -1202,62 +1177,20 @@ class EngineArgs:
if is_gguf(self.model):
self.quantization = self.load_format = "gguf"
# NOTE(woosuk): In V1, we use separate processes for workers (unless
# VLLM_ENABLE_V1_MULTIPROCESSING=0), so setting a seed here
# doesn't affect the user process.
if self.seed is None:
logger.warning_once(
"`seed=None` is equivalent to `seed=0` in V1 Engine. "
"You will no longer be allowed to pass `None` in v0.13.",
scope="local",
if not envs.VLLM_ENABLE_V1_MULTIPROCESSING:
logger.warning(
"The global random seed is set to %d. Since "
"VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may "
"affect the random state of the Python process that "
"launched vLLM.",
self.seed,
)
self.seed = 0
if not envs.VLLM_ENABLE_V1_MULTIPROCESSING:
logger.warning(
"The global random seed is set to %d. Since "
"VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may "
"affect the random state of the Python process that "
"launched vLLM.",
self.seed,
)
if self.disable_mm_preprocessor_cache:
logger.warning_once(
"`--disable-mm-preprocessor-cache` is deprecated "
"and will be removed in v0.13. "
"Please use `--mm-processor-cache-gb 0` instead.",
scope="local",
)
self.mm_processor_cache_gb = 0
elif envs.VLLM_MM_INPUT_CACHE_GIB != 4:
logger.warning_once(
"VLLM_MM_INPUT_CACHE_GIB` is deprecated "
"and will be removed in v0.13. "
"Please use `--mm-processor-cache-gb %d` instead.",
envs.VLLM_MM_INPUT_CACHE_GIB,
scope="local",
)
self.mm_processor_cache_gb = envs.VLLM_MM_INPUT_CACHE_GIB
if self.enable_multimodal_encoder_data_parallel:
logger.warning_once(
"--enable-multimodal-encoder-data-parallel` is deprecated "
"and will be removed in v0.13. "
"Please use `--mm-encoder-tp-mode data` instead.",
scope="local",
)
self.mm_encoder_tp_mode = "data"
return ModelConfig(
model=self.model,
hf_config_path=self.hf_config_path,
runner=self.runner,
convert=self.convert,
task=self.task,
tokenizer=self.tokenizer,
tokenizer_mode=self.tokenizer_mode,
trust_remote_code=self.trust_remote_code,
@ -1741,18 +1674,6 @@ class EngineArgs:
# Compilation config overrides
compilation_config = copy.deepcopy(self.compilation_config)
if self.cuda_graph_sizes is not None:
logger.warning(
"--cuda-graph-sizes is deprecated and will be removed in v0.13.0 or "
"v1.0.0, whichever is soonest. Please use --cudagraph-capture-sizes "
"instead."
)
if compilation_config.cudagraph_capture_sizes is not None:
raise ValueError(
"cuda_graph_sizes and compilation_config."
"cudagraph_capture_sizes are mutually exclusive"
)
compilation_config.cudagraph_capture_sizes = self.cuda_graph_sizes
if self.cudagraph_capture_sizes is not None:
if compilation_config.cudagraph_capture_sizes is not None:
raise ValueError(
@ -1862,6 +1783,7 @@ class EngineArgs:
except Exception:
# This is only used to set default_max_num_batched_tokens
device_memory = 0
device_name = ""
# NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
# throughput, see PR #17885 for more details.
@ -1926,16 +1848,6 @@ class EngineArgs:
default_chunked_prefill = model_config.is_chunked_prefill_supported
default_prefix_caching = model_config.is_prefix_caching_supported
if self.prefill_context_parallel_size > 1:
default_chunked_prefill = False
default_prefix_caching = False
logger.warning_once(
"--prefill-context-parallel-size > 1 is not compatible with "
"chunked prefill and prefix caching now. Chunked prefill "
"and prefix caching have been disabled by default.",
scope="local",
)
if self.enable_chunked_prefill is None:
self.enable_chunked_prefill = default_chunked_prefill
@ -2121,11 +2033,13 @@ def human_readable_int(value):
"k": 10**3,
"m": 10**6,
"g": 10**9,
"t": 10**12,
}
binary_multiplier = {
"K": 2**10,
"M": 2**20,
"G": 2**30,
"T": 2**40,
}
number, suffix = match.groups()

View File

@ -8,3 +8,5 @@ Shared constants for vLLM entrypoints.
# These constants help mitigate header abuse attacks
H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT = 4194304 # 4 MB
H11_MAX_HEADER_COUNT_DEFAULT = 256
MCP_PREFIX = "mcp_"

View File

@ -19,6 +19,7 @@ from vllm import envs
from vllm.entrypoints.chat_utils import (
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.constants import MCP_PREFIX
from vllm.entrypoints.openai.parser.harmony_utils import (
get_encoding,
get_streamable_parser_for_assistant,
@ -303,7 +304,7 @@ class ParsableContext(ConversationContext):
result_str = result.content[0].text
message = ResponseFunctionToolCallOutputItem(
id=f"fco_{random_uuid()}",
id=f"mcpo_{random_uuid()}",
type="function_call_output",
call_id=f"call_{random_uuid()}",
output=result_str,
@ -385,6 +386,9 @@ class ParsableContext(ConversationContext):
if not self.parser.response_messages:
return []
last_msg = self.parser.response_messages[-1]
# change this to a mcp_ function call
last_msg.id = f"{MCP_PREFIX}{random_uuid()}"
self.parser.response_messages[-1] = last_msg
if last_msg.name == "code_interpreter":
return await self.call_python_tool(self._tool_sessions["python"], last_msg)
elif last_msg.name == "web_search_preview":

View File

@ -9,7 +9,7 @@ import cloudpickle
import torch.nn as nn
from pydantic import ValidationError
from tqdm.auto import tqdm
from typing_extensions import TypeVar, deprecated
from typing_extensions import TypeVar
from vllm.beam_search import (
BeamSearchInstance,
@ -73,7 +73,6 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
from vllm.tasks import PoolingTask
from vllm.tokenizers import MistralTokenizer, TokenizerLike
from vllm.tokenizers.hf import get_cached_tokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils.collection_utils import as_iter, is_list_of
from vllm.utils.counter import Counter
@ -199,7 +198,7 @@ class LLM:
quantization: QuantizationMethods | None = None,
revision: str | None = None,
tokenizer_revision: str | None = None,
seed: int | None = None,
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: float = 4,
cpu_offload_gb: float = 0,
@ -367,16 +366,6 @@ class LLM:
def get_tokenizer(self) -> TokenizerLike:
return self.llm_engine.get_tokenizer()
@deprecated("`set_tokenizer` is deprecated and will be removed in v0.13.")
def set_tokenizer(self, tokenizer: TokenizerLike) -> None:
# While CachedTokenizer is dynamic, have no choice but
# compare class name. Misjudgment will arise from
# user-defined tokenizer started with 'Cached'
if tokenizer.__class__.__name__.startswith("Cached"):
self.llm_engine.tokenizer = tokenizer
else:
self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer)
def reset_mm_cache(self) -> None:
self.input_processor.clear_mm_cache()
self.llm_engine.reset_mm_cache()

View File

@ -176,7 +176,7 @@ class FrontendArgs:
enable_force_include_usage: bool = False
"""If set to True, including usage on every request."""
enable_tokenizer_info_endpoint: bool = False
"""Enable the /get_tokenizer_info endpoint. May expose chat
"""Enable the `/tokenizer_info` endpoint. May expose chat
templates and other tokenizer configuration."""
enable_log_outputs: bool = False
"""If True, log model outputs (generations).

View File

@ -51,7 +51,11 @@ from vllm.entrypoints.openai.protocol import (
ToolCall,
UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
from vllm.entrypoints.openai.serving_engine import (
GenerationError,
OpenAIServing,
clamp_prompt_logprobs,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall
@ -380,6 +384,8 @@ class OpenAIServingChat(OpenAIServing):
tokenizer,
request_metadata,
)
except GenerationError as e:
return self._convert_generation_error_to_response(e)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
@ -1120,6 +1126,10 @@ class OpenAIServingChat(OpenAIServing):
# if the model is finished generating
else:
# check for error finish reason and abort streaming
# finish_reason='error' indicates a retryable error
self._raise_if_error(output.finish_reason, request_id)
# check to make sure we haven't "forgotten" to stream
# any tokens that were generated but previously
# matched by partial json parsing
@ -1287,6 +1297,8 @@ class OpenAIServingChat(OpenAIServing):
delta=False,
)
except GenerationError as e:
yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
except Exception as e:
# TODO: Use a vllm-specific Validation Error
logger.exception("Error in chat completion stream generator.")
@ -1327,6 +1339,9 @@ class OpenAIServingChat(OpenAIServing):
role = self.get_chat_request_role(request)
for output in final_res.outputs:
# check for error finish reason and raise GenerationError
# finish_reason='error' indicates a retryable request-level internal error
self._raise_if_error(output.finish_reason, request_id)
token_ids = output.token_ids
out_logprobs = output.logprobs
tool_call_info = None

View File

@ -24,7 +24,11 @@ from vllm.entrypoints.openai.protocol import (
RequestResponseMetadata,
UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
from vllm.entrypoints.openai.serving_engine import (
GenerationError,
OpenAIServing,
clamp_prompt_logprobs,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
@ -300,6 +304,8 @@ class OpenAIServingCompletion(OpenAIServing):
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except GenerationError as e:
return self._convert_generation_error_to_response(e)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
@ -437,6 +443,8 @@ class OpenAIServingCompletion(OpenAIServing):
finish_reason = output.finish_reason
stop_reason = output.stop_reason
self._raise_if_error(finish_reason, request_id)
chunk = CompletionStreamResponse(
id=request_id,
created=created_time,
@ -498,8 +506,11 @@ class OpenAIServingCompletion(OpenAIServing):
# report to FastAPI middleware aggregate usage across all choices
request_metadata.final_usage_info = final_usage_info
except GenerationError as e:
yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
except Exception as e:
# TODO: Use a vllm-specific Validation Error
logger.exception("Error in completion stream generator.")
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
yield "data: [DONE]\n\n"
@ -530,6 +541,8 @@ class OpenAIServingCompletion(OpenAIServing):
out_logprobs: GenericSequence[dict[int, Logprob] | None] | None
for output in final_res.outputs:
self._raise_if_error(output.finish_reason, request_id)
assert request.max_tokens is not None
if request.echo:
if request.return_token_ids:

View File

@ -133,6 +133,15 @@ from vllm.utils.async_utils import (
from vllm.utils.collection_utils import is_list_of
from vllm.v1.engine import EngineCoreRequest
class GenerationError(Exception):
"""raised when finish_reason indicates internal server error (500)"""
def __init__(self, message: str = "Internal server error"):
super().__init__(message)
self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR
logger = init_logger(__name__)
CompletionLikeRequest: TypeAlias = (
@ -456,6 +465,29 @@ class OpenAIServing:
# Iterate through all beam inference results
for i, result in enumerate(output):
current_beam = all_beams[i]
# check for error finish reason and abort beam search
if result.outputs[0].finish_reason == "error":
# yield error output and terminate beam search
yield RequestOutput(
request_id=request_id,
prompt=prompt_text,
outputs=[
CompletionOutput(
index=0,
text="",
token_ids=[],
cumulative_logprob=None,
logprobs=None,
finish_reason="error",
)
],
finished=True,
prompt_token_ids=prompt_token_ids,
prompt_logprobs=None,
)
return
if result.outputs[0].logprobs is not None:
logprobs = result.outputs[0].logprobs[0]
all_beams_token_id.extend(list(logprobs.keys()))
@ -780,6 +812,35 @@ class OpenAIServing:
)
return json_str
def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None:
"""Raise GenerationError if finish_reason indicates an error."""
if finish_reason == "error":
logger.error(
"Request %s failed with an internal error during generation",
request_id,
)
raise GenerationError("Internal server error")
def _convert_generation_error_to_response(
self, e: GenerationError
) -> ErrorResponse:
"""Convert GenerationError to ErrorResponse."""
return self.create_error_response(
str(e),
err_type="InternalServerError",
status_code=e.status_code,
)
def _convert_generation_error_to_streaming_response(
self, e: GenerationError
) -> str:
"""Convert GenerationError to streaming error response."""
return self.create_streaming_error_response(
str(e),
err_type="InternalServerError",
status_code=e.status_code,
)
async def _check_model(
self,
request: AnyRequest,
@ -1339,6 +1400,7 @@ class OpenAIServing:
)
engine_prompt = engine_prompts[0]
request_prompt = request_prompts[0]
prompt_text, _, _ = self._get_prompt_components(request_prompt)
# Update the sampling params.
sampling_params.max_tokens = self.max_model_len - len(

View File

@ -50,6 +50,7 @@ from openai.types.responses.response_reasoning_item import (
)
from openai.types.responses.tool import Mcp, Tool
from openai_harmony import Message as OpenAIHarmonyMessage
from pydantic import TypeAdapter
from vllm import envs
from vllm.engine.protocol import EngineClient
@ -94,7 +95,10 @@ from vllm.entrypoints.openai.protocol import (
ResponseUsage,
StreamingResponsesResponse,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_engine import (
GenerationError,
OpenAIServing,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.responses_utils import (
construct_input_messages,
@ -541,6 +545,8 @@ class OpenAIServingResponses(OpenAIServing):
tokenizer,
request_metadata,
)
except GenerationError as e:
return self._convert_generation_error_to_response(e)
except Exception as e:
return self.create_error_response(str(e))
@ -648,6 +654,8 @@ class OpenAIServingResponses(OpenAIServing):
status = "incomplete"
elif context.finish_reason == "abort":
status = "cancelled"
else:
self._raise_if_error(context.finish_reason, request.request_id)
else:
status = "incomplete"
elif isinstance(context, ParsableContext):
@ -673,6 +681,9 @@ class OpenAIServingResponses(OpenAIServing):
assert len(final_res.outputs) == 1
final_output = final_res.outputs[0]
# finish_reason='error' indicates retryable internal error
self._raise_if_error(final_output.finish_reason, request.request_id)
output = self._make_response_output_items(request, final_output, tokenizer)
if request.enable_response_messages:
@ -1066,6 +1077,8 @@ class OpenAIServingResponses(OpenAIServing):
async for event in generator:
event_deque.append(event)
new_event_signal.set() # Signal new event available
except GenerationError as e:
response = self._convert_generation_error_to_response(e)
except Exception as e:
logger.exception("Background request failed for %s", request.request_id)
response = self.create_error_response(str(e))
@ -1089,6 +1102,8 @@ class OpenAIServingResponses(OpenAIServing):
):
try:
response = await self.responses_full_generator(request, *args, **kwargs)
except GenerationError as e:
response = self._convert_generation_error_to_response(e)
except Exception as e:
logger.exception("Background request failed for %s", request.request_id)
response = self.create_error_response(str(e))
@ -1227,6 +1242,8 @@ class OpenAIServingResponses(OpenAIServing):
continue
if ctx.last_output.outputs:
output = ctx.last_output.outputs[0]
# finish_reason='error' indicates a retryable error
self._raise_if_error(output.finish_reason, request.request_id)
if reasoning_parser:
delta_message = reasoning_parser.extract_reasoning_streaming(
previous_text=previous_text,
@ -1522,6 +1539,9 @@ class OpenAIServingResponses(OpenAIServing):
async for ctx in result_generator:
assert isinstance(ctx, StreamingHarmonyContext)
# finish_reason='error' indicates a retryable error
self._raise_if_error(ctx.finish_reason, request.request_id)
if ctx.is_expecting_start():
current_output_index += 1
sent_output_item_added = False
@ -2016,18 +2036,25 @@ class OpenAIServingResponses(OpenAIServing):
)
)
async for event_data in processer(
request,
sampling_params,
result_generator,
context,
model_name,
tokenizer,
request_metadata,
created_time,
_increment_sequence_number_and_return,
):
yield event_data
try:
async for event_data in processer(
request,
sampling_params,
result_generator,
context,
model_name,
tokenizer,
request_metadata,
created_time,
_increment_sequence_number_and_return,
):
yield event_data
except GenerationError as e:
error_json = self._convert_generation_error_to_streaming_response(e)
yield _increment_sequence_number_and_return(
TypeAdapter(StreamingResponsesResponse).validate_json(error_json)
)
return
async def empty_async_generator():
# A hack to trick Python to think this is a generator but

View File

@ -22,6 +22,7 @@ from openai.types.responses.response_reasoning_item import ResponseReasoningItem
from openai.types.responses.tool import Tool
from vllm import envs
from vllm.entrypoints.constants import MCP_PREFIX
from vllm.entrypoints.openai.protocol import (
ChatCompletionMessageParam,
ResponseInputOutputItem,
@ -44,13 +45,13 @@ def make_response_output_items_from_parsable_context(
)
if isinstance(output_messages[-1], ResponseFunctionToolCall):
mcp_message = McpCall(
id=f"mcp_{random_uuid()}",
id=f"{MCP_PREFIX}{random_uuid()}",
arguments=output_messages[-1].arguments,
name=output_messages[-1].name,
server_label=output_messages[
-1
].name, # TODO: store the server label
type="mcp_call",
type=f"{MCP_PREFIX}call",
status="completed",
output=message.output,
# TODO: support error output
@ -98,12 +99,63 @@ def construct_input_messages(
if isinstance(request_input, str):
messages.append({"role": "user", "content": request_input})
else:
for item in request_input:
messages.append(construct_chat_message_with_tool_call(item))
input_messages = construct_chat_messages_with_tool_call(request_input)
messages.extend(input_messages)
return messages
def construct_chat_message_with_tool_call(
def _maybe_combine_reasoning_and_tool_call(
item: ResponseInputOutputItem, messages: list[ChatCompletionMessageParam]
) -> ChatCompletionMessageParam | None:
"""Many models treat MCP calls and reasoning as a single message.
This function checks if the last message is a reasoning message and
the current message is a tool call"""
if not (
isinstance(item, ResponseFunctionToolCall) and item.id.startswith(MCP_PREFIX)
):
return None
if len(messages) == 0:
return None
last_message = messages[-1]
if not (
last_message.get("role") == "assistant"
and last_message.get("reasoning") is not None
):
return None
last_message["tool_calls"] = [
ChatCompletionMessageToolCallParam(
id=item.call_id,
function=FunctionCallTool(
name=item.name,
arguments=item.arguments,
),
type="function",
)
]
return last_message
def construct_chat_messages_with_tool_call(
input_messages: list[ResponseInputOutputItem],
) -> list[ChatCompletionMessageParam]:
"""This function wraps _construct_single_message_from_response_item
Because some chatMessages come from multiple response items
for example a reasoning item and a MCP tool call are two response items
but are one chat message
"""
messages: list[ChatCompletionMessageParam] = []
for item in input_messages:
maybe_combined_message = _maybe_combine_reasoning_and_tool_call(item, messages)
if maybe_combined_message is not None:
messages[-1] = maybe_combined_message
else:
messages.append(_construct_single_message_from_response_item(item))
return messages
def _construct_single_message_from_response_item(
item: ResponseInputOutputItem,
) -> ChatCompletionMessageParam:
if isinstance(item, ResponseFunctionToolCall):

View File

@ -72,10 +72,9 @@ if TYPE_CHECKING:
VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25
VLLM_VIDEO_LOADER_BACKEND: str = "opencv"
VLLM_MEDIA_CONNECTOR: str = "http"
VLLM_MM_INPUT_CACHE_GIB: int = 4
VLLM_TARGET_DEVICE: str = "cuda"
VLLM_MAIN_CUDA_VERSION: str = "12.9"
VLLM_FLOAT32_MATMUL_PRECISION: Literal["highest", "high", "medium"] = "highest"
VLLM_FLOAT32_MATMUL_PRECISION: Literal["ieee", "tf32"] = "ieee"
MAX_JOBS: str | None = None
NVCC_THREADS: str | None = None
VLLM_USE_PRECOMPILED: bool = False
@ -457,11 +456,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MAIN_CUDA_VERSION": lambda: os.getenv("VLLM_MAIN_CUDA_VERSION", "").lower()
or "12.9",
# Controls PyTorch float32 matmul precision mode within vLLM workers.
# Valid options mirror torch.set_float32_matmul_precision
# Accepted values:
# - "ieee" (default): force full IEEE FP32 matmul precision.
# - "tf32": enable TensorFloat32-based fast matmul.
"VLLM_FLOAT32_MATMUL_PRECISION": env_with_choices(
"VLLM_FLOAT32_MATMUL_PRECISION",
"highest",
["highest", "high", "medium"],
"ieee",
["ieee", "tf32"],
case_sensitive=False,
),
# Maximum number of compilation jobs to run in parallel.
@ -786,9 +787,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
# imported at runtime.
# If a non-existing backend is used, an AssertionError will be thrown.
"VLLM_MEDIA_CONNECTOR": lambda: os.getenv("VLLM_MEDIA_CONNECTOR", "http"),
# [DEPRECATED] Cache size (in GiB per process) for multimodal input cache
# Default is 4 GiB per API process + 4 GiB per engine core process
"VLLM_MM_INPUT_CACHE_GIB": lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")),
# Path to the XLA persistent cache directory.
# Only used for XLA devices such as TPUs.
"VLLM_XLA_CACHE_PATH": lambda: os.path.expanduser(
@ -1580,6 +1578,12 @@ def __getattr__(name: str):
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def _is_envs_cache_enabled() -> bool:
"""Checked if __getattr__ is wrapped with functools.cache"""
global __getattr__
return hasattr(__getattr__, "cache_clear")
def enable_envs_cache() -> None:
"""
Enables caching of environment variables. This is useful for performance
@ -1590,6 +1594,9 @@ def enable_envs_cache() -> None:
runtime overhead. This also means that environment variables should NOT
be updated after the service is initialized.
"""
if _is_envs_cache_enabled():
# Avoid wrapping functools.cache multiple times
return
# Tag __getattr__ with functools.cache
global __getattr__
__getattr__ = functools.cache(__getattr__)
@ -1599,6 +1606,17 @@ def enable_envs_cache() -> None:
__getattr__(key)
def disable_envs_cache() -> None:
"""
Resets the environment variables cache. It could be used to isolate environments
between unit tests.
"""
global __getattr__
# If __getattr__ is wrapped by functions.cache, unwrap the caching layer.
if _is_envs_cache_enabled():
__getattr__ = __getattr__.__wrapped__
def __dir__():
return list(environment_variables.keys())
@ -1661,7 +1679,6 @@ def compile_factors() -> dict[str, object]:
"VLLM_MEDIA_CONNECTOR",
"VLLM_ASSETS_CACHE",
"VLLM_ASSETS_CACHE_MODEL_CLEAN",
"VLLM_MM_INPUT_CACHE_GIB",
"VLLM_WORKER_MULTIPROC_METHOD",
"VLLM_ENABLE_V1_MULTIPROCESSING",
"VLLM_V1_OUTPUT_PROC_CHUNK_SIZE",

View File

@ -4,7 +4,10 @@
from contextlib import contextmanager
from typing import Any
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
@ -49,6 +52,7 @@ __all__ = [
"FusedMoEPermuteExpertsUnpermute",
"FusedMoEActivationFormat",
"FusedMoEPrepareAndFinalize",
"RoutingMethodType",
"SharedFusedMoE",
"activation_without_mul",
"override_config",

Some files were not shown because too many files have changed in this diff Show More