mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-01 01:07:05 +08:00
Merge remote-tracking branch 'origin/main' into fix-async-spec-penalty
This commit is contained in:
commit
8c0779a646
@ -15,6 +15,21 @@ steps:
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
- label: "Build arm64 wheel - CUDA 13.0"
|
||||
depends_on: ~
|
||||
id: build-wheel-arm64-cuda-13-0
|
||||
agents:
|
||||
queue: arm64_cpu_queue_postmerge
|
||||
commands:
|
||||
# NOTE: torch_cuda_arch_list is derived from upstream PyTorch build files here:
|
||||
# https://github.com/pytorch/pytorch/blob/main/.ci/aarch64_linux/aarch64_ci_build.sh#L7
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=13.0.1 --build-arg torch_cuda_arch_list='8.7 8.9 9.0 10.0+PTX 12.0' --build-arg BUILD_BASE_IMAGE=nvidia/cuda:13.0.1-devel-ubuntu22.04 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
||||
- "mkdir artifacts"
|
||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh manylinux_2_35"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
# aarch64 build
|
||||
- label: "Build arm64 CPU wheel"
|
||||
depends_on: ~
|
||||
@ -25,7 +40,7 @@ steps:
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_BUILD_ACL=ON --tag vllm-ci:build-image --target vllm-build --progress plain -f docker/Dockerfile.cpu ."
|
||||
- "mkdir artifacts"
|
||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh manylinux_2_35"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
@ -39,7 +54,7 @@ steps:
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
||||
- "mkdir artifacts"
|
||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh manylinux_2_31"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
@ -52,7 +67,7 @@ steps:
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=13.0.1 --build-arg BUILD_BASE_IMAGE=nvidia/cuda:13.0.1-devel-ubuntu22.04 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
||||
- "mkdir artifacts"
|
||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh manylinux_2_35"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
|
||||
@ -372,6 +372,17 @@ if __name__ == "__main__":
|
||||
|
||||
print(f"Found {len(wheel_files)} wheel files for version {version}: {wheel_files}")
|
||||
|
||||
# keep only "official" files for a non-nightly version (specified by cli args)
|
||||
PY_VERSION_RE = re.compile(r"^\d+\.\d+\.\d+([a-zA-Z0-9.+-]*)?$")
|
||||
if PY_VERSION_RE.match(version):
|
||||
# upload-wheels.sh ensures no "dev" is in args.version
|
||||
wheel_files = list(
|
||||
filter(lambda x: version in x and "dev" not in x, wheel_files)
|
||||
)
|
||||
print(f"Non-nightly version detected, wheel files used: {wheel_files}")
|
||||
else:
|
||||
print("Nightly version detected, keeping all wheel files.")
|
||||
|
||||
# Generate index and metadata, assuming wheels and indices are stored as:
|
||||
# s3://vllm-wheels/{version}/<wheel files>
|
||||
# s3://vllm-wheels/<anything>/<index files>
|
||||
|
||||
@ -36,6 +36,11 @@ function cpu_tests() {
|
||||
set -e
|
||||
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m"
|
||||
|
||||
# Run model tests
|
||||
docker exec cpu-test bash -c "
|
||||
set -e
|
||||
pytest -x -v -s tests/models/multimodal/generation/test_whisper.py -m cpu_model"
|
||||
|
||||
# Run kernel tests
|
||||
docker exec cpu-test bash -c "
|
||||
set -e
|
||||
|
||||
@ -1,73 +0,0 @@
|
||||
#!/usr/bin/env bash
# Serve $MODEL with expert parallelism and async EPLB once per all2all
# backend, run the GSM8K eval against each server, and fail if the reported
# accuracy is below THRESHOLD.
set -euxo pipefail

# args: [THRESHOLD] [NUM_QUESTIONS] [START_PORT]
THRESHOLD=${1:-0.25}
NUM_Q=${2:-1319}
PORT=${3:-8030}
OUT_DIR=${OUT_DIR:-/tmp/vllm-scheduled}
mkdir -p "${OUT_DIR}"

# Poll the server's /health endpoint once per second; give up after 600s.
wait_for_server() {
  local port=$1
  timeout 600 bash -c '
    until curl -sf "http://127.0.0.1:'"$port"'/health" > /dev/null; do
      sleep 1
    done'
}

MODEL="deepseek-ai/DeepSeek-V2-lite"

# Set BACKENDS based on platform
if command -v rocm-smi &> /dev/null || [[ -d /opt/rocm ]] || [[ -n "${ROCM_PATH:-}" ]]; then
  # ROCm platform
  BACKENDS=("allgather_reducescatter")
  # Disable MOE padding for ROCm since it is causing eplb to fail
  export VLLM_ROCM_MOE_PADDING=0
else
  # Non-ROCm platform (CUDA/other)
  BACKENDS=("deepep_high_throughput" "deepep_low_latency")
fi

# Stop the background server: send SIGTERM, wait up to ~10s (20 x 0.5s) for
# it to exit, then SIGKILL whatever is left. Safe to call when no server runs.
cleanup() {
  if [[ -n "${SERVER_PID:-}" ]] && kill -0 "${SERVER_PID}" 2>/dev/null; then
    kill "${SERVER_PID}" 2>/dev/null || true
    for _ in {1..20}; do
      kill -0 "${SERVER_PID}" 2>/dev/null || break
      sleep 0.5
    done
    kill -9 "${SERVER_PID}" 2>/dev/null || true
  fi
}
trap cleanup EXIT

for BACK in "${BACKENDS[@]}"; do
  VLLM_DEEP_GEMM_WARMUP=skip \
  VLLM_ALL2ALL_BACKEND=$BACK \
  vllm serve "$MODEL" \
    --enforce-eager \
    --tensor-parallel-size 2 \
    --data-parallel-size 2 \
    --enable-expert-parallel \
    --enable-eplb \
    --eplb-config '{"window_size":200,"step_interval":600,"use_async":true}' \
    --trust-remote-code \
    --max-model-len 2048 \
    --port "$PORT" &
  SERVER_PID=$!
  wait_for_server "$PORT"

  # Build a filesystem-safe tag from the model name. printf avoids a trailing
  # newline, and the tr set lists only '/', ':' and ' ' so that ordinary
  # letters (notably 'n') are never rewritten.
  TAG=$(printf '%s' "$MODEL" | tr '/: ' '___')
  OUT="${OUT_DIR}/${TAG}_${BACK}_async_eplb.json"
  python3 tests/evals/gsm8k/gsm8k_eval.py --host http://127.0.0.1 --port "$PORT" --num-questions "${NUM_Q}" --save-results "${OUT}"
  # The heredoc is unquoted on purpose: the shell substitutes ${OUT},
  # ${MODEL}, ${BACK} and ${THRESHOLD} before Python sees the code.
  python3 - <<PY
import json; acc=json.load(open('${OUT}'))['accuracy']
print(f"${MODEL} ${BACK}: accuracy {acc:.3f}")
assert acc >= ${THRESHOLD}, f"${MODEL} ${BACK} accuracy {acc}"
PY

  cleanup
  SERVER_PID=
  sleep 1
  # Use a fresh port for the next backend's server.
  PORT=$((PORT+1))
done
|
||||
@ -50,7 +50,6 @@ for BACK in "${BACKENDS[@]}"; do
|
||||
--data-parallel-size 2 \
|
||||
--enable-expert-parallel \
|
||||
--enable-eplb \
|
||||
--eplb-config '{"window_size":200,"step_interval":600}' \
|
||||
--trust-remote-code \
|
||||
--max-model-len 2048 \
|
||||
--port $PORT &
|
||||
|
||||
@ -1,74 +0,0 @@
|
||||
#!/usr/bin/env bash
# Serve $MODEL with expert parallelism, async EPLB and the qwen3_next_mtp
# speculative method once per all2all backend, run the GSM8K eval against
# each server, and fail if the reported accuracy is below THRESHOLD.
set -euxo pipefail

# args: [THRESHOLD] [NUM_QUESTIONS] [START_PORT]
THRESHOLD=${1:-0.25}
NUM_Q=${2:-1319}
PORT=${3:-8040}
OUT_DIR=${OUT_DIR:-/tmp/vllm-scheduled}
mkdir -p "${OUT_DIR}"

# Poll the server's /health endpoint once per second; give up after 600s.
wait_for_server() {
  local port=$1
  timeout 600 bash -c '
    until curl -sf "http://127.0.0.1:'"$port"'/health" > /dev/null; do
      sleep 1
    done'
}

MODEL="Qwen/Qwen3-Next-80B-A3B-Instruct"

# Set BACKENDS based on platform
if command -v rocm-smi &> /dev/null || [[ -d /opt/rocm ]] || [[ -n "${ROCM_PATH:-}" ]]; then
  # ROCm platform
  BACKENDS=("allgather_reducescatter")
  # Disable MOE padding for ROCm since it is causing eplb to fail
  export VLLM_ROCM_MOE_PADDING=0
else
  # Non-ROCm platform (CUDA/other)
  BACKENDS=("deepep_high_throughput" "deepep_low_latency")
fi

# Stop the background server: send SIGTERM, wait up to ~10s (20 x 0.5s) for
# it to exit, then SIGKILL whatever is left. Safe to call when no server runs.
cleanup() {
  if [[ -n "${SERVER_PID:-}" ]] && kill -0 "${SERVER_PID}" 2>/dev/null; then
    kill "${SERVER_PID}" 2>/dev/null || true
    for _ in {1..20}; do
      kill -0 "${SERVER_PID}" 2>/dev/null || break
      sleep 0.5
    done
    kill -9 "${SERVER_PID}" 2>/dev/null || true
  fi
}
trap cleanup EXIT

for BACK in "${BACKENDS[@]}"; do
  VLLM_DEEP_GEMM_WARMUP=skip \
  VLLM_ALL2ALL_BACKEND=$BACK \
  vllm serve "$MODEL" \
    --enforce-eager \
    --tensor-parallel-size 4 \
    --enable-expert-parallel \
    --enable-eplb \
    --eplb-config '{"window_size":200,"step_interval":600,"use_async":true}' \
    --speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":1}' \
    --trust-remote-code \
    --max-model-len 2048 \
    --gpu-memory-utilization 0.9 \
    --port "$PORT" &
  SERVER_PID=$!
  wait_for_server "$PORT"

  # Build a filesystem-safe tag from the model name. printf avoids a trailing
  # newline, and the tr set lists only '/', ':' and ' ' so that ordinary
  # letters (notably the 'n' in "Qwen") are never rewritten.
  TAG=$(printf '%s' "$MODEL" | tr '/: ' '___')
  OUT="${OUT_DIR}/${TAG}_${BACK}.json"
  python3 tests/evals/gsm8k/gsm8k_eval.py --host http://127.0.0.1 --port "$PORT" --num-questions "${NUM_Q}" --save-results "${OUT}"
  # The heredoc is unquoted on purpose: the shell substitutes ${OUT},
  # ${MODEL}, ${BACK} and ${THRESHOLD} before Python sees the code.
  python3 - <<PY
import json; acc=json.load(open('${OUT}'))['accuracy']
print(f"${MODEL} ${BACK}: accuracy {acc:.3f}")
assert acc >= ${THRESHOLD}, f"${MODEL} ${BACK} accuracy {acc}"
PY

  cleanup
  SERVER_PID=
  sleep 1
  # Use a fresh port for the next backend's server.
  PORT=$((PORT+1))
done
|
||||
@ -34,9 +34,10 @@ if [[ ${#wheel_files[@]} -ne 1 ]]; then
|
||||
fi
|
||||
wheel="${wheel_files[0]}"
|
||||
|
||||
# current build image uses ubuntu 20.04, which corresponds to manylinux_2_31
|
||||
# default build image uses ubuntu 20.04, which corresponds to manylinux_2_31
|
||||
# we also accept params as manylinux tag
|
||||
# refer to https://github.com/mayeut/pep600_compliance?tab=readme-ov-file#acceptable-distros-to-build-wheels
|
||||
manylinux_version="manylinux_2_31"
|
||||
manylinux_version="${1:-manylinux_2_31}"
|
||||
|
||||
# Rename 'linux' to the appropriate manylinux version in the wheel filename
|
||||
if [[ "$wheel" != *"linux"* ]]; then
|
||||
@ -96,8 +97,11 @@ if [[ "$BUILDKITE_BRANCH" == "main" && "$BUILDKITE_PULL_REQUEST" == "false" ]];
|
||||
aws s3 cp --recursive "$INDICES_OUTPUT_DIR/" "s3://$BUCKET/nightly/"
|
||||
fi
|
||||
|
||||
# copy to /<pure_version>/ only if it does not have "dev" in the version
|
||||
# re-generate and copy to /<pure_version>/ only if it does not have "dev" in the version
|
||||
if [[ "$version" != *"dev"* ]]; then
|
||||
echo "Uploading indices to overwrite /$pure_version/"
|
||||
echo "Re-generating indices for /$pure_version/"
|
||||
rm -rf "$INDICES_OUTPUT_DIR/*"
|
||||
mkdir -p "$INDICES_OUTPUT_DIR"
|
||||
$PYTHON .buildkite/scripts/generate-nightly-index.py --version "$pure_version" --current-objects "$obj_json" --output-dir "$INDICES_OUTPUT_DIR" --comment "version $pure_version" $alias_arg
|
||||
aws s3 cp --recursive "$INDICES_OUTPUT_DIR/" "s3://$BUCKET/$pure_version/"
|
||||
fi
|
||||
|
||||
@ -326,10 +326,10 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py
|
||||
|
||||
- label: V1 Test e2e + engine # 30min
|
||||
timeout_in_minutes: 45
|
||||
- label: V1 Test e2e + engine # 65min
|
||||
timeout_in_minutes: 90
|
||||
mirror_hardwares: [amdexperimental]
|
||||
agent_pool: mi325_1
|
||||
agent_pool: mi325_4
|
||||
# grade: Blocking
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -435,7 +435,7 @@ steps:
|
||||
|
||||
- label: Examples Test # 30min
|
||||
timeout_in_minutes: 45
|
||||
mirror_hardwares: [amdexperimental]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
agent_pool: mi325_1
|
||||
# grade: Blocking
|
||||
working_dir: "/vllm-workspace/examples"
|
||||
@ -455,7 +455,6 @@ steps:
|
||||
# for multi-modal models
|
||||
- python3 offline_inference/audio_language.py --seed 0
|
||||
- python3 offline_inference/vision_language.py --seed 0
|
||||
- python3 offline_inference/vision_language_pooling.py --seed 0
|
||||
- python3 offline_inference/vision_language_multi_image.py --seed 0
|
||||
- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0
|
||||
# for pooling models
|
||||
|
||||
@ -1379,22 +1379,4 @@ steps:
|
||||
num_gpus: 2
|
||||
working_dir: "/vllm-workspace"
|
||||
commands:
|
||||
- bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1
|
||||
|
||||
- label: DeepSeek V2-Lite Async EPLB Accuracy
|
||||
timeout_in_minutes: 60
|
||||
gpu: h100
|
||||
optional: true
|
||||
num_gpus: 4
|
||||
working_dir: "/vllm-workspace"
|
||||
commands:
|
||||
- bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_async_eplb.sh 0.25 1319 8030
|
||||
|
||||
- label: Qwen3-Next-80B-A3B-Instruct MTP Async EPLB Accuracy
|
||||
timeout_in_minutes: 60
|
||||
gpu: h100
|
||||
optional: true
|
||||
num_gpus: 4
|
||||
working_dir: "/vllm-workspace"
|
||||
commands:
|
||||
- bash .buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh 0.8 1319 8040
|
||||
- bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1
|
||||
@ -32,12 +32,11 @@ def benchmark_propose(args):
|
||||
|
||||
model_config = ModelConfig(
|
||||
model="facebook/opt-125m",
|
||||
task="generate",
|
||||
max_model_len=args.num_token + args.num_spec_token,
|
||||
tokenizer="facebook/opt-125m",
|
||||
tokenizer_mode="auto",
|
||||
dtype="auto",
|
||||
seed=None,
|
||||
seed=0,
|
||||
trust_remote_code=False,
|
||||
)
|
||||
proposer = NgramProposer(
|
||||
|
||||
@ -574,7 +574,7 @@ async def benchmark(
|
||||
)
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
"Total Token throughput (tok/s):", metrics.total_token_throughput
|
||||
"Total token throughput (tok/s):", metrics.total_token_throughput
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
150
benchmarks/kernels/benchmark_mla_k_concat.py
Normal file
150
benchmarks/kernels/benchmark_mla_k_concat.py
Normal file
@ -0,0 +1,150 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Benchmark script comparing torch.cat vs direct copy for k_nope/k_pe concatenation
|
||||
in MLA (Multi-head Latent Attention) prefill.
|
||||
|
||||
This validates that the optimization from commit 8d4142bd is beneficial across
|
||||
various batch sizes, not just the originally tested batch size of 32768.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
# DeepSeek-V3 MLA dimensions
|
||||
NUM_HEADS = 128
|
||||
QK_NOPE_HEAD_DIM = 128
|
||||
PE_DIM = 64
|
||||
|
||||
|
||||
def cat_method(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor:
    """Join ``k_nope`` and ``k_pe`` along the last dimension using ``torch.cat``.

    ``k_pe`` is first expanded so that its leading dimensions match those of
    ``k_nope``; its last dimension is left as-is (``-1``).
    """
    leading_dims = k_nope.shape[:-1]
    k_pe_expanded = k_pe.expand((*leading_dims, -1))
    return torch.cat((k_nope, k_pe_expanded), dim=-1)
|
||||
|
||||
|
||||
def direct_copy_method(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor:
    """Build the concatenated tensor by writing into a preallocated buffer.

    An uninitialized output with ``k_nope``'s dtype and device is allocated,
    then ``k_nope`` is assigned to the leading slice of the last dimension and
    ``k_pe`` to the trailing slice (the assignment broadcasts ``k_pe``).
    """
    nope_dim = k_nope.shape[-1]
    out_shape = (*k_nope.shape[:-1], nope_dim + k_pe.shape[-1])
    k = torch.empty(out_shape, dtype=k_nope.dtype, device=k_nope.device)
    k[..., :nope_dim] = k_nope
    k[..., nope_dim:] = k_pe
    return k
|
||||
|
||||
|
||||
def benchmark_method(
    method: Callable,
    k_nope: torch.Tensor,
    k_pe: torch.Tensor,
    num_warmup: int = 10,
    num_iters: int = 100,
) -> float:
    """Benchmark a concatenation method and return mean latency in ms.

    Args:
        method: Callable invoked as ``method(k_nope, k_pe)``.
        k_nope: First tensor passed to ``method``.
        k_pe: Second tensor passed to ``method``.
        num_warmup: Number of untimed calls made before measuring.
        num_iters: Number of timed calls; must be positive.

    Returns:
        Wall-clock time per call, averaged over ``num_iters``, in milliseconds.

    Raises:
        ValueError: If ``num_iters`` is not positive.
    """
    if num_iters <= 0:
        raise ValueError(f"num_iters must be positive, got {num_iters}")

    # Only synchronize with the GPU when an input actually lives on CUDA, so
    # the helper also works for CPU tensors.
    needs_sync = k_nope.is_cuda or k_pe.is_cuda

    # Warmup
    for _ in range(num_warmup):
        method(k_nope, k_pe)
    if needs_sync:
        torch.cuda.synchronize()

    # Benchmark: synchronize before reading the end time so the measurement
    # covers the completion of all timed calls, not just their launch.
    start = time.perf_counter()
    for _ in range(num_iters):
        method(k_nope, k_pe)
    if needs_sync:
        torch.cuda.synchronize()
    end = time.perf_counter()

    return (end - start) / num_iters * 1000  # Convert to ms
|
||||
|
||||
|
||||
@torch.inference_mode()
def run_benchmark(dtype: torch.dtype, dtype_name: str):
    """Time ``cat_method`` against ``direct_copy_method`` for one dtype.

    Sets the default torch device to ``"cuda"``, sweeps batch sizes from 32 to
    65536 (powers of two), prints a per-batch table followed by a summary, and
    returns the collected rows.

    Args:
        dtype: Dtype the float32 random inputs are converted to.
        dtype_name: Label printed for ``dtype`` in the report header.

    Returns:
        A list of ``(batch_size, cat_ms, direct_ms, speedup, reduction_pct)``
        tuples, one per batch size, in sweep order.
    """
    torch.set_default_device("cuda")

    rule = "=" * 80
    print(rule)
    print("Benchmark: torch.cat vs direct copy for MLA k_nope/k_pe concatenation")
    print(rule)
    print(
        f"Tensor shapes: k_nope=[B, {NUM_HEADS}, {QK_NOPE_HEAD_DIM}], "
        f"k_pe=[B, 1, {PE_DIM}]"
    )
    print(f"dtype: {dtype_name}")
    print()
    print(
        f"{'Batch Size':>12} | {'cat (ms)':>10} | {'direct (ms)':>12} | "
        f"{'Speedup':>8} | {'Reduction':>10}"
    )
    print("-" * 70)

    results = []
    # Powers of two: 2**5 == 32 up to 2**16 == 65536.
    for batch_size in (2**exponent for exponent in range(5, 17)):
        # Inputs are drawn in float32 and then converted, since the target
        # dtype (e.g. FP8) may not be directly supported by randn.
        k_nope = torch.randn(
            batch_size, NUM_HEADS, QK_NOPE_HEAD_DIM, dtype=torch.float32, device="cuda"
        ).to(dtype)
        k_pe = torch.randn(
            batch_size, 1, PE_DIM, dtype=torch.float32, device="cuda"
        ).to(dtype)

        cat_ms = benchmark_method(cat_method, k_nope, k_pe)
        direct_ms = benchmark_method(direct_copy_method, k_nope, k_pe)

        ratio = cat_ms / direct_ms
        saved_pct = (1 - direct_ms / cat_ms) * 100
        results.append((batch_size, cat_ms, direct_ms, ratio, saved_pct))

        print(
            f"{batch_size:>12} | {cat_ms:>10.3f} | {direct_ms:>12.3f} | "
            f"{ratio:>7.2f}x | {saved_pct:>9.1f}%"
        )

    print(rule)

    # Summary statistics over the speedup column.
    speedups = [row[3] for row in results]
    print("\nSpeedup summary:")
    print(f" Min: {min(speedups):.2f}x")
    print(f" Max: {max(speedups):.2f}x")
    print(f" Mean: {sum(speedups) / len(speedups):.2f}x")

    # First batch size in the sweep at which direct copy is at least as fast.
    crossover_batch = next(
        (size for size, _, _, ratio, _ in results if ratio >= 1.0), None
    )

    print("\nConclusion:")
    if crossover_batch:
        print(f" - Direct copy becomes beneficial at batch size >= {crossover_batch}")
    # Average speedup restricted to large batches (>= 512, typical for prefill).
    large_batch_speedups = [ratio for size, _, _, ratio, _ in results if size >= 512]
    if large_batch_speedups:
        avg_large = sum(large_batch_speedups) / len(large_batch_speedups)
        print(f" - For batch sizes >= 512: avg speedup = {avg_large:.2f}x")
    print(" - MLA prefill typically uses large batches, so optimization is effective")

    return results
|
||||
|
||||
|
||||
@torch.inference_mode()
def main():
    """Run the benchmark for bfloat16 and then for float8_e4m3fn."""
    for dtype_name in ("bfloat16", "float8_e4m3fn"):
        print("\n")
        # The dtype attribute is resolved only when its turn comes, so a
        # missing dtype does not prevent the earlier runs.
        run_benchmark(getattr(torch, dtype_name), dtype_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -99,7 +99,6 @@ def benchmark_mrope(
|
||||
# the parameters to compute the q k v size based on tp_size
|
||||
mrope_helper_class = get_rope(
|
||||
head_size=head_dim,
|
||||
rotary_dim=head_dim,
|
||||
max_position=max_position,
|
||||
is_neox_style=is_neox_style,
|
||||
rope_parameters=rope_parameters,
|
||||
|
||||
@ -32,8 +32,8 @@ def get_benchmark(head_size, rotary_dim, is_neox_style, device):
|
||||
def benchmark(batch_size, seq_len, num_heads, provider):
|
||||
dtype = torch.bfloat16
|
||||
max_position = 8192
|
||||
base = 10000
|
||||
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
|
||||
rope_parameters = {"partial_rotary_factor": rotary_dim / head_size}
|
||||
rope = get_rope(head_size, max_position, is_neox_style, rope_parameters)
|
||||
rope = rope.to(dtype=dtype, device=device)
|
||||
cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device)
|
||||
|
||||
|
||||
@ -140,16 +140,21 @@ function(vllm_prepare_torch_gomp_shim TORCH_GOMP_SHIM_DIR)
|
||||
run_python(_VLLM_TORCH_GOMP_PATH
|
||||
"
|
||||
import os, glob
|
||||
try:
|
||||
import torch
|
||||
torch_pkg = os.path.dirname(torch.__file__)
|
||||
site_root = os.path.dirname(torch_pkg)
|
||||
torch_libs = os.path.join(site_root, 'torch.libs')
|
||||
print(glob.glob(os.path.join(torch_libs, 'libgomp-*.so*'))[0])
|
||||
except:
|
||||
print('')
|
||||
import torch
|
||||
torch_pkg = os.path.dirname(torch.__file__)
|
||||
site_root = os.path.dirname(torch_pkg)
|
||||
|
||||
# Search both torch.libs and torch/lib
|
||||
roots = [os.path.join(site_root, 'torch.libs'), os.path.join(torch_pkg, 'lib')]
|
||||
candidates = []
|
||||
for root in roots:
|
||||
if not os.path.isdir(root):
|
||||
continue
|
||||
candidates.extend(glob.glob(os.path.join(root, 'libgomp*.so*')))
|
||||
|
||||
print(candidates[0] if candidates else '')
|
||||
"
|
||||
"failed to probe torch.libs for libgomp")
|
||||
"failed to probe for libgomp")
|
||||
|
||||
if(_VLLM_TORCH_GOMP_PATH STREQUAL "" OR NOT EXISTS "${_VLLM_TORCH_GOMP_PATH}")
|
||||
return()
|
||||
|
||||
@ -117,7 +117,6 @@ torch::Tensor get_scheduler_metadata(
|
||||
input.casual = casual;
|
||||
input.isa = isa;
|
||||
input.enable_kv_split = enable_kv_split;
|
||||
TORCH_CHECK(casual, "Only supports casual mask for now.");
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() {
|
||||
CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] {
|
||||
|
||||
@ -481,8 +481,6 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
|
||||
largest = value;
|
||||
}
|
||||
}
|
||||
|
||||
__syncwarp(); // Ensure all threads have valid data before reduction
|
||||
// Get the top2 warpwise
|
||||
T max1 = cg::reduce(tile, largest, cg::greater<T>());
|
||||
|
||||
@ -589,7 +587,6 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
int pre_count_equal_to_top_value = 0;
|
||||
// Use loop to find the largest top_group
|
||||
while (count_equal_to_top_value < target_num_min) {
|
||||
__syncwarp(); // Ensure all threads have valid data before reduction
|
||||
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
|
||||
if (value == topk_group_value) {
|
||||
value = neg_inf<T>();
|
||||
@ -644,10 +641,8 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
}
|
||||
}
|
||||
queue.done();
|
||||
__syncwarp();
|
||||
// Get the topk_idx
|
||||
queue.dumpIdx(s_topk_idx);
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
// Load the valid score value
|
||||
|
||||
@ -860,4 +860,4 @@ torch::Tensor moe_wna16_marlin_gemm(
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm);
|
||||
}
|
||||
}
|
||||
@ -84,7 +84,7 @@ Total input tokens: 1369
|
||||
Total generated tokens: 2212
|
||||
Request throughput (req/s): 1.73
|
||||
Output token throughput (tok/s): 382.89
|
||||
Total Token throughput (tok/s): 619.85
|
||||
Total token throughput (tok/s): 619.85
|
||||
---------------Time to First Token----------------
|
||||
Mean TTFT (ms): 71.54
|
||||
Median TTFT (ms): 73.88
|
||||
|
||||
@ -21,30 +21,20 @@ The mental model is that server-level metrics help explain the values of request
|
||||
|
||||
### v1 Metrics
|
||||
|
||||
In v1, the following metrics are exposed via a Prometheus-compatible `/metrics` endpoint using the `vllm:` prefix:
|
||||
In v1, an extensive set of metrics is exposed via a Prometheus-compatible `/metrics` endpoint using the `vllm:` prefix, for example:
|
||||
|
||||
- `vllm:num_requests_running` (Gauge) - Number of requests currently running.
|
||||
- `vllm:num_requests_waiting` (Gauge) - Number of requests currently waiting.
|
||||
- `vllm:kv_cache_usage_perc` (Gauge) - Fraction of used KV cache blocks (0–1).
|
||||
- `vllm:prefix_cache_queries` (Counter) - Number of prefix cache queries.
|
||||
- `vllm:prefix_cache_hits` (Counter) - Number of prefix cache hits.
|
||||
- `vllm:mm_cache_queries` (Counter) - (For multimodal models) Number of multimodal cache queries.
|
||||
- `vllm:mm_cache_hits` (Counter) - (For multimodal models) Number of multimodal cache hits.
|
||||
- `vllm:num_preemptions_total` (Counter) - Number of preemptions.
|
||||
- `vllm:prompt_tokens_total` (Counter) - Total number of prompt tokens processed.
|
||||
- `vllm:generation_tokens_total` (Counter) - Total number of generated tokens.
|
||||
- `vllm:iteration_tokens_total` (Histogram) - Histogram of tokens processed in each engine step.
|
||||
- `vllm:cache_config_info` (Gauge) - Information about the cache configuration.
|
||||
- `vllm:request_success_total` (Counter) - Number of finished requests (by finish reason).
|
||||
- `vllm:request_prompt_tokens` (Histogram) - Histogram of input prompt token counts.
|
||||
- `vllm:request_generation_tokens` (Histogram) - Histogram of generation token counts.
|
||||
- `vllm:request_params_n` (Histogram) - Histogram of request parameter n.
|
||||
- `vllm:request_params_max_tokens` - (Histogram) - Histogram of max_tokens parameter in requests.
|
||||
- `vllm:time_to_first_token_seconds` (Histogram) - Time to first token (TTFT).
|
||||
- `vllm:inter_token_latency_seconds` (Histogram) - Inter-token latency.
|
||||
- `vllm:e2e_request_latency_seconds` (Histogram) - End-to-end request latency.
|
||||
- `vllm:request_queue_time_seconds` (Histogram) - Time spent in the queue.
|
||||
- `vllm:request_inference_time_seconds` (Histogram) - Request inference time.
|
||||
- `vllm:request_prefill_time_seconds` (Histogram) - Request prefill time.
|
||||
- `vllm:request_decode_time_seconds` (Histogram) - Request decode time.
|
||||
|
||||
|
||||
@ -152,5 +152,5 @@ The interface for the model/module may change during vLLM's development. If you
|
||||
## Deprecation announcement
|
||||
|
||||
!!! warning "Deprecations"
|
||||
- `use_v1` parameter in `Platform.get_attn_backend_cls` is deprecated. It will be removed in v0.13.0 or v1.0.0.
|
||||
- `_Backend` in `vllm.attention` is deprecated. It will be removed in v0.13.0 or v1.0.0. Please use `vllm.attention.backends.registry.register_backend` to add new attention backend to `AttentionBackendEnum` instead.
|
||||
- `use_v1` parameter in `Platform.get_attn_backend_cls` is deprecated. It has been removed in v0.13.0.
|
||||
- `_Backend` in `vllm.attention` is deprecated. It has been removed in v0.13.0. Please use `vllm.attention.backends.registry.register_backend` to add new attention backend to `AttentionBackendEnum` instead.
|
||||
|
||||
@ -61,7 +61,7 @@ Now let´s see an example for each of the cases, starting with the `choice`, as
|
||||
print(completion.choices[0].message.content)
|
||||
```
|
||||
|
||||
The next example shows how to use the `regex`. The idea is to generate an email address, given a simple regex template:
|
||||
The next example shows how to use the `regex`. The supported regex syntax depends on the structured output backend. For example, `xgrammar`, `guidance`, and `outlines` use Rust-style regex, while `lm-format-enforcer` uses Python's `re` module. The idea is to generate an email address, given a simple regex template:
|
||||
|
||||
??? code
|
||||
|
||||
|
||||
@ -26,3 +26,4 @@ The backends below live **outside** the main `vllm` repository and follow the
|
||||
| Rebellions ATOM / REBEL NPU | `vllm-rbln` | <https://github.com/rebellions-sw/vllm-rbln> |
|
||||
| IBM Spyre AIU | `vllm-spyre` | <https://github.com/vllm-project/vllm-spyre> |
|
||||
| Cambricon MLU | `vllm-mlu` | <https://github.com/Cambricon/vllm-mlu> |
|
||||
| Baidu Kunlun XPU | N/A, install from source | <https://github.com/baidu/vLLM-Kunlun> |
|
||||
|
||||
@ -29,8 +29,27 @@ uv pip install --pre vllm==<version>+cpu --extra-index-url https://wheels.vllm.a
|
||||
|
||||
The `uv` approach works for vLLM `v0.6.6` and later. A unique feature of `uv` is that packages in `--extra-index-url` have [higher priority than the default index](https://docs.astral.sh/uv/pip/compatibility/#packages-that-exist-on-multiple-indexes). If the latest public release is `v0.6.6.post1`, `uv`'s behavior allows installing a commit before `v0.6.6.post1` by specifying the `--extra-index-url`. In contrast, `pip` combines packages from `--extra-index-url` and the default index, choosing only the latest version, which makes it difficult to install a development version prior to the released version.
|
||||
|
||||
!!! note
|
||||
Nightly wheels are currently unsupported for this architecture. (e.g. to bisect the behavior change, performance regression).
|
||||
**Install the latest code**
|
||||
|
||||
LLM inference is a fast-evolving field, and the latest code may contain bug fixes, performance improvements, and new features that are not released yet. To allow users to try the latest code without waiting for the next release, vLLM provides working pre-built Arm CPU wheels for every commit since `v0.11.2` on <https://wheels.vllm.ai/nightly>. For native CPU wheels, this index should be used:
|
||||
|
||||
* `https://wheels.vllm.ai/nightly/cpu/vllm`
|
||||
|
||||
To install from nightly index, copy the link address of the `*.whl` under this index to run, for example:
|
||||
|
||||
```bash
|
||||
uv pip install -U https://wheels.vllm.ai/c756fb678184b867ed94e5613a529198f1aee423/vllm-0.13.0rc2.dev11%2Bgc756fb678.cpu-cp38-abi3-manylinux_2_31_aarch64.whl # current nightly build (the filename will change!)
|
||||
```
|
||||
|
||||
**Install specific revisions**
|
||||
|
||||
If you want to access the wheels for previous commits (e.g. to bisect a behavior change or performance regression), specify the full commit hash in the index:
|
||||
https://wheels.vllm.ai/${VLLM_COMMIT}/cpu/vllm .
|
||||
Then, copy the link address of the `*.whl` under this index to run:
|
||||
|
||||
```bash
|
||||
uv pip install -U <wheel-url>
|
||||
```
|
||||
|
||||
# --8<-- [end:pre-built-wheels]
|
||||
# --8<-- [start:build-wheel-from-source]
|
||||
@ -81,7 +100,23 @@ Testing has been conducted on AWS Graviton3 instances for compatibility.
|
||||
# --8<-- [end:build-wheel-from-source]
|
||||
# --8<-- [start:pre-built-images]
|
||||
|
||||
Currently, there are no pre-built Arm CPU images.
|
||||
See [Using Docker](../../deployment/docker.md) for instructions on using the official Docker image.
|
||||
|
||||
Stable vLLM Docker images are being pre-built for Arm from version 0.12.0. Available image tags are here: [https://gallery.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo](https://gallery.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo).
|
||||
Please replace `<version>` in the command below with a specific version string (e.g., `0.12.0`).
|
||||
|
||||
```bash
|
||||
docker pull public.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo:v<version>
|
||||
```
|
||||
|
||||
You can also access the latest code with Docker images. These are not intended for production use and are meant for CI and testing only. They will expire after several days.
|
||||
|
||||
The latest code can contain bugs and may not be stable. Please use it with caution.
|
||||
|
||||
```bash
|
||||
export VLLM_COMMIT=6299628d326f429eba78736acb44e76749b281f5 # use full commit hash from the main branch
|
||||
docker pull public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:${VLLM_COMMIT}-arm64-cpu
|
||||
```
|
||||
|
||||
# --8<-- [end:pre-built-images]
|
||||
# --8<-- [start:build-image-from-source]
|
||||
|
||||
149
docs/mkdocs/hooks/generate_metrics.py
Normal file
149
docs/mkdocs/hooks/generate_metrics.py
Normal file
@ -0,0 +1,149 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import ast
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
logger = logging.getLogger("mkdocs")
|
||||
|
||||
ROOT_DIR = Path(__file__).parent.parent.parent.parent
|
||||
DOCS_DIR = ROOT_DIR / "docs"
|
||||
GENERATED_METRICS_DIR = DOCS_DIR / "generated" / "metrics"
|
||||
|
||||
# Files to scan for metric definitions - each will generate a separate table
|
||||
METRIC_SOURCE_FILES = [
|
||||
{"path": "vllm/v1/metrics/loggers.py", "output": "general.md"},
|
||||
{
|
||||
"path": "vllm/v1/spec_decode/metrics.py",
|
||||
"output": "spec_decode.md",
|
||||
},
|
||||
{
|
||||
"path": "vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py",
|
||||
"output": "nixl_connector.md",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class MetricExtractor(ast.NodeVisitor):
|
||||
"""AST visitor to extract metric definitions."""
|
||||
|
||||
def __init__(self):
|
||||
self.metrics: list[dict[str, str]] = []
|
||||
|
||||
def visit_Call(self, node: ast.Call) -> None:
|
||||
"""Visit function calls to find metric class instantiations."""
|
||||
metric_type = self._get_metric_type(node)
|
||||
if metric_type:
|
||||
name = self._extract_kwarg(node, "name")
|
||||
documentation = self._extract_kwarg(node, "documentation")
|
||||
|
||||
if name:
|
||||
self.metrics.append(
|
||||
{
|
||||
"name": name,
|
||||
"type": metric_type,
|
||||
"documentation": documentation or "",
|
||||
}
|
||||
)
|
||||
|
||||
self.generic_visit(node)
|
||||
|
||||
def _get_metric_type(self, node: ast.Call) -> str | None:
|
||||
"""Determine if this call creates a metric and return its type."""
|
||||
metric_type_map = {
|
||||
"_gauge_cls": "gauge",
|
||||
"_counter_cls": "counter",
|
||||
"_histogram_cls": "histogram",
|
||||
}
|
||||
if isinstance(node.func, ast.Attribute):
|
||||
return metric_type_map.get(node.func.attr)
|
||||
return None
|
||||
|
||||
def _extract_kwarg(self, node: ast.Call, key: str) -> str | None:
|
||||
"""Extract a keyword argument value from a function call."""
|
||||
for keyword in node.keywords:
|
||||
if keyword.arg == key:
|
||||
return self._get_string_value(keyword.value)
|
||||
return None
|
||||
|
||||
def _get_string_value(self, node: ast.AST) -> str | None:
|
||||
"""Extract string value from an AST node."""
|
||||
if isinstance(node, ast.Constant):
|
||||
return str(node.value) if node.value is not None else None
|
||||
return None
|
||||
|
||||
|
||||
def extract_metrics_from_file(filepath: Path) -> list[dict[str, str]]:
    """Parse the Python file at ``filepath`` and return its metric definitions.

    The file is read as UTF-8, parsed with ``ast`` and walked by a
    ``MetricExtractor``; the extractor's ``metrics`` list is returned.

    Raises:
        RuntimeError: if reading, parsing or walking the file fails. The
            original exception is chained as the cause.
    """
    try:
        with open(filepath, encoding="utf-8") as handle:
            module_ast = ast.parse(handle.read(), filename=str(filepath))

        visitor = MetricExtractor()
        visitor.visit(module_ast)
    except Exception as e:
        raise RuntimeError(f"Failed to parse {filepath}: {e}") from e
    return visitor.metrics
|
||||
|
||||
|
||||
def generate_markdown_table(metrics: list[dict[str, str]]) -> str:
    """Render extracted metrics as a Markdown table.

    Args:
        metrics: Dicts with the keys ``"name"``, ``"type"`` and
            ``"documentation"``.

    Returns:
        A three-column table (name, capitalised type, description) with rows
        sorted by type and then by name, terminated by a newline. When
        ``metrics`` is empty the string ``"No metrics found.\\n"`` is
        returned instead.
    """
    if not metrics:
        return "No metrics found.\n"

    rows = [
        "| Metric Name | Type | Description |",
        "|-------------|------|-------------|",
    ]

    # Sort by type, then by name
    for metric in sorted(metrics, key=lambda m: (m["type"], m["name"])):
        name = metric["name"]
        metric_type = metric["type"].capitalize()
        # A table row must stay on a single line, so fold newlines to spaces.
        doc = metric["documentation"].replace("\n", " ").strip()
        # An unescaped "|" would be parsed as a column separator and break
        # the row, so escape it.
        doc = doc.replace("|", "\\|")
        rows.append(f"| `{name}` | {metric_type} | {doc} |")

    return "\n".join(rows) + "\n"
|
||||
|
||||
|
||||
def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
    """Write one Markdown metrics table per entry in ``METRIC_SOURCE_FILES``.

    For every configured source file the metrics are extracted, rendered as
    a table and written under ``GENERATED_METRICS_DIR``. The ``command`` and
    ``dirty`` arguments are accepted but not used here.

    Raises:
        FileNotFoundError: if a configured source file does not exist.
    """
    logger.info("Generating metrics documentation")

    # Make sure the output directory is present before writing any table.
    GENERATED_METRICS_DIR.mkdir(parents=True, exist_ok=True)

    per_file_counts = []
    for entry in METRIC_SOURCE_FILES:
        rel_source, out_name = entry["path"], entry["output"]

        source_file = ROOT_DIR / rel_source
        if not source_file.exists():
            raise FileNotFoundError(f"Metrics source file not found: {source_file}")

        logger.debug("Extracting metrics from: %s", rel_source)
        found = extract_metrics_from_file(source_file)
        logger.debug("Found %d metrics in %s", len(found), rel_source)

        # Render the table for this source and store it next to the others.
        rendered = generate_markdown_table(found)
        destination = GENERATED_METRICS_DIR / out_name
        destination.write_text(rendered, encoding="utf-8")

        per_file_counts.append(len(found))
        logger.info(
            "Generated metrics table: %s (%d metrics)",
            destination.relative_to(ROOT_DIR),
            len(found),
        )

    logger.info(
        "Total metrics generated: %d across %d files",
        sum(per_file_counts),
        len(METRIC_SOURCE_FILES),
    )
|
||||
@ -316,10 +316,13 @@ We have split the `encode` task into two more specific token-wise tasks: `token_
|
||||
|
||||
### Remove softmax from PoolingParams
|
||||
|
||||
We are going to remove `softmax` and `activation` from `PoolingParams`. Instead, use `use_activation`, since we allow `classify` and `token_classify` to use any activation function.
|
||||
We are going to remove `softmax` and `activation` from `PoolingParams` in v0.15. Instead, use `use_activation`, since we allow `classify` and `token_classify` to use any activation function.
|
||||
|
||||
### as_reward_model
|
||||
|
||||
!!! warning
|
||||
We are going to remove `--convert reward` in v0.15, use `--convert embed` instead.
|
||||
|
||||
Pooling models now support all pooling tasks by default; you can use them without any additional settings.
|
||||
|
||||
- Extracting hidden states prefers using `token_embed` task.
|
||||
|
||||
@ -568,7 +568,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
|
||||
```
|
||||
|
||||
!!! note
|
||||
Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: [examples/pooling/score/qwen3_reranker.py](../../examples/pooling/score/qwen3_reranker.py).
|
||||
Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: [examples/pooling/score/offline_reranker.py](../../examples/pooling/score/offline_reranker.py).
|
||||
|
||||
```bash
|
||||
vllm serve Qwen/Qwen3-Reranker-0.6B --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}'
|
||||
|
||||
@ -24,7 +24,7 @@ There are two distinct modes supported for online deployments - self-contained w
|
||||
|
||||
vLLM supports "self-contained" data parallel deployments that expose a single API endpoint.
|
||||
|
||||
It can be configured by simply including e.g. `--data-parallel-size=4` in the vllm serve command line arguments. This will require 4 GPUs. It can be combined with tensor parallel, for example `--data-parallel-size=4 --tensor-parallel-size=2`, which would require 8 GPUs.
|
||||
It can be configured by simply including e.g. `--data-parallel-size=4` in the vllm serve command line arguments. This will require 4 GPUs. It can be combined with tensor parallel, for example `--data-parallel-size=4 --tensor-parallel-size=2`, which would require 8 GPUs. When sizing DP deployments, remember that `--max-num-seqs` applies per DP rank.
|
||||
|
||||
Running a single data parallel deployment across multiple nodes requires a different `vllm serve` to be run on each node, specifying which DP ranks should run on that node. In this case, there will still be a single HTTP entrypoint - the API server(s) will run only on one node, but it doesn't necessarily need to be co-located with the DP ranks.
|
||||
|
||||
@ -80,6 +80,18 @@ When deploying large DP sizes using this method, the API server process can beco
|
||||

|
||||
</figure>
|
||||
|
||||
## Hybrid Load Balancing
|
||||
|
||||
Hybrid load balancing sits between the internal and external approaches. Each node runs its own API server(s) that only queue requests to the data-parallel engines colocated on that node. An upstream load balancer (for example, an ingress controller or traffic router) spreads user requests across those per-node endpoints.
|
||||
|
||||
Enable this mode with `--data-parallel-hybrid-lb` while still launching every node with the global data-parallel size. The key differences from internal load balancing are:
|
||||
|
||||
- You must provide `--data-parallel-size-local` and `--data-parallel-start-rank` so each node knows which ranks it owns.
|
||||
- Not compatible with `--headless` since every node exposes an API endpoint.
|
||||
- Scale `--api-server-count` per node based on the number of local ranks.
|
||||
|
||||
In this configuration, each node keeps scheduling decisions local, which reduces cross-node traffic and avoids single node bottlenecks at larger DP sizes.
|
||||
|
||||
## External Load Balancing
|
||||
|
||||
For larger scale deployments especially, it can make sense to handle the orchestration and load balancing of data parallel ranks externally.
|
||||
|
||||
@ -40,10 +40,12 @@ EP_SIZE = TP_SIZE × DP_SIZE
|
||||
|
||||
Where:
|
||||
|
||||
- `TP_SIZE`: Tensor parallel size (always 1 for now)
|
||||
- `TP_SIZE`: Tensor parallel size
|
||||
- `DP_SIZE`: Data parallel size
|
||||
- `EP_SIZE`: Expert parallel size (computed automatically)
|
||||
|
||||
When EP is enabled, MoE layers use expert parallelism instead of tensor parallelism, while attention layers continue to use tensor parallelism if `TP_SIZE > 1`.
|
||||
|
||||
### Example Command
|
||||
|
||||
The following command serves a `DeepSeek-V3-0324` model with 1-way tensor parallel, 8-way (attention) data parallel, and 8-way expert parallel. The attention weights are replicated across all GPUs, while the expert weights are split across GPUs. It will work on a H200 (or H20) node with 8 GPUs. For H100, you can try to serve a smaller model or refer to the multi-node deployment section.
|
||||
@ -81,7 +83,7 @@ vllm serve deepseek-ai/DeepSeek-V3-0324 \
|
||||
--data-parallel-size-local 8 \ # Local DP size on this node (8 GPUs per node)
|
||||
--data-parallel-address 192.168.1.100 \ # Replace with actual IP of Node 1
|
||||
--data-parallel-rpc-port 13345 \ # RPC communication port, can be any port as long as reachable by all nodes
|
||||
--api-server-count=8 # Number of API servers for load handling (scaling this out to total ranks are recommended)
|
||||
--api-server-count=8 # Number of API servers for load handling (scaling this out to the number of local ranks is recommended)
|
||||
|
||||
# Node 2 (Secondary - headless mode, no API server)
|
||||
vllm serve deepseek-ai/DeepSeek-V3-0324 \
|
||||
@ -119,9 +121,6 @@ While MoE models are typically trained so that each expert receives a similar nu
|
||||
|
||||
Enable EPLB with the `--enable-eplb` flag.
|
||||
|
||||
!!! note "Model Support"
|
||||
Currently only DeepSeek V3 architecture is supported.
|
||||
|
||||
When enabled, vLLM collects load statistics with every forward pass and periodically rebalances expert distribution.
|
||||
|
||||
### EPLB Parameters
|
||||
@ -134,6 +133,8 @@ Configure EPLB with the `--eplb-config` argument, which accepts a JSON string. T
|
||||
| `step_interval`| Frequency of rebalancing (every N engine steps) | 3000 |
|
||||
| `log_balancedness` | Log balancedness metrics (avg tokens per expert ÷ max tokens per expert) | `false` |
|
||||
| `num_redundant_experts` | Additional global experts per EP rank beyond equal distribution | `0` |
|
||||
| `use_async` | Use non-blocking EPLB for reduced latency overhead | `false` |
|
||||
| `policy` | The policy type for expert parallel load balancing | `"default"` |
|
||||
|
||||
For example:
|
||||
|
||||
@ -183,6 +184,26 @@ vllm serve deepseek-ai/DeepSeek-V3-0324 \
|
||||
|
||||
For multi-node deployment, add these EPLB flags to each node's command. We recommend setting `--eplb-config '{"num_redundant_experts":32}'` in large-scale use cases so the most popular experts are always available.
|
||||
|
||||
## Advanced Configuration
|
||||
|
||||
### Performance Optimization
|
||||
|
||||
- **DeepEP kernels**: The `high_throughput` and `low_latency` kernels are optimized for disaggregated serving and may show poor performance for mixed workloads
|
||||
- **Dual Batch Overlap**: Use `--enable-dbo` to overlap all-to-all communication with compute. See [Dual Batch Overlap](../design/dbo.md) for more details.
|
||||
- **Async scheduling (experimental)**: Try `--async-scheduling` to overlap scheduling with model execution.
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
- **`non-zero status: 7 cannot register cq buf`**: When using Infiniband/RoCE, make sure host VM and pods show `ulimit -l` "unlimited".
|
||||
- **`init failed for transport: IBGDA`**: The InfiniBand GDA kernel modules are missing. Run `tools/ep_kernels/configure_system_drivers.sh` on each GPU node and reboot. Also fixes error `NVSHMEM API called before NVSHMEM initialization has completed`.
|
||||
- **NVSHMEM peer disconnect**: Usually a networking misconfiguration. If deploying via Kubernetes, verify that every pod runs with `hostNetwork: true`, `securityContext.privileged: true` to access Infiniband.
|
||||
|
||||
### Benchmarking
|
||||
|
||||
- Use simulator flags `VLLM_MOE_ROUTING_SIMULATION_STRATEGY=uniform_random` and `VLLM_RANDOMIZE_DP_DUMMY_INPUTS=1` so token routing is balanced across EP ranks.
|
||||
|
||||
- Increasing `VLLM_MOE_DP_CHUNK_SIZE` may increase throughput by increasing the maximum batch size for inter-rank token transfers. This may cause DeepEP to throw `assert self.nvshmem_qp_depth >= (num_max_dispatch_tokens_per_rank + 1) * 2`, which can be fixed by increasing environment variable `NVSHMEM_QP_DEPTH`.
|
||||
|
||||
## Disaggregated Serving (Prefill/Decode Split)
|
||||
|
||||
For production deployments requiring strict SLA guarantees for time-to-first-token and inter-token latency, disaggregated serving allows independent scaling of prefill and decode operations.
|
||||
@ -273,3 +294,9 @@ except Exception as e:
|
||||
print(f"❌ Error during disaggregated serving: {e}")
|
||||
print("Check that both prefill and decode instances are running and accessible")
|
||||
```
|
||||
|
||||
### Benchmarking
|
||||
|
||||
- To simulate the decode deployment of disaggregated serving, pass `--kv-transfer-config '{"kv_connector":"DecodeBenchConnector","kv_role":"kv_both"}'` to the `vllm serve` invocation. The connector populates KV cache with random values so decode can be profiled in isolation.
|
||||
|
||||
- **CUDAGraph capture**: Use `--compilation_config '{"cudagraph_mode": "FULL_DECODE_ONLY"}'` to enable CUDA graph capture for decode only and save KV cache.
|
||||
|
||||
@ -851,7 +851,7 @@ endpoints are compatible with both [Jina AI's re-rank API interface](https://jin
|
||||
[Cohere's re-rank API interface](https://docs.cohere.com/v2/reference/rerank) to ensure compatibility with
|
||||
popular open-source tools.
|
||||
|
||||
Code example: [examples/pooling/score/jinaai_rerank_client.py](../../examples/pooling/score/jinaai_rerank_client.py)
|
||||
Code example: [examples/pooling/score/openai_reranker.py](../../examples/pooling/score/openai_reranker.py)
|
||||
|
||||
#### Example Request
|
||||
|
||||
|
||||
@ -33,11 +33,19 @@ Then query the endpoint to get the latest metrics from the server:
|
||||
|
||||
The following metrics are exposed:
|
||||
|
||||
??? code
|
||||
## General Metrics
|
||||
|
||||
```python
|
||||
--8<-- "vllm/engine/metrics.py:metrics-definitions"
|
||||
```
|
||||
--8<-- "docs/generated/metrics/general.md"
|
||||
|
||||
## Speculative Decoding Metrics
|
||||
|
||||
--8<-- "docs/generated/metrics/spec_decode.md"
|
||||
|
||||
## NIXL KV Connector Metrics
|
||||
|
||||
--8<-- "docs/generated/metrics/nixl_connector.md"
|
||||
|
||||
## Deprecation Policy
|
||||
|
||||
Note: when metrics are deprecated in version `X.Y`, they are hidden in version `X.Y+1`
|
||||
but can be re-enabled using the `--show-hidden-metrics-for-version=X.Y` escape hatch,
|
||||
|
||||
@ -422,7 +422,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
default=0,
|
||||
help="Set the seed when initializing `vllm.LLM`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
@ -4,6 +4,9 @@
|
||||
from argparse import Namespace
|
||||
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import AttentionConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
@ -20,6 +23,11 @@ def parse_args():
|
||||
|
||||
|
||||
def main(args: Namespace):
|
||||
if current_platform.is_rocm():
|
||||
args.attention_config = AttentionConfig(
|
||||
backend=AttentionBackendEnum.FLEX_ATTENTION
|
||||
)
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
|
||||
@ -4,6 +4,9 @@
|
||||
from argparse import Namespace
|
||||
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import AttentionConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
@ -20,6 +23,11 @@ def parse_args():
|
||||
|
||||
|
||||
def main(args: Namespace):
|
||||
if current_platform.is_rocm():
|
||||
args.attention_config = AttentionConfig(
|
||||
backend=AttentionBackendEnum.FLEX_ATTENTION
|
||||
)
|
||||
|
||||
# Sample prompts.
|
||||
text_1 = "What is the capital of France?"
|
||||
texts_2 = [
|
||||
|
||||
@ -77,7 +77,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
default=0,
|
||||
help="Set the seed when initializing `vllm.LLM`.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
@ -158,7 +158,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
default=0,
|
||||
help="Set the seed when initializing `vllm.LLM`.",
|
||||
)
|
||||
|
||||
|
||||
@ -158,7 +158,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
default=0,
|
||||
help="Set the seed when initializing `vllm.LLM`.",
|
||||
)
|
||||
|
||||
|
||||
@ -2031,7 +2031,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
default=0,
|
||||
help="Set the seed when initializing `vllm.LLM`.",
|
||||
)
|
||||
|
||||
|
||||
@ -1382,7 +1382,7 @@ def run_generate(
|
||||
model,
|
||||
question: str,
|
||||
image_urls: list[str],
|
||||
seed: int | None,
|
||||
seed: int,
|
||||
tensor_parallel_size: int | None,
|
||||
):
|
||||
req_data = model_example_map[model](question, image_urls)
|
||||
@ -1416,7 +1416,7 @@ def run_chat(
|
||||
model: str,
|
||||
question: str,
|
||||
image_urls: list[str],
|
||||
seed: int | None,
|
||||
seed: int,
|
||||
tensor_parallel_size: int | None,
|
||||
):
|
||||
req_data = model_example_map[model](question, image_urls)
|
||||
@ -1494,7 +1494,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
default=0,
|
||||
help="Set the seed when initializing `vllm.LLM`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
@ -21,7 +21,7 @@
|
||||
# --worker \
|
||||
# /abs/path/to/huggingface/cache \
|
||||
# -e VLLM_HOST_IP=<worker_node_ip>
|
||||
#
|
||||
#
|
||||
# Each worker requires a unique VLLM_HOST_IP value.
|
||||
# Keep each terminal session open. Closing a session stops the associated Ray
|
||||
# node and thereby shuts down the entire cluster.
|
||||
@ -59,6 +59,34 @@ if [ "${NODE_TYPE}" != "--head" ] && [ "${NODE_TYPE}" != "--worker" ]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Extract VLLM_HOST_IP from ADDITIONAL_ARGS (e.g. "-e VLLM_HOST_IP=...").
|
||||
VLLM_HOST_IP=""
|
||||
for ((i = 0; i < ${#ADDITIONAL_ARGS[@]}; i++)); do
|
||||
arg="${ADDITIONAL_ARGS[$i]}"
|
||||
case "${arg}" in
|
||||
-e)
|
||||
next="${ADDITIONAL_ARGS[$((i + 1))]:-}"
|
||||
if [[ "${next}" == VLLM_HOST_IP=* ]]; then
|
||||
VLLM_HOST_IP="${next#VLLM_HOST_IP=}"
|
||||
break
|
||||
fi
|
||||
;;
|
||||
-eVLLM_HOST_IP=* | VLLM_HOST_IP=*)
|
||||
VLLM_HOST_IP="${arg#*=}"
|
||||
break
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# For the head node, HEAD_NODE_ADDRESS and VLLM_HOST_IP should be consistent.
|
||||
if [[ "${NODE_TYPE}" == "--head" && -n "${VLLM_HOST_IP}" ]]; then
|
||||
if [[ "${VLLM_HOST_IP}" != "${HEAD_NODE_ADDRESS}" ]]; then
|
||||
echo "Warning: VLLM_HOST_IP (${VLLM_HOST_IP}) differs from head_node_ip (${HEAD_NODE_ADDRESS})."
|
||||
echo "Using VLLM_HOST_IP as the head node address."
|
||||
HEAD_NODE_ADDRESS="${VLLM_HOST_IP}"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Generate a unique container name with random suffix.
|
||||
# Docker container names must be unique on each host.
|
||||
# The random suffix allows multiple Ray containers to run simultaneously on the same machine,
|
||||
@ -74,36 +102,17 @@ cleanup() {
|
||||
trap cleanup EXIT
|
||||
|
||||
# Build the Ray start command based on the node role.
|
||||
# The head node manages the cluster and accepts connections on port 6379,
|
||||
# The head node manages the cluster and accepts connections on port 6379,
|
||||
# while workers connect to the head's address.
|
||||
RAY_START_CMD="ray start --block"
|
||||
if [ "${NODE_TYPE}" == "--head" ]; then
|
||||
RAY_START_CMD+=" --head --port=6379"
|
||||
RAY_START_CMD+=" --head --node-ip-address=${HEAD_NODE_ADDRESS} --port=6379"
|
||||
else
|
||||
|
||||
RAY_START_CMD+=" --address=${HEAD_NODE_ADDRESS}:6379"
|
||||
fi
|
||||
|
||||
# Parse VLLM_HOST_IP from additional args if present.
|
||||
# This is needed for multi-NIC configurations where Ray needs explicit IP bindings.
|
||||
VLLM_HOST_IP=""
|
||||
for arg in "${ADDITIONAL_ARGS[@]}"; do
|
||||
if [[ $arg == "-e" ]]; then
|
||||
continue
|
||||
if [ -n "${VLLM_HOST_IP}" ]; then
|
||||
RAY_START_CMD+=" --node-ip-address=${VLLM_HOST_IP}"
|
||||
fi
|
||||
if [[ $arg == VLLM_HOST_IP=* ]]; then
|
||||
VLLM_HOST_IP="${arg#VLLM_HOST_IP=}"
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
# Build Ray IP environment variables if VLLM_HOST_IP is set.
|
||||
# These variables ensure Ray binds to the correct network interface on multi-NIC systems.
|
||||
RAY_IP_VARS=()
|
||||
if [ -n "${VLLM_HOST_IP}" ]; then
|
||||
RAY_IP_VARS=(
|
||||
-e "RAY_NODE_IP_ADDRESS=${VLLM_HOST_IP}"
|
||||
-e "RAY_OVERRIDE_NODE_IP_ADDRESS=${VLLM_HOST_IP}"
|
||||
)
|
||||
fi
|
||||
|
||||
# Launch the container with the assembled parameters.
|
||||
@ -118,6 +127,5 @@ docker run \
|
||||
--shm-size 10.24g \
|
||||
--gpus all \
|
||||
-v "${PATH_TO_HF_HOME}:/root/.cache/huggingface" \
|
||||
"${RAY_IP_VARS[@]}" \
|
||||
"${ADDITIONAL_ARGS[@]}" \
|
||||
"${DOCKER_IMAGE}" -c "${RAY_START_CMD}"
|
||||
|
||||
@ -16,7 +16,7 @@ import requests
|
||||
# - start vllm in serving mode with the below args
|
||||
# --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
|
||||
# --model-impl terratorch
|
||||
# --task embed --trust-remote-code
|
||||
# --trust-remote-code
|
||||
# --skip-tokenizer-init --enforce-eager
|
||||
# --io-processor-plugin terratorch_segmentation
|
||||
# --enable-mm-embeds
|
||||
|
||||
@ -305,7 +305,7 @@ def get_query(modality: QueryModality):
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def run_encode(model: str, modality: QueryModality, seed: int | None):
|
||||
def run_encode(model: str, modality: QueryModality, seed: int):
|
||||
query = get_query(modality)
|
||||
req_data = model_example_map[model](query)
|
||||
|
||||
@ -335,7 +335,7 @@ def run_encode(model: str, modality: QueryModality, seed: int | None):
|
||||
print("-" * 50)
|
||||
|
||||
|
||||
def run_score(model: str, modality: QueryModality, seed: int | None):
|
||||
def run_score(model: str, modality: QueryModality, seed: int):
|
||||
query = get_query(modality)
|
||||
req_data = model_example_map[model](query)
|
||||
|
||||
@ -390,7 +390,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
default=0,
|
||||
help="Set the seed when initializing `vllm.LLM`.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
@ -51,6 +51,7 @@ hooks:
|
||||
- docs/mkdocs/hooks/remove_announcement.py
|
||||
- docs/mkdocs/hooks/generate_examples.py
|
||||
- docs/mkdocs/hooks/generate_argparse.py
|
||||
- docs/mkdocs/hooks/generate_metrics.py
|
||||
- docs/mkdocs/hooks/url_schemes.py
|
||||
|
||||
plugins:
|
||||
|
||||
@ -50,4 +50,5 @@ ijson # Required for mistral streaming tool parser
|
||||
setproctitle # Used to set process names for better debugging and monitoring
|
||||
openai-harmony >= 0.0.3 # Required for gpt-oss
|
||||
anthropic == 0.71.0
|
||||
model-hosting-container-standards >= 0.1.9, < 1.0.0
|
||||
model-hosting-container-standards >= 0.1.9, < 1.0.0
|
||||
mcp
|
||||
@ -1,2 +1,2 @@
|
||||
lmcache
|
||||
lmcache >= 0.3.10.post1
|
||||
nixl >= 0.7.1 # Required for disaggregated prefill
|
||||
|
||||
@ -138,6 +138,17 @@ elif current_platform.is_rocm():
|
||||
CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
|
||||
|
||||
|
||||
def has_cuda_graph_wrapper_metadata() -> bool:
    """Report whether ``torch._inductor.utils`` has ``CUDAGraphWrapperMetadata``.

    Returns ``False`` when the attribute is missing, or when importing the
    module itself raises ``AttributeError``; any other import error
    propagates to the caller.
    """
    from importlib import import_module

    try:
        inductor_utils = import_module("torch._inductor.utils")
    except AttributeError:
        return False
    return hasattr(inductor_utils, "CUDAGraphWrapperMetadata")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name, model_kwargs, backend, matches, custom_ops",
|
||||
# Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
|
||||
@ -145,7 +156,20 @@ CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
|
||||
# quant_fp4 only has the custom impl
|
||||
+ list(flat_product(MODELS_FP4, [""])),
|
||||
)
|
||||
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
|
||||
@pytest.mark.parametrize(
|
||||
"inductor_graph_partition",
|
||||
[
|
||||
pytest.param(
|
||||
True,
|
||||
marks=pytest.mark.skipif(
|
||||
not has_cuda_graph_wrapper_metadata(),
|
||||
reason="This test requires"
|
||||
"torch._inductor.utils.CUDAGraphWrapperMetadata to run",
|
||||
),
|
||||
),
|
||||
False,
|
||||
],
|
||||
)
|
||||
def test_attn_quant(
|
||||
model_name: str,
|
||||
model_kwargs: dict[str, Any],
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
import logging
|
||||
from contextlib import nullcontext
|
||||
from unittest.mock import patch
|
||||
|
||||
@ -13,7 +12,6 @@ from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmConfig
|
||||
from vllm.config.compilation import CompilationMode, PassConfig
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.logger import _print_warning_once
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import _is_torch_equal_or_newer
|
||||
|
||||
@ -290,7 +288,7 @@ def test_moe_splitting_ops_deepep_ht_attn_fusion_no_inductor():
|
||||
),
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
pass_config={"enable_attn_fusion": True, "enable_noop": True},
|
||||
pass_config={"fuse_attn_quant": True, "eliminate_noops": True},
|
||||
custom_ops=["+quant_fp8"],
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
),
|
||||
@ -442,62 +440,3 @@ def test_cudagraph_sizes_post_init(
|
||||
vllm_config.compilation_config.max_cudagraph_capture_size
|
||||
== expected_max_size
|
||||
)
|
||||
|
||||
|
||||
def test_pass_config_deprecation(caplog_vllm):
    """Check that deprecated ``PassConfig`` flags warn and map to new fields.

    For each deprecated flag the test asserts that a deprecation message is
    logged, that the replacement field(s) are set to ``True`` and that the
    deprecated attribute itself still reads ``True``. It then checks that
    old-style and new-style configs hash identically.
    """
    caplog_vllm.set_level(logging.WARNING)

    # Clear cache to ensure warnings are re-issued
    _print_warning_once.cache_clear()

    # (deprecated flag, expected log message, replacement fields)
    renamed_flags = [
        (
            "enable_fusion",
            "enable_fusion is deprecated",
            ("fuse_norm_quant", "fuse_act_quant"),
        ),
        (
            "enable_attn_fusion",
            "enable_attn_fusion is deprecated",
            ("fuse_attn_quant",),
        ),
        (
            "enable_noop",
            "enable_noop is deprecated",
            ("eliminate_noops",),
        ),
        (
            "enable_sequence_parallelism",
            "enable_sequence_parallelism is deprecated",
            ("enable_sp",),
        ),
        (
            "enable_async_tp",
            "enable_async_tp is deprecated",
            ("fuse_gemm_comms",),
        ),
        (
            "enable_fi_allreduce_fusion",
            "enable_fi_allreduce_fusion is deprecated",
            ("fuse_allreduce_rms",),
        ),
    ]
    for old_flag, expected_message, new_fields in renamed_flags:
        caplog_vllm.clear()
        config = PassConfig(**{old_flag: True})
        assert expected_message in caplog_vllm.text
        for field_name in new_fields:
            assert getattr(config, field_name) is True
        assert getattr(config, old_flag) is True

    # Test hash consistency: old-style and new-style spellings must agree.
    equivalent_kwargs = [
        ({"enable_fusion": True}, {"fuse_norm_quant": True, "fuse_act_quant": True}),
        ({"enable_async_tp": True}, {"fuse_gemm_comms": True}),
    ]
    for old_kwargs, new_kwargs in equivalent_kwargs:
        config_old = PassConfig(**old_kwargs)
        config_new = PassConfig(**new_kwargs)
        assert config_old.compute_hash() == config_new.compute_hash()
|
||||
|
||||
@ -128,14 +128,12 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
||||
|
||||
|
||||
class TestRotaryEmbedding(torch.nn.Module):
|
||||
def __init__(self, head_dim=64, rotary_dim=None, max_position=2048, base=10000):
|
||||
def __init__(self, head_dim=64, max_position=2048, base=10000):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.rotary_dim = rotary_dim or head_dim
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.rotary_dim,
|
||||
max_position=max_position,
|
||||
rope_parameters={"rope_type": "default", "rope_theta": base},
|
||||
)
|
||||
@ -170,7 +168,6 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position,
|
||||
rope_parameters={"rope_type": "default", "rope_theta": base},
|
||||
)
|
||||
|
||||
@ -741,7 +741,7 @@ class VllmRunner:
|
||||
tokenizer_name: str | None = None,
|
||||
tokenizer_mode: str = "auto",
|
||||
trust_remote_code: bool = True,
|
||||
seed: int | None = 0,
|
||||
seed: int = 0,
|
||||
max_model_len: int | None = 1024,
|
||||
dtype: str = "auto",
|
||||
disable_log_stats: bool = True,
|
||||
|
||||
276
tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py
Normal file
276
tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py
Normal file
@ -0,0 +1,276 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Test that the interaction between EPLB and FusedMoE Layer is okay for DP w/ NVFP4
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_quant_config
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
|
||||
from vllm.distributed.parallel_state import (
|
||||
ensure_model_parallel_initialized,
|
||||
get_dp_group,
|
||||
)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.quantization.modelopt import (
|
||||
ModelOptNvFp4Config,
|
||||
ModelOptNvFp4FusedMoE,
|
||||
)
|
||||
|
||||
from .eplb_utils import distributed_run, set_env_vars_and_device
|
||||
|
||||
|
||||
@dataclass
class TestConfig:
    """Sizes describing one EPLB / FusedMoE test scenario.

    The name starts with ``Test`` and the dataclass has a generated
    ``__init__``, so pytest would try (and fail, with a warning) to collect
    it as a test class. ``__test__ = False`` opts it out of collection. The
    attribute is unannotated, so it is not a dataclass field and the
    constructor signature is unchanged.
    """

    __test__ = False

    # Number of FusedMoE layers built on every rank.
    num_layers: int
    # Total number of experts across all ranks.
    num_experts: int
    # Number of experts whose weights are created on a single rank.
    num_local_experts: int
    # Passed to FusedMoE as ``top_k``.
    num_topk: int
    # Width of the hidden-state input.
    hidden_size: int
    # Passed to FusedMoE as ``intermediate_size``.
    intermediate_size: int
    # Number of rows in each random hidden-state / router-logit input.
    num_tokens: int
|
||||
|
||||
|
||||
def make_fused_moe_layer(
    rank: int,
    layer_idx: int,
    test_config: TestConfig,
) -> FusedMoE:
    """Build one NVFP4 FusedMoE layer on ``cuda:{rank}`` filled with random data.

    Args:
        rank: Index of the CUDA device the layer is moved to.
        layer_idx: Only used to give the layer a unique ``prefix``.
        test_config: Expert counts and tensor sizes for the layer.

    Returns:
        The layer after ``process_weights_after_loading`` and
        ``maybe_init_modular_kernel`` have been called on it.
    """
    device = torch.device(f"cuda:{rank}")

    quant_config = ModelOptNvFp4Config(
        is_checkpoint_nvfp4_serialized=True,
        kv_cache_quant_algo=None,
        exclude_modules=[],
    )

    fml = FusedMoE(
        num_experts=test_config.num_experts,
        top_k=test_config.num_topk,
        hidden_size=test_config.hidden_size,
        intermediate_size=test_config.intermediate_size,
        prefix=f"dummy_layer_{layer_idx}",
        activation="silu",
        is_act_and_mul=True,
        params_dtype=torch.bfloat16,
        quant_config=quant_config,
    )

    nvfp4_fused_moe = ModelOptNvFp4FusedMoE(quant_config, fml)
    nvfp4_fused_moe.create_weights(
        fml,
        test_config.num_local_experts,
        test_config.hidden_size,
        test_config.intermediate_size,
        params_dtype=torch.uint8,
        global_num_experts=test_config.num_experts,
    )

    fml = fml.to(device)

    # Only the two quantized weight tensors are used here; the third return
    # value is discarded instead of shadowing ``quant_config`` above.
    w1_q, w2_q, _ = make_test_quant_config(
        test_config.num_local_experts,
        test_config.intermediate_size,
        test_config.hidden_size,
        in_dtype=torch.bfloat16,
        quant_dtype="nvfp4",
        block_shape=None,
        per_act_token_quant=False,
    )

    fml.w13_weight.data = w1_q
    fml.w2_weight.data = w2_q

    # Random scales. The attribute order matches the original statement order
    # so the sequence of RNG draws is unchanged.
    for name in (
        "w2_input_scale",
        "w13_input_scale",
        "w2_weight_scale_2",
        "w13_weight_scale_2",
    ):
        param = getattr(fml, name)
        param.data = torch.randn_like(param.data) / 5

    # These scales are sampled with torch.randn's default dtype and then cast
    # to the parameter's own dtype.
    for name in ("w2_weight_scale", "w13_weight_scale"):
        param = getattr(fml, name)
        param.data = (torch.randn(param.data.shape, device=device) / 5).to(
            param.data.dtype
        )

    nvfp4_fused_moe.process_weights_after_loading(fml)

    fml.maybe_init_modular_kernel()

    return fml
|
||||
|
||||
|
||||
def _forward_all_layers(
    fml_layers, hidden_states, router_logits, test_config, world_size, vllm_config
):
    """Run each layer once on clones of its inputs inside a forward context."""
    outputs = []
    with set_forward_context(
        {},
        num_tokens=test_config.num_tokens,
        num_tokens_across_dp=torch.tensor(
            [test_config.num_tokens] * world_size, device="cpu", dtype=torch.int
        ),
        vllm_config=vllm_config,
    ):
        for lidx, fml in enumerate(fml_layers):
            outputs.append(
                fml(hidden_states[lidx].clone(), router_logits[lidx].clone())
            )
    return outputs


def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
    """Per-rank body: layer outputs must survive an EPLB expert shuffle.

    Builds ``num_layers`` NVFP4 FusedMoE layers, records their outputs,
    rearranges expert weights in place according to a random permutation,
    installs the matching logical-to-physical map on every layer, and asserts
    that the outputs are close to the ones recorded before the shuffle.

    Args:
        env: Passed unchanged to ``set_env_vars_and_device``.
        world_size: Used as the data-parallel size.
        test_config: Sizes for the layers and inputs.
    """
    set_env_vars_and_device(env)

    vllm_config = VllmConfig()
    vllm_config.parallel_config.data_parallel_size = world_size
    vllm_config.parallel_config.enable_expert_parallel = True

    # NOTE(review): indentation was lost in this view; the whole test body is
    # assumed to run under the current-config context - confirm against the
    # upstream file.
    with set_current_vllm_config(vllm_config):
        ensure_model_parallel_initialized(
            tensor_model_parallel_size=1, pipeline_model_parallel_size=1
        )

        ep_group = get_dp_group().cpu_group
        ep_rank = torch.distributed.get_rank()

        device = torch.device(f"cuda:{ep_rank}")

        # make_fused_moe_layer already moves the layer to cuda:{rank}.
        fml_layers = [
            make_fused_moe_layer(ep_rank, layer_idx, test_config)
            for layer_idx in range(test_config.num_layers)
        ]
        rank_expert_weights = [fml.get_expert_weights() for fml in fml_layers]

        # Inputs are drawn layer by layer (hidden states, then router logits)
        # so the sequence of RNG draws matches the original loop.
        hidden_states = []
        router_logits = []
        for _ in range(test_config.num_layers):
            hidden_states.append(
                torch.randn(
                    (test_config.num_tokens, test_config.hidden_size),
                    dtype=torch.bfloat16,
                    device=device,
                )
            )
            router_logits.append(
                torch.randn(
                    (test_config.num_tokens, test_config.num_experts),
                    dtype=torch.bfloat16,
                    device=device,
                )
            )

        out_before_shuffle = _forward_all_layers(
            fml_layers, hidden_states, router_logits, test_config, world_size,
            vllm_config,
        )

        # Identity placement for every layer, built directly as int64 rather
        # than via a float ``torch.Tensor(range(...))`` that is cast on
        # assignment.
        indices = torch.arange(test_config.num_experts, dtype=torch.long).repeat(
            test_config.num_layers, 1
        )

        # One independent random permutation of the experts per layer.
        shuffled_indices = torch.stack(
            [
                torch.randperm(test_config.num_experts)
                for _ in range(test_config.num_layers)
            ]
        )

        rearrange_expert_weights_inplace(
            indices,
            shuffled_indices,
            rank_expert_weights,
            ep_group,
            is_profile=False,
        )

        num_global_experts = test_config.num_experts

        # Invert each layer's physical->logical permutation.
        logical_to_physical_map_list = []
        for lidx in range(test_config.num_layers):
            physical_to_logical_map = shuffled_indices[lidx].to(device)
            logical_to_physical_map = torch.empty(
                (num_global_experts,), dtype=torch.int32, device=device
            )
            logical_to_physical_map[physical_to_logical_map] = torch.arange(
                0, num_global_experts, dtype=torch.int32, device=device
            )
            logical_to_physical_map_list.append(
                logical_to_physical_map.reshape(num_global_experts, 1)
            )

        logical_to_physical_map = torch.stack(logical_to_physical_map_list)

        for lidx, fml in enumerate(fml_layers):
            # Every logical expert has exactly one replica in this test.
            logical_replica_count = torch.ones(
                (test_config.num_layers, num_global_experts),
                dtype=torch.int32,
                device=device,
            )
            fml.enable_eplb = True
            fml.set_eplb_state(
                lidx,
                torch.zeros(
                    (test_config.num_layers, num_global_experts),
                    dtype=torch.int32,
                    device=device,
                ),
                logical_to_physical_map,
                logical_replica_count,
            )

        out_after_shuffle = _forward_all_layers(
            fml_layers, hidden_states, router_logits, test_config, world_size,
            vllm_config,
        )

        for lidx in range(test_config.num_layers):
            torch.testing.assert_close(
                out_before_shuffle[lidx],
                out_after_shuffle[lidx],
                atol=1e-1,
                rtol=1e-1,
            )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("world_size", [2, 4])
@pytest.mark.parametrize("num_layers", [8])
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("intermediate_size", [256])
@pytest.mark.parametrize("num_tokens", [256])
@pytest.mark.parametrize("backend", ["latency", "throughput"])
def test_eplb_fml(
    world_size: int,
    num_layers: int,
    num_experts: int,
    hidden_size: int,
    intermediate_size: int,
    num_tokens: int,
    backend: str,
    monkeypatch,
):
    """Launch ``_test_eplb_fml`` on ``world_size`` ranks for each FlashInfer backend.

    Skipped when fewer than ``world_size`` CUDA devices are visible.
    """
    # The environment is set before the GPU check, as in the original order.
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", backend)

    available_gpus = torch.cuda.device_count()
    if available_gpus < world_size:
        pytest.skip(f"Need at least {world_size} GPUs to run the test")

    # Experts are split evenly across ranks; top-k is fixed at 4.
    scenario = TestConfig(
        num_layers=num_layers,
        num_experts=num_experts,
        num_local_experts=num_experts // world_size,
        num_topk=4,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        num_tokens=num_tokens,
    )

    distributed_run(_test_eplb_fml, world_size, scenario)
|
||||
@ -350,21 +350,35 @@ def test_human_readable_model_len():
|
||||
assert args.max_model_len == 1_000_000
|
||||
args = parser.parse_args(["--max-model-len", "10k"])
|
||||
assert args.max_model_len == 10_000
|
||||
args = parser.parse_args(["--max-model-len", "2g"])
|
||||
assert args.max_model_len == 2_000_000_000
|
||||
args = parser.parse_args(["--max-model-len", "2t"])
|
||||
assert args.max_model_len == 2_000_000_000_000
|
||||
|
||||
# Capital
|
||||
args = parser.parse_args(["--max-model-len", "3K"])
|
||||
assert args.max_model_len == 1024 * 3
|
||||
assert args.max_model_len == 2**10 * 3
|
||||
args = parser.parse_args(["--max-model-len", "10M"])
|
||||
assert args.max_model_len == 2**20 * 10
|
||||
args = parser.parse_args(["--max-model-len", "4G"])
|
||||
assert args.max_model_len == 2**30 * 4
|
||||
args = parser.parse_args(["--max-model-len", "4T"])
|
||||
assert args.max_model_len == 2**40 * 4
|
||||
|
||||
# Decimal values
|
||||
args = parser.parse_args(["--max-model-len", "10.2k"])
|
||||
assert args.max_model_len == 10200
|
||||
# ..truncated to the nearest int
|
||||
args = parser.parse_args(["--max-model-len", "10.212345k"])
|
||||
args = parser.parse_args(["--max-model-len", "10.2123451234567k"])
|
||||
assert args.max_model_len == 10212
|
||||
args = parser.parse_args(["--max-model-len", "10.2123451234567m"])
|
||||
assert args.max_model_len == 10212345
|
||||
args = parser.parse_args(["--max-model-len", "10.2123451234567g"])
|
||||
assert args.max_model_len == 10212345123
|
||||
args = parser.parse_args(["--max-model-len", "10.2123451234567t"])
|
||||
assert args.max_model_len == 10212345123456
|
||||
|
||||
# Invalid (do not allow decimals with binary multipliers)
|
||||
for invalid in ["1a", "pwd", "10.24", "1.23M"]:
|
||||
for invalid in ["1a", "pwd", "10.24", "1.23M", "1.22T"]:
|
||||
with pytest.raises(ArgumentError):
|
||||
args = parser.parse_args(["--max-model-len", invalid])
|
||||
parser.parse_args(["--max-model-len", invalid])
|
||||
|
||||
228
tests/entrypoints/openai/test_chat_error.py
Normal file
228
tests/entrypoints/openai/test_chat_error.py
Normal file
@ -0,0 +1,228 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config.multimodal import MultiModalConfig
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
|
||||
MODEL_NAME = "openai-community/gpt2"
|
||||
MODEL_NAME_SHORT = "gpt2"
|
||||
BASE_MODEL_PATHS = [
|
||||
BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
|
||||
BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT),
|
||||
]
|
||||
|
||||
|
||||
@dataclass
class MockHFConfig:
    """Placeholder HF config exposing only a ``model_type``."""

    # Single field; defaults to the literal "any".
    model_type: str = "any"
|
||||
|
||||
|
||||
@dataclass
class MockModelConfig:
    """Mock model config assigned to the mocked engine's ``model_config``.

    Only the annotated attributes are dataclass fields (and therefore
    constructor parameters); the unannotated ones are plain class attributes
    shared by every instance.
    """

    # --- plain class attributes -------------------------------------------
    task = "generate"
    runner_type = "generate"
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    max_model_len = 100
    tokenizer_revision = None
    multimodal_config = MultiModalConfig()
    hf_config = MockHFConfig()
    logits_processor_pattern = None
    # --- dataclass fields (order is part of the constructor signature) ----
    logits_processors: list[str] | None = None
    diff_sampling_param: dict | None = None
    allowed_local_media_path: str = ""
    allowed_media_domains: list[str] | None = None
    encoder_config = None
    generation_config: str = "auto"
    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
    skip_tokenizer_init = False

    def get_diff_sampling_param(self):
        """Return ``diff_sampling_param``, or a new empty dict when it is falsy."""
        overrides = self.diff_sampling_param
        if overrides:
            return overrides
        return {}
|
||||
|
||||
|
||||
def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
    """Create an ``OpenAIServingChat`` whose input processing is stubbed out.

    ``_process_inputs`` and ``_preprocess_chat`` are replaced by ``AsyncMock``
    objects returning fixed values.
    """

    async def _fake_preprocess_chat(*args, **kwargs):
        # return conversation, request_prompts, engine_prompts
        conversation = [{"role": "user", "content": "Test"}]
        request_prompts = [[1, 2, 3]]
        engine_prompts = [{"prompt_token_ids": [1, 2, 3]}]
        return (conversation, request_prompts, engine_prompts)

    async def _fake_process_inputs(
        request_id,
        engine_prompt,
        sampling_params,
        *,
        lora_request,
        trace_headers,
        priority,
    ):
        # Return a copy of the engine prompt together with an empty dict.
        prompt_copy = dict(engine_prompt)
        return prompt_copy, {}

    serving_models = OpenAIServingModels(
        engine_client=engine,
        base_model_paths=BASE_MODEL_PATHS,
    )
    chat_handler = OpenAIServingChat(
        engine,
        serving_models,
        response_role="assistant",
        request_logger=None,
        chat_template=None,
        chat_template_content_format="auto",
    )

    chat_handler._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
    chat_handler._preprocess_chat = AsyncMock(side_effect=_fake_preprocess_chat)
    return chat_handler
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_error_non_stream():
    """A non-streaming chat request ending with finish_reason='error' yields a 500 ErrorResponse."""
    engine = MagicMock(spec=AsyncLLM)
    engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    engine.errored = False
    engine.model_config = MockModelConfig()
    engine.input_processor = MagicMock()
    engine.io_processor = MagicMock()

    handler = _build_serving_chat(engine)

    # The only engine output: finished, empty, and flagged as an error.
    failed_output = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[
            CompletionOutput(
                index=0,
                text="",
                token_ids=[],
                cumulative_logprob=None,
                logprobs=None,
                finish_reason="error",
            )
        ],
        finished=True,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )

    async def mock_generate(*args, **kwargs):
        yield failed_output

    engine.generate = MagicMock(side_effect=mock_generate)

    chat_request = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "Test prompt"}],
        max_tokens=10,
        stream=False,
    )

    result = await handler.create_chat_completion(chat_request)

    assert isinstance(result, ErrorResponse)
    error = result.error
    assert error.type == "InternalServerError"
    assert error.message == "Internal server error"
    assert error.code == HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_error_stream():
    """A streaming chat request ending with finish_reason='error' emits the error and then [DONE]."""
    engine = MagicMock(spec=AsyncLLM)
    engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    engine.errored = False
    engine.model_config = MockModelConfig()
    engine.input_processor = MagicMock()
    engine.io_processor = MagicMock()

    handler = _build_serving_chat(engine)

    def _step(finish_reason, finished):
        # Both engine steps carry the same text and tokens; only the finish
        # state differs.
        return RequestOutput(
            request_id="test-id",
            prompt="Test prompt",
            prompt_token_ids=[1, 2, 3],
            prompt_logprobs=None,
            outputs=[
                CompletionOutput(
                    index=0,
                    text="Hello",
                    token_ids=[100],
                    cumulative_logprob=None,
                    logprobs=None,
                    finish_reason=finish_reason,
                )
            ],
            finished=finished,
            metrics=None,
            lora_request=None,
            encoder_prompt=None,
            encoder_prompt_token_ids=None,
        )

    # First an unfinished step, then the finished step flagged as an error.
    steps = [_step(None, False), _step("error", True)]

    async def mock_generate(*args, **kwargs):
        for step in steps:
            yield step

    engine.generate = MagicMock(side_effect=mock_generate)

    chat_request = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "Test prompt"}],
        max_tokens=10,
        stream=True,
    )

    stream = await handler.create_chat_completion(chat_request)

    chunks = [chunk async for chunk in stream]

    assert len(chunks) >= 2
    assert any("Internal server error" in chunk for chunk in chunks), (
        f"Expected error message in chunks: {chunks}"
    )
    assert chunks[-1] == "data: [DONE]\n\n"
|
||||
216
tests/entrypoints/openai/test_completion_error.py
Normal file
216
tests/entrypoints/openai/test_completion_error.py
Normal file
@ -0,0 +1,216 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config.multimodal import MultiModalConfig
|
||||
from vllm.entrypoints.openai.protocol import CompletionRequest, ErrorResponse
|
||||
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
|
||||
MODEL_NAME = "openai-community/gpt2"
|
||||
MODEL_NAME_SHORT = "gpt2"
|
||||
BASE_MODEL_PATHS = [
|
||||
BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
|
||||
BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT),
|
||||
]
|
||||
|
||||
|
||||
@dataclass
class MockHFConfig:
    """Placeholder HF config exposing only a ``model_type``."""

    # Single field; defaults to the literal "any".
    model_type: str = "any"
|
||||
|
||||
|
||||
@dataclass
class MockModelConfig:
    """Mock model config assigned to the mocked engine's ``model_config``.

    Only the annotated attributes are dataclass fields (and therefore
    constructor parameters); the unannotated ones are plain class attributes
    shared by every instance.
    """

    # --- plain class attributes -------------------------------------------
    task = "generate"
    runner_type = "generate"
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    max_model_len = 100
    tokenizer_revision = None
    multimodal_config = MultiModalConfig()
    hf_config = MockHFConfig()
    logits_processor_pattern = None
    # --- dataclass fields (order is part of the constructor signature) ----
    logits_processors: list[str] | None = None
    diff_sampling_param: dict | None = None
    allowed_local_media_path: str = ""
    allowed_media_domains: list[str] | None = None
    encoder_config = None
    generation_config: str = "auto"
    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
    skip_tokenizer_init = False

    def get_diff_sampling_param(self):
        """Return ``diff_sampling_param``, or a new empty dict when it is falsy."""
        overrides = self.diff_sampling_param
        if overrides:
            return overrides
        return {}
|
||||
|
||||
|
||||
def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
    """Create an ``OpenAIServingCompletion`` whose ``_process_inputs`` is stubbed out."""

    async def _fake_process_inputs(
        request_id,
        engine_prompt,
        sampling_params,
        *,
        lora_request,
        trace_headers,
        priority,
    ):
        # Return a copy of the engine prompt together with an empty dict.
        prompt_copy = dict(engine_prompt)
        return prompt_copy, {}

    serving_models = OpenAIServingModels(
        engine_client=engine,
        base_model_paths=BASE_MODEL_PATHS,
    )
    completion_handler = OpenAIServingCompletion(
        engine,
        serving_models,
        request_logger=None,
    )

    completion_handler._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
    return completion_handler
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_completion_error_non_stream():
    """A non-streaming completion ending with finish_reason='error' yields a 500 ErrorResponse."""
    engine = MagicMock(spec=AsyncLLM)
    engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    engine.errored = False
    engine.model_config = MockModelConfig()
    engine.input_processor = MagicMock()
    engine.io_processor = MagicMock()

    handler = _build_serving_completion(engine)

    # The only engine output: finished, empty, and flagged as an error.
    failed_output = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[
            CompletionOutput(
                index=0,
                text="",
                token_ids=[],
                cumulative_logprob=None,
                logprobs=None,
                finish_reason="error",
            )
        ],
        finished=True,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )

    async def mock_generate(*args, **kwargs):
        yield failed_output

    engine.generate = MagicMock(side_effect=mock_generate)

    completion_request = CompletionRequest(
        model=MODEL_NAME,
        prompt="Test prompt",
        max_tokens=10,
        stream=False,
    )

    result = await handler.create_completion(completion_request)

    assert isinstance(result, ErrorResponse)
    error = result.error
    assert error.type == "InternalServerError"
    assert error.message == "Internal server error"
    assert error.code == HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_completion_error_stream():
    """A streaming completion ending with finish_reason='error' emits the error and then [DONE]."""
    engine = MagicMock(spec=AsyncLLM)
    engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    engine.errored = False
    engine.model_config = MockModelConfig()
    engine.input_processor = MagicMock()
    engine.io_processor = MagicMock()

    handler = _build_serving_completion(engine)

    def _step(finish_reason, finished):
        # Both engine steps carry the same text and tokens; only the finish
        # state differs.
        return RequestOutput(
            request_id="test-id",
            prompt="Test prompt",
            prompt_token_ids=[1, 2, 3],
            prompt_logprobs=None,
            outputs=[
                CompletionOutput(
                    index=0,
                    text="Hello",
                    token_ids=[100],
                    cumulative_logprob=None,
                    logprobs=None,
                    finish_reason=finish_reason,
                )
            ],
            finished=finished,
            metrics=None,
            lora_request=None,
            encoder_prompt=None,
            encoder_prompt_token_ids=None,
        )

    # First an unfinished step, then the finished step flagged as an error.
    steps = [_step(None, False), _step("error", True)]

    async def mock_generate(*args, **kwargs):
        for step in steps:
            yield step

    engine.generate = MagicMock(side_effect=mock_generate)

    completion_request = CompletionRequest(
        model=MODEL_NAME,
        prompt="Test prompt",
        max_tokens=10,
        stream=True,
    )

    stream = await handler.create_completion(completion_request)

    chunks = [chunk async for chunk in stream]

    assert len(chunks) >= 2
    assert any("Internal server error" in chunk for chunk in chunks), (
        f"Expected error message in chunks: {chunks}"
    )
    assert chunks[-1] == "data: [DONE]\n\n"
|
||||
89
tests/entrypoints/openai/test_responses_error.py
Normal file
89
tests/entrypoints/openai/test_responses_error.py
Normal file
@ -0,0 +1,89 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from http import HTTPStatus
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.serving_engine import GenerationError, OpenAIServing
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_raise_if_error_raises_generation_error():
    """``_raise_if_error`` raises GenerationError only for finish_reason 'error'."""
    # Minimal OpenAIServing instance backed entirely by mocks.
    model_config = MagicMock()
    model_config.max_model_len = 100
    engine = MagicMock()
    engine.model_config = model_config

    serving = OpenAIServing(
        engine_client=engine,
        models=MagicMock(),
        request_logger=None,
    )

    # The "error" finish reason must raise.
    with pytest.raises(GenerationError) as exc_info:
        serving._raise_if_error("error", "test-request-id")

    raised = exc_info.value
    assert str(raised) == "Internal server error"
    assert raised.status_code == HTTPStatus.INTERNAL_SERVER_ERROR

    # Any other finish reason, including None, must not raise.
    for finish_reason in ("stop", "length", None):
        serving._raise_if_error(finish_reason, "test-request-id")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_convert_generation_error_to_response():
    """``_convert_generation_error_to_response`` returns a 500 InternalServerError ErrorResponse."""
    # Minimal OpenAIServing instance backed entirely by mocks.
    model_config = MagicMock()
    model_config.max_model_len = 100
    engine = MagicMock()
    engine.model_config = model_config

    serving = OpenAIServing(
        engine_client=engine,
        models=MagicMock(),
        request_logger=None,
    )

    # Convert a GenerationError into an ErrorResponse.
    converted = serving._convert_generation_error_to_response(
        GenerationError("Internal server error")
    )

    assert isinstance(converted, ErrorResponse)
    error = converted.error
    assert error.type == "InternalServerError"
    assert error.message == "Internal server error"
    assert error.code == HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_convert_generation_error_to_streaming_response():
    """``_convert_generation_error_to_streaming_response`` returns a string naming the error."""
    # Minimal OpenAIServing instance backed entirely by mocks.
    model_config = MagicMock()
    model_config.max_model_len = 100
    engine = MagicMock()
    engine.model_config = model_config

    serving = OpenAIServing(
        engine_client=engine,
        models=MagicMock(),
        request_logger=None,
    )

    # Convert a GenerationError into the streaming error payload.
    payload = serving._convert_generation_error_to_streaming_response(
        GenerationError("Internal server error")
    )

    assert isinstance(payload, str)
    # Both the message and the error type must appear in the payload.
    for expected in ("Internal server error", "InternalServerError"):
        assert expected in payload
|
||||
@ -116,7 +116,6 @@ def test_mrope(
|
||||
|
||||
mrope_helper_class = get_rope(
|
||||
head_size=head_dim,
|
||||
rotary_dim=head_dim,
|
||||
max_position=max_position,
|
||||
is_neox_style=is_neox_style,
|
||||
rope_parameters=config.rope_parameters,
|
||||
@ -185,7 +184,6 @@ def test_mrope_torch_compile_tracing(
|
||||
|
||||
mrope_helper_class = get_rope(
|
||||
head_size=head_dim,
|
||||
rotary_dim=head_dim,
|
||||
max_position=max_position,
|
||||
is_neox_style=is_neox_style,
|
||||
rope_parameters=config.rope_parameters,
|
||||
|
||||
@ -83,8 +83,12 @@ def test_rotary_embedding(
|
||||
torch.set_default_device(device)
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
|
||||
rope = get_rope(head_size, rotary_dim, max_position, is_neox_style, rope_parameters)
|
||||
rope_parameters = {
|
||||
"rope_type": "default",
|
||||
"rope_theta": rope_theta,
|
||||
"partial_rotary_factor": rotary_dim / head_size,
|
||||
}
|
||||
rope = get_rope(head_size, max_position, is_neox_style, rope_parameters)
|
||||
rope = rope.to(dtype=dtype, device=torch.get_default_device())
|
||||
|
||||
positions = torch.randint(0, max_position, (batch_size, seq_len))
|
||||
@ -150,9 +154,9 @@ def test_rope_module_cache():
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
rope_parameters["rope_theta"] = rope_theta
|
||||
rope_parameters["partial_rotary_factor"] = rotary_dim / head_size
|
||||
rope = get_rope(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
is_neox_style,
|
||||
rope_parameters,
|
||||
@ -177,9 +181,9 @@ def test_rope_module_cache():
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
rope_parameters["rope_theta"] = rope_theta
|
||||
rope_parameters["partial_rotary_factor"] = rotary_dim / head_size
|
||||
rope = get_rope(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
is_neox_style,
|
||||
rope_parameters,
|
||||
|
||||
@ -70,12 +70,12 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
|
||||
f"{torch.cuda.device_count()}"
|
||||
)
|
||||
|
||||
# `cuda_graph_sizes=[16]` to reduce load time.
|
||||
# `cudagraph_capture_sizes=[16]` to reduce load time.
|
||||
with vllm_runner(
|
||||
model_case.model_id,
|
||||
tensor_parallel_size=model_case.tp,
|
||||
load_format="dummy",
|
||||
cuda_graph_sizes=[16],
|
||||
cudagraph_capture_sizes=[16],
|
||||
) as llm:
|
||||
# Disabled as check_model is broken: https://github.com/vllm-project/vllm/pull/18465#issuecomment-3329880562
|
||||
# def check_model(model):
|
||||
|
||||
@ -18,7 +18,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9
|
||||
IS_SUPPORTED_BY_GPU = (
|
||||
current_platform.is_cuda() and current_platform.get_device_capability()[0] >= 9
|
||||
)
|
||||
|
||||
|
||||
def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@ -17,7 +17,6 @@ def test_idefics_multimodal(
|
||||
with vllm_runner(
|
||||
model_name="HuggingFaceM4/Idefics3-8B-Llama3",
|
||||
runner="pooling",
|
||||
task="classify",
|
||||
convert="classify",
|
||||
load_format="dummy",
|
||||
max_model_len=512,
|
||||
@ -86,7 +85,6 @@ def test_gemma_multimodal(
|
||||
with vllm_runner(
|
||||
model_name="google/gemma-3-4b-it",
|
||||
runner="pooling",
|
||||
task="classify",
|
||||
convert="classify",
|
||||
load_format="auto",
|
||||
hf_overrides=update_config,
|
||||
|
||||
@ -92,16 +92,19 @@ def run_test(
|
||||
*,
|
||||
tensor_parallel_size: int,
|
||||
distributed_executor_backend: str | None = None,
|
||||
dtype: str = "half",
|
||||
) -> None:
|
||||
prompt_list = PROMPTS * 10
|
||||
expected_list = EXPECTED[model] * 10
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
dtype="half",
|
||||
dtype=dtype,
|
||||
max_model_len=448,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
# TODO (NickLucche) figure out output differences with non-eager and re-enable
|
||||
enforce_eager=True,
|
||||
) as vllm_model:
|
||||
llm = vllm_model.llm
|
||||
|
||||
@ -120,12 +123,28 @@ def run_test(
|
||||
|
||||
@pytest.mark.core_model
|
||||
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@create_new_process_for_each_test()
|
||||
def test_models(vllm_runner, model) -> None:
|
||||
def test_models(vllm_runner, model, dtype) -> None:
|
||||
run_test(
|
||||
vllm_runner,
|
||||
model,
|
||||
tensor_parallel_size=1,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.cpu_model
|
||||
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
def test_models_cpu(vllm_runner, model, dtype) -> None:
|
||||
# @create_new_process_for_each_test() does not work for some runners
|
||||
# TODO: to fix cpu privilege issues in run-cpu-test-arm.sh
|
||||
run_test(
|
||||
vllm_runner,
|
||||
model,
|
||||
tensor_parallel_size=1,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -173,10 +173,7 @@ class _HfExamplesInfo:
|
||||
|
||||
_TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
# [Decoder-only]
|
||||
"AfmoeForCausalLM": _HfExamplesInfo(
|
||||
"arcee-ai/Trinity-Nano",
|
||||
is_available_online=False,
|
||||
),
|
||||
"AfmoeForCausalLM": _HfExamplesInfo("arcee-ai/Trinity-Nano-Preview"),
|
||||
"ApertusForCausalLM": _HfExamplesInfo("swiss-ai/Apertus-8B-Instruct-2509"),
|
||||
"AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B", trust_remote_code=True),
|
||||
"AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B", trust_remote_code=True),
|
||||
|
||||
@ -147,7 +147,7 @@ def test_video_backend_handles_broken_frames(monkeypatch: pytest.MonkeyPatch):
|
||||
"""
|
||||
Regression test for handling videos with broken frames.
|
||||
This test uses a pre-corrupted video file (assets/corrupted.mp4) that
|
||||
contains broken/unreadable frames to verify the video loader handles
|
||||
contains broken frames to verify the video loader handles
|
||||
them gracefully without crashing and returns accurate metadata.
|
||||
"""
|
||||
with monkeypatch.context() as m:
|
||||
@ -177,3 +177,125 @@ def test_video_backend_handles_broken_frames(monkeypatch: pytest.MonkeyPatch):
|
||||
f"Expected fewer than {metadata['total_num_frames']} frames, "
|
||||
f"but loaded {frames.shape[0]} frames"
|
||||
)
|
||||
|
||||
|
||||
@VIDEO_LOADER_REGISTRY.register("test_video_backend_override_1")
class TestVideoBackendOverride1(VideoLoader):
    """Stub loader registered under ``test_video_backend_override_1``.

    It ignores its input and always hands back FAKE_OUTPUT_1 plus metadata
    naming this backend, so a test can tell which loader was selected.
    """

    @classmethod
    def load_bytes(
        cls, data: bytes, num_frames: int = -1, **kwargs
    ) -> tuple[npt.NDArray, dict]:
        """Return the canned frames and a metadata dict identifying this backend."""
        backend_metadata = {"video_backend": "test_video_backend_override_1"}
        return FAKE_OUTPUT_1, backend_metadata
|
||||
|
||||
|
||||
@VIDEO_LOADER_REGISTRY.register("test_video_backend_override_2")
class TestVideoBackendOverride2(VideoLoader):
    """Stub loader registered under ``test_video_backend_override_2``.

    It ignores its input and always hands back FAKE_OUTPUT_2 plus metadata
    naming this backend, so a test can tell which loader was selected.
    """

    @classmethod
    def load_bytes(
        cls, data: bytes, num_frames: int = -1, **kwargs
    ) -> tuple[npt.NDArray, dict]:
        """Return the canned frames and a metadata dict identifying this backend."""
        backend_metadata = {"video_backend": "test_video_backend_override_2"}
        return FAKE_OUTPUT_2, backend_metadata
|
||||
|
||||
|
||||
def test_video_media_io_backend_kwarg_override(monkeypatch: pytest.MonkeyPatch):
    """
    Verify that a video_backend kwarg takes precedence over the
    VLLM_VIDEO_LOADER_BACKEND environment variable.

    With this, a different video backend can be picked dynamically through
    --media-io-kwargs while the global env var stays untouched, e.g. when
    a plugin sets a default backend but one request needs another.
    """
    env_backend = "test_video_backend_override_1"
    kwarg_backend = "test_video_backend_override_2"

    with monkeypatch.context() as patched_env:
        # The environment points at the first stub loader.
        patched_env.setenv("VLLM_VIDEO_LOADER_BACKEND", env_backend)

        image_io = ImageMediaIO()

        # No kwarg given: the loader named by the env var must be used.
        env_frames, env_metadata = VideoMediaIO(image_io, num_frames=10).load_bytes(
            b"test"
        )
        np.testing.assert_array_equal(env_frames, FAKE_OUTPUT_1)
        assert env_metadata["video_backend"] == env_backend

        # kwarg given: it must win over the env var.
        kwarg_frames, kwarg_metadata = VideoMediaIO(
            image_io, num_frames=10, video_backend=kwarg_backend
        ).load_bytes(b"test")
        np.testing.assert_array_equal(kwarg_frames, FAKE_OUTPUT_2)
        assert kwarg_metadata["video_backend"] == kwarg_backend
|
||||
|
||||
|
||||
def test_video_media_io_backend_kwarg_not_passed_to_loader(
    monkeypatch: pytest.MonkeyPatch,
):
    """
    Verify that VideoMediaIO consumes the video_backend kwarg itself and does
    NOT forward it to the load_bytes method of the underlying video loader.

    Any other kwarg is still expected to reach the loader.
    """
    backend_name = "test_reject_video_backend_kwarg"

    @VIDEO_LOADER_REGISTRY.register(backend_name)
    class RejectVideoBackendKwargLoader(VideoLoader):
        """Stub loader that raises if it is handed a video_backend kwarg."""

        @classmethod
        def load_bytes(
            cls, data: bytes, num_frames: int = -1, **kwargs
        ) -> tuple[npt.NDArray, dict]:
            # Expected path: video_backend was stripped before reaching the
            # loader, so report which kwargs did arrive.
            if "video_backend" not in kwargs:
                return FAKE_OUTPUT_1, {"received_kwargs": list(kwargs)}
            raise AssertionError(
                "video_backend should be consumed by VideoMediaIO, "
                "not passed to loader"
            )

    with monkeypatch.context() as patched_env:
        patched_env.setenv("VLLM_VIDEO_LOADER_BACKEND", backend_name)

        image_io = ImageMediaIO()

        # video_backend is supplied on purpose; the loader must never see it.
        video_io = VideoMediaIO(
            image_io,
            num_frames=10,
            video_backend=backend_name,
            other_kwarg="should_pass_through",
        )

        # The stub raises AssertionError here if video_backend leaked through.
        loaded_frames, loaded_metadata = video_io.load_bytes(b"test")
        np.testing.assert_array_equal(loaded_frames, FAKE_OUTPUT_1)
        # Unrelated kwargs must still be forwarded.
        assert "other_kwarg" in loaded_metadata["received_kwargs"]
|
||||
|
||||
|
||||
def test_video_media_io_backend_env_var_fallback(monkeypatch: pytest.MonkeyPatch):
    """
    Verify that VideoMediaIO falls back to the VLLM_VIDEO_LOADER_BACKEND
    env var when video_backend is None or is not provided at all.
    """
    expected_backend = "test_video_backend_override_2"

    with monkeypatch.context() as patched_env:
        patched_env.setenv("VLLM_VIDEO_LOADER_BACKEND", expected_backend)

        image_io = ImageMediaIO()

        # First an explicit video_backend=None, then no video_backend at all;
        # both must resolve to the loader named by the env var.
        for backend_kwargs in ({"video_backend": None}, {}):
            video_io = VideoMediaIO(image_io, num_frames=10, **backend_kwargs)
            frames, metadata = video_io.load_bytes(b"test")
            np.testing.assert_array_equal(frames, FAKE_OUTPUT_2)
            assert metadata["video_backend"] == expected_backend
|
||||
|
||||
@ -212,11 +212,11 @@ def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
|
||||
task = "wikitext"
|
||||
rtol = 0.1
|
||||
|
||||
# Smaller cuda_graph_sizes to speed up the test.
|
||||
# Smaller cudagraph_capture_sizes to speed up the test.
|
||||
results = lm_eval.simple_evaluate(
|
||||
model="vllm",
|
||||
model_args=config.get_model_args(
|
||||
tp_size=tp_size, kwargs={"cuda_graph_sizes": [16]}
|
||||
tp_size=tp_size, kwargs={"cudagraph_capture_sizes": [16]}
|
||||
),
|
||||
tasks=task,
|
||||
batch_size=64,
|
||||
|
||||
195
tests/reasoning/test_minimax_m2_append_reasoning_parser.py
Normal file
195
tests/reasoning/test_minimax_m2_append_reasoning_parser.py
Normal file
@ -0,0 +1,195 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from tests.reasoning.utils import run_reasoning_extraction
|
||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||
|
||||
parser_name = "minimax_m2_append_think"
|
||||
end_token = "</think>"
|
||||
|
||||
# MiniMax M2 model path
|
||||
REASONING_MODEL_NAME = "MiniMaxAI/MiniMax-M2"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def minimax_m2_tokenizer():
    """Module-scoped fixture providing the tokenizer for REASONING_MODEL_NAME."""
    tokenizer = AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
    return tokenizer
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MiniMaxM2AppendThinkReasoningParser behavior:
|
||||
# - Prepends <think> to the beginning of the output
|
||||
# - Does NOT separate reasoning and content
|
||||
# - Returns everything as content (with <think> prepended)
|
||||
# - reasoning is always None
|
||||
#
|
||||
# This parser is used when you want to keep the raw output with <think> added
|
||||
# =============================================================================
|
||||
|
||||
# Case: simple output with end token
|
||||
SIMPLE_OUTPUT = {
|
||||
"output": "This is reasoning</think>This is response",
|
||||
"reasoning": None,
|
||||
"content": "<think>This is reasoning</think>This is response",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
|
||||
# Case: output without end token (reasoning in progress)
|
||||
NO_END_TOKEN = {
|
||||
"output": "This is reasoning in progress",
|
||||
"reasoning": None,
|
||||
"content": "<think>This is reasoning in progress",
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
|
||||
# Case: only end token
|
||||
ONLY_END_TOKEN = {
|
||||
"output": "</think>This is response",
|
||||
"reasoning": None,
|
||||
"content": "<think></think>This is response",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
|
||||
# Case: multiple lines
|
||||
MULTIPLE_LINES = {
|
||||
"output": "Line 1\nLine 2</think>Response 1\nResponse 2",
|
||||
"reasoning": None,
|
||||
"content": "<think>Line 1\nLine 2</think>Response 1\nResponse 2",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
|
||||
# Case: empty output (non-streaming prepends <think>)
|
||||
EMPTY = {
|
||||
"output": "",
|
||||
"reasoning": None,
|
||||
"content": "<think>",
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
|
||||
# Case: empty output streaming (no tokens = no output)
|
||||
EMPTY_STREAMING = {
|
||||
"output": "",
|
||||
"reasoning": None,
|
||||
"content": None,
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
|
||||
# Case: special characters
|
||||
SPECIAL_CHARS = {
|
||||
"output": "Let me think... 1+1=2</think>Yes!",
|
||||
"reasoning": None,
|
||||
"content": "<think>Let me think... 1+1=2</think>Yes!",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
|
||||
# Case: code in output
|
||||
CODE_OUTPUT = {
|
||||
"output": "```python\nprint('hi')\n```</think>Here's the code.",
|
||||
"reasoning": None,
|
||||
"content": "<think>```python\nprint('hi')\n```</think>Here's the code.",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
|
||||
# One row per scenario: (id stem, non-streaming expectation, streaming
# expectation). Only the empty-output scenario differs between the two modes.
_CASE_TABLE = [
    ("simple_output", SIMPLE_OUTPUT, SIMPLE_OUTPUT),
    ("no_end_token", NO_END_TOKEN, NO_END_TOKEN),
    ("only_end_token", ONLY_END_TOKEN, ONLY_END_TOKEN),
    ("multiple_lines", MULTIPLE_LINES, MULTIPLE_LINES),
    ("empty", EMPTY, EMPTY_STREAMING),
    ("special_chars", SPECIAL_CHARS, SPECIAL_CHARS),
    ("code_output", CODE_OUTPUT, CODE_OUTPUT),
]

# Each row expands to a non-streaming param followed by its streaming twin,
# whose id carries the "_streaming" suffix.
TEST_CASES = [
    pytest.param(streaming, case, id=f"{stem}_streaming" if streaming else stem)
    for stem, plain_case, streaming_case in _CASE_TABLE
    for streaming, case in ((False, plain_case), (True, streaming_case))
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_reasoning(
    streaming: bool,
    param_dict: dict,
    minimax_m2_tokenizer,
):
    """Run one TEST_CASES entry through the parser and check every expectation.

    Compares the extracted reasoning and content, and the result of
    is_reasoning_end, against the values stored in param_dict.
    """
    tokenizer = minimax_m2_tokenizer
    raw_tokens = tokenizer.tokenize(param_dict["output"])

    # Turn every token back into its text piece, one string per token.
    token_texts: list[str] = [
        tokenizer.convert_tokens_to_string([raw_token]) for raw_token in raw_tokens
    ]

    parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
    parser: ReasoningParser = parser_cls(tokenizer)

    reasoning, content = run_reasoning_extraction(
        parser, token_texts, streaming=streaming
    )

    assert reasoning == param_dict["reasoning"]
    assert content == param_dict["content"]

    # is_reasoning_end works on token ids rather than on text.
    token_ids = tokenizer.convert_tokens_to_ids(raw_tokens)
    assert parser.is_reasoning_end(token_ids) == param_dict["is_reasoning_end"]
|
||||
230
tests/reasoning/test_minimax_m2_reasoning_parser.py
Normal file
230
tests/reasoning/test_minimax_m2_reasoning_parser.py
Normal file
@ -0,0 +1,230 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from tests.reasoning.utils import run_reasoning_extraction
|
||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||
|
||||
parser_name = "minimax_m2"
|
||||
end_token = "</think>"
|
||||
|
||||
# MiniMax M2 model path
|
||||
REASONING_MODEL_NAME = "MiniMaxAI/MiniMax-M2"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def minimax_m2_tokenizer():
    """Module-scoped fixture providing the tokenizer for REASONING_MODEL_NAME."""
    tokenizer = AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
    return tokenizer
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MiniMax M2 specific behavior:
|
||||
# - Model does NOT generate <think> start token
|
||||
# - Model only generates </think> end token
|
||||
# - All content before </think> is reasoning
|
||||
# - All content after </think> is the actual response (content)
|
||||
# =============================================================================
|
||||
|
||||
# Case: reasoning + end token + content (typical case)
|
||||
SIMPLE_REASONING = {
|
||||
"output": "This is a reasoning section</think>This is the rest",
|
||||
"reasoning": "This is a reasoning section",
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
|
||||
# Case: reasoning + end token only (no content after)
|
||||
COMPLETE_REASONING = {
|
||||
"output": "This is a reasoning section</think>",
|
||||
"reasoning": "This is a reasoning section",
|
||||
"content": None,
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
|
||||
# Case: no end token yet (streaming in progress, all is reasoning)
|
||||
NO_END_TOKEN = {
|
||||
"output": "This is reasoning in progress",
|
||||
"reasoning": "This is reasoning in progress",
|
||||
"content": None,
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
|
||||
# Case: multiple lines of reasoning
|
||||
MULTIPLE_LINES = {
|
||||
"output": "First line\nSecond line</think>Response first line\nResponse second",
|
||||
"reasoning": "First line\nSecond line",
|
||||
"content": "Response first line\nResponse second",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
|
||||
# Case: only end token (empty reasoning, immediate response)
|
||||
SHORTEST_REASONING_NO_STREAMING = {
|
||||
"output": "</think>This is the response",
|
||||
"reasoning": "",
|
||||
"content": "This is the response",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
|
||||
# Case: only end token streaming (reasoning is None because it's just the token)
|
||||
SHORTEST_REASONING_STREAMING = {
|
||||
"output": "</think>This is the response",
|
||||
"reasoning": None,
|
||||
"content": "This is the response",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
|
||||
# Case: empty output
|
||||
EMPTY = {
|
||||
"output": "",
|
||||
"reasoning": "",
|
||||
"content": None,
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
|
||||
# Case: empty streaming
|
||||
EMPTY_STREAMING = {
|
||||
"output": "",
|
||||
"reasoning": None,
|
||||
"content": None,
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
|
||||
# Case: long reasoning with special characters
|
||||
SPECIAL_CHARS = {
|
||||
"output": "Let me think... 1+1=2, right?</think>Yes, 1+1=2.",
|
||||
"reasoning": "Let me think... 1+1=2, right?",
|
||||
"content": "Yes, 1+1=2.",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
|
||||
# Case: reasoning with code blocks
|
||||
CODE_IN_REASONING = {
|
||||
"output": "```python\nprint('hello')\n```</think>Here is the code.",
|
||||
"reasoning": "```python\nprint('hello')\n```",
|
||||
"content": "Here is the code.",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
|
||||
# Core cases: no start token (MiniMax M2 actual behavior)
# One row per scenario: (id stem, non-streaming expectation, streaming
# expectation). The two modes only differ for the shortest and empty outputs.
_CASE_TABLE = [
    ("simple_reasoning", SIMPLE_REASONING, SIMPLE_REASONING),
    ("complete_reasoning", COMPLETE_REASONING, COMPLETE_REASONING),
    ("no_end_token", NO_END_TOKEN, NO_END_TOKEN),
    ("multiple_lines", MULTIPLE_LINES, MULTIPLE_LINES),
    (
        "shortest_reasoning",
        SHORTEST_REASONING_NO_STREAMING,
        SHORTEST_REASONING_STREAMING,
    ),
    ("empty", EMPTY, EMPTY_STREAMING),
    ("special_chars", SPECIAL_CHARS, SPECIAL_CHARS),
    ("code_in_reasoning", CODE_IN_REASONING, CODE_IN_REASONING),
]

# Each row expands to a non-streaming param followed by its streaming twin,
# whose id carries the "_streaming" suffix.
TEST_CASES = [
    pytest.param(streaming, case, id=f"{stem}_streaming" if streaming else stem)
    for stem, plain_case, streaming_case in _CASE_TABLE
    for streaming, case in ((False, plain_case), (True, streaming_case))
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_reasoning(
    streaming: bool,
    param_dict: dict,
    minimax_m2_tokenizer,
):
    """Run one TEST_CASES entry through the parser and check every expectation.

    Checks, against the values stored in param_dict:
    - the (reasoning, content) pair returned by run_reasoning_extraction,
    - the result of parser.is_reasoning_end on the output token ids,
    - the result of parser.extract_content_ids on the output token ids.
    """
    output = minimax_m2_tokenizer.tokenize(param_dict["output"])
    # Convert each token back to its text piece, one string per token.
    output_tokens: list[str] = [
        minimax_m2_tokenizer.convert_tokens_to_string([token]) for token in output
    ]
    parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)(
        minimax_m2_tokenizer
    )

    reasoning, content = run_reasoning_extraction(
        parser, output_tokens, streaming=streaming
    )

    assert reasoning == param_dict["reasoning"]
    assert content == param_dict["content"]

    # Test is_reasoning_end
    output_ids = minimax_m2_tokenizer.convert_tokens_to_ids(output)
    is_reasoning_end = parser.is_reasoning_end(output_ids)
    assert is_reasoning_end == param_dict["is_reasoning_end"]

    # Test extract_content
    # Always feed token ids here. The "no content" branch used to pass the
    # token strings (`output`), which cannot contain an integer end-token id
    # and therefore did not check the parser on the real ids.
    content_ids = parser.extract_content_ids(output_ids)
    if param_dict["content"] is not None:
        expected_content_ids = minimax_m2_tokenizer.convert_tokens_to_ids(
            minimax_m2_tokenizer.tokenize(param_dict["content"])
        )
    else:
        expected_content_ids = []
    assert content_ids == expected_content_ids
|
||||
@ -18,47 +18,53 @@ def mistral_tokenizer():
|
||||
return mistral_tokenizer
|
||||
|
||||
|
||||
SIMPLE_REASONING = {
|
||||
INVALID_SIMPLE_REASONING = {
|
||||
"output": "This is a reasoning section[/THINK]This is the rest",
|
||||
"reasoning": "This is a reasoning section",
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": True,
|
||||
"reasoning": None,
|
||||
"content": "This is a reasoning sectionThis is the rest",
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
COMPLETE_REASONING = {
|
||||
INVALID_COMPLETE_REASONING = {
|
||||
"output": "This is a reasoning section[/THINK]",
|
||||
"reasoning": "This is a reasoning section",
|
||||
"content": None,
|
||||
"is_reasoning_end": True,
|
||||
"reasoning": None,
|
||||
"content": "This is a reasoning section",
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
NO_CONTENT = {
|
||||
"output": "This is content",
|
||||
"reasoning": "This is content",
|
||||
"output": "[THINK]This is reasoning",
|
||||
"reasoning": "This is reasoning",
|
||||
"content": None,
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
NO_REASONING = {
|
||||
"output": "This is content",
|
||||
"reasoning": None,
|
||||
"content": "This is content",
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
NO_REASONING_STREAMING = {
|
||||
"output": "This is a reasoning section",
|
||||
"reasoning": "This is a reasoning section",
|
||||
"content": None,
|
||||
"reasoning": None,
|
||||
"content": "This is a reasoning section",
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
MULTIPLE_LINES = {
|
||||
INVALID_MULTIPLE_LINES = {
|
||||
"output": "This\nThat[/THINK]This is the rest\nThat",
|
||||
"reasoning": "This\nThat",
|
||||
"content": "This is the rest\nThat",
|
||||
"is_reasoning_end": True,
|
||||
"reasoning": None,
|
||||
"content": "This\nThatThis is the rest\nThat",
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
SHORTEST_REASONING_NO_STREAMING = {
|
||||
"output": "[/THINK]This is the rest",
|
||||
"reasoning": "",
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
SHORTEST_REASONING = {
|
||||
INVALID_SHORTEST_REASONING_NO_STREAMING = {
|
||||
"output": "[/THINK]This is the rest",
|
||||
"reasoning": None,
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": True,
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
INVALID_SHORTEST_REASONING = {
|
||||
"output": "[/THINK]This is the rest",
|
||||
"reasoning": None,
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
REASONING_WITH_THINK = {
|
||||
"output": "[THINK]This is a reasoning section[/THINK]This is the rest",
|
||||
@ -78,17 +84,17 @@ MULTIPLE_LINES_WITH_THINK = {
|
||||
"content": "This is the rest\nThat",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
SHORTEST_REASONING_NO_STREAMING_WITH_THINK = {
|
||||
"output": "[/THINK]This is the rest",
|
||||
"reasoning": "",
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
SHORTEST_REASONING_WITH_THINK = {
|
||||
INVALID_SHORTEST_REASONING_NO_STREAMING_WITH_THINK = {
|
||||
"output": "[/THINK]This is the rest",
|
||||
"reasoning": None,
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": True,
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
INVALID_SHORTEST_REASONING_WITH_THINK = {
|
||||
"output": "[/THINK]This is the rest",
|
||||
"reasoning": None,
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
THINK_NO_END = {
|
||||
"output": "[THINK]This is a reasoning section",
|
||||
@ -98,8 +104,8 @@ THINK_NO_END = {
|
||||
}
|
||||
EMPTY = {
|
||||
"output": "",
|
||||
"reasoning": "",
|
||||
"content": None,
|
||||
"reasoning": None,
|
||||
"content": "",
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
EMPTY_STREAMING = {
|
||||
@ -109,47 +115,48 @@ EMPTY_STREAMING = {
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
NEW_LINE = {
|
||||
"output": "\n[THINK]This is a reasoning section[/THINK]\nThis is the rest",
|
||||
"output": "Before\n[THINK]This is a reasoning section[/THINK]\nThis is the rest",
|
||||
"reasoning": "This is a reasoning section",
|
||||
"content": "\nThis is the rest",
|
||||
"content": "Before\n\nThis is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
# Streaming cannot handle new lines at the beginning of the output
|
||||
# because we need to support [THINK]...[/THINK] and [/THINK]...
|
||||
# We cannot know if the text before [THINK] is reasoning content
|
||||
# or not.
|
||||
NEW_LINE_STREAMING = {
|
||||
"output": "\n[THINK]This is a reasoning section[/THINK]\nThis is the rest",
|
||||
"reasoning": "\nThis is a reasoning section",
|
||||
"content": "\nThis is the rest",
|
||||
"output": "Before\n[THINK]This is a reasoning section[/THINK]\nThis is the rest",
|
||||
"reasoning": "This is a reasoning section",
|
||||
"content": "Before\n\nThis is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
|
||||
TEST_CASES = [
|
||||
pytest.param(
|
||||
False,
|
||||
SIMPLE_REASONING,
|
||||
id="simple_reasoning",
|
||||
INVALID_SIMPLE_REASONING,
|
||||
id="invalid_simple_reasoning",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
SIMPLE_REASONING,
|
||||
id="simple_reasoning_streaming",
|
||||
INVALID_SIMPLE_REASONING,
|
||||
id="invalid_simple_reasoning_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
COMPLETE_REASONING,
|
||||
id="complete_reasoning",
|
||||
INVALID_COMPLETE_REASONING,
|
||||
id="invalid_complete_reasoning",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
COMPLETE_REASONING,
|
||||
id="complete_reasoning_streaming",
|
||||
INVALID_COMPLETE_REASONING,
|
||||
id="invalid_complete_reasoning_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
NO_CONTENT,
|
||||
id="no_content_token",
|
||||
id="no_content",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
NO_REASONING,
|
||||
id="no_reasoning",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
@ -158,23 +165,23 @@ TEST_CASES = [
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
MULTIPLE_LINES,
|
||||
id="multiple_lines",
|
||||
INVALID_MULTIPLE_LINES,
|
||||
id="invalid_multiple_lines",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
MULTIPLE_LINES,
|
||||
id="multiple_lines_streaming",
|
||||
INVALID_MULTIPLE_LINES,
|
||||
id="invalid_multiple_lines_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
SHORTEST_REASONING,
|
||||
id="shortest",
|
||||
INVALID_SHORTEST_REASONING,
|
||||
id="invalid_shortest",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
SHORTEST_REASONING_NO_STREAMING,
|
||||
id="shortest_streaming",
|
||||
INVALID_SHORTEST_REASONING_NO_STREAMING,
|
||||
id="invalid_shortest_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
@ -208,13 +215,13 @@ TEST_CASES = [
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
SHORTEST_REASONING_NO_STREAMING_WITH_THINK,
|
||||
id="shortest_with_think",
|
||||
INVALID_SHORTEST_REASONING_NO_STREAMING_WITH_THINK,
|
||||
id="invalid_shortest_with_think",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
SHORTEST_REASONING_WITH_THINK,
|
||||
id="shortest_with_think_streaming",
|
||||
INVALID_SHORTEST_REASONING_WITH_THINK,
|
||||
id="invalid_shortest_with_think_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
@ -316,10 +323,26 @@ def test_mistral_reasoning(
|
||||
|
||||
# Test extract_content
|
||||
if param_dict["content"] is not None:
|
||||
content = parser.extract_content_ids(output_tokens)
|
||||
assert content == mistral_tokenizer.tokenizer.encode(
|
||||
param_dict["content"], bos=False, eos=False
|
||||
# Handle the case where there are tokens outputted before Thinking.
|
||||
# This should not occur if the model is well trained and prompted.
|
||||
if "[THINK]" in param_dict["output"] and not param_dict["output"].startswith(
|
||||
"[THINK]"
|
||||
):
|
||||
before_content = param_dict["output"].split("[THINK]")[0]
|
||||
before_token_ids = mistral_tokenizer.tokenizer.encode(
|
||||
before_content, bos=False, eos=False
|
||||
)
|
||||
left_to_encode = param_dict["content"][len(before_content) :]
|
||||
# Normal situation.
|
||||
else:
|
||||
before_token_ids = []
|
||||
left_to_encode = param_dict["content"]
|
||||
|
||||
content_tokens = parser.extract_content_ids(output_tokens)
|
||||
expected_token_ids = before_token_ids + mistral_tokenizer.tokenizer.encode(
|
||||
left_to_encode, bos=False, eos=False
|
||||
)
|
||||
assert content_tokens == expected_token_ids
|
||||
else:
|
||||
content = parser.extract_content_ids(output_tokens)
|
||||
assert content == []
|
||||
|
||||
@ -3,12 +3,45 @@
|
||||
# for users who do not have any compilers installed on their system
|
||||
|
||||
set -e
|
||||
set -x
|
||||
|
||||
merge_base_commit=$(git merge-base HEAD origin/main)
|
||||
echo "Current merge base commit with main: $merge_base_commit"
|
||||
echo "INFO: current merge base commit with main: $merge_base_commit"
|
||||
git show --oneline -s $merge_base_commit
|
||||
|
||||
# test whether the metadata.json url is valid, retry each 3 minutes up to 5 times
|
||||
# this avoids cumbersome error messages & manual retries in case the precompiled wheel
|
||||
# for the given commit is still being built in the release pipeline
|
||||
meta_json_url="https://wheels.vllm.ai/$merge_base_commit/vllm/metadata.json"
|
||||
echo "INFO: will use metadata.json from $meta_json_url"
|
||||
|
||||
for i in {1..5}; do
|
||||
echo "Checking metadata.json URL (attempt $i)..."
|
||||
if curl --fail "$meta_json_url" > metadata.json; then
|
||||
echo "INFO: metadata.json URL is valid."
|
||||
# check whether it is valid json by python
|
||||
if python3 -m json.tool metadata.json; then
|
||||
echo "INFO: metadata.json is valid JSON. Proceeding with the test."
|
||||
else
|
||||
echo "CRITICAL: metadata.json exists but is not valid JSON, please do report in #sig-ci channel!"
|
||||
exit 1
|
||||
fi
|
||||
break
|
||||
fi
|
||||
# failure handling
|
||||
if [ $i -eq 5 ]; then
|
||||
echo "ERROR: metadata.json URL is still not valid after 5 attempts."
|
||||
echo "ERROR: Please check whether the precompiled wheel for commit $merge_base_commit exists."
|
||||
echo " NOTE: If $merge_base_commit is a new commit on main, maybe try again after its release pipeline finishes."
|
||||
echo " NOTE: If it fails, please report in #sig-ci channel."
|
||||
exit 1
|
||||
else
|
||||
echo "WARNING: metadata.json URL is not valid. Retrying in 3 minutes..."
|
||||
sleep 180
|
||||
fi
|
||||
done
|
||||
|
||||
set -x
|
||||
|
||||
cd /vllm-workspace/
|
||||
|
||||
# uninstall vllm
|
||||
@ -29,6 +62,6 @@ python3 -c 'import vllm'
|
||||
|
||||
# Check if the clangd log file was created
|
||||
if [ ! -f /tmp/changed.file ]; then
|
||||
echo "changed.file was not created, python only compilation failed"
|
||||
echo "ERROR: changed.file was not created, python only compilation failed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
@ -89,64 +89,6 @@ def test_update_config():
|
||||
new_config3 = update_config(config3, {"a": "new_value"})
|
||||
|
||||
|
||||
# Can remove once --task option is fully deprecated
|
||||
@pytest.mark.parametrize(
|
||||
("model_id", "expected_runner_type", "expected_convert_type", "expected_task"),
|
||||
[
|
||||
("distilbert/distilgpt2", "generate", "none", "generate"),
|
||||
("intfloat/multilingual-e5-small", "pooling", "none", "embed"),
|
||||
("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify", "classify"),
|
||||
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "none", "classify"),
|
||||
("Qwen/Qwen2.5-Math-RM-72B", "pooling", "none", "embed"),
|
||||
("openai/whisper-small", "generate", "none", "transcription"),
|
||||
],
|
||||
)
|
||||
def test_auto_task(
|
||||
model_id, expected_runner_type, expected_convert_type, expected_task
|
||||
):
|
||||
config = ModelConfig(model_id, task="auto")
|
||||
|
||||
assert config.runner_type == expected_runner_type
|
||||
assert config.convert_type == expected_convert_type
|
||||
|
||||
|
||||
# Can remove once --task option is fully deprecated
|
||||
@pytest.mark.parametrize(
|
||||
("model_id", "expected_runner_type", "expected_convert_type", "expected_task"),
|
||||
[
|
||||
("distilbert/distilgpt2", "pooling", "embed", "embed"),
|
||||
("intfloat/multilingual-e5-small", "pooling", "embed", "embed"),
|
||||
("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify", "classify"),
|
||||
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify", "classify"),
|
||||
("Qwen/Qwen2.5-Math-RM-72B", "pooling", "embed", "embed"),
|
||||
("openai/whisper-small", "pooling", "embed", "embed"),
|
||||
],
|
||||
)
|
||||
def test_score_task(
|
||||
model_id, expected_runner_type, expected_convert_type, expected_task
|
||||
):
|
||||
config = ModelConfig(model_id, task="score")
|
||||
|
||||
assert config.runner_type == expected_runner_type
|
||||
assert config.convert_type == expected_convert_type
|
||||
|
||||
|
||||
# Can remove once --task option is fully deprecated
|
||||
@pytest.mark.parametrize(
|
||||
("model_id", "expected_runner_type", "expected_convert_type", "expected_task"),
|
||||
[
|
||||
("openai/whisper-small", "generate", "none", "transcription"),
|
||||
],
|
||||
)
|
||||
def test_transcription_task(
|
||||
model_id, expected_runner_type, expected_convert_type, expected_task
|
||||
):
|
||||
config = ModelConfig(model_id, task="transcription")
|
||||
|
||||
assert config.runner_type == expected_runner_type
|
||||
assert config.convert_type == expected_convert_type
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_id", "expected_runner_type", "expected_convert_type"),
|
||||
[
|
||||
@ -1085,7 +1027,7 @@ def test_vllm_config_explicit_overrides():
|
||||
)
|
||||
|
||||
# Override one field but not others
|
||||
pass_config = PassConfig(enable_noop=False)
|
||||
pass_config = PassConfig(eliminate_noops=False)
|
||||
compilation_config = CompilationConfig(pass_config=pass_config)
|
||||
config = VllmConfig(
|
||||
model_config=regular_model,
|
||||
|
||||
@ -8,6 +8,7 @@ import pytest
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.envs import (
|
||||
disable_envs_cache,
|
||||
enable_envs_cache,
|
||||
env_list_with_choices,
|
||||
env_set_with_choices,
|
||||
@ -57,6 +58,43 @@ def test_getattr_with_cache(monkeypatch: pytest.MonkeyPatch):
|
||||
envs.__getattr__ = envs.__getattr__.__wrapped__
|
||||
|
||||
|
||||
def test_getattr_with_reset(monkeypatch: pytest.MonkeyPatch) -> None:
    """Check that the envs cache freezes values and that disabling it resyncs.

    While the cache is enabled, a value that has already been read must not
    follow later changes to the environment. Once the cache is disabled,
    reads must reflect the current environment again.
    """
    monkeypatch.setenv("VLLM_HOST_IP", "1.1.1.1")
    # __getattr__ is not decorated with functools.cache
    assert not hasattr(envs.__getattr__, "cache_info")

    # Enable envs cache and ignore ongoing environment changes
    enable_envs_cache()
    try:
        assert envs.VLLM_HOST_IP == "1.1.1.1"
        # With cache enabled, the environment variable value is cached and
        # unchanged
        monkeypatch.setenv("VLLM_HOST_IP", "2.2.2.2")
        assert envs.VLLM_HOST_IP == "1.1.1.1"
    finally:
        # Always undo the cache so a failed assertion above cannot leak a
        # cached `envs` module into the tests that run afterwards.
        disable_envs_cache()

    assert envs.VLLM_HOST_IP == "2.2.2.2"
    # After cache disabled, the environment variable value would be synced
    # with os.environ
    monkeypatch.setenv("VLLM_HOST_IP", "3.3.3.3")
    assert envs.VLLM_HOST_IP == "3.3.3.3"
|
||||
|
||||
|
||||
def test_is_envs_cache_enabled() -> None:
    """Check that enabling the envs cache is idempotent and fully reversible.

    Repeated `enable_envs_cache()` calls must be undone by a single
    `disable_envs_cache()`, and disabling an already-disabled cache must be
    a harmless no-op.
    """
    assert not envs._is_envs_cache_enabled()
    enable_envs_cache()
    try:
        assert envs._is_envs_cache_enabled()

        # Only wrap one-layer of cache, so we only need to
        # call disable once to reset.
        enable_envs_cache()
        enable_envs_cache()
        enable_envs_cache()
    finally:
        # Runs on failure too, so the cache never stays enabled for
        # subsequent tests.
        disable_envs_cache()
    assert not envs._is_envs_cache_enabled()

    disable_envs_cache()
    assert not envs._is_envs_cache_enabled()
|
||||
|
||||
|
||||
class TestEnvWithChoices:
|
||||
"""Test cases for env_with_choices function."""
|
||||
|
||||
|
||||
@ -119,7 +119,7 @@ class RemoteOpenAIServer:
|
||||
vllm_serve_args: list[str],
|
||||
*,
|
||||
env_dict: dict[str, str] | None = None,
|
||||
seed: int | None = 0,
|
||||
seed: int = 0,
|
||||
auto_port: bool = True,
|
||||
max_wait_seconds: float | None = None,
|
||||
override_hf_configs: dict[str, Any] | None = None,
|
||||
@ -283,7 +283,7 @@ class RemoteOpenAIServerCustom(RemoteOpenAIServer):
|
||||
child_process_fxn: Callable[[dict[str, str] | None, str, list[str]], None],
|
||||
*,
|
||||
env_dict: dict[str, str] | None = None,
|
||||
seed: int | None = 0,
|
||||
seed: int = 0,
|
||||
auto_port: bool = True,
|
||||
max_wait_seconds: float | None = None,
|
||||
) -> None:
|
||||
|
||||
@ -13,6 +13,7 @@ import torch
|
||||
|
||||
from tests.evals.gsm8k.gsm8k_eval import evaluate_gsm8k
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.utils.import_utils import has_deep_ep
|
||||
|
||||
# Detect Blackwell / B200 (compute capability 10.x)
|
||||
try:
|
||||
@ -44,6 +45,7 @@ DEEPEP_BACKENDS = [
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not has_deep_ep(), reason="These tests require deep_ep to run")
|
||||
@pytest.mark.parametrize("all2all_backend", DEEPEP_BACKENDS)
|
||||
@pytest.mark.xfail(
|
||||
IS_BLACKWELL,
|
||||
|
||||
@ -152,8 +152,8 @@ def run_tests(
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_FA")
|
||||
else:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
|
||||
# lock matmul precision to full FP32
|
||||
m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "highest")
|
||||
# lock matmul precision to full FP32 (IEEE)
|
||||
m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "ieee")
|
||||
# m.setenv("VLLM_BATCH_INVARIANT", "1")
|
||||
outputs: list[tuple[str, list, list]] = []
|
||||
for n, (
|
||||
|
||||
@ -280,9 +280,20 @@ def test_speculators_model_integration(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["model_setup", "mm_enabled", "enable_chunked_prefill"],
|
||||
["model_setup", "mm_enabled", "enable_chunked_prefill", "model_impl"],
|
||||
[
|
||||
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False, False),
|
||||
(
|
||||
("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
|
||||
False,
|
||||
False,
|
||||
"auto",
|
||||
),
|
||||
(
|
||||
("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
|
||||
False,
|
||||
False,
|
||||
"transformers",
|
||||
),
|
||||
pytest.param(
|
||||
(
|
||||
"eagle3",
|
||||
@ -292,6 +303,7 @@ def test_speculators_model_integration(
|
||||
),
|
||||
False,
|
||||
False,
|
||||
"auto",
|
||||
marks=pytest.mark.skip(
|
||||
reason="architecture of its eagle3 is LlamaForCausalLMEagle3"
|
||||
),
|
||||
@ -305,6 +317,7 @@ def test_speculators_model_integration(
|
||||
),
|
||||
False,
|
||||
False,
|
||||
"auto",
|
||||
marks=pytest.mark.skip(
|
||||
reason="Skipping due to its head_dim not being a a multiple of 32"
|
||||
),
|
||||
@ -318,6 +331,7 @@ def test_speculators_model_integration(
|
||||
),
|
||||
False,
|
||||
True,
|
||||
"auto",
|
||||
marks=large_gpu_mark(min_gb=40),
|
||||
), # works on 4x H100
|
||||
(
|
||||
@ -329,6 +343,7 @@ def test_speculators_model_integration(
|
||||
),
|
||||
False,
|
||||
False,
|
||||
"auto",
|
||||
),
|
||||
pytest.param(
|
||||
(
|
||||
@ -339,6 +354,7 @@ def test_speculators_model_integration(
|
||||
),
|
||||
False,
|
||||
False,
|
||||
"auto",
|
||||
marks=large_gpu_mark(min_gb=80),
|
||||
), # works on 4x H100
|
||||
pytest.param(
|
||||
@ -350,6 +366,7 @@ def test_speculators_model_integration(
|
||||
),
|
||||
True,
|
||||
True,
|
||||
"auto",
|
||||
marks=large_gpu_mark(min_gb=80),
|
||||
), # works on 4x H100
|
||||
(
|
||||
@ -361,10 +378,12 @@ def test_speculators_model_integration(
|
||||
),
|
||||
False,
|
||||
False,
|
||||
"auto",
|
||||
),
|
||||
],
|
||||
ids=[
|
||||
"qwen3_eagle3",
|
||||
"qwen3_eagle3-transformers",
|
||||
"qwen3_vl_eagle3",
|
||||
"qwen2_5_vl_eagle3",
|
||||
"llama3_eagle",
|
||||
@ -381,6 +400,7 @@ def test_eagle_correctness(
|
||||
model_setup: tuple[str, str, str, int],
|
||||
mm_enabled: bool,
|
||||
enable_chunked_prefill: bool,
|
||||
model_impl: str,
|
||||
attn_backend: str,
|
||||
):
|
||||
if attn_backend == "TREE_ATTN":
|
||||
@ -389,6 +409,17 @@ def test_eagle_correctness(
|
||||
"TREE_ATTN is flaky in the test disable for now until it can be "
|
||||
"resolved (see https://github.com/vllm-project/vllm/issues/22922)"
|
||||
)
|
||||
if model_impl == "transformers":
|
||||
import transformers
|
||||
from packaging.version import Version
|
||||
|
||||
installed = Version(transformers.__version__)
|
||||
required = Version("5.0.0.dev")
|
||||
if installed < required:
|
||||
pytest.skip(
|
||||
"Eagle3 with the Transformers modeling backend requires "
|
||||
f"transformers>={required}, but got {installed}"
|
||||
)
|
||||
|
||||
# Generate test prompts inside the function instead of using fixture
|
||||
test_prompts = get_test_prompts(mm_enabled)
|
||||
@ -448,6 +479,7 @@ def test_eagle_correctness(
|
||||
max_model_len=max_model_len,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
model_impl=model_impl,
|
||||
)
|
||||
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
||||
matches = 0
|
||||
|
||||
54
tests/v1/engine/test_init_error_messaging.py
Normal file
54
tests/v1/engine/test_init_error_messaging.py
Normal file
@ -0,0 +1,54 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.v1.core.kv_cache_utils import check_enough_kv_cache_memory
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
|
||||
|
||||
def test_kv_cache_oom_no_memory():
    """Expect a ValueError when zero bytes are available for the KV cache."""
    from unittest.mock import MagicMock

    # A MagicMock stands in for the config; only max_model_len is pinned.
    config = MagicMock()
    config.model_config.max_model_len = 2048

    spec = {
        "layer_0": FullAttentionSpec(
            block_size=16,
            num_kv_heads=8,
            head_size=128,
            dtype="float16",
        )
    }

    # Third argument is the available memory in bytes: none at all here.
    with pytest.raises(ValueError):
        check_enough_kv_cache_memory(config, spec, 0)
|
||||
|
||||
|
||||
def test_kv_cache_oom_insufficient_memory(monkeypatch):
    """Expect a ValueError when the KV cache needs more memory than available.

    `max_memory_usage_bytes` is patched to report a 100 GiB requirement while
    only 1 GiB is offered to `check_enough_kv_cache_memory`.
    """
    from unittest.mock import MagicMock

    config = MagicMock()
    config.model_config.max_model_len = 2048
    config.cache_config.block_size = 16
    config.parallel_config.tensor_parallel_size = 1
    config.parallel_config.pipeline_parallel_size = 1
    config.parallel_config.decode_context_parallel_size = 1

    # Force the computed requirement to be far larger than what is available.
    monkeypatch.setattr(
        "vllm.v1.core.kv_cache_utils.max_memory_usage_bytes",
        lambda c, s: 100 * 1024**3,  # 100 GiB
    )

    spec = {
        "layer_0": FullAttentionSpec(
            block_size=16,
            num_kv_heads=8,
            head_size=128,
            dtype="float16",
        )
    }

    with pytest.raises(ValueError):
        check_enough_kv_cache_memory(config, spec, 1024**3)  # 1 GiB
|
||||
163
tests/v1/kv_connector/unit/test_cache_pollution_prevention.py
Normal file
163
tests/v1/kv_connector/unit/test_cache_pollution_prevention.py
Normal file
@ -0,0 +1,163 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""
|
||||
test that invalid blocks are evicted from prefix cache to prevent pollution.
|
||||
|
||||
verifies that when sync-loading fails, invalid blocks are removed from the
|
||||
prefix cache hash table so future requests cannot match and reuse corrupted data.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
|
||||
from .utils import (
|
||||
create_model_runner_output,
|
||||
create_request,
|
||||
create_scheduler,
|
||||
create_vllm_config,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
def _make_get_num_new_matched_tokens(
|
||||
req_num_new_matched_tokens: dict[str, int],
|
||||
async_load: bool,
|
||||
) -> Callable[[Request, int], tuple[int, bool]]:
|
||||
def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
|
||||
value = req_num_new_matched_tokens.get(request.request_id, 0)
|
||||
return value, async_load
|
||||
|
||||
return get_num_new_matched_tokens
|
||||
|
||||
|
||||
@pytest.fixture
def fail_scheduler():
    """Return a scheduler whose kv_load_failure_policy is set to 'fail'."""
    vllm_config = create_vllm_config()
    vllm_config.kv_transfer_config.kv_load_failure_policy = "fail"
    return create_scheduler(vllm_config)
|
||||
|
||||
|
||||
def test_invalid_blocks_evicted_prevents_cache_pollution(
    fail_scheduler: Scheduler,
):
    """
    Verify invalid blocks are evicted to prevent future cache hits.

    Scenario:
    1. request 1 loads externally-computed blocks (sync mode)
    2. some blocks fail to load and are marked invalid
    3. with fail policy, invalid blocks should be evicted from prefix cache
    4. request is marked as FINISHED_ERROR
    """
    num_prompt_blocks = 100
    num_external_computed_blocks = 99
    invalid_block_idx = 50

    num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
    num_external_computed_tokens = (
        num_external_computed_blocks * fail_scheduler.block_size
    )

    # request 1: will have invalid blocks
    request1 = create_request(num_tokens=num_prompt_tokens, request_id=1)
    fail_scheduler.add_request(request=request1)

    req_num_new_matched_tokens = {
        request1.request_id: num_external_computed_tokens,
    }

    # mock connector indicating sync load
    fail_scheduler.connector = Mock()
    fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
    )
    fail_scheduler.connector.request_finished.return_value = (False, None)
    fail_scheduler.connector.take_events.return_value = ()

    scheduler_output = fail_scheduler.schedule()

    # request should be running with sync KV load
    assert len(fail_scheduler.running) == 1
    assert request1.status == RequestStatus.RUNNING

    # get allocated block IDs
    req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
    invalid_block_id = req_block_ids[invalid_block_idx]
    invalid_block_ids = {invalid_block_id}

    # get the block object to verify eviction later
    block = fail_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id]

    # cache the blocks to simulate they've been computed and cached
    # (in real scenario blocks would be cached after compute)
    fail_scheduler.kv_cache_manager.cache_blocks(request1, num_external_computed_tokens)

    # verify block has a hash (is cached) before reporting invalid blocks
    assert block.block_hash is not None, (
        f"block {invalid_block_id} should be cached (have a hash) before "
        f"eviction test, but hash is None"
    )

    # report invalid blocks
    model_runner_output = create_model_runner_output(
        [request1],
        invalid_block_ids=invalid_block_ids,
        use_eos=False,
    )

    fail_scheduler.update_from_output(scheduler_output, model_runner_output)

    # verify request finished with error (fail policy)
    assert request1.status == RequestStatus.FINISHED_ERROR

    # critical assertion: invalid block and all subsequent blocks should be evicted
    # all blocks from invalid_block_idx onwards become invalid since they were
    # computed based on the failed block
    for idx in range(invalid_block_idx, len(req_block_ids)):
        block_id = req_block_ids[idx]
        block_obj = fail_scheduler.kv_cache_manager.block_pool.blocks[block_id]
        assert block_obj.block_hash is None, (
            f"block {block_id} at index {idx} should have been evicted "
            f"(hash reset to None), but hash is {block_obj.block_hash}. "
            f"All blocks from index {invalid_block_idx} onwards should be evicted "
            f"since they depend on the invalid block at index {invalid_block_idx}."
        )

    # verify cache contains exactly the valid blocks (before first affected block)
    # and none of the invalid blocks (from first affected block onwards)

    # valid blocks: all blocks before invalid_block_idx should be cached
    for idx in range(invalid_block_idx):
        block_id = req_block_ids[idx]
        block_obj = fail_scheduler.kv_cache_manager.block_pool.blocks[block_id]
        assert block_obj.block_hash is not None, (
            f"valid block {block_id} at index {idx} should still be cached "
            f"(have a hash), but hash is None. Only blocks from index "
            f"{invalid_block_idx} onwards should be evicted."
        )

    # invalid blocks: verify they're not in the cached_block_hash_to_block map
    cached_blocks = (
        fail_scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block
    )
    # a cache entry is either a single block or a dict of blocks; flatten both
    cached_block_ids = {
        b.block_id
        for blocks_val in cached_blocks._cache.values()
        for b in (
            [blocks_val] if not isinstance(blocks_val, dict) else blocks_val.values()
        )
    }

    for idx in range(invalid_block_idx, len(req_block_ids)):
        block_id = req_block_ids[idx]
        assert block_id not in cached_block_ids, (
            f"invalid block {block_id} at index {idx} should not be in cache hash table"
        )
|
||||
147
tests/v1/kv_connector/unit/test_error_propagation.py
Normal file
147
tests/v1/kv_connector/unit/test_error_propagation.py
Normal file
@ -0,0 +1,147 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.request import FinishReason, Request, RequestStatus
|
||||
|
||||
from .utils import (
|
||||
create_model_runner_output,
|
||||
create_request,
|
||||
create_scheduler,
|
||||
create_vllm_config,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
def _make_get_num_new_matched_tokens(
|
||||
req_num_new_matched_tokens: dict[str, int],
|
||||
async_load: bool,
|
||||
) -> Callable[[Request, int], tuple[int, bool]]:
|
||||
def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
|
||||
value = req_num_new_matched_tokens.get(request.request_id, 0)
|
||||
return value, async_load
|
||||
|
||||
return get_num_new_matched_tokens
|
||||
|
||||
|
||||
@pytest.fixture
def fail_scheduler():
    """Return a scheduler whose kv_load_failure_policy is set to 'fail'."""
    vllm_config = create_vllm_config()
    vllm_config.kv_transfer_config.kv_load_failure_policy = "fail"
    return create_scheduler(vllm_config)
|
||||
|
||||
|
||||
def test_error_propagation_sync_load(fail_scheduler: Scheduler):
    """Test invalid_block_ids with fail policy -> FINISHED_ERROR (sync load)."""
    num_prompt_blocks = 100
    num_external_computed_blocks = 99
    invalid_block_idx = 50

    num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
    num_external_computed_tokens = (
        num_external_computed_blocks * fail_scheduler.block_size
    )

    request = create_request(num_tokens=num_prompt_tokens)
    fail_scheduler.add_request(request=request)

    req_num_new_matched_tokens = {
        request.request_id: num_external_computed_tokens,
    }

    # mock connector reporting a synchronous load (async_load=False)
    fail_scheduler.connector = Mock()
    fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
    )
    fail_scheduler.connector.request_finished.return_value = (False, None)
    fail_scheduler.connector.take_events.return_value = ()

    scheduler_output = fail_scheduler.schedule()

    assert len(fail_scheduler.running) == 1
    assert len(scheduler_output.scheduled_new_reqs) == 1
    assert fail_scheduler.connector.get_num_new_matched_tokens.call_count == 1

    # report one of the request's allocated blocks as invalid
    req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
    invalid_block_ids = {req_block_ids[invalid_block_idx]}
    model_runner_output = create_model_runner_output(
        [request],
        invalid_block_ids=invalid_block_ids,
        use_eos=True,
    )

    outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output)

    assert request.status == RequestStatus.FINISHED_ERROR
    assert request.get_finished_reason() == FinishReason.ERROR

    # the error must surface as exactly one output for this request
    assert len(outputs) == 1
    engine_outputs = next(iter(outputs.values()))
    assert len(engine_outputs.outputs) == 1
    output = engine_outputs.outputs[0]
    assert output.request_id == request.request_id
    assert output.finish_reason == FinishReason.ERROR

    assert len(fail_scheduler.running) == 0
|
||||
|
||||
|
||||
def test_error_propagation_async_load(fail_scheduler: Scheduler):
    """Test invalid_block_ids with fail policy -> FINISHED_ERROR (async load)."""
    num_prompt_blocks = 100
    num_external_computed_blocks = 99
    invalid_block_idx = 50

    num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
    num_external_computed_tokens = (
        num_external_computed_blocks * fail_scheduler.block_size
    )

    request = create_request(num_tokens=num_prompt_tokens)
    fail_scheduler.add_request(request=request)

    req_num_new_matched_tokens = {
        request.request_id: num_external_computed_tokens,
    }

    # mock connector reporting an asynchronous load (async_load=True)
    fail_scheduler.connector = Mock()
    fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, True)
    )
    fail_scheduler.connector.request_finished.return_value = (False, None)
    fail_scheduler.connector.take_events.return_value = ()

    scheduler_output = fail_scheduler.schedule()

    # with async load the request waits for remote KVs instead of running
    assert len(fail_scheduler.waiting) == 1
    assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    assert request.num_computed_tokens == 0

    # the request was not scheduled, so read its block IDs from the manager
    (req_block_ids,) = fail_scheduler.kv_cache_manager.get_block_ids(request.request_id)
    invalid_block_ids = {req_block_ids[invalid_block_idx]}
    model_runner_output = create_model_runner_output(
        reqs=[],
        finished_recving=set(),
        invalid_block_ids=invalid_block_ids,
        use_eos=True,
    )

    outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output)

    assert request.status == RequestStatus.FINISHED_ERROR
    assert request.get_finished_reason() == FinishReason.ERROR

    # the error must surface as exactly one output for this request
    assert len(outputs) == 1
    engine_outputs = next(iter(outputs.values()))
    assert len(engine_outputs.outputs) == 1
    output = engine_outputs.outputs[0]
    assert output.request_id == request.request_id
    assert output.finish_reason == FinishReason.ERROR

    assert len(fail_scheduler.waiting) == 0
|
||||
454
tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py
Normal file
454
tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py
Normal file
@ -0,0 +1,454 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""
|
||||
Tests for correctness in invalid block handling.
|
||||
|
||||
These tests verify correct behavior in three scenarios:
|
||||
1. Sync recompute case: Blocks should not be freed for running requests
|
||||
that need to recompute invalid blocks
|
||||
2. Sync fail case: Invalid blocks must be evicted from cache when request fails
|
||||
3. Async recompute case: Invalid blocks should not be cached after transfer
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.request import FinishReason, Request, RequestStatus
|
||||
|
||||
from .utils import (
|
||||
create_model_runner_output,
|
||||
create_request,
|
||||
create_scheduler,
|
||||
create_vllm_config,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
def _make_get_num_new_matched_tokens(
|
||||
req_num_new_matched_tokens: dict[str, int],
|
||||
async_load: bool,
|
||||
) -> Callable[[Request, int], tuple[int, bool]]:
|
||||
def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
|
||||
value = req_num_new_matched_tokens.get(request.request_id, 0)
|
||||
return value, async_load
|
||||
|
||||
return get_num_new_matched_tokens
|
||||
|
||||
|
||||
@pytest.fixture
def fail_scheduler():
    """Return a scheduler whose kv_load_failure_policy is set to 'fail'."""
    vllm_config = create_vllm_config()
    vllm_config.kv_transfer_config.kv_load_failure_policy = "fail"
    return create_scheduler(vllm_config)
|
||||
|
||||
|
||||
@pytest.fixture
def recompute_scheduler():
    """Return a scheduler whose kv_load_failure_policy is set to 'recompute'."""
    vllm_config = create_vllm_config()
    vllm_config.kv_transfer_config.kv_load_failure_policy = "recompute"
    return create_scheduler(vllm_config)
|
||||
|
||||
|
||||
def test_sync_recompute_blocks_not_freed_for_running_requests(
    recompute_scheduler: Scheduler,
):
    """
    Test sync recompute case - blocks must not be freed for running requests.

    When a running request has invalid blocks and retry_policy is 'recompute':
    1. Request should remain in RUNNING state
    2. num_computed_tokens should be truncated to invalid block boundary
    3. Blocks should NOT be freed (request still needs them for recomputation)
    4. Request should remain in scheduler.requests and scheduler.running
    """
    num_prompt_blocks = 100
    num_external_computed_blocks = 99
    invalid_block_idx = 50

    num_prompt_tokens = num_prompt_blocks * recompute_scheduler.block_size
    num_external_computed_tokens = (
        num_external_computed_blocks * recompute_scheduler.block_size
    )

    request = create_request(num_tokens=num_prompt_tokens)
    recompute_scheduler.add_request(request=request)

    req_num_new_matched_tokens = {
        request.request_id: num_external_computed_tokens,
    }

    # mock connector indicating sync load
    recompute_scheduler.connector = Mock()
    recompute_scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
    )
    recompute_scheduler.connector.request_finished.return_value = (False, None)
    recompute_scheduler.connector.take_events.return_value = ()

    scheduler_output = recompute_scheduler.schedule()

    # request should be running with sync KV load
    assert len(recompute_scheduler.running) == 1
    assert len(scheduler_output.scheduled_new_reqs) == 1
    assert request.status == RequestStatus.RUNNING

    # get the allocated block IDs before invalid blocks are reported
    req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
    invalid_block_ids = {req_block_ids[invalid_block_idx]}

    # store original num_computed_tokens for comparison
    original_num_computed_tokens = request.num_computed_tokens

    model_runner_output = create_model_runner_output(
        [request],
        invalid_block_ids=invalid_block_ids,
        use_eos=False,  # not finished - should continue running
    )

    outputs = recompute_scheduler.update_from_output(
        scheduler_output, model_runner_output
    )

    # critical assertions for recompute case:

    # 1. request should still be RUNNING (not finished, not aborted)
    assert request.status == RequestStatus.RUNNING, (
        f"Request should remain RUNNING for recompute, got {request.status}"
    )

    # 2. num_computed_tokens should be truncated to first invalid block
    expected_truncated_tokens = invalid_block_idx * recompute_scheduler.block_size
    assert request.num_computed_tokens == expected_truncated_tokens, (
        f"num_computed_tokens should be truncated to {expected_truncated_tokens}, "
        f"got {request.num_computed_tokens}"
    )
    assert request.num_computed_tokens < original_num_computed_tokens, (
        "num_computed_tokens should be reduced after invalid block detection"
    )

    # 3. no output should be generated (request is still running)
    # the request should be skipped in the output loop
    assert len(outputs) == 0 or request.request_id not in [
        out.request_id for outs in outputs.values() for out in outs.outputs
    ], "No output should be generated for recompute requests"

    # 4. request should still be in running queue
    assert request in recompute_scheduler.running, (
        "Request should remain in running queue for recomputation"
    )

    # 5. request should still be in scheduler.requests (not deleted)
    assert request.request_id in recompute_scheduler.requests, (
        "Request should not be deleted from scheduler.requests"
    )

    # 6. blocks should NOT be freed - verify blocks are still allocated
    try:
        allocated_blocks = recompute_scheduler.kv_cache_manager.get_block_ids(
            request.request_id
        )
        assert allocated_blocks is not None
        assert len(allocated_blocks[0]) > 0, (
            "Blocks should still be allocated for recomputation"
        )
    except KeyError:
        pytest.fail(
            "Blocks were freed incorrectly! Running requests need their blocks "
            "to recompute invalid portions."
        )

    # 7. verify request can be rescheduled in next step
    scheduler_output_2 = recompute_scheduler.schedule()

    # request should appear in the new schedule to recompute invalid blocks
    scheduled_req_ids = [
        req.request_id for req in scheduler_output_2.scheduled_new_reqs
    ]
    if scheduler_output_2.num_scheduled_tokens:
        scheduled_req_ids.extend(scheduler_output_2.num_scheduled_tokens.keys())

    assert (
        request.request_id in scheduled_req_ids or len(recompute_scheduler.running) > 0
    ), "Request should be reschedulable for recomputation"
|
||||
|
||||
|
||||
def test_sync_fail_invalid_blocks_evicted(fail_scheduler: Scheduler):
    """
    Sync-load failure under the 'fail' policy must evict the invalid block.

    Scenario exercised here:
    1. A 100-block prompt is scheduled with 99 blocks reported as externally
       computed by a mocked connector configured for sync loading.
    2. The model runner output flags the block at index 50 as invalid.
    3. The request must end in FINISHED_ERROR, be removed from the scheduler,
       and exactly one ERROR output must be produced for it.
    4. The request must no longer hold any blocks, and the flagged block must
       have its hash cleared so that later requests cannot reuse invalid data.
    """
    block_size = fail_scheduler.block_size
    kv_cache_manager = fail_scheduler.kv_cache_manager
    invalid_block_idx = 50
    num_prompt_tokens = 100 * block_size
    num_external_computed_tokens = 99 * block_size

    request = create_request(num_tokens=num_prompt_tokens)
    fail_scheduler.add_request(request=request)

    # connector mock that reports the external tokens as a sync load
    connector = Mock()
    connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(
            {request.request_id: num_external_computed_tokens}, False
        )
    )
    connector.request_finished.return_value = (False, None)
    connector.take_events.return_value = ()
    fail_scheduler.connector = connector

    scheduler_output = fail_scheduler.schedule()

    # with a sync KV load the request goes straight to the running queue
    assert len(fail_scheduler.running) == 1
    assert request.status == RequestStatus.RUNNING

    # pick the block that will be reported as invalid
    invalid_block_id = scheduler_output.scheduled_new_reqs[0].block_ids[0][
        invalid_block_idx
    ]

    # the block must exist in the pool before it is reported as invalid
    block = kv_cache_manager.block_pool.blocks[invalid_block_id]
    assert block is not None

    # reporting the invalid block is expected to fail the request
    outputs = fail_scheduler.update_from_output(
        scheduler_output,
        create_model_runner_output(
            [request],
            invalid_block_ids={invalid_block_id},
            use_eos=True,
        ),
    )

    # the request finished with an error ...
    assert request.status == RequestStatus.FINISHED_ERROR
    assert request.get_finished_reason() == FinishReason.ERROR

    # ... and exactly one matching error output was emitted
    assert len(outputs) == 1
    engine_outputs = next(iter(outputs.values()))
    assert len(engine_outputs.outputs) == 1
    (output,) = engine_outputs.outputs
    assert output.request_id == request.request_id
    assert output.finish_reason == FinishReason.ERROR

    # the scheduler no longer knows about the request
    assert request.request_id not in fail_scheduler.requests
    assert len(fail_scheduler.running) == 0

    # critical: the finished request must not keep any blocks; otherwise a
    # future request could pick up the invalid data
    try:
        block_ids = kv_cache_manager.get_block_ids(request.request_id)
        # an empty block list is acceptable - it means everything was freed
        still_allocated = block_ids is not None and len(block_ids[0]) > 0
        if still_allocated:
            pytest.fail(
                f"Invalid blocks still tracked for finished request! "
                f"Request {request.request_id} should have been freed but "
                f"still has {len(block_ids[0])} blocks allocated."
            )
    except KeyError:
        # expected - the request is gone from block tracking altogether
        pass

    # critical: eviction from the prefix cache resets the block hash
    assert block.block_hash is None, (
        f"Invalid block {invalid_block_id} should have been evicted from cache "
        f"(hash should be None), but hash is still {block.block_hash}"
    )
|
||||
|
||||
|
||||
def test_async_recompute_blocks_not_cached_when_invalid(
    recompute_scheduler: Scheduler,
):
    """
    Async load with the 'recompute' policy must never cache invalid blocks.

    Flow covered by this test:
    1. A request is scheduled with an async KV load, so its blocks are
       allocated but carry no hash yet.
    2. An invalid block is reported while the transfer is still in flight;
       evict_blocks must not be called and num_computed_tokens is truncated
       to the tokens that precede the invalid block.
    3. The transfer completes; on the next schedule() cache_blocks must be
       called once with the truncated token count, and the request moves to
       RUNNING with the failed/finished receiving sets cleared.
    """
    from unittest.mock import patch

    block_size = recompute_scheduler.block_size
    kv_cache_manager = recompute_scheduler.kv_cache_manager
    invalid_block_idx = 50
    num_prompt_tokens = 100 * block_size
    num_external_computed_tokens = 99 * block_size
    # only the tokens in front of the invalid block are expected to survive
    expected_valid_tokens = invalid_block_idx * block_size

    request = create_request(num_tokens=num_prompt_tokens)
    recompute_scheduler.add_request(request=request)

    # connector mock that reports the external tokens as an async load
    connector = Mock()
    connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(
            {request.request_id: num_external_computed_tokens}, True
        )
    )
    connector.request_finished.return_value = (False, None)
    connector.take_events.return_value = ()
    recompute_scheduler.connector = connector

    scheduler_output = recompute_scheduler.schedule()

    # the request is parked until the remote KVs arrive
    assert len(recompute_scheduler.waiting) == 1
    assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    assert request.num_computed_tokens == 0

    # locate the block that will be reported as invalid
    (req_block_ids,) = kv_cache_manager.get_block_ids(request.request_id)
    invalid_block_id = req_block_ids[invalid_block_idx]
    block = kv_cache_manager.block_pool.blocks[invalid_block_id]

    # nothing has been cached yet, so the block carries no hash
    assert block.block_hash is None, (
        "Async loading blocks should not be cached yet (no hash)"
    )

    # critical: record every evict_blocks call while still delegating to the
    # real implementation
    original_evict_blocks = kv_cache_manager.evict_blocks
    evict_blocks_calls = []

    def evict_blocks_spy(block_ids):
        evict_blocks_calls.append(set(block_ids))
        return original_evict_blocks(block_ids)

    # report the invalid block while the transfer is NOT finished
    in_flight_output = create_model_runner_output(
        reqs=[],
        finished_recving=None,
        invalid_block_ids={invalid_block_id},
        use_eos=False,
    )
    with patch.object(kv_cache_manager, "evict_blocks", evict_blocks_spy):
        recompute_scheduler.update_from_output(scheduler_output, in_flight_output)

    # async-only invalid blocks are excluded from eviction
    assert len(evict_blocks_calls) == 0, (
        f"evict_blocks should not be called for async-only invalid blocks, "
        f"but was called {len(evict_blocks_calls)} time(s) with {evict_blocks_calls}"
    )

    # recompute policy: the request keeps waiting instead of erroring out
    assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    assert request.request_id in recompute_scheduler.failed_recving_kv_req_ids

    # num_computed_tokens was cut back to just before the invalid block
    assert request.num_computed_tokens == expected_valid_tokens

    # the invalid block still has no hash (it was never cached, nor evicted)
    assert block.block_hash is None, (
        f"Async loading blocks shouldn't be cached or evicted. "
        f"Block {invalid_block_id} hash should be None but is {block.block_hash}"
    )

    # simulate completion of the async transfer
    transfer_done_output = create_model_runner_output(
        reqs=[],
        finished_recving={request.request_id},
        invalid_block_ids=None,
        use_eos=False,
    )
    recompute_scheduler.update_from_output(scheduler_output, transfer_done_output)

    # the request is now flagged both as finished receiving and as failed
    assert request.request_id in recompute_scheduler.finished_recving_kv_req_ids
    assert request.request_id in recompute_scheduler.failed_recving_kv_req_ids

    # critical: the invalid data from the transfer was never cached
    assert block.block_hash is None, (
        f"Invalid block {invalid_block_id} should not be cached before recompute "
        f"(hash should be None), but hash is {block.block_hash}"
    )

    # critical end-to-end check: record the arguments cache_blocks receives
    original_cache_blocks = kv_cache_manager.cache_blocks
    cache_blocks_calls = []

    def cache_blocks_spy(req, num_tokens):
        cache_blocks_calls.append((req.request_id, num_tokens))
        return original_cache_blocks(req, num_tokens)

    with patch.object(kv_cache_manager, "cache_blocks", cache_blocks_spy):
        # schedule() triggers _update_waiting_for_remote_kv(), which is
        # expected to call cache_blocks with the truncated token count
        recompute_scheduler.schedule()

    assert len(cache_blocks_calls) == 1, (
        f"cache_blocks should be called exactly once, "
        f"got {len(cache_blocks_calls)} calls"
    )
    cached_req_id, cached_num_tokens = cache_blocks_calls[0]
    assert cached_req_id == request.request_id
    assert cached_num_tokens == expected_valid_tokens, (
        f"cache_blocks should be called with truncated value {expected_valid_tokens}, "
        f"but was called with {cached_num_tokens}"
    )

    # WAITING_FOR_REMOTE_KVS -> WAITING -> RUNNING happens in one schedule() call
    assert request.status == RequestStatus.RUNNING

    # the scheduler may already have scheduled further new tokens, so only a
    # lower bound on num_computed_tokens can be asserted
    assert request.num_computed_tokens >= expected_valid_tokens, (
        f"num_computed_tokens should be at least {expected_valid_tokens}, "
        f"got {request.num_computed_tokens}"
    )

    # both receiving sets have been cleared for this request
    assert request.request_id not in recompute_scheduler.failed_recving_kv_req_ids
    assert request.request_id not in recompute_scheduler.finished_recving_kv_req_ids

    # and the request sits in the running queue
    assert request in recompute_scheduler.running
|
||||
756
tests/v1/kv_connector/unit/test_lmcache_connector.py
Normal file
756
tests/v1/kv_connector/unit/test_lmcache_connector.py
Normal file
@ -0,0 +1,756 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.distributed.kv_events import BlockStored
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import (
|
||||
LMCacheConnectorV1,
|
||||
LMCacheKVEvents,
|
||||
)
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
|
||||
|
||||
@pytest.fixture
def mock_lmcache_engine_event():
    """Return a stand-in for a single event as produced by the lmcache engine."""

    class MockEvent:
        """Attribute bag exposing the fields given at construction time."""

        def __init__(self, **fields):
            for field_name, field_value in fields.items():
                setattr(self, field_name, field_value)

    event_fields = {
        "block_hashes": ["hash1", "hash2"],
        "parent_block_hash": "parent_hash",
        "token_ids": [1, 2, 3, 4],
        "lora_id": None,
        "block_size": 16,
        "medium": "GPU",
    }
    return MockEvent(**event_fields)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_connector():
    """Build a spec'd LMCacheConnectorV1 mock whose event methods are real.

    The connector itself and its lmcache engine are mocks, while the three
    event-related methods are bound from the real class so that the tests
    exercise the actual implementation.
    """
    connector = MagicMock(spec=LMCacheConnectorV1)
    connector._kv_cache_events = None
    connector._lmcache_engine = MagicMock()

    # Bind the real implementations to the mock instance
    real_method_names = (
        "get_kv_connector_kv_cache_events",
        "update_connector_output",
        "take_events",
    )
    for method_name in real_method_names:
        unbound = getattr(LMCacheConnectorV1, method_name)
        setattr(connector, method_name, unbound.__get__(connector, LMCacheConnectorV1))

    return connector
|
||||
|
||||
|
||||
class TestGetKVConnectorKVCacheEvents:
    """Tests for LMCacheConnectorV1.get_kv_connector_kv_cache_events."""

    def test_returns_none_when_no_events(self, mock_connector):
        """The engine reporting None yields None and is queried exactly once."""
        engine = mock_connector._lmcache_engine
        engine.get_kv_events.return_value = None

        assert mock_connector.get_kv_connector_kv_cache_events() is None
        engine.get_kv_events.assert_called_once()

    def test_returns_none_when_empty_list(self, mock_connector):
        """An empty list from the engine also yields None."""
        mock_connector._lmcache_engine.get_kv_events.return_value = []

        assert mock_connector.get_kv_connector_kv_cache_events() is None

    def test_converts_single_event(self, mock_connector, mock_lmcache_engine_event):
        """One engine event becomes one BlockStored with the same fields."""
        mock_connector._lmcache_engine.get_kv_events.return_value = [
            mock_lmcache_engine_event
        ]

        result = mock_connector.get_kv_connector_kv_cache_events()

        assert result is not None
        assert isinstance(result, LMCacheKVEvents)
        assert result.get_number_of_workers() == 1

        events = result.get_all_events()
        assert len(events) == 1
        stored = events[0]
        assert isinstance(stored, BlockStored)
        assert stored.block_hashes == ["hash1", "hash2"]
        assert stored.parent_block_hash == "parent_hash"
        assert stored.token_ids == [1, 2, 3, 4]
        assert stored.lora_id is None
        assert stored.block_size == 16
        assert stored.medium == "GPU"

    def test_converts_multiple_events(self, mock_connector):
        """Several engine events are all converted, keeping their order."""

        class MockEvent:
            def __init__(self, index):
                self.block_hashes = [f"hash{index}"]
                self.parent_block_hash = f"parent{index}"
                self.token_ids = [index]
                self.lora_id = None
                self.block_size = 16
                self.medium = "GPU"

        mock_connector._lmcache_engine.get_kv_events.return_value = [
            MockEvent(index) for index in range(5)
        ]

        result = mock_connector.get_kv_connector_kv_cache_events()

        assert result is not None
        assert isinstance(result, LMCacheKVEvents)

        converted_events = result.get_all_events()
        assert len(converted_events) == 5

        for position, converted in enumerate(converted_events):
            assert isinstance(converted, BlockStored)
            assert converted.block_hashes == [f"hash{position}"]
            assert converted.parent_block_hash == f"parent{position}"
            assert converted.token_ids == [position]

    def test_preserves_event_attributes(self, mock_connector):
        """Non-default attribute values survive the conversion unchanged."""

        class MockEventWithLora:
            def __init__(self):
                self.block_hashes = ["hash_a", "hash_b", "hash_c"]
                self.parent_block_hash = "parent_xyz"
                self.token_ids = [100, 200, 300]
                self.lora_id = 42
                self.block_size = 32
                self.medium = "DISK"

        engine_events = [MockEventWithLora()]
        mock_connector._lmcache_engine.get_kv_events.return_value = engine_events

        result = mock_connector.get_kv_connector_kv_cache_events()
        event = result.get_all_events()[0]

        assert event.block_hashes == ["hash_a", "hash_b", "hash_c"]
        assert event.parent_block_hash == "parent_xyz"
        assert event.token_ids == [100, 200, 300]
        assert event.lora_id == 42
        assert event.block_size == 32
        assert event.medium == "DISK"

    def test_handles_none_parent_block_hash(self, mock_connector):
        """A None parent_block_hash is carried over as None."""

        class MockEventNoParent:
            def __init__(self):
                self.block_hashes = ["hash1"]
                self.parent_block_hash = None
                self.token_ids = [1, 2]
                self.lora_id = None
                self.block_size = 16
                self.medium = "GPU"

        engine_events = [MockEventNoParent()]
        mock_connector._lmcache_engine.get_kv_events.return_value = engine_events

        result = mock_connector.get_kv_connector_kv_cache_events()

        first_event = result.get_all_events()[0]
        assert first_event.parent_block_hash is None
|
||||
|
||||
|
||||
class TestUpdateConnectorOutput:
    """Tests for LMCacheConnectorV1.update_connector_output."""

    @staticmethod
    def _stored(block_hash, token_ids):
        """Build a GPU BlockStored (block size 16, no parent, no LoRA)."""
        return BlockStored(
            block_hashes=[block_hash],
            parent_block_hash=None,
            token_ids=token_ids,
            block_size=16,
            lora_id=None,
            medium="GPU",
        )

    def test_does_nothing_when_kv_cache_events_is_none(self, mock_connector):
        """An output without kv_cache_events leaves the connector untouched."""
        mock_connector.update_connector_output(
            KVConnectorOutput(kv_cache_events=None)
        )

        assert mock_connector._kv_cache_events is None

    def test_does_nothing_when_kv_cache_events_is_not_lmcache_kv_events(
        self, mock_connector
    ):
        """kv_cache_events of a foreign type leaves the connector untouched."""
        # A MagicMock is deliberately not an LMCacheKVEvents instance
        fake_events = MagicMock()

        mock_connector.update_connector_output(
            KVConnectorOutput(kv_cache_events=fake_events)
        )

        assert mock_connector._kv_cache_events is None

    def test_sets_kv_cache_events_when_none(self, mock_connector):
        """With no stored events, the incoming object is stored as-is."""
        kv_events = LMCacheKVEvents(num_workers=1)
        kv_events.add_events([self._stored("hash1", [1, 2])])

        mock_connector.update_connector_output(
            KVConnectorOutput(kv_cache_events=kv_events)
        )

        assert mock_connector._kv_cache_events is kv_events

    def test_adds_events_when_kv_cache_events_already_exists(self, mock_connector):
        """Incoming events are appended to the ones already stored."""
        event1 = self._stored("hash1", [1])
        event2 = self._stored("hash2", [2])

        # Existing state: the same event reported by each of 2 workers
        existing_events = LMCacheKVEvents(num_workers=2)
        for _ in range(2):
            existing_events.add_events([event1])
        mock_connector._kv_cache_events = existing_events

        # Incoming state: one new event from a single worker
        new_events = LMCacheKVEvents(num_workers=1)
        new_events.add_events([event2])

        mock_connector.update_connector_output(
            KVConnectorOutput(kv_cache_events=new_events)
        )

        # 2 from existing + 1 from new
        all_events = mock_connector._kv_cache_events.get_all_events()
        assert len(all_events) == 3
        assert event1 in all_events
        assert event2 in all_events

    def test_increments_workers_when_kv_cache_events_already_exists(
        self, mock_connector
    ):
        """Worker counts of stored and incoming events are summed."""
        # 2 workers already recorded
        mock_connector._kv_cache_events = LMCacheKVEvents(num_workers=2)

        # 3 more workers arrive with one event
        new_events = LMCacheKVEvents(num_workers=3)
        new_events.add_events([self._stored("hash1", [1])])

        mock_connector.update_connector_output(
            KVConnectorOutput(kv_cache_events=new_events)
        )

        # 2 + 3 = 5
        assert mock_connector._kv_cache_events.get_number_of_workers() == 5

    def test_multiple_updates(self, mock_connector):
        """Three consecutive updates accumulate events and worker counts."""
        updates = (
            (1, "hash1", 1),
            (2, "hash2", 2),
            (1, "hash3", 3),
        )
        for num_workers, block_hash, token in updates:
            batch = LMCacheKVEvents(num_workers=num_workers)
            batch.add_events([self._stored(block_hash, [token])])
            mock_connector.update_connector_output(
                KVConnectorOutput(kv_cache_events=batch)
            )

        # Final state: 3 events, 1+2+1 workers
        stored_events = mock_connector._kv_cache_events
        assert len(stored_events.get_all_events()) == 3
        assert stored_events.get_number_of_workers() == 4

    def test_updates_with_empty_events(self, mock_connector):
        """An update carrying no events still adds its worker count."""
        # First update carries one real event
        events1 = LMCacheKVEvents(num_workers=1)
        events1.add_events([self._stored("hash1", [1])])
        mock_connector.update_connector_output(
            KVConnectorOutput(kv_cache_events=events1)
        )

        # Second update carries no events at all
        events2 = LMCacheKVEvents(num_workers=2)
        mock_connector.update_connector_output(
            KVConnectorOutput(kv_cache_events=events2)
        )

        # The original event is still the only one; workers are 1 + 2
        stored_events = mock_connector._kv_cache_events
        assert len(stored_events.get_all_events()) == 1
        assert stored_events.get_number_of_workers() == 3
|
||||
|
||||
|
||||
class TestTakeEvents:
    """Tests for LMCacheConnectorV1.take_events."""

    @staticmethod
    def _stored(block_hash, token_id):
        """Build a single-token GPU BlockStored (block size 16, no parent)."""
        return BlockStored(
            block_hashes=[block_hash],
            parent_block_hash=None,
            token_ids=[token_id],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )

    def test_yields_nothing_when_kv_cache_events_is_none(self, mock_connector):
        """With no stored events, take_events produces an empty sequence."""
        mock_connector._kv_cache_events = None

        assert list(mock_connector.take_events()) == []

    def test_yields_events_and_clears(self, mock_connector):
        """Stored events are handed out and the stored state is reset."""
        event1 = self._stored("hash1", 1)
        event2 = self._stored("hash2", 2)
        kv_events = LMCacheKVEvents(num_workers=1)
        kv_events.add_events([event1, event2])
        mock_connector._kv_cache_events = kv_events

        events = list(mock_connector.take_events())

        # Both events came out
        assert len(events) == 2
        assert event1 in events
        assert event2 in events

        # And the connector no longer holds them
        assert mock_connector._kv_cache_events is None

    def test_aggregates_before_yielding(self, mock_connector):
        """Only the event reported by all workers is yielded."""
        common_event = self._stored("hash_common", 1)
        uncommon_event = self._stored("hash_uncommon", 2)

        kv_events = LMCacheKVEvents(num_workers=3)
        # Each of the 3 workers reports common_event
        for _ in range(3):
            kv_events.add_events([common_event])
        # A single worker additionally reports uncommon_event
        kv_events.add_events([uncommon_event])
        mock_connector._kv_cache_events = kv_events

        events = list(mock_connector.take_events())

        # Only the common event remains
        assert len(events) == 1
        assert events[0] == common_event

    def test_multiple_take_events_calls(self, mock_connector):
        """take_events can be called repeatedly across refills."""
        # First call: one event is stored
        event1 = self._stored("hash1", 1)
        kv_events1 = LMCacheKVEvents(num_workers=1)
        kv_events1.add_events([event1])
        mock_connector._kv_cache_events = kv_events1

        events1 = list(mock_connector.take_events())
        assert len(events1) == 1
        assert events1[0] == event1
        assert mock_connector._kv_cache_events is None

        # Second call: nothing is stored
        events2 = list(mock_connector.take_events())
        assert events2 == []

        # Third call: a new event has been stored in the meantime
        event2 = self._stored("hash2", 2)
        kv_events2 = LMCacheKVEvents(num_workers=1)
        kv_events2.add_events([event2])
        mock_connector._kv_cache_events = kv_events2

        events3 = list(mock_connector.take_events())
        assert len(events3) == 1
        assert events3[0] == event2

    def test_yields_empty_after_aggregation_removes_all(self, mock_connector):
        """No event is yielded when the workers share no common event."""
        kv_events = LMCacheKVEvents(num_workers=2)
        # Worker 1 reports hash1, worker 2 reports hash2
        kv_events.add_events([self._stored("hash1", 1)])
        kv_events.add_events([self._stored("hash2", 2)])
        mock_connector._kv_cache_events = kv_events

        events = list(mock_connector.take_events())

        # Nothing in common, so nothing is yielded and state is cleared
        assert events == []
        assert mock_connector._kv_cache_events is None
|
||||
|
||||
|
||||
class TestIntegrationScenarios:
    """Integration scenarios that chain the connector's KV-event APIs.

    The tests drive the ``mock_connector`` fixture through
    ``get_kv_connector_kv_cache_events`` -> ``update_connector_output`` ->
    ``take_events`` and check both the yielded events and the value left in
    ``mock_connector._kv_cache_events`` afterwards.
    """

    def test_full_workflow(self, mock_connector, mock_lmcache_engine_event):
        """Test a complete workflow from getting events to taking them."""
        # Step 1: Get events from lmcache engine
        mock_connector._lmcache_engine.get_kv_events.return_value = [
            mock_lmcache_engine_event
        ]
        kv_events = mock_connector.get_kv_connector_kv_cache_events()

        assert kv_events is not None
        assert len(kv_events.get_all_events()) == 1

        # Step 2: Update connector output (simulate receiving from worker)
        mock_connector.update_connector_output(
            KVConnectorOutput(kv_cache_events=kv_events)
        )

        assert mock_connector._kv_cache_events is not None

        # Step 3: Take events; the pending-events slot must be cleared.
        taken_events = list(mock_connector.take_events())

        assert len(taken_events) == 1
        assert mock_connector._kv_cache_events is None

    def test_multiple_workers_workflow(self, mock_connector):
        """Test workflow with multiple workers."""

        class MockEvent:
            """Stand-in for an engine event carrying a single block hash."""

            def __init__(self, hash_val):
                self.block_hashes = [hash_val]
                self.parent_block_hash = None
                self.token_ids = [1]
                self.lora_id = None
                self.block_size = 16
                self.medium = "GPU"

        # Worker 1 and worker 2 each report the shared event plus one event
        # of their own; the connector receives one output per worker.
        for unique_hash in ("hash_worker1", "hash_worker2"):
            mock_connector._lmcache_engine.get_kv_events.return_value = [
                MockEvent("hash_common"),
                MockEvent(unique_hash),
            ]
            worker_events = mock_connector.get_kv_connector_kv_cache_events()
            mock_connector.update_connector_output(
                KVConnectorOutput(kv_cache_events=worker_events)
            )

        # Take events (should only get common events)
        taken_events = list(mock_connector.take_events())

        # With aggregation, only events reported by both workers should be
        # present. In this case, hash_common was reported by both.
        event_hashes = [e.block_hashes[0] for e in taken_events]
        assert "hash_common" in event_hashes

    def test_empty_workflow(self, mock_connector):
        """Test workflow when there are no events at any stage."""
        # Get events returns None
        mock_connector._lmcache_engine.get_kv_events.return_value = None
        kv_events = mock_connector.get_kv_connector_kv_cache_events()

        assert kv_events is None

        # Update with None
        mock_connector.update_connector_output(
            KVConnectorOutput(kv_cache_events=None)
        )

        # Take events
        taken_events = list(mock_connector.take_events())

        assert taken_events == []
        assert mock_connector._kv_cache_events is None

    def test_repeated_cycles(self, mock_connector):
        """Test multiple cycles of the complete workflow."""

        class MockEvent:
            """Stand-in for an engine event tagged with its cycle number."""

            def __init__(self, cycle_num):
                self.block_hashes = [f"hash_cycle_{cycle_num}"]
                self.parent_block_hash = None
                self.token_ids = [cycle_num]
                self.lora_id = None
                self.block_size = 16
                self.medium = "GPU"

        for cycle in range(3):
            # Get events
            mock_connector._lmcache_engine.get_kv_events.return_value = [
                MockEvent(cycle)
            ]
            kv_events = mock_connector.get_kv_connector_kv_cache_events()

            # Update
            mock_connector.update_connector_output(
                KVConnectorOutput(kv_cache_events=kv_events)
            )

            # Take
            taken_events = list(mock_connector.take_events())

            # Verify: exactly this cycle's event, and no state carried over.
            assert len(taken_events) == 1
            assert taken_events[0].block_hashes[0] == f"hash_cycle_{cycle}"
            assert mock_connector._kv_cache_events is None

    def test_lmcache_kv_events_aggregation(self):
        """
        Test LMCacheKVEvents aggregation across TP ranks using
        KVOutputAggregator (used by MultiprocExecutor).
        """
        from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
        from vllm.v1.outputs import ModelRunnerOutput

        def make_event(block_hash, parent_hash, token_ids):
            # All events in this test share block size, LoRA id and medium.
            return BlockStored(
                block_hashes=[block_hash],
                parent_block_hash=parent_hash,
                token_ids=token_ids,
                block_size=16,
                lora_id=None,
                medium="GPU",
            )

        # Create KVOutputAggregator for 3 workers (simulating TP=3)
        aggregator = KVOutputAggregator(expected_finished_count=3)

        # One event every worker reports, plus one unique event per worker.
        # The list index is the worker rank (0, 1, 2).
        common_event = make_event("hash_common", "parent_common", [1, 2, 3])
        unique_events = [
            make_event("hash_worker1", "parent_w1", [4, 5]),
            make_event("hash_worker2", "parent_w2", [6, 7]),
            make_event("hash_worker3", "parent_w3", [8, 9]),
        ]

        # Create a ModelRunnerOutput for each worker, each carrying the
        # common event and that worker's unique event.
        worker_outputs = []
        for i, unique_event in enumerate(unique_events):
            worker_events = LMCacheKVEvents(num_workers=1)
            worker_events.add_events([common_event, unique_event])

            output = ModelRunnerOutput(
                req_ids=[f"req_{i}"],
                req_id_to_index={f"req_{i}": 0},
                sampled_token_ids=[[123]],  # dummy token
                logprobs=None,
                prompt_logprobs_dict={},
                pooler_output=[None],
                kv_connector_output=KVConnectorOutput(
                    # Workers 0,1 finished sending
                    finished_sending={f"req_{i}_send"} if i < 2 else None,
                    # Workers 1,2 finished receiving
                    finished_recving={f"req_{i}_recv"} if i > 0 else None,
                    kv_cache_events=worker_events,
                ),
            )
            worker_outputs.append(output)

        # Use the real aggregation mechanism (like MultiprocExecutor.execute_model)
        aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0)
        kv_cache_events = aggregated_output.kv_connector_output.kv_cache_events

        assert isinstance(kv_cache_events, LMCacheKVEvents)

        # The output aggregator only combines the per-worker events; it does
        # not reduce them, so aggregate() is called explicitly to keep only
        # the common events.
        kv_cache_events.aggregate()
        aggregated_events = kv_cache_events.get_all_events()

        # Only the common event should remain after aggregation
        # because it's the only event reported by all 3 workers
        assert len(aggregated_events) == 1
        assert aggregated_events[0] == common_event

        # Verify the common event properties
        assert aggregated_events[0].block_hashes == ["hash_common"]
        assert aggregated_events[0].parent_block_hash == "parent_common"
        assert aggregated_events[0].token_ids == [1, 2, 3]
|
||||
@ -7,7 +7,7 @@ Here we break down the requirements in 2 steps:
|
||||
1. Build and install the Python libraries (both [pplx-kernels](https://github.com/ppl-ai/pplx-kernels) and [DeepEP](https://github.com/deepseek-ai/DeepEP)), including necessary dependencies like NVSHMEM. This step does not require any privileged access. Any user can do this.
|
||||
2. Configure NVIDIA driver to enable IBGDA. This step requires root access, and must be done on the host machine.
|
||||
|
||||
2 is necessary for multi-node deployment.
|
||||
Step 2 is necessary for multi-node deployment.
|
||||
|
||||
All scripts accept a positional argument as workspace path for staging the build, defaulting to `$(pwd)/ep_kernels_workspace`.
|
||||
|
||||
@ -23,6 +23,6 @@ TORCH_CUDA_ARCH_LIST="10.0" bash install_python_libraries.sh
|
||||
Additional step for multi-node deployment:
|
||||
|
||||
```bash
|
||||
sudo bash configure_system_drivers.sh
|
||||
sudo bash configure_system_drivers.sh # update-initramfs can take several minutes
|
||||
sudo reboot # Reboot is required to load the new driver
|
||||
```
|
||||
|
||||
@ -294,6 +294,12 @@ class AttentionImpl(ABC, Generic[T]):
|
||||
# Some features like decode context parallelism require the softmax lse.
|
||||
can_return_lse_for_decode: bool = False
|
||||
|
||||
# Whether the attention impl supports Prefill Context Parallelism.
|
||||
supports_pcp: bool = False
|
||||
# Whether the attention impl (or ops) supports MTP
|
||||
# when cp_kv_cache_interleave_size > 1
|
||||
supports_mtp_with_cp_non_trivial_interleave_size: bool = False
|
||||
|
||||
# some attention backends might not always want to return lse
|
||||
# even if they can return lse (for efficiency reasons)
|
||||
need_to_return_lse_for_decode: bool = False
|
||||
|
||||
@ -252,35 +252,3 @@ def register_backend(
|
||||
return lambda x: x
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
# Backwards compatibility alias for plugins
|
||||
class _BackendMeta(type):
    """Metaclass to provide deprecation warnings when accessing _Backend.

    Both attribute access (``_Backend.NAME``) and item access
    (``_Backend["NAME"]``) on a class using this metaclass are answered from
    ``AttentionBackendEnum``, after logging a warning that the old name has
    been renamed.
    """

    def __getattribute__(cls, name: str):
        # A small set of introspection dunders is exempt from the warning;
        # every other attribute lookup logs the rename notice.
        if name not in ("__class__", "__mro__", "__name__"):
            logger.warning(
                "_Backend has been renamed to AttentionBackendEnum. "
                "Please update your code to use AttentionBackendEnum instead. "
                "_Backend will be removed in a future release."
            )
        # NOTE(review): the return is taken to sit outside the `if`; placed
        # inside it, the exempt dunders would fall through and yield None.
        # Confirm against the original layout.
        return getattr(AttentionBackendEnum, name)

    def __getitem__(cls, name: str):
        # Subscript access always warns, then defers to the enum's own
        # lookup-by-name.
        logger.warning(
            "_Backend has been renamed to AttentionBackendEnum. "
            "Please update your code to use AttentionBackendEnum instead. "
            "_Backend will be removed in a future release."
        )
        return AttentionBackendEnum[name]
|
||||
|
||||
|
||||
class _Backend(metaclass=_BackendMeta):
    """Deprecated alias retained for plugins; use AttentionBackendEnum.

    The class body is intentionally empty: all behaviour comes from the
    ``_BackendMeta`` metaclass. It exists only so that older plugin code
    keeps working, and it will be removed in a future release.
    """
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import inspect
|
||||
from functools import cache
|
||||
from typing import cast, get_args
|
||||
|
||||
@ -73,39 +72,18 @@ def _cached_get_attn_backend(
|
||||
) -> type[AttentionBackend]:
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
sig = inspect.signature(current_platform.get_attn_backend_cls)
|
||||
if "use_v1" in sig.parameters:
|
||||
logger.warning_once(
|
||||
"use_v1 parameter for get_attn_backend_cls is deprecated and will "
|
||||
"be removed in v0.13.0 or v1.0.0, whichever is soonest. Please "
|
||||
"remove it from your plugin code."
|
||||
)
|
||||
attention_cls = current_platform.get_attn_backend_cls(
|
||||
backend,
|
||||
head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
True, # use_v1
|
||||
use_mla,
|
||||
has_sink,
|
||||
use_sparse,
|
||||
use_mm_prefix,
|
||||
attn_type,
|
||||
)
|
||||
else:
|
||||
attention_cls = current_platform.get_attn_backend_cls(
|
||||
backend,
|
||||
head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
use_mla,
|
||||
has_sink,
|
||||
use_sparse,
|
||||
use_mm_prefix,
|
||||
attn_type,
|
||||
)
|
||||
attention_cls = current_platform.get_attn_backend_cls(
|
||||
backend,
|
||||
head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
use_mla,
|
||||
has_sink,
|
||||
use_sparse,
|
||||
use_mm_prefix,
|
||||
attn_type,
|
||||
)
|
||||
if not attention_cls:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend for {current_platform.device_name}"
|
||||
|
||||
@ -788,7 +788,7 @@ async def benchmark(
|
||||
)
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
"Total Token throughput (tok/s):", metrics.total_token_throughput
|
||||
"Total token throughput (tok/s):", metrics.total_token_throughput
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -171,22 +171,24 @@ class TorchCompileWithNoGuardsWrapper:
|
||||
|
||||
compiled_ptr = self.check_invariants_and_forward
|
||||
|
||||
aot_context = nullcontext()
|
||||
if envs.VLLM_USE_AOT_COMPILE:
|
||||
if hasattr(torch._dynamo.config, "enable_aot_compile"):
|
||||
torch._dynamo.config.enable_aot_compile = True
|
||||
aot_context = torch._dynamo.config.patch(enable_aot_compile=True)
|
||||
else:
|
||||
msg = "torch._dynamo.config.enable_aot_compile is not "
|
||||
msg += "available. AOT compile is disabled and please "
|
||||
msg += "upgrade PyTorch version to use AOT compile."
|
||||
logger.warning(msg)
|
||||
|
||||
self._compiled_callable = torch.compile(
|
||||
compiled_ptr,
|
||||
fullgraph=True,
|
||||
dynamic=False,
|
||||
backend=backend,
|
||||
options=options,
|
||||
)
|
||||
with aot_context:
|
||||
self._compiled_callable = torch.compile(
|
||||
compiled_ptr,
|
||||
fullgraph=True,
|
||||
dynamic=False,
|
||||
backend=backend,
|
||||
options=options,
|
||||
)
|
||||
|
||||
if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE:
|
||||
torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
|
||||
|
||||
@ -17,7 +17,6 @@ from vllm.config.utils import (
|
||||
Range,
|
||||
config,
|
||||
get_hash_factors,
|
||||
handle_deprecated,
|
||||
hash_factors,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
@ -127,27 +126,6 @@ class PassConfig:
|
||||
fuse_allreduce_rms: bool = Field(default=None)
|
||||
"""Enable flashinfer allreduce fusion."""
|
||||
|
||||
# Deprecated flags
|
||||
enable_fusion: bool = Field(default=None)
|
||||
"""Deprecated in: v0.12.0. Use fuse_norm_quant and fuse_act_quant
|
||||
instead. Will be removed in v0.13.0 or v1.0.0, whichever is sooner.
|
||||
"""
|
||||
enable_attn_fusion: bool = Field(default=None)
|
||||
"""Deprecated in: v0.12.0. Use fuse_attn_quant instead.
|
||||
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
|
||||
enable_noop: bool = Field(default=None)
|
||||
"""Deprecated in: v0.12.0. Use eliminate_noops instead.
|
||||
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
|
||||
enable_sequence_parallelism: bool = Field(default=None)
|
||||
"""Deprecated in: v0.12.0. Use enable_sp instead.
|
||||
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
|
||||
enable_async_tp: bool = Field(default=None)
|
||||
"""Deprecated in: v0.12.0. Use fuse_gemm_comms instead.
|
||||
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
|
||||
enable_fi_allreduce_fusion: bool = Field(default=None)
|
||||
"""Deprecated in: v0.12.0. Use fuse_allreduce_rms instead.
|
||||
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
|
||||
|
||||
fi_allreduce_fusion_max_size_mb: float | None = None
|
||||
"""The threshold of the communicated tensor sizes under which
|
||||
vllm should use flashinfer fused allreduce. Specified as a
|
||||
@ -206,15 +184,7 @@ class PassConfig:
|
||||
Any future fields that don't affect compilation should be excluded.
|
||||
"""
|
||||
|
||||
ignored_fields = [
|
||||
"enable_fusion",
|
||||
"enable_attn_fusion",
|
||||
"enable_noop",
|
||||
"enable_sequence_parallelism",
|
||||
"enable_async_tp",
|
||||
"enable_fi_allreduce_fusion",
|
||||
]
|
||||
return hash_factors(get_hash_factors(self, ignored_factors=ignored_fields))
|
||||
return hash_factors(get_hash_factors(self, set()))
|
||||
|
||||
@field_validator(
|
||||
"fuse_norm_quant",
|
||||
@ -224,12 +194,6 @@ class PassConfig:
|
||||
"enable_sp",
|
||||
"fuse_gemm_comms",
|
||||
"fuse_allreduce_rms",
|
||||
"enable_fusion",
|
||||
"enable_attn_fusion",
|
||||
"enable_noop",
|
||||
"enable_sequence_parallelism",
|
||||
"enable_async_tp",
|
||||
"enable_fi_allreduce_fusion",
|
||||
mode="wrap",
|
||||
)
|
||||
@classmethod
|
||||
@ -242,49 +206,6 @@ class PassConfig:
|
||||
def __post_init__(self) -> None:
|
||||
# Handle deprecation and defaults
|
||||
|
||||
# Map old flags to new flags and issue warnings
|
||||
handle_deprecated(
|
||||
self,
|
||||
"enable_fusion",
|
||||
["fuse_norm_quant", "fuse_act_quant"],
|
||||
"v0.13.0 or v1.0.0, whichever is sooner",
|
||||
)
|
||||
|
||||
handle_deprecated(
|
||||
self,
|
||||
"enable_attn_fusion",
|
||||
"fuse_attn_quant",
|
||||
"v0.13.0 or v1.0.0, whichever is sooner",
|
||||
)
|
||||
|
||||
handle_deprecated(
|
||||
self,
|
||||
"enable_sequence_parallelism",
|
||||
"enable_sp",
|
||||
"v0.13.0 or v1.0.0, whichever is sooner",
|
||||
)
|
||||
|
||||
handle_deprecated(
|
||||
self,
|
||||
"enable_async_tp",
|
||||
"fuse_gemm_comms",
|
||||
"v0.13.0 or v1.0.0, whichever is sooner",
|
||||
)
|
||||
|
||||
handle_deprecated(
|
||||
self,
|
||||
"enable_fi_allreduce_fusion",
|
||||
"fuse_allreduce_rms",
|
||||
"v0.13.0 or v1.0.0, whichever is sooner",
|
||||
)
|
||||
|
||||
handle_deprecated(
|
||||
self,
|
||||
"enable_noop",
|
||||
"eliminate_noops",
|
||||
"v0.13.0 or v1.0.0, whichever is sooner",
|
||||
)
|
||||
|
||||
if not self.eliminate_noops:
|
||||
if self.fuse_norm_quant or self.fuse_act_quant:
|
||||
logger.warning_once(
|
||||
|
||||
@ -64,6 +64,11 @@ class KVTransferConfig:
|
||||
enable_permute_local_kv: bool = False
|
||||
"""Experiment feature flag to enable HND to NHD KV Transfer"""
|
||||
|
||||
kv_load_failure_policy: Literal["recompute", "fail"] = "recompute"
|
||||
"""Policy for handling KV cache load failures.
|
||||
'recompute': reschedule the request to recompute failed blocks (default)
|
||||
'fail': immediately fail the request with an error finish reason"""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
|
||||
@ -73,17 +73,6 @@ logger = init_logger(__name__)
|
||||
RunnerOption = Literal["auto", RunnerType]
|
||||
ConvertType = Literal["none", "embed", "classify", "reward"]
|
||||
ConvertOption = Literal["auto", ConvertType]
|
||||
TaskOption = Literal[
|
||||
"auto",
|
||||
"generate",
|
||||
"embedding",
|
||||
"embed",
|
||||
"classify",
|
||||
"score",
|
||||
"reward",
|
||||
"transcription",
|
||||
"draft",
|
||||
]
|
||||
TokenizerMode = Literal["auto", "hf", "slow", "mistral", "deepseek_v32"]
|
||||
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
|
||||
LogprobsMode = Literal[
|
||||
@ -93,12 +82,6 @@ HfOverrides = dict[str, Any] | Callable[[PretrainedConfig], PretrainedConfig]
|
||||
ModelImpl = Literal["auto", "vllm", "transformers", "terratorch"]
|
||||
LayerBlockType = Literal["attention", "linear_attention", "mamba"]
|
||||
|
||||
_RUNNER_TASKS: dict[RunnerType, list[TaskOption]] = {
|
||||
"generate": ["generate", "transcription"],
|
||||
"pooling": ["embedding", "embed", "classify", "score", "reward"],
|
||||
"draft": ["draft"],
|
||||
}
|
||||
|
||||
_RUNNER_CONVERTS: dict[RunnerType, list[ConvertType]] = {
|
||||
"generate": [],
|
||||
"pooling": ["embed", "classify", "reward"],
|
||||
@ -126,12 +109,6 @@ class ModelConfig:
|
||||
"""Convert the model using adapters defined in
|
||||
[vllm.model_executor.models.adapters][]. The most common use case is to
|
||||
adapt a text generation model to be used for pooling tasks."""
|
||||
task: TaskOption | None = None
|
||||
"""[DEPRECATED] The task to use the model for. If the model supports more
|
||||
than one model runner, this is used to select which model runner to run.
|
||||
|
||||
Note that the model may support other tasks using the same model runner.
|
||||
"""
|
||||
tokenizer: SkipValidation[str] = None # type: ignore
|
||||
"""Name or path of the Hugging Face tokenizer to use. If unspecified, model
|
||||
name or path will be used."""
|
||||
@ -335,7 +312,6 @@ class ModelConfig:
|
||||
ignored_factors = {
|
||||
"runner",
|
||||
"convert",
|
||||
"task",
|
||||
"tokenizer",
|
||||
"tokenizer_mode",
|
||||
"seed",
|
||||
@ -510,97 +486,6 @@ class ModelConfig:
|
||||
is_generative_model = registry.is_text_generation_model(architectures, self)
|
||||
is_pooling_model = registry.is_pooling_model(architectures, self)
|
||||
|
||||
def _task_to_convert(task: TaskOption) -> ConvertType:
|
||||
if task == "embedding" or task == "embed":
|
||||
return "embed"
|
||||
if task == "classify":
|
||||
return "classify"
|
||||
if task == "reward":
|
||||
logger.warning(
|
||||
"Pooling models now default support all pooling; "
|
||||
"you can use it without any settings."
|
||||
)
|
||||
return "embed"
|
||||
if task == "score":
|
||||
new_task = self._get_default_pooling_task(architectures)
|
||||
return "classify" if new_task == "classify" else "embed"
|
||||
|
||||
return "none"
|
||||
|
||||
if self.task is not None:
|
||||
runner: RunnerOption = "auto"
|
||||
convert: ConvertOption = "auto"
|
||||
msg_prefix = (
|
||||
"The 'task' option has been deprecated and will be "
|
||||
"removed in v0.13.0 or v1.0, whichever comes first."
|
||||
)
|
||||
msg_hint = "Please remove this option."
|
||||
|
||||
is_generative_task = self.task in _RUNNER_TASKS["generate"]
|
||||
is_pooling_task = self.task in _RUNNER_TASKS["pooling"]
|
||||
|
||||
if is_generative_model and is_pooling_model:
|
||||
if is_generative_task:
|
||||
runner = "generate"
|
||||
convert = "auto"
|
||||
msg_hint = (
|
||||
"Please replace this option with `--runner "
|
||||
"generate` to continue using this model "
|
||||
"as a generative model."
|
||||
)
|
||||
elif is_pooling_task:
|
||||
runner = "pooling"
|
||||
convert = "auto"
|
||||
msg_hint = (
|
||||
"Please replace this option with `--runner "
|
||||
"pooling` to continue using this model "
|
||||
"as a pooling model."
|
||||
)
|
||||
else: # task == "auto"
|
||||
pass
|
||||
elif is_generative_model or is_pooling_model:
|
||||
if is_generative_task:
|
||||
runner = "generate"
|
||||
convert = "auto"
|
||||
msg_hint = "Please remove this option"
|
||||
elif is_pooling_task:
|
||||
runner = "pooling"
|
||||
convert = _task_to_convert(self.task)
|
||||
msg_hint = (
|
||||
"Please replace this option with `--convert "
|
||||
f"{convert}` to continue using this model "
|
||||
"as a pooling model."
|
||||
)
|
||||
else: # task == "auto"
|
||||
pass
|
||||
else:
|
||||
# Neither generative nor pooling model - try to convert if possible
|
||||
if is_pooling_task:
|
||||
runner = "pooling"
|
||||
convert = _task_to_convert(self.task)
|
||||
msg_hint = (
|
||||
"Please replace this option with `--runner pooling "
|
||||
f"--convert {convert}` to continue using this model "
|
||||
"as a pooling model."
|
||||
)
|
||||
else:
|
||||
debug_info = {
|
||||
"architectures": architectures,
|
||||
"is_generative_model": is_generative_model,
|
||||
"is_pooling_model": is_pooling_model,
|
||||
}
|
||||
raise AssertionError(
|
||||
"The model should be a generative or "
|
||||
"pooling model when task is set to "
|
||||
f"{self.task!r}. Found: {debug_info}"
|
||||
)
|
||||
|
||||
self.runner = runner
|
||||
self.convert = convert
|
||||
|
||||
msg = f"{msg_prefix} {msg_hint}"
|
||||
warnings.warn(msg, DeprecationWarning, stacklevel=2)
|
||||
|
||||
self.runner_type = self._get_runner_type(architectures, self.runner)
|
||||
self.convert_type = self._get_convert_type(
|
||||
architectures, self.runner_type, self.convert
|
||||
@ -654,6 +539,11 @@ class ModelConfig:
|
||||
|
||||
self.original_max_model_len = self.max_model_len
|
||||
self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
|
||||
|
||||
if self.is_encoder_decoder:
|
||||
self.mm_processor_cache_gb = 0
|
||||
logger.info("Encoder-decoder model detected, disabling mm processor cache.")
|
||||
|
||||
# Init multimodal config if needed
|
||||
if self._model_info.supports_multimodal:
|
||||
if (
|
||||
@ -903,6 +793,13 @@ class ModelConfig:
|
||||
runner_type: RunnerType,
|
||||
convert: ConvertOption,
|
||||
) -> ConvertType:
|
||||
if convert == "reward":
|
||||
logger.warning(
|
||||
"`--convert reward` is deprecated and will be removed in v0.15. "
|
||||
"Please use `--convert embed` instead."
|
||||
)
|
||||
return "embed"
|
||||
|
||||
if convert != "auto":
|
||||
return convert
|
||||
|
||||
@ -918,22 +815,6 @@ class ModelConfig:
|
||||
|
||||
return convert_type
|
||||
|
||||
def _get_default_pooling_task(
    self,
    architectures: list[str],
) -> Literal["embed", "classify", "reward"]:
    """Choose the pooling task to fall back on for ``architectures``.

    Cross-encoder models (as reported by ``self.registry``) map to
    ``"classify"``. Otherwise the first architecture with a pooling-runner
    default match supplies the convert type. If no architecture matches,
    the result is ``"embed"``.
    """
    is_cross_encoder = self.registry.is_cross_encoder_model(architectures, self)
    if is_cross_encoder:
        return "classify"

    for architecture in architectures:
        defaults = try_match_architecture_defaults(
            architecture, runner_type="pooling"
        )
        if not defaults:
            continue

        # Only the convert type nested in the match is needed here.
        _matched, (_runner, convert_type) = defaults
        assert convert_type != "none"
        return convert_type

    return "embed"
|
||||
|
||||
def _parse_quant_hf_config(self, hf_config: PretrainedConfig):
|
||||
quant_cfg = getattr(hf_config, "quantization_config", None)
|
||||
if quant_cfg is None:
|
||||
|
||||
@ -317,11 +317,6 @@ class ParallelConfig:
|
||||
"num_redundant_experts."
|
||||
)
|
||||
|
||||
if self.prefill_context_parallel_size > 1:
|
||||
raise ValueError(
|
||||
"Prefill context parallelism is not fully supported. "
|
||||
"Please set prefill_context_parallel_size to 1."
|
||||
)
|
||||
return self
|
||||
|
||||
@property
|
||||
|
||||
@ -111,13 +111,15 @@ class PoolerConfig:
|
||||
def get_use_activation(o: object):
|
||||
if softmax := getattr(o, "softmax", None) is not None:
|
||||
logger.warning_once(
|
||||
"softmax will be deprecated, please use use_activation instead."
|
||||
"softmax will be deprecated and will be removed in v0.15. "
|
||||
"Please use use_activation instead."
|
||||
)
|
||||
return softmax
|
||||
|
||||
if activation := getattr(o, "activation", None) is not None:
|
||||
logger.warning_once(
|
||||
"activation will be deprecated, please use use_activation instead."
|
||||
"activation will be deprecated and will be removed in v0.15. "
|
||||
"Please use use_activation instead."
|
||||
)
|
||||
return activation
|
||||
|
||||
|
||||
@ -73,14 +73,28 @@ def get_field(cls: ConfigType, name: str) -> Field:
|
||||
)
|
||||
|
||||
|
||||
def getattr_iter(object: object, names: Iterable[str], default: Any) -> Any:
|
||||
def getattr_iter(
|
||||
object: object, names: Iterable[str], default: Any, warn: bool = False
|
||||
) -> Any:
|
||||
"""
|
||||
A helper function that retrieves an attribute from an object which may
|
||||
have multiple possible names. This is useful when fetching attributes from
|
||||
arbitrary `transformers.PretrainedConfig` instances.
|
||||
|
||||
In the case where the first name in `names` is the preferred name, and
|
||||
any other names are deprecated aliases, setting `warn=True` will log a
|
||||
warning when a deprecated name is used.
|
||||
"""
|
||||
for name in names:
|
||||
for i, name in enumerate(names):
|
||||
if hasattr(object, name):
|
||||
if warn and i > 0:
|
||||
logger.warning_once(
|
||||
"%s contains a deprecated attribute name '%s'. "
|
||||
"Please use the preferred attribute name '%s' instead.",
|
||||
type(object).__name__,
|
||||
name,
|
||||
names[0],
|
||||
)
|
||||
return getattr(object, name)
|
||||
return default
|
||||
|
||||
|
||||
@ -666,8 +666,9 @@ class VllmConfig:
|
||||
|
||||
default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level]
|
||||
self._apply_optimization_level_defaults(default_config)
|
||||
|
||||
if (
|
||||
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
self.compilation_config.cudagraph_mode.requires_piecewise_compilation()
|
||||
and self.compilation_config.mode != CompilationMode.VLLM_COMPILE
|
||||
):
|
||||
logger.info(
|
||||
@ -692,22 +693,29 @@ class VllmConfig:
|
||||
|
||||
if current_platform.support_static_graph_mode():
|
||||
# if cudagraph_mode has full cudagraphs, we need to check support
|
||||
if (
|
||||
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||
and self.model_config is not None
|
||||
):
|
||||
if self.model_config.pooler_config is not None:
|
||||
if model_config := self.model_config:
|
||||
if (
|
||||
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||
and model_config.pooler_config is not None
|
||||
):
|
||||
logger.warning_once(
|
||||
"Pooling models do not support full cudagraphs. "
|
||||
"Overriding cudagraph_mode to PIECEWISE."
|
||||
)
|
||||
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
|
||||
elif self.model_config.is_encoder_decoder:
|
||||
logger.warning_once(
|
||||
"Encoder-decoder models do not support full cudagraphs. "
|
||||
"Overriding cudagraph_mode to PIECEWISE."
|
||||
elif (
|
||||
model_config.is_encoder_decoder
|
||||
and self.compilation_config.cudagraph_mode
|
||||
not in (CUDAGraphMode.NONE, CUDAGraphMode.FULL_DECODE_ONLY)
|
||||
):
|
||||
logger.info_once(
|
||||
"Encoder-decoder models do not support %s. "
|
||||
"Overriding cudagraph_mode to FULL_DECODE_ONLY.",
|
||||
self.compilation_config.cudagraph_mode.name,
|
||||
)
|
||||
self.compilation_config.cudagraph_mode = (
|
||||
CUDAGraphMode.FULL_DECODE_ONLY
|
||||
)
|
||||
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
|
||||
|
||||
# disable cudagraph when enforce eager execution
|
||||
if self.model_config is not None and self.model_config.enforce_eager:
|
||||
@ -742,27 +750,17 @@ class VllmConfig:
|
||||
# TODO: Move after https://github.com/vllm-project/vllm/pull/26847 lands
|
||||
self._set_compile_ranges()
|
||||
|
||||
if self.model_config and self.model_config.is_encoder_decoder:
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
self.scheduler_config.max_num_encoder_input_tokens = (
|
||||
MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
|
||||
if (
|
||||
self.model_config
|
||||
and self.model_config.architecture == "WhisperForConditionalGeneration"
|
||||
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
|
||||
):
|
||||
logger.warning(
|
||||
"Whisper is known to have issues with "
|
||||
"forked workers. If startup is hanging, "
|
||||
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
|
||||
"to 'spawn'."
|
||||
)
|
||||
logger.debug(
|
||||
"Encoder-decoder model detected: setting "
|
||||
"`max_num_encoder_input_tokens` to encoder length (%s)",
|
||||
self.scheduler_config.max_num_encoder_input_tokens,
|
||||
)
|
||||
if (
|
||||
self.model_config.architecture == "WhisperForConditionalGeneration"
|
||||
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
|
||||
):
|
||||
logger.warning(
|
||||
"Whisper is known to have issues with "
|
||||
"forked workers. If startup is hanging, "
|
||||
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
|
||||
"to 'spawn'."
|
||||
)
|
||||
|
||||
if (
|
||||
self.kv_events_config is not None
|
||||
@ -812,11 +810,6 @@ class VllmConfig:
|
||||
f"({self.parallel_config.cp_kv_cache_interleave_size})."
|
||||
)
|
||||
|
||||
assert (
|
||||
self.parallel_config.cp_kv_cache_interleave_size == 1
|
||||
or self.speculative_config is None
|
||||
), "MTP with cp_kv_cache_interleave_size > 1 is not supported now."
|
||||
|
||||
# Do this after all the updates to compilation_config.mode
|
||||
self.compilation_config.set_splitting_ops_for_v1(
|
||||
all2all_backend=self.parallel_config.all2all_backend,
|
||||
@ -1006,7 +999,7 @@ class VllmConfig:
|
||||
max_graph_size = min(max_num_seqs * 2, 512)
|
||||
# 1, 2, 4, then multiples of 8 up to 256 and then multiples of 16
|
||||
# up to max_graph_size
|
||||
cuda_graph_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
|
||||
cudagraph_capture_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
|
||||
range(256, max_graph_size + 1, 16))
|
||||
|
||||
In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
import pickle
|
||||
import threading
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
@ -43,6 +44,33 @@ VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
|
||||
from_bytes_big = functools.partial(int.from_bytes, byteorder="big")
|
||||
|
||||
|
||||
# Memory fence for cross-process shared memory visibility.
|
||||
# Required for correct producer-consumer synchronization when using
|
||||
# shared memory without locks.
|
||||
_memory_fence_lock = threading.Lock()
|
||||
|
||||
|
||||
def memory_fence():
    """
    Issue a memory fence by cycling the module-level fence lock.

    The function takes ``_memory_fence_lock`` and releases it straight away;
    no data is guarded by the lock itself. Callers in this module invoke it
    right before reading, and right after writing, the shared-memory
    metadata flags, and rely on the acquire/release round-trip for ordering.

    NOTE(review): the barrier strength of a ``threading.Lock``
    acquire/release pair is a property of the interpreter and platform, not
    of this code -- confirm it holds on every supported target.
    """
    _memory_fence_lock.acquire()
    try:
        # Intentionally empty: only the acquire/release pair matters.
        pass
    finally:
        # Always hand the lock back, even if something interrupts us.
        _memory_fence_lock.release()
|
||||
|
||||
|
||||
def to_bytes_big(value: int, size: int) -> bytes:
    """Encode ``value`` as exactly ``size`` bytes in big-endian order.

    Any ``OverflowError`` raised by ``int.to_bytes`` (value negative or too
    large for ``size`` bytes) propagates to the caller.
    """
    encoded = value.to_bytes(size, "big")
    return encoded
|
||||
|
||||
@ -414,6 +442,10 @@ class MessageQueue:
|
||||
n_warning = 1
|
||||
while True:
|
||||
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
|
||||
# Memory fence ensures we see the latest read flags from readers.
|
||||
# Without this, we may read stale flags from our CPU cache and
|
||||
# spin indefinitely even though readers have completed.
|
||||
memory_fence()
|
||||
read_count = sum(metadata_buffer[1:])
|
||||
written_flag = metadata_buffer[0]
|
||||
if written_flag and read_count != self.buffer.n_reader:
|
||||
@ -458,6 +490,10 @@ class MessageQueue:
|
||||
metadata_buffer[i] = 0
|
||||
# mark the block as written
|
||||
metadata_buffer[0] = 1
|
||||
# Memory fence ensures the write is visible to readers on other cores
|
||||
# before we proceed. Without this, readers may spin indefinitely
|
||||
# waiting for a write that's stuck in our CPU's store buffer.
|
||||
memory_fence()
|
||||
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
|
||||
break
|
||||
|
||||
@ -473,6 +509,10 @@ class MessageQueue:
|
||||
n_warning = 1
|
||||
while True:
|
||||
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
|
||||
# Memory fence ensures we see the latest writes from the writer.
|
||||
# Without this, we may read stale flags from our CPU cache
|
||||
# and spin indefinitely even though writer has updated them.
|
||||
memory_fence()
|
||||
read_flag = metadata_buffer[self.local_reader_rank + 1]
|
||||
written_flag = metadata_buffer[0]
|
||||
if not written_flag or read_flag:
|
||||
@ -513,6 +553,10 @@ class MessageQueue:
|
||||
# caller has read from the buffer
|
||||
# set the read flag
|
||||
metadata_buffer[self.local_reader_rank + 1] = 1
|
||||
# Memory fence ensures the read flag is visible to the writer.
|
||||
# Without this, writer may not see our read completion and
|
||||
# could wait indefinitely for all readers to finish.
|
||||
memory_fence()
|
||||
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
|
||||
|
||||
self._read_spin_timer.record_activity()
|
||||
|
||||
@ -322,6 +322,9 @@ async def transfer_layer(
|
||||
num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
|
||||
assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
|
||||
assert num_physical_experts == ep_size * num_local_physical_experts
|
||||
# A buffer to hold the expert weights in one layer during the exchange.
|
||||
# NOTE: Currently we assume the same weights across different layers
|
||||
# have the same shape.
|
||||
|
||||
is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer(
|
||||
num_local_experts=num_local_physical_experts,
|
||||
|
||||
@ -5,7 +5,7 @@ import queue
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from collections import Counter, deque
|
||||
from collections.abc import Callable
|
||||
from dataclasses import asdict
|
||||
from itertools import count
|
||||
@ -54,11 +54,26 @@ class BlockStored(KVCacheEvent):
|
||||
lora_id: int | None
|
||||
medium: str | None
|
||||
|
||||
def __hash__(self) -> int:
    """Hash the event from all of its fields."""
    # block_hashes and token_ids are converted to tuples so that the
    # combined key is hashable; the remaining fields are used as-is.
    key = (
        tuple(self.block_hashes),
        self.parent_block_hash,
        tuple(self.token_ids),
        self.block_size,
        self.lora_id,
        self.medium,
    )
    return hash(key)
|
||||
|
||||
|
||||
class BlockRemoved(KVCacheEvent):
    """KV cache event carrying a list of block hashes and their medium."""

    block_hashes: list[ExternalBlockHash]
    medium: str | None

    def __hash__(self) -> int:
        """Hash the event from its block hashes and medium."""
        # block_hashes is a list (unhashable), so freeze it into a tuple.
        key = (tuple(self.block_hashes), self.medium)
        return hash(key)
|
||||
|
||||
|
||||
class AllBlocksCleared(KVCacheEvent):
    """KV cache event that declares no fields of its own."""
|
||||
@ -68,6 +83,119 @@ class KVEventBatch(EventBatch):
|
||||
events: list[BlockStored | BlockRemoved | AllBlocksCleared]
|
||||
|
||||
|
||||
class KVEventAggregator:
    """
    Tally KV events reported by a group of workers.

    Every event handed to ``add_events`` is counted in a ``Counter``. An
    event whose tally equals the configured number of workers is treated as
    having been emitted by all of them.
    """

    __slots__ = ("_event_counter", "_num_workers")

    def __init__(self, num_workers: int) -> None:
        """
        Create an empty aggregator.

        :param num_workers: Number of workers expected to report events.
        :raises ValueError: If ``num_workers`` is zero or negative.
        """
        if num_workers <= 0:
            raise ValueError("num_workers must be greater than zero.")
        self._num_workers: int = num_workers
        self._event_counter: Counter[KVCacheEvent] = Counter()

    def add_events(self, events: list[KVCacheEvent]) -> None:
        """
        Count every event of one worker batch.

        :param events: List of KVCacheEvent objects (must be hashable).
        :raises TypeError: If ``events`` is not a ``list``.
        """
        if isinstance(events, list):
            self._event_counter.update(events)
            return
        raise TypeError("events must be a list of KVCacheEvent.")

    def get_common_events(self) -> list[KVCacheEvent]:
        """
        Collect the events whose tally equals the worker count.

        :return: List of events present in all workers.
        """
        required = self._num_workers
        common: list[KVCacheEvent] = []
        for event, seen in self._event_counter.items():
            if seen == required:
                common.append(event)
        return common

    def get_all_events(self) -> list[KVCacheEvent]:
        """
        Expand the tally back into a flat list.

        :return: Every counted event, repeated once per time it was added.
        """
        return [*self._event_counter.elements()]

    def clear_events(self) -> None:
        """
        Forget every counted event; the worker count is left untouched.
        """
        self._event_counter.clear()

    def increment_workers(self, count: int = 1) -> None:
        """
        Raise the number of workers contributing events.

        :param count: Number to increment the workers by.
        :raises ValueError: If ``count`` is zero or negative.
        """
        if count <= 0:
            raise ValueError("count must be positive.")
        self._num_workers = self._num_workers + count

    def reset_workers(self) -> None:
        """
        Set the number of workers back to 1.
        """
        self._num_workers = 1

    def get_number_of_workers(self) -> int:
        """
        Report the current worker count.

        :return: int number of workers.
        """
        return self._num_workers

    def __repr__(self) -> str:
        workers = self._num_workers
        distinct = len(self._event_counter)
        return f"<KVEventAggregator workers={workers}, events={distinct}>"
|
||||
|
||||
|
||||
class KVConnectorKVEvents(ABC):
    """
    Interface for a container of KV events produced by a connector.

    Every method is abstract; concrete connectors decide how events are
    stored, counted and aggregated.
    """

    @abstractmethod
    def add_events(self, events: list[KVCacheEvent]) -> None:
        """Add ``events`` to this container."""
        raise NotImplementedError

    @abstractmethod
    def aggregate(self) -> "KVConnectorKVEvents":
        """Aggregate the held events and return a container with the result."""
        raise NotImplementedError

    @abstractmethod
    def increment_workers(self, count: int = 1) -> None:
        """Increase the number of contributing workers by ``count``."""
        raise NotImplementedError

    @abstractmethod
    def get_all_events(self) -> list[KVCacheEvent]:
        """Return the events currently held by this container."""
        raise NotImplementedError

    @abstractmethod
    def get_number_of_workers(self) -> int:
        """Return the number of workers tracked by this container."""
        raise NotImplementedError

    @abstractmethod
    def clear_events(self) -> None:
        """Drop the events held by this container."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class EventPublisher(ABC):
|
||||
"""Lightweight publisher for EventBatch batches with data parallelism
|
||||
support.
|
||||
|
||||
@ -78,6 +78,7 @@ class KVOutputAggregator:
|
||||
finished_sending = set[str]()
|
||||
finished_recving = set[str]()
|
||||
aggregated_kv_connector_stats = None
|
||||
combined_kv_cache_events = None
|
||||
invalid_block_ids = set[int]()
|
||||
for model_runner_output in outputs:
|
||||
assert model_runner_output is not None
|
||||
@ -119,6 +120,19 @@ class KVOutputAggregator:
|
||||
aggregated_kv_connector_stats.aggregate(kv_connector_stats)
|
||||
)
|
||||
|
||||
# Combine kv_cache_events from all workers.
|
||||
if combined_kv_cache_events is None:
|
||||
# Use the first worker's kv_cache events as start event list.
|
||||
combined_kv_cache_events = kv_output.kv_cache_events
|
||||
elif kv_cache_events := kv_output.kv_cache_events:
|
||||
assert isinstance(
|
||||
combined_kv_cache_events,
|
||||
type(kv_cache_events),
|
||||
)
|
||||
worker_kv_cache_events = kv_cache_events.get_all_events()
|
||||
combined_kv_cache_events.add_events(worker_kv_cache_events)
|
||||
combined_kv_cache_events.increment_workers(1)
|
||||
|
||||
invalid_block_ids |= kv_output.invalid_block_ids
|
||||
|
||||
# select output of the worker specified by output_rank
|
||||
@ -129,6 +143,7 @@ class KVOutputAggregator:
|
||||
finished_sending=finished_sending or None,
|
||||
finished_recving=finished_recving or None,
|
||||
kv_connector_stats=aggregated_kv_connector_stats or None,
|
||||
kv_cache_events=combined_kv_cache_events or None,
|
||||
invalid_block_ids=invalid_block_ids,
|
||||
expected_finished_count=self._expected_finished_count,
|
||||
)
|
||||
|
||||
@ -49,7 +49,7 @@ from vllm.v1.outputs import KVConnectorOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import KVCacheEvent
|
||||
from vllm.distributed.kv_events import KVCacheEvent, KVConnectorKVEvents
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||
KVConnectorPromMetrics,
|
||||
KVConnectorStats,
|
||||
@ -379,6 +379,14 @@ class KVConnectorBase_V1(ABC):
|
||||
"""
|
||||
return None
|
||||
|
||||
def get_kv_connector_kv_cache_events(self) -> Optional["KVConnectorKVEvents"]:
    """
    Return the KV cache events the connector gathered over the last interval.

    The model runner is expected to call this once after every model
    execution and before cleanup. This default implementation has nothing
    to report and always returns None.
    """
    return None
|
||||
|
||||
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
|
||||
"""
|
||||
Get the KVConnector handshake metadata for this connector.
|
||||
|
||||
@ -1,14 +1,18 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from lmcache.integration.vllm.vllm_v1_adapter import (
|
||||
LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
|
||||
)
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import (
|
||||
BlockStored,
|
||||
KVCacheEvent,
|
||||
KVConnectorKVEvents,
|
||||
KVEventAggregator,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorMetadata,
|
||||
@ -16,6 +20,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.forward_context import ForwardContext
|
||||
@ -26,6 +31,44 @@ if TYPE_CHECKING:
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class LMCacheKVEvents(KVConnectorKVEvents):
    """
    KVConnectorKVEvents implementation that delegates to a KVEventAggregator.
    """

    def __init__(self, num_workers: int) -> None:
        """Create the backing aggregator for ``num_workers`` workers."""
        self._aggregator = KVEventAggregator(num_workers)

    def add_events(self, events: list[KVCacheEvent]) -> None:
        """Forward ``events`` to the backing aggregator."""
        self._aggregator.add_events(events)

    def aggregate(self) -> "LMCacheKVEvents":
        """
        Reduce the held events to the common ones, in place.

        The aggregator is emptied, re-seeded with only the common events and
        its worker count is reset; ``self`` is returned.
        """
        aggregator = self._aggregator
        shared = aggregator.get_common_events()
        aggregator.clear_events()
        aggregator.add_events(shared)
        aggregator.reset_workers()
        return self

    def increment_workers(self, count: int = 1) -> None:
        """Raise the aggregator's worker count by ``count``."""
        self._aggregator.increment_workers(count)

    def get_all_events(self) -> list[KVCacheEvent]:
        """Return every event currently held by the aggregator."""
        return self._aggregator.get_all_events()

    def get_number_of_workers(self) -> int:
        """Return the aggregator's current worker count."""
        return self._aggregator.get_number_of_workers()

    def clear_events(self) -> None:
        """Empty the aggregator and reset its worker count."""
        aggregator = self._aggregator
        aggregator.clear_events()
        aggregator.reset_workers()

    def __repr__(self) -> str:
        events = self.get_all_events()
        return f"<LMCacheKVEvents events={events}>"
|
||||
|
||||
|
||||
class LMCacheConnectorV1(KVConnectorBase_V1):
|
||||
def __init__(
|
||||
self,
|
||||
@ -50,10 +93,17 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
|
||||
cls = _adapter.LMCacheConnectorV1Impl
|
||||
else:
|
||||
logger.info("Initializing latest dev LMCache connector")
|
||||
# lazy import
|
||||
from lmcache.integration.vllm.vllm_v1_adapter import (
|
||||
LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
|
||||
)
|
||||
|
||||
cls = LMCacheConnectorLatestImpl
|
||||
|
||||
self._lmcache_engine = cls(vllm_config, role, self)
|
||||
|
||||
self._kv_cache_events: LMCacheKVEvents | None = None
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
@ -151,6 +201,31 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
|
||||
# Fallback for older versions that don't support this method
|
||||
return set()
|
||||
|
||||
def get_kv_connector_kv_cache_events(self) -> LMCacheKVEvents | None:
    """
    Get the KV connector kv cache events collected during the last interval.

    Events are fetched from the LMCache engine, re-wrapped as ``BlockStored``
    objects and returned in a single-worker ``LMCacheKVEvents`` container.
    Returns None when the engine reports no events.
    """
    raw_events = self._lmcache_engine.get_kv_events()  # type: ignore [attr-defined]
    if not raw_events:
        return None

    # Copy each engine event field-by-field into a BlockStored.
    stored: list[BlockStored] = []
    for raw in raw_events:
        stored.append(
            BlockStored(
                block_hashes=raw.block_hashes,
                parent_block_hash=raw.parent_block_hash,
                token_ids=raw.token_ids,
                lora_id=raw.lora_id,
                block_size=raw.block_size,
                medium=raw.medium,
            )
        )

    # This container represents exactly one worker's report.
    result = LMCacheKVEvents(num_workers=1)
    result.add_events(stored)
    return result
|
||||
|
||||
# ==============================
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
@ -198,6 +273,28 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
|
||||
"""
|
||||
return self._lmcache_engine.build_connector_meta(scheduler_output)
|
||||
|
||||
def update_connector_output(self, connector_output: KVConnectorOutput):
    """
    Update KVConnector state from worker-side connectors output.

    Only the KV cache events of the output are consumed: they are either
    adopted as the pending events or merged into the events already pending.

    Args:
        connector_output (KVConnectorOutput): the worker-side
            connectors output.
    """
    incoming = connector_output.kv_cache_events
    # Ignore missing events and containers of any other connector type.
    if not (incoming and isinstance(incoming, LMCacheKVEvents)):
        return

    pending = self._kv_cache_events
    if pending is None:
        # Nothing pending yet: adopt the incoming container as-is.
        self._kv_cache_events = incoming
        return

    # Merge events and account for the workers that contributed them.
    pending.add_events(incoming.get_all_events())
    pending.increment_workers(incoming.get_number_of_workers())
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
@ -214,3 +311,17 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
|
||||
returned by the engine.
|
||||
"""
|
||||
return self._lmcache_engine.request_finished(request, block_ids)
|
||||
|
||||
def take_events(self) -> Iterable["KVCacheEvent"]:
    """
    Take the KV cache events from the connector.

    The pending events are aggregated, snapshotted and detached from the
    connector *before* the first event is yielded. This keeps the connector
    state consistent when the caller stops iterating early or abandons the
    generator, and it prevents events that arrive while the caller is still
    iterating from being wiped once iteration finishes.

    Yields:
        New KV cache events since the last call.
    """
    pending = self._kv_cache_events
    if pending is None:
        return

    # Detach first so later updates start a fresh accumulation.
    self._kv_cache_events = None

    pending.aggregate()
    # Copy so that clearing the container cannot affect what we yield.
    events = list(pending.get_all_events())
    pending.clear_events()

    yield from events
|
||||
|
||||
@ -27,7 +27,7 @@ from lmcache.v1.lookup_client.lmcache_async_lookup_client import (
|
||||
LMCacheAsyncLookupServer,
|
||||
)
|
||||
from lmcache.v1.offload_server.zmq_server import ZMQOffloadServer
|
||||
from lmcache.v1.plugin.plugin_launcher import PluginLauncher
|
||||
from lmcache.v1.plugin.runtime_plugin_launcher import RuntimePluginLauncher
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
@ -683,7 +683,7 @@ class LMCacheConnectorV1Impl:
|
||||
self.api_server = InternalAPIServer(self)
|
||||
self.api_server.start()
|
||||
# Launch plugins
|
||||
self.plugin_launcher = PluginLauncher(
|
||||
self.plugin_launcher = RuntimePluginLauncher(
|
||||
self.config,
|
||||
role,
|
||||
self.worker_count,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user